PostgreSQL RangeOperators ¶
See also
Contents
Description ¶
PostgreSQL provides a set of SQL operators that can be used together with the range data types (see the PostgreSQL documentation for the full details of range operators).
This class provides named constants for these operators, which helps avoid typos.
The operator names overlap with the names of the corresponding lookups.
class RangeOperators:
    """String constants for PostgreSQL's range operators."""

    # Equality
    EQUAL = "="
    NOT_EQUAL = "<>"
    # Containment and overlap
    CONTAINS = "@>"
    CONTAINED_BY = "<@"
    OVERLAPS = "&&"
    # Relative position
    FULLY_LT = "<<"  # strictly left of
    FULLY_GT = ">>"  # strictly right of
    NOT_LT = "&>"  # does not extend to the left of
    NOT_GT = "&<"  # does not extend to the right of
    ADJACENT_TO = "-|-"
https://twitter.com/l_avrot and https://twitter.com/be_haki ¶
See also
A reader from Reddit just pointed out that in #PostgreSQL you can use a range type to filter by a date range:
SELECT * FROM sales WHERE created <@ daterange(date '2019-01-01', date '2020-01-01', '[)');
Nice!

I call @> and <@ the bird operators because they look like birds and they’re so cute!
django/contrib/postgres/fields/ranges.py ¶
See also
1import datetime
2import json
3
4from django.contrib.postgres import forms
5from django.contrib.postgres import lookups
6from django.db import models
7from psycopg2.extras import DateRange
8from psycopg2.extras import DateTimeTZRange
9from psycopg2.extras import NumericRange
10from psycopg2.extras import Range
11
12from .utils import AttributeSetter
13
# Public names of this module: the range model fields, the RangeBoundary
# expression and the RangeOperators constants. ``from ... import *`` exposes
# only these names.
__all__ = [
    "RangeField",
    "IntegerRangeField",
    "BigIntegerRangeField",
    "DecimalRangeField",
    "DateTimeRangeField",
    "DateRangeField",
    "RangeBoundary",
    "RangeOperators",
]
24
25
class RangeBoundary(models.Expression):
    """SQL literal describing which ends of a range are inclusive.

    Compiles to a quoted two-character bounds string such as ``'[)'``, the form
    taken by the third argument of a PostgreSQL range constructor.
    """

    def __init__(self, inclusive_lower=True, inclusive_upper=False):
        # "[" / "]" mark an inclusive end, "(" / ")" an exclusive one.
        self.lower, self.upper = (
            "[" if inclusive_lower else "(",
            "]" if inclusive_upper else ")",
        )

    def as_sql(self, compiler, connection):
        # A constant literal: no query parameters are needed.
        return f"'{self.lower}{self.upper}'", []
35
36
class RangeOperators:
    """Named constants for PostgreSQL range operators.

    Reference:
    https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
    """

    # Equality.
    EQUAL = "="
    NOT_EQUAL = "<>"
    # Containment and overlap.
    CONTAINS = "@>"
    CONTAINED_BY = "<@"
    OVERLAPS = "&&"
    # Relative position.
    FULLY_LT = "<<"  # strictly left of
    FULLY_GT = ">>"  # strictly right of
    NOT_LT = "&>"  # does not extend to the left of
    NOT_GT = "&<"  # does not extend to the right of
    ADJACENT_TO = "-|-"
49
50
class RangeField(models.Field):
    """Base class for the PostgreSQL range model fields.

    Subclasses provide ``base_field`` (a field class, instantiated per field in
    ``__init__``), ``range_type`` (the psycopg2 ``Range`` class used for Python
    values), ``form_field`` and ``db_type()``.
    """

    empty_strings_allowed = False

    def __init__(self, *args, **kwargs):
        # Give every RangeField its own base_field instance; the ``model``
        # setter below keeps that instance's model in step with this field's.
        if hasattr(self, "base_field"):
            self.base_field = self.base_field()
        super().__init__(*args, **kwargs)

    @property
    def model(self):
        """The model this field is attached to; AttributeError until assigned."""
        if "model" in self.__dict__:
            return self.__dict__["model"]
        raise AttributeError(
            "'%s' object has no attribute 'model'" % self.__class__.__name__
        )

    @model.setter
    def model(self, model):
        # Store on the instance dict (bypassing this property) and mirror the
        # value onto the inner base_field.
        self.__dict__["model"] = model
        self.base_field.model = model

    @classmethod
    def _choices_is_value(cls, value):
        # Lists/tuples count as plain choice values here (presumably so a
        # (lower, upper) pair is not mistaken for a choices group).
        if isinstance(value, (list, tuple)):
            return True
        return super()._choices_is_value(value)

    def get_prep_value(self, value):
        # A (lower, upper) list/tuple becomes a range_type instance; None,
        # Range instances and any other value are returned unchanged.
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def to_python(self, value):
        if isinstance(value, str):
            # Strings are taken to be the JSON produced by value_to_string().
            data = json.loads(value)
            for bound in ("lower", "upper"):
                if bound in data:
                    data[bound] = self.base_field.to_python(data[bound])
            return self.range_type(**data)
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def set_attributes_from_name(self, name):
        super().set_attributes_from_name(name)
        # The inner base_field shares this field's name (its attname is used
        # by value_to_string()).
        self.base_field.set_attributes_from_name(name)

    def value_to_string(self, obj):
        """Serialize the range stored on *obj* to a JSON string.

        Returns None for a NULL value and ``{"empty": true}`` for an empty
        range; otherwise an object with ``bounds``, ``lower`` and ``upper``,
        where each non-None bound is rendered by ``base_field``.
        """
        value = self.value_from_object(obj)
        if value is None:
            return None
        if value.isempty:
            return json.dumps({"empty": True})
        base_field = self.base_field

        def dump_bound(bound):
            if bound is None:
                return None
            # base_field.value_to_string() is handed an object carrying the
            # bound under base_field's attname.
            holder = AttributeSetter(base_field.attname, bound)
            return base_field.value_to_string(holder)

        return json.dumps(
            {
                "bounds": value._bounds,
                "lower": dump_bound(value.lower),
                "upper": dump_bound(value.upper),
            }
        )

    def formfield(self, **kwargs):
        # Default to the subclass's form_field unless a form_class is passed.
        return super().formfield(**{"form_class": self.form_field, **kwargs})
123
124
class IntegerRangeField(RangeField):
    """Range of integers, stored in the PostgreSQL ``int4range`` type."""

    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.IntegerField

    def db_type(self, connection):
        # Fixed column type; the connection is not consulted.
        return "int4range"
132
133
class BigIntegerRangeField(RangeField):
    """Range of big integers, stored in the PostgreSQL ``int8range`` type."""

    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.BigIntegerField

    def db_type(self, connection):
        # Fixed column type; the connection is not consulted.
        return "int8range"
141
142
class DecimalRangeField(RangeField):
    """Range of decimals, stored in the PostgreSQL ``numrange`` type."""

    form_field = forms.DecimalRangeField
    range_type = NumericRange
    base_field = models.DecimalField

    def db_type(self, connection):
        # Fixed column type; the connection is not consulted.
        return "numrange"
150
151
class DateTimeRangeField(RangeField):
    """Range of datetimes, stored in the PostgreSQL ``tstzrange`` type."""

    form_field = forms.DateTimeRangeField
    range_type = DateTimeTZRange
    base_field = models.DateTimeField

    def db_type(self, connection):
        # Fixed column type; the connection is not consulted.
        return "tstzrange"
159
160
class DateRangeField(RangeField):
    """Range of dates, stored in the PostgreSQL ``daterange`` type."""

    form_field = forms.DateRangeField
    range_type = DateRange
    base_field = models.DateField

    def db_type(self, connection):
        # Fixed column type; the connection is not consulted.
        return "daterange"
168
169
# Register the generic range lookups from ``lookups`` on the base RangeField
# class, making them available to the concrete range fields above.
RangeField.register_lookup(lookups.DataContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)
173
174
class DateTimeRangeContains(lookups.PostgresSimpleLookup):
    """``contains`` lookup for date-based ranges that types/casts the rhs.

    A bare ``date``/``datetime`` rhs is wrapped in a ``Value`` with a matching
    output field. An expression rhs whose output field is not the lhs range
    field's class gets ``::<base type>`` appended to the generated SQL.
    """

    operator = RangeOperators.CONTAINS
    lookup_name = "contains"

    def process_rhs(self, compiler, connection):
        rhs = self.rhs
        if isinstance(rhs, datetime.date):
            # datetime.datetime subclasses datetime.date, so it is told apart
            # here to pick the right output field.
            if isinstance(rhs, datetime.datetime):
                field = models.DateTimeField()
            else:
                field = models.DateField()
            self.rhs = models.Value(rhs, output_field=field).resolve_expression(
                compiler.query
            )
        return super().process_rhs(compiler, connection)

    def as_sql(self, compiler, connection):
        # super().as_sql() runs process_rhs(), which may replace self.rhs, so
        # self.rhs is only inspected afterwards.
        sql, params = super().as_sql(compiler, connection)
        if not isinstance(self.rhs, models.Expression):
            return sql, params
        rhs_field = self.rhs._output_field_or_none
        if not rhs_field:
            return sql, params
        if isinstance(rhs_field, self.lhs.output_field.__class__):
            # The rhs already has the lhs range type: no cast needed.
            return sql, params
        base_type = self.lhs.output_field.base_field.get_internal_type()
        return "%s::%s" % (sql, connection.data_types.get(base_type)), params
212
213
# Register the date-aware ``contains`` lookup (DateTimeRangeContains) on the
# two date-based range fields.
DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains)
216
217
class RangeContainedBy(lookups.PostgresSimpleLookup):
    """``contained_by`` lookup for scalar fields: is the value inside a range?

    The rhs is cast to the range type that ``type_mapping`` pairs with the lhs
    column's db type. Float columns are additionally cast to ``numeric``,
    matching the ``numrange`` their rhs is cast to.
    """

    lookup_name = "contained_by"
    operator = RangeOperators.CONTAINED_BY
    # Scalar column db_type() -> PostgreSQL range type used to cast the rhs.
    type_mapping = {
        "integer": "int4range",
        "bigint": "int8range",
        "double precision": "numrange",
        "date": "daterange",
        "timestamp with time zone": "tstzrange",
    }

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        column_type = self.lhs.output_field.db_type(connection)
        range_type = self.type_mapping[column_type]
        return "%s::%s" % (sql, range_type), params

    def process_lhs(self, compiler, connection):
        sql, params = super().process_lhs(compiler, connection)
        is_float = isinstance(self.lhs.output_field, models.FloatField)
        return ("%s::numeric" % sql if is_float else sql), params

    def get_prep_lookup(self):
        # NOTE(review): RangeField itself defines no ``range_type``, so a
        # list/tuple rhs appears to fail with AttributeError inside
        # get_prep_value(); a Range instance passes through unchanged.
        return RangeField().get_prep_value(self.rhs)
242
243
# Enable ``<scalar field>__contained_by=<range>`` on the scalar field types
# below; RangeContainedBy.type_mapping supplies the range type to cast to.
models.DateField.register_lookup(RangeContainedBy)
models.DateTimeField.register_lookup(RangeContainedBy)
models.IntegerField.register_lookup(RangeContainedBy)
models.BigIntegerField.register_lookup(RangeContainedBy)
models.FloatField.register_lookup(RangeContainedBy)
249
250
@RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup):
    """``fully_lt`` lookup backed by the ``<<`` (strictly left of) operator."""

    operator = RangeOperators.FULLY_LT
    lookup_name = "fully_lt"
255
256
@RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup):
    """``fully_gt`` lookup backed by the ``>>`` (strictly right of) operator."""

    operator = RangeOperators.FULLY_GT
    lookup_name = "fully_gt"
261
262
@RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup):
    """``not_lt`` lookup backed by ``&>`` (does not extend to the left of)."""

    operator = RangeOperators.NOT_LT
    lookup_name = "not_lt"
267
268
@RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup):
    """``not_gt`` lookup backed by ``&<`` (does not extend to the right of)."""

    operator = RangeOperators.NOT_GT
    lookup_name = "not_gt"
273
274
@RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup):
    """``adjacent_to`` lookup backed by the ``-|-`` (is adjacent to) operator."""

    operator = RangeOperators.ADJACENT_TO
    lookup_name = "adjacent_to"
279
280
@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform: the range's lower bound via SQL ``lower()``."""

    function = "lower"
    lookup_name = "startswith"

    @property
    def output_field(self):
        # The bound has the element type of the range being transformed.
        range_field = self.lhs.output_field
        return range_field.base_field
289
290
@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform: the range's upper bound via SQL ``upper()``."""

    function = "upper"
    lookup_name = "endswith"

    @property
    def output_field(self):
        # The bound has the element type of the range being transformed.
        range_field = self.lhs.output_field
        return range_field.base_field
299
300
@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform: boolean result of SQL ``isempty(range)``."""

    output_field = models.BooleanField()
    function = "isempty"
    lookup_name = "isempty"
306
307
@RangeField.register_lookup
class LowerInclusive(models.Transform):
    """``lower_inc`` transform: boolean result of SQL ``LOWER_INC(range)``."""

    output_field = models.BooleanField()
    function = "LOWER_INC"
    lookup_name = "lower_inc"
313
314
@RangeField.register_lookup
class LowerInfinite(models.Transform):
    """``lower_inf`` transform: boolean result of SQL ``LOWER_INF(range)``."""

    output_field = models.BooleanField()
    function = "LOWER_INF"
    lookup_name = "lower_inf"
320
321
@RangeField.register_lookup
class UpperInclusive(models.Transform):
    """``upper_inc`` transform: boolean result of SQL ``UPPER_INC(range)``."""

    output_field = models.BooleanField()
    function = "UPPER_INC"
    lookup_name = "upper_inc"
327
328
@RangeField.register_lookup
class UpperInfinite(models.Transform):
    """``upper_inf`` transform: boolean result of SQL ``UPPER_INF(range)``."""

    output_field = models.BooleanField()
    function = "UPPER_INF"
    lookup_name = "upper_inf"
tests/postgres_tests/models.py ¶
1from django.core.serializers.json import DjangoJSONEncoder
2from django.db import models
3
4from .fields import ArrayField
5from .fields import BigIntegerRangeField
6from .fields import CICharField
7from .fields import CIEmailField
8from .fields import CITextField
9from .fields import DateRangeField
10from .fields import DateTimeRangeField
11from .fields import DecimalRangeField
12from .fields import EnumField
13from .fields import HStoreField
14from .fields import IntegerRangeField
15from .fields import JSONField
16from .fields import SearchVectorField
17
18
class Tag:
    """Small value object wrapping an integer ``tag_id``.

    TagField converts database integers to Tag instances and back.
    """

    def __init__(self, tag_id):
        self.tag_id = tag_id

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of answering
        # False outright.
        if not isinstance(other, Tag):
            return NotImplemented
        return self.tag_id == other.tag_id

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash on tag_id so equal
        # Tags hash equally and can live in sets / dict keys.
        return hash(self.tag_id)

    def __repr__(self):
        # Readable output in assertion failures and debugging.
        return "%s(%r)" % (type(self).__name__, self.tag_id)
25
26
class TagField(models.SmallIntegerField):
    """SmallIntegerField that exposes values as Tag objects.

    Stores ``Tag.tag_id`` in the database and wraps loaded integers in Tag.
    NULL maps to None in both directions.
    """

    def from_db_value(self, value, expression, connection):
        # Database -> Python: wrap non-NULL integers in a Tag.
        if value is None:
            return value
        return Tag(int(value))

    def to_python(self, value):
        # Tag instances and None pass through; anything else is coerced via int().
        if isinstance(value, Tag):
            return value
        if value is None:
            return value
        return Tag(int(value))

    def get_prep_value(self, value):
        # Python -> database. None must pass through as NULL, like the other
        # conversions, instead of failing with AttributeError on None.tag_id.
        if value is None:
            return None
        return value.tag_id
42
43
class PostgreSQLModel(models.Model):
    """Abstract base for the test models that use PostgreSQL-specific fields."""

    class Meta:
        abstract = True
        # NOTE(review): presumably read by the test framework to restrict these
        # models to PostgreSQL runs -- confirm.
        required_db_vendor = "postgresql"
48
49
class IntegerArrayModel(PostgreSQLModel):
    """Integer array that may be left blank and defaults to an empty list."""

    field = ArrayField(models.IntegerField(), default=list, blank=True)
52
53
class NullableIntegerArrayModel(PostgreSQLModel):
    """Nullable integer arrays, including a nested array with nullable items."""

    field = ArrayField(models.IntegerField(), blank=True, null=True)
    field_nested = ArrayField(ArrayField(models.IntegerField(null=True)), null=True)
57
58
class CharArrayModel(PostgreSQLModel):
    """Array of short strings (max_length=10 per element)."""

    field = ArrayField(models.CharField(max_length=10))
61
62
class DateTimeArrayModel(PostgreSQLModel):
    """Arrays of datetimes, dates and times."""

    datetimes = ArrayField(models.DateTimeField())
    dates = ArrayField(models.DateField())
    times = ArrayField(models.TimeField())
67
68
class NestedIntegerArrayModel(PostgreSQLModel):
    """Two-level (array of arrays) integer field."""

    field = ArrayField(ArrayField(models.IntegerField()))
71
72
class OtherTypesArrayModel(PostgreSQLModel):
    """Arrays of assorted base field types.

    Includes the custom TagField, JSON values and integer range fields.
    """

    ips = ArrayField(models.GenericIPAddressField(), default=list)
    uuids = ArrayField(models.UUIDField(), default=list)
    decimals = ArrayField(
        models.DecimalField(max_digits=5, decimal_places=2), default=list
    )
    tags = ArrayField(TagField(), blank=True, null=True)
    json = ArrayField(JSONField(default=dict), default=list)
    int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)
    bigint_ranges = ArrayField(BigIntegerRangeField(), blank=True, null=True)
83
84
class HStoreModel(PostgreSQLModel):
    """A nullable HStoreField plus a nullable array of HStoreFields."""

    field = HStoreField(blank=True, null=True)
    array_field = ArrayField(HStoreField(), null=True)
88
89
class ArrayEnumModel(PostgreSQLModel):
    """Array whose elements are EnumField values (max_length=20)."""

    array_of_enums = ArrayField(EnumField(max_length=20))
92
93
class CharFieldModel(models.Model):
    """Plain CharField model (a models.Model, not a PostgreSQLModel)."""

    field = models.CharField(max_length=16)
96
97
class TextFieldModel(models.Model):
    """Plain TextField model (a models.Model, not a PostgreSQLModel)."""

    field = models.TextField()

    def __str__(self):
        # Display the stored text itself.
        return self.field
103
104
# Scene/Character/Line models are used to test full text search. They're
# populated with content from Monty Python and the Holy Grail.
class Scene(models.Model):
    """A scene and its setting; Line rows point at it via a foreign key."""

    scene = models.CharField(max_length=255)
    setting = models.CharField(max_length=255)

    def __str__(self):
        return self.scene
113
114
class Character(models.Model):
    """A speaking character; Line rows point at it via a foreign key."""

    name = models.CharField(max_length=255)

    def __str__(self):
        return self.name
120
121
class CITestModel(PostgreSQLModel):
    """Model built from the case-insensitive CI* text fields.

    ``name`` (a CICharField) is the primary key.
    """

    name = CICharField(primary_key=True, max_length=255)
    email = CIEmailField()
    description = CITextField()
    array_field = ArrayField(CITextField(), null=True)

    def __str__(self):
        return self.name
130
131
class Line(PostgreSQLModel):
    """A line of dialogue spoken by a Character in a Scene."""

    scene = models.ForeignKey("Scene", models.CASCADE)
    character = models.ForeignKey("Character", models.CASCADE)
    dialogue = models.TextField(blank=True, null=True)
    # Optional stored search vector for the dialogue text.
    dialogue_search_vector = SearchVectorField(blank=True, null=True)
    # Optional per-row text-search configuration name.
    dialogue_config = models.CharField(max_length=100, blank=True, null=True)

    def __str__(self):
        # dialogue is nullable; fall back to an empty string.
        return self.dialogue or ""
141
142
class RangesModel(PostgreSQLModel):
    """One nullable column for each range field type.

    The ``*_inner`` columns allow a range to be compared against a second
    range of the same type in the same row.
    """

    ints = IntegerRangeField(blank=True, null=True)
    bigints = BigIntegerRangeField(blank=True, null=True)
    decimals = DecimalRangeField(blank=True, null=True)
    timestamps = DateTimeRangeField(blank=True, null=True)
    timestamps_inner = DateTimeRangeField(blank=True, null=True)
    dates = DateRangeField(blank=True, null=True)
    dates_inner = DateRangeField(blank=True, null=True)
151
152
class RangeLookupsModel(PostgreSQLModel):
    """Scalar columns for looking values up against ranges.

    Has one nullable column for each scalar field type that has the
    ``contained_by`` range lookup registered, plus an optional link to a
    RangesModel row.
    """

    parent = models.ForeignKey(RangesModel, models.SET_NULL, blank=True, null=True)
    integer = models.IntegerField(blank=True, null=True)
    big_integer = models.BigIntegerField(blank=True, null=True)
    float = models.FloatField(blank=True, null=True)
    timestamp = models.DateTimeField(blank=True, null=True)
    date = models.DateField(blank=True, null=True)
160
161
class JSONModel(PostgreSQLModel):
    """Nullable JSON columns with the default encoder and DjangoJSONEncoder."""

    field = JSONField(blank=True, null=True)
    field_custom = JSONField(blank=True, null=True, encoder=DjangoJSONEncoder)
165
166
class ArrayFieldSubclass(ArrayField):
    """ArrayField subclass that is always an array of IntegerField."""

    def __init__(self, *args, **kwargs):
        # All arguments are deliberately ignored; the base field is fixed.
        super().__init__(models.IntegerField())
170
171
class AggregateTestModel(models.Model):
    """
    To test postgres-specific general aggregation functions.

    Provides nullable/blank char, integer and boolean columns to aggregate over.
    """

    char_field = models.CharField(max_length=30, blank=True)
    integer_field = models.IntegerField(null=True)
    boolean_field = models.BooleanField(null=True)
180
181
class StatTestModel(models.Model):
    """
    To test postgres-specific aggregation functions for statistics.

    ``int1``/``int2`` hold paired values; ``related_field`` optionally links an
    AggregateTestModel row.
    """

    int1 = models.IntegerField()
    int2 = models.IntegerField()
    related_field = models.ForeignKey(AggregateTestModel, models.SET_NULL, null=True)
190
191
class NowTestModel(models.Model):
    """Single nullable datetime column that defaults to None."""

    when = models.DateTimeField(null=True, default=None)
194
195
class UUIDTestModel(models.Model):
    """Single nullable UUID column that defaults to None."""

    uuid = models.UUIDField(default=None, null=True)
198
199
class Room(models.Model):
    """Hotel room identified by a unique number."""

    number = models.IntegerField(unique=True)
202
203
class HotelReservation(PostgreSQLModel):
    """A booking of a Room, used by the exclusion constraint tests.

    ``datespan`` is a date range; ``start``/``end`` are separate datetimes.
    ``cancelled`` lets a constraint condition ignore cancelled bookings.
    """

    room = models.ForeignKey("Room", on_delete=models.CASCADE)
    datespan = DateRangeField()
    start = models.DateTimeField()
    end = models.DateTimeField()
    cancelled = models.BooleanField(default=False)
tests/postgres_tests/test_constraints.py ¶
1import datetime
2from unittest import mock
3
4from django.db import connection
5from django.db import transaction
6from django.db.models import F
7from django.db.models import Func
8from django.db.models import Q
9from django.db.models.constraints import CheckConstraint
10from django.db.utils import IntegrityError
11from django.utils import timezone
12
13from . import PostgreSQLTestCase
14from .models import HotelReservation
15from .models import RangesModel
16from .models import Room
17
18try:
19 from django.contrib.postgres.constraints import ExclusionConstraint
20 from django.contrib.postgres.fields import (
21 DateTimeRangeField,
22 RangeBoundary,
23 RangeOperators,
24 )
25
26 from psycopg2.extras import DateRange, NumericRange
27except ImportError:
28 pass
29
30
class SchemaTests(PostgreSQLTestCase):
    """CheckConstraints built from range lookups, added via the schema editor."""

    def get_constraints(self, table):
        """Get the constraints on the table using a new cursor."""
        with connection.cursor() as cursor:
            constraints = connection.introspection.get_constraints(cursor, table)
        return constraints

    def _add_ranges_constraint(self, constraint):
        """Add *constraint* to RangesModel.

        Asserts the constraint is absent beforehand and present afterwards.
        """
        table = RangesModel._meta.db_table
        self.assertNotIn(constraint.name, self.get_constraints(table))
        with connection.schema_editor() as editor:
            editor.add_constraint(RangesModel, constraint)
        self.assertIn(constraint.name, self.get_constraints(table))

    def test_check_constraint_range_value(self):
        self._add_ranges_constraint(
            CheckConstraint(
                check=Q(ints__contained_by=NumericRange(10, 30)),
                name="ints_between",
            )
        )
        # (20, 50) sticks out of (10, 30); (10, 30) itself is accepted.
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(ints=(20, 50))
        RangesModel.objects.create(ints=(10, 30))

    def test_check_constraint_daterange_contains(self):
        self._add_ranges_constraint(
            CheckConstraint(
                check=Q(dates__contains=F("dates_inner")),
                name="dates_contains",
            )
        )
        first = datetime.date(2016, 1, 1)
        last = datetime.date(2016, 1, 4)
        # An inner range ending a day later is not contained.
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(
                dates=(first, last),
                dates_inner=(first, last.replace(day=5)),
            )
        RangesModel.objects.create(
            dates=(first, last),
            dates_inner=(first, last),
        )

    def test_check_constraint_datetimerange_contains(self):
        self._add_ranges_constraint(
            CheckConstraint(
                check=Q(timestamps__contains=F("timestamps_inner")),
                name="timestamps_contains",
            )
        )
        first = datetime.datetime(2016, 1, 1)
        last = datetime.datetime(2016, 1, 2, 12)
        # An inner range ending an hour later is not contained.
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(
                timestamps=(first, last),
                timestamps_inner=(first, last.replace(hour=13)),
            )
        RangesModel.objects.create(
            timestamps=(first, last),
            timestamps_inner=(first, last),
        )
100
101
class ExclusionConstraintTests(PostgreSQLTestCase):
    """ExclusionConstraint tests.

    Covers argument validation, repr/equality/deconstruction, and enforcement
    against real tables.
    """

    def get_constraints(self, table):
        """Get the constraints on the table using a new cursor."""
        with connection.cursor() as cursor:
            return connection.introspection.get_constraints(cursor, table)

    def test_invalid_condition(self):
        msg = "ExclusionConstraint.condition must be a Q instance."
        with self.assertRaisesMessage(ValueError, msg):
            ExclusionConstraint(
                index_type="GIST",
                name="exclude_invalid_condition",
                expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
                condition=F("invalid"),
            )

    def test_invalid_index_type(self):
        msg = "Exclusion constraints only support GiST or SP-GiST indexes."
        with self.assertRaisesMessage(ValueError, msg):
            ExclusionConstraint(
                index_type="gin",
                name="exclude_invalid_index_type",
                expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
            )

    def test_invalid_expressions(self):
        msg = "The expressions must be a list of 2-tuples."
        # A bare string, a 1-tuple and a 3-tuple must all be rejected. The
        # trailing comma in ("foo",) matters: ("foo") is just the string "foo",
        # which would duplicate the first case.
        for expressions in (["foo"], [("foo",)], [("foo_1", "foo_2", "foo_3")]):
            with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg):
                ExclusionConstraint(
                    index_type="GIST",
                    name="exclude_invalid_expressions",
                    expressions=expressions,
                )

    def test_empty_expressions(self):
        msg = "At least one expression is required to define an exclusion constraint."
        for empty_expressions in (None, []):
            with self.subTest(empty_expressions), self.assertRaisesMessage(
                ValueError, msg
            ):
                ExclusionConstraint(
                    index_type="GIST",
                    name="exclude_empty_expressions",
                    expressions=empty_expressions,
                )

    def test_repr(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                (F("room"), RangeOperators.EQUAL),
            ],
        )
        self.assertEqual(
            repr(constraint),
            "<ExclusionConstraint: index_type=GIST, expressions=["
            "(F(datespan), '&&'), (F(room), '=')]>",
        )
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
            condition=Q(cancelled=False),
            index_type="SPGiST",
        )
        self.assertEqual(
            repr(constraint),
            "<ExclusionConstraint: index_type=SPGiST, expressions=["
            "(F(datespan), '-|-')], condition=(AND: ('cancelled', False))>",
        )

    def test_eq(self):
        constraint_1 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                (F("room"), RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        constraint_2 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        constraint_3 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[("datespan", RangeOperators.OVERLAPS)],
            condition=Q(cancelled=False),
        )
        self.assertEqual(constraint_1, constraint_1)
        self.assertEqual(constraint_1, mock.ANY)
        self.assertNotEqual(constraint_1, constraint_2)
        self.assertNotEqual(constraint_1, constraint_3)
        self.assertNotEqual(constraint_2, constraint_3)
        self.assertNotEqual(constraint_1, object())

    def test_deconstruct(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
            },
        )

    def test_deconstruct_index_type(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            index_type="SPGIST",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "index_type": "SPGIST",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
            },
        )

    def test_deconstruct_condition(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
                "condition": Q(cancelled=False),
            },
        )

    def _test_range_overlaps(self, constraint):
        """Add *constraint* to HotelReservation and check its enforcement.

        The callers in this class pass constraints that forbid overlapping
        bookings of the same room, with condition=Q(cancelled=False).
        """
        # Create exclusion constraint.
        self.assertNotIn(
            constraint.name, self.get_constraints(HotelReservation._meta.db_table)
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(HotelReservation, constraint)
        self.assertIn(
            constraint.name, self.get_constraints(HotelReservation._meta.db_table)
        )
        # Add initial reservations.
        room101 = Room.objects.create(number=101)
        room102 = Room.objects.create(number=102)
        datetimes = [
            timezone.datetime(2018, 6, 20),
            timezone.datetime(2018, 6, 24),
            timezone.datetime(2018, 6, 26),
            timezone.datetime(2018, 6, 28),
            timezone.datetime(2018, 6, 29),
        ]
        HotelReservation.objects.create(
            datespan=DateRange(datetimes[0].date(), datetimes[1].date()),
            start=datetimes[0],
            end=datetimes[1],
            room=room102,
        )
        HotelReservation.objects.create(
            datespan=DateRange(datetimes[1].date(), datetimes[3].date()),
            start=datetimes[1],
            end=datetimes[3],
            room=room102,
        )
        # Overlap dates.
        with self.assertRaises(IntegrityError), transaction.atomic():
            reservation = HotelReservation(
                datespan=(datetimes[1].date(), datetimes[2].date()),
                start=datetimes[1],
                end=datetimes[2],
                room=room102,
            )
            reservation.save()
        # Valid range.
        HotelReservation.objects.bulk_create(
            [
                # Other room.
                HotelReservation(
                    datespan=(datetimes[1].date(), datetimes[2].date()),
                    start=datetimes[1],
                    end=datetimes[2],
                    room=room101,
                ),
                # Cancelled reservation. It genuinely overlaps the second room102
                # booking (in both datespan and start/end), so it is accepted
                # only because the constraint's condition skips cancelled rows.
                # An empty datespan such as (d1, d1) could never overlap and
                # would not exercise the condition.
                HotelReservation(
                    datespan=(datetimes[1].date(), datetimes[2].date()),
                    start=datetimes[1],
                    end=datetimes[2],
                    room=room102,
                    cancelled=True,
                ),
                # Other adjacent dates.
                HotelReservation(
                    datespan=(datetimes[3].date(), datetimes[4].date()),
                    start=datetimes[3],
                    end=datetimes[4],
                    room=room102,
                ),
            ]
        )

    def test_range_overlaps_custom(self):
        class TsTzRange(Func):
            function = "TSTZRANGE"
            output_field = DateTimeRangeField()

        constraint = ExclusionConstraint(
            name="exclude_overlapping_reservations_custom",
            expressions=[
                (TsTzRange("start", "end", RangeBoundary()), RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        self._test_range_overlaps(constraint)

    def test_range_overlaps(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping_reservations",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        self._test_range_overlaps(constraint)

    def test_range_adjacent(self):
        constraint_name = "ints_adjacent"
        self.assertNotIn(
            constraint_name, self.get_constraints(RangesModel._meta.db_table)
        )
        constraint = ExclusionConstraint(
            name=constraint_name,
            expressions=[("ints", RangeOperators.ADJACENT_TO)],
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(RangesModel, constraint)
        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
        RangesModel.objects.create(ints=(20, 50))
        # (10, 20) ends exactly where (20, 50) starts: adjacent, rejected.
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(ints=(10, 20))
        RangesModel.objects.create(ints=(10, 19))
        RangesModel.objects.create(ints=(51, 60))