PostgreSQL RangeOperators

Description

PostgreSQL provides a set of SQL operators that can be used together with the range data types (see the PostgreSQL documentation for the full details of range operators).

This class is meant as a convenient way to avoid typos when writing the operator strings.

The operator names overlap with the names of the corresponding lookups.

class RangeOperators:
    """Named constants for the PostgreSQL range operators.

    Referencing these attributes instead of typing the raw operator strings
    avoids typos.
    """

    # Equality.
    EQUAL = "="
    NOT_EQUAL = "<>"

    # Containment and overlap.
    CONTAINS = "@>"
    CONTAINED_BY = "<@"
    OVERLAPS = "&&"

    # Relative position of two ranges.
    FULLY_LT = "<<"
    FULLY_GT = ">>"
    NOT_LT = "&>"
    NOT_GT = "&<"
    ADJACENT_TO = "-|-"

https://twitter.com/l_avrot and https://twitter.com/be_haki

A reader from Reddit just pointed out that in #PostgreSQL to filter a date range you can use a range type:

SELECT * FROM sales WHERE created <@ daterange(date '2019-01-01', date '2020-01-01', '[)');

Nice!

../../../../../../_images/l_avrot_operators.png

I call @> and <@ the bird operators because they look like birds and they’re so cute!

django/contrib/postgres/fields/ranges.py

  1import datetime
  2import json
  3
  4from django.contrib.postgres import forms
  5from django.contrib.postgres import lookups
  6from django.db import models
  7from psycopg2.extras import DateRange
  8from psycopg2.extras import DateTimeTZRange
  9from psycopg2.extras import NumericRange
 10from psycopg2.extras import Range
 11
 12from .utils import AttributeSetter
 13
 14__all__ = [
 15    "RangeField",
 16    "IntegerRangeField",
 17    "BigIntegerRangeField",
 18    "DecimalRangeField",
 19    "DateTimeRangeField",
 20    "DateRangeField",
 21    "RangeBoundary",
 22    "RangeOperators",
 23]
 24
 25
 26class RangeBoundary(models.Expression):
 27    """A class that represents range boundaries."""
 28
 29    def __init__(self, inclusive_lower=True, inclusive_upper=False):
 30        self.lower = "[" if inclusive_lower else "("
 31        self.upper = "]" if inclusive_upper else ")"
 32
 33    def as_sql(self, compiler, connection):
 34        return "'%s%s'" % (self.lower, self.upper), []
 35
 36
 37class RangeOperators:
 38    # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
 39    EQUAL = "="
 40    NOT_EQUAL = "<>"
 41    CONTAINS = "@>"
 42    CONTAINED_BY = "<@"
 43    OVERLAPS = "&&"
 44    FULLY_LT = "<<"
 45    FULLY_GT = ">>"
 46    NOT_LT = "&>"
 47    NOT_GT = "&<"
 48    ADJACENT_TO = "-|-"
 49
 50
 51class RangeField(models.Field):
 52    empty_strings_allowed = False
 53
 54    def __init__(self, *args, **kwargs):
 55        # Initializing base_field here ensures that its model matches the model for self.
 56        if hasattr(self, "base_field"):
 57            self.base_field = self.base_field()
 58        super().__init__(*args, **kwargs)
 59
 60    @property
 61    def model(self):
 62        try:
 63            return self.__dict__["model"]
 64        except KeyError:
 65            raise AttributeError(
 66                "'%s' object has no attribute 'model'" % self.__class__.__name__
 67            )
 68
 69    @model.setter
 70    def model(self, model):
 71        self.__dict__["model"] = model
 72        self.base_field.model = model
 73
 74    @classmethod
 75    def _choices_is_value(cls, value):
 76        return isinstance(value, (list, tuple)) or super()._choices_is_value(value)
 77
 78    def get_prep_value(self, value):
 79        if value is None:
 80            return None
 81        elif isinstance(value, Range):
 82            return value
 83        elif isinstance(value, (list, tuple)):
 84            return self.range_type(value[0], value[1])
 85        return value
 86
 87    def to_python(self, value):
 88        if isinstance(value, str):
 89            # Assume we're deserializing
 90            vals = json.loads(value)
 91            for end in ("lower", "upper"):
 92                if end in vals:
 93                    vals[end] = self.base_field.to_python(vals[end])
 94            value = self.range_type(**vals)
 95        elif isinstance(value, (list, tuple)):
 96            value = self.range_type(value[0], value[1])
 97        return value
 98
 99    def set_attributes_from_name(self, name):
100        super().set_attributes_from_name(name)
101        self.base_field.set_attributes_from_name(name)
102
103    def value_to_string(self, obj):
104        value = self.value_from_object(obj)
105        if value is None:
106            return None
107        if value.isempty:
108            return json.dumps({"empty": True})
109        base_field = self.base_field
110        result = {"bounds": value._bounds}
111        for end in ("lower", "upper"):
112            val = getattr(value, end)
113            if val is None:
114                result[end] = None
115            else:
116                obj = AttributeSetter(base_field.attname, val)
117                result[end] = base_field.value_to_string(obj)
118        return json.dumps(result)
119
120    def formfield(self, **kwargs):
121        kwargs.setdefault("form_class", self.form_field)
122        return super().formfield(**kwargs)
123
124
class IntegerRangeField(RangeField):
    """Range field with IntegerField bounds, stored as ``int4range``."""

    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.IntegerField

    def db_type(self, connection):
        # The column type is the same for every connection.
        return "int4range"
132
133
class BigIntegerRangeField(RangeField):
    """Range field with BigIntegerField bounds, stored as ``int8range``."""

    form_field = forms.IntegerRangeField
    range_type = NumericRange
    base_field = models.BigIntegerField

    def db_type(self, connection):
        # The column type is the same for every connection.
        return "int8range"
141
142
class DecimalRangeField(RangeField):
    """Range field with DecimalField bounds, stored as ``numrange``."""

    form_field = forms.DecimalRangeField
    range_type = NumericRange
    base_field = models.DecimalField

    def db_type(self, connection):
        # The column type is the same for every connection.
        return "numrange"
150
151
class DateTimeRangeField(RangeField):
    """Range field with DateTimeField bounds, stored as ``tstzrange``."""

    form_field = forms.DateTimeRangeField
    range_type = DateTimeTZRange
    base_field = models.DateTimeField

    def db_type(self, connection):
        # The column type is the same for every connection.
        return "tstzrange"
159
160
class DateRangeField(RangeField):
    """Range field with DateField bounds, stored as ``daterange``."""

    form_field = forms.DateRangeField
    range_type = DateRange
    base_field = models.DateField

    def db_type(self, connection):
        # The column type is the same for every connection.
        return "daterange"
168
169
170RangeField.register_lookup(lookups.DataContains)
171RangeField.register_lookup(lookups.ContainedBy)
172RangeField.register_lookup(lookups.Overlap)
173
174
class DateTimeRangeContains(lookups.PostgresSimpleLookup):
    """
    ``contains`` lookup for date and datetime ranges that casts the
    right-hand side to the type matching the range's base field.
    """

    lookup_name = "contains"
    operator = RangeOperators.CONTAINS

    def process_rhs(self, compiler, connection):
        # Wrap a plain date/datetime in a typed Value expression so that
        # as_sql() can tell whether a cast is required.
        raw_rhs = self.rhs
        if isinstance(raw_rhs, datetime.date):
            # datetime.datetime subclasses datetime.date, so test it second.
            if isinstance(raw_rhs, datetime.datetime):
                value_field = models.DateTimeField()
            else:
                value_field = models.DateField()
            wrapped = models.Value(raw_rhs, output_field=value_field)
            self.rhs = wrapped.resolve_expression(compiler.query)
        return super().process_rhs(compiler, connection)

    def as_sql(self, compiler, connection):
        sql, params = super().as_sql(compiler, connection)
        # Read self.rhs only now: process_rhs() may have replaced it.
        rhs = self.rhs
        lhs_field = self.lhs.output_field
        # A cast is needed for a typed expression unless its output field is
        # already of the same range type as the left-hand side.
        needs_cast = (
            isinstance(rhs, models.Expression)
            and rhs._output_field_or_none
            and not isinstance(rhs._output_field_or_none, lhs_field.__class__)
        )
        if needs_cast:
            internal_type = lhs_field.base_field.get_internal_type()
            cast_sql = "::{}".format(connection.data_types.get(internal_type))
        else:
            cast_sql = ""
        return "%s%s" % (sql, cast_sql), params
212
213
214DateRangeField.register_lookup(DateTimeRangeContains)
215DateTimeRangeField.register_lookup(DateTimeRangeContains)
216
217
class RangeContainedBy(lookups.PostgresSimpleLookup):
    """``contained_by`` lookup for scalar (non-range) fields.

    The right-hand side is cast to the range type that corresponds to the
    database type of the left-hand field.
    """

    lookup_name = "contained_by"
    # Scalar column type -> range type used for the right-hand side cast.
    # "smallint" is needed because the lookup is registered on IntegerField
    # and is therefore inherited by small-integer fields; without the entry
    # process_rhs() raised KeyError for them.
    type_mapping = {
        "smallint": "int4range",
        "integer": "int4range",
        "bigint": "int8range",
        "double precision": "numrange",
        "date": "daterange",
        "timestamp with time zone": "tstzrange",
    }
    operator = RangeOperators.CONTAINED_BY

    def process_rhs(self, compiler, connection):
        # Append an explicit cast so the parameter is typed as a range.
        rhs, rhs_params = super().process_rhs(compiler, connection)
        cast_type = self.type_mapping[self.lhs.output_field.db_type(connection)]
        return "%s::%s" % (rhs, cast_type), rhs_params

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if isinstance(self.lhs.output_field, models.FloatField):
            # numrange has numeric elements, so float columns are cast.
            lhs = "%s::numeric" % lhs
        elif self.lhs.output_field.db_type(connection) == "smallint":
            # PostgreSQL has no smallint range type; widen the column to the
            # element type of int4range.
            lhs = "%s::integer" % lhs
        return lhs, lhs_params

    def get_prep_lookup(self):
        # Reuse RangeField's conversion of list/tuple pairs to range objects.
        return RangeField().get_prep_value(self.rhs)
242
243
244models.DateField.register_lookup(RangeContainedBy)
245models.DateTimeField.register_lookup(RangeContainedBy)
246models.IntegerField.register_lookup(RangeContainedBy)
247models.BigIntegerField.register_lookup(RangeContainedBy)
248models.FloatField.register_lookup(RangeContainedBy)
249
250
@RangeField.register_lookup
class FullyLessThan(lookups.PostgresSimpleLookup):
    """``fully_lt`` lookup bound to ``RangeOperators.FULLY_LT`` (``<<``)."""

    operator = RangeOperators.FULLY_LT
    lookup_name = "fully_lt"
255
256
@RangeField.register_lookup
class FullGreaterThan(lookups.PostgresSimpleLookup):
    """``fully_gt`` lookup bound to ``RangeOperators.FULLY_GT`` (``>>``)."""

    operator = RangeOperators.FULLY_GT
    lookup_name = "fully_gt"
261
262
@RangeField.register_lookup
class NotLessThan(lookups.PostgresSimpleLookup):
    """``not_lt`` lookup bound to ``RangeOperators.NOT_LT`` (``&>``)."""

    operator = RangeOperators.NOT_LT
    lookup_name = "not_lt"
267
268
@RangeField.register_lookup
class NotGreaterThan(lookups.PostgresSimpleLookup):
    """``not_gt`` lookup bound to ``RangeOperators.NOT_GT`` (``&<``)."""

    operator = RangeOperators.NOT_GT
    lookup_name = "not_gt"
273
274
@RangeField.register_lookup
class AdjacentToLookup(lookups.PostgresSimpleLookup):
    """``adjacent_to`` lookup bound to ``RangeOperators.ADJACENT_TO`` (``-|-``)."""

    operator = RangeOperators.ADJACENT_TO
    lookup_name = "adjacent_to"
279
280
@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform using the SQL function ``lower``."""

    function = "lower"
    lookup_name = "startswith"

    @property
    def output_field(self):
        # The result is typed like the bounds of the range being transformed.
        range_field = self.lhs.output_field
        return range_field.base_field
289
290
@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform using the SQL function ``upper``."""

    function = "upper"
    lookup_name = "endswith"

    @property
    def output_field(self):
        # The result is typed like the bounds of the range being transformed.
        range_field = self.lhs.output_field
        return range_field.base_field
299
300
@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform using the SQL function ``isempty``; boolean result."""

    function = "isempty"
    lookup_name = "isempty"
    output_field = models.BooleanField()
306
307
@RangeField.register_lookup
class LowerInclusive(models.Transform):
    """``lower_inc`` transform using the SQL function ``LOWER_INC``; boolean result."""

    function = "LOWER_INC"
    lookup_name = "lower_inc"
    output_field = models.BooleanField()
313
314
@RangeField.register_lookup
class LowerInfinite(models.Transform):
    """``lower_inf`` transform using the SQL function ``LOWER_INF``; boolean result."""

    function = "LOWER_INF"
    lookup_name = "lower_inf"
    output_field = models.BooleanField()
320
321
@RangeField.register_lookup
class UpperInclusive(models.Transform):
    """``upper_inc`` transform using the SQL function ``UPPER_INC``; boolean result."""

    function = "UPPER_INC"
    lookup_name = "upper_inc"
    output_field = models.BooleanField()
327
328
@RangeField.register_lookup
class UpperInfinite(models.Transform):
    """``upper_inf`` transform using the SQL function ``UPPER_INF``; boolean result."""

    function = "UPPER_INF"
    lookup_name = "upper_inf"
    output_field = models.BooleanField()

tests/postgres_tests/models.py

  1from django.core.serializers.json import DjangoJSONEncoder
  2from django.db import models
  3
  4from .fields import ArrayField
  5from .fields import BigIntegerRangeField
  6from .fields import CICharField
  7from .fields import CIEmailField
  8from .fields import CITextField
  9from .fields import DateRangeField
 10from .fields import DateTimeRangeField
 11from .fields import DecimalRangeField
 12from .fields import EnumField
 13from .fields import HStoreField
 14from .fields import IntegerRangeField
 15from .fields import JSONField
 16from .fields import SearchVectorField
 17
 18
 19class Tag:
 20    def __init__(self, tag_id):
 21        self.tag_id = tag_id
 22
 23    def __eq__(self, other):
 24        return isinstance(other, Tag) and self.tag_id == other.tag_id
 25
 26
 27class TagField(models.SmallIntegerField):
 28    def from_db_value(self, value, expression, connection):
 29        if value is None:
 30            return value
 31        return Tag(int(value))
 32
 33    def to_python(self, value):
 34        if isinstance(value, Tag):
 35            return value
 36        if value is None:
 37            return value
 38        return Tag(int(value))
 39
 40    def get_prep_value(self, value):
 41        return value.tag_id
 42
 43
class PostgreSQLModel(models.Model):
    """Abstract base for models that declare PostgreSQL as required vendor."""

    class Meta:
        abstract = True
        required_db_vendor = "postgresql"
 48
 49
class IntegerArrayModel(PostgreSQLModel):
    """Model with an integer array that defaults to an empty list."""

    field = ArrayField(models.IntegerField(), default=list, blank=True)
 52
 53
class NullableIntegerArrayModel(PostgreSQLModel):
    """Model with a nullable integer array and a nullable nested array."""

    field = ArrayField(models.IntegerField(), blank=True, null=True)
    field_nested = ArrayField(ArrayField(models.IntegerField(null=True)), null=True)
 57
 58
class CharArrayModel(PostgreSQLModel):
    """Model with an array of CharField(max_length=10) values."""

    field = ArrayField(models.CharField(max_length=10))
 61
 62
class DateTimeArrayModel(PostgreSQLModel):
    """Model with one array each of datetimes, dates and times."""

    datetimes = ArrayField(models.DateTimeField())
    dates = ArrayField(models.DateField())
    times = ArrayField(models.TimeField())
 67
 68
class NestedIntegerArrayModel(PostgreSQLModel):
    """Model with an array of integer arrays."""

    field = ArrayField(ArrayField(models.IntegerField()))
 71
 72
class OtherTypesArrayModel(PostgreSQLModel):
    """Model with arrays of non-integer base fields, including range fields."""

    ips = ArrayField(models.GenericIPAddressField(), default=list)
    uuids = ArrayField(models.UUIDField(), default=list)
    decimals = ArrayField(
        models.DecimalField(max_digits=5, decimal_places=2), default=list
    )
    # Uses the custom TagField defined in this module as the base field.
    tags = ArrayField(TagField(), blank=True, null=True)
    json = ArrayField(JSONField(default=dict), default=list)
    int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)
    bigint_ranges = ArrayField(BigIntegerRangeField(), blank=True, null=True)
 83
 84
class HStoreModel(PostgreSQLModel):
    """Model with a nullable HStoreField and a nullable array of HStoreFields."""

    field = HStoreField(blank=True, null=True)
    array_field = ArrayField(HStoreField(), null=True)
 88
 89
class ArrayEnumModel(PostgreSQLModel):
    """Model with an array of EnumField(max_length=20) values."""

    array_of_enums = ArrayField(EnumField(max_length=20))
 92
 93
class CharFieldModel(models.Model):
    """Model with a single CharField(max_length=16)."""

    field = models.CharField(max_length=16)
 96
 97
class TextFieldModel(models.Model):
    """Model with a single TextField; its string form is the field value."""

    field = models.TextField()

    def __str__(self):
        return self.field
103
104
105# Scene/Character/Line models are used to test full text search. They're
106# populated with content from Monty Python and the Holy Grail.
class Scene(models.Model):
    """A scene with a name and a setting; its string form is the scene name."""

    scene = models.CharField(max_length=255)
    setting = models.CharField(max_length=255)

    def __str__(self):
        return self.scene
113
114
class Character(models.Model):
    """A named character; its string form is the name."""

    name = models.CharField(max_length=255)

    def __str__(self):
        return self.name
120
121
class CITestModel(PostgreSQLModel):
    """Model built from the CI* fields; ``name`` is the primary key."""

    name = CICharField(primary_key=True, max_length=255)
    email = CIEmailField()
    description = CITextField()
    array_field = ArrayField(CITextField(), null=True)

    def __str__(self):
        return self.name
130
131
class Line(PostgreSQLModel):
    """A line of dialogue linked to a Scene and a Character."""

    scene = models.ForeignKey("Scene", models.CASCADE)
    character = models.ForeignKey("Character", models.CASCADE)
    dialogue = models.TextField(blank=True, null=True)
    dialogue_search_vector = SearchVectorField(blank=True, null=True)
    dialogue_config = models.CharField(max_length=100, blank=True, null=True)

    def __str__(self):
        # dialogue is nullable, so fall back to an empty string.
        return self.dialogue or ""
141
142
class RangesModel(PostgreSQLModel):
    """Model with nullable columns for each of the range field types."""

    ints = IntegerRangeField(blank=True, null=True)
    bigints = BigIntegerRangeField(blank=True, null=True)
    decimals = DecimalRangeField(blank=True, null=True)
    timestamps = DateTimeRangeField(blank=True, null=True)
    # The *_inner columns give a second range of the same type as their
    # counterpart, so two columns can be compared with each other.
    timestamps_inner = DateTimeRangeField(blank=True, null=True)
    dates = DateRangeField(blank=True, null=True)
    dates_inner = DateRangeField(blank=True, null=True)
151
152
class RangeLookupsModel(PostgreSQLModel):
    """Model with nullable scalar columns and an optional RangesModel parent."""

    parent = models.ForeignKey(RangesModel, models.SET_NULL, blank=True, null=True)
    integer = models.IntegerField(blank=True, null=True)
    big_integer = models.BigIntegerField(blank=True, null=True)
    float = models.FloatField(blank=True, null=True)
    timestamp = models.DateTimeField(blank=True, null=True)
    date = models.DateField(blank=True, null=True)
160
161
class JSONModel(PostgreSQLModel):
    """Model with two nullable JSON fields, one using DjangoJSONEncoder."""

    field = JSONField(blank=True, null=True)
    field_custom = JSONField(blank=True, null=True, encoder=DjangoJSONEncoder)
165
166
class ArrayFieldSubclass(ArrayField):
    """ArrayField subclass that always uses an IntegerField base field."""

    def __init__(self, *args, **kwargs):
        # The received arguments are not forwarded to ArrayField.
        super().__init__(models.IntegerField())
170
171
class AggregateTestModel(models.Model):
    """
    To test postgres-specific general aggregation functions
    """

    # One column each of text, nullable integer and nullable boolean.
    char_field = models.CharField(max_length=30, blank=True)
    integer_field = models.IntegerField(null=True)
    boolean_field = models.BooleanField(null=True)
180
181
class StatTestModel(models.Model):
    """
    To test postgres-specific aggregation functions for statistics
    """

    # Two integer columns plus an optional link to AggregateTestModel.
    int1 = models.IntegerField()
    int2 = models.IntegerField()
    related_field = models.ForeignKey(AggregateTestModel, models.SET_NULL, null=True)
190
191
class NowTestModel(models.Model):
    """Model with a nullable DateTimeField that defaults to None."""

    when = models.DateTimeField(null=True, default=None)
194
195
class UUIDTestModel(models.Model):
    """Model with a nullable UUIDField that defaults to None."""

    uuid = models.UUIDField(default=None, null=True)
198
199
class Room(models.Model):
    """A room identified by a unique number."""

    number = models.IntegerField(unique=True)
202
203
class HotelReservation(PostgreSQLModel):
    """A reservation of a Room, stored as a date range and as start/end datetimes."""

    room = models.ForeignKey("Room", on_delete=models.CASCADE)
    datespan = DateRangeField()
    start = models.DateTimeField()
    end = models.DateTimeField()
    cancelled = models.BooleanField(default=False)

tests/postgres_tests/test_constraints.py

  1import datetime
  2from unittest import mock
  3
  4from django.db import connection
  5from django.db import transaction
  6from django.db.models import F
  7from django.db.models import Func
  8from django.db.models import Q
  9from django.db.models.constraints import CheckConstraint
 10from django.db.utils import IntegrityError
 11from django.utils import timezone
 12
 13from . import PostgreSQLTestCase
 14from .models import HotelReservation
 15from .models import RangesModel
 16from .models import Room
 17
 18try:
 19    from django.contrib.postgres.constraints import ExclusionConstraint
 20    from django.contrib.postgres.fields import (
 21        DateTimeRangeField,
 22        RangeBoundary,
 23        RangeOperators,
 24    )
 25
 26    from psycopg2.extras import DateRange, NumericRange
 27except ImportError:
 28    pass
 29
 30
class SchemaTests(PostgreSQLTestCase):
    """Check constraints on RangesModel built from range lookups."""

    def get_constraints(self, table):
        """Get the constraints on the table using a new cursor."""
        with connection.cursor() as cursor:
            return connection.introspection.get_constraints(cursor, table)

    def test_check_constraint_range_value(self):
        # ints must be contained by NumericRange(10, 30): (20, 50) is
        # rejected, (10, 30) is accepted.
        constraint_name = "ints_between"
        self.assertNotIn(
            constraint_name, self.get_constraints(RangesModel._meta.db_table)
        )
        constraint = CheckConstraint(
            check=Q(ints__contained_by=NumericRange(10, 30)),
            name=constraint_name,
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(RangesModel, constraint)
        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(ints=(20, 50))
        RangesModel.objects.create(ints=(10, 30))

    def test_check_constraint_daterange_contains(self):
        # dates must contain dates_inner: an inner range ending one day
        # later is rejected, an identical inner range is accepted.
        constraint_name = "dates_contains"
        self.assertNotIn(
            constraint_name, self.get_constraints(RangesModel._meta.db_table)
        )
        constraint = CheckConstraint(
            check=Q(dates__contains=F("dates_inner")),
            name=constraint_name,
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(RangesModel, constraint)
        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
        date_1 = datetime.date(2016, 1, 1)
        date_2 = datetime.date(2016, 1, 4)
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(
                dates=(date_1, date_2),
                dates_inner=(date_1, date_2.replace(day=5)),
            )
        RangesModel.objects.create(
            dates=(date_1, date_2),
            dates_inner=(date_1, date_2),
        )

    def test_check_constraint_datetimerange_contains(self):
        # timestamps must contain timestamps_inner: an inner range ending
        # one hour later is rejected, an identical inner range is accepted.
        constraint_name = "timestamps_contains"
        self.assertNotIn(
            constraint_name, self.get_constraints(RangesModel._meta.db_table)
        )
        constraint = CheckConstraint(
            check=Q(timestamps__contains=F("timestamps_inner")),
            name=constraint_name,
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(RangesModel, constraint)
        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
        datetime_1 = datetime.datetime(2016, 1, 1)
        datetime_2 = datetime.datetime(2016, 1, 2, 12)
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(
                timestamps=(datetime_1, datetime_2),
                timestamps_inner=(datetime_1, datetime_2.replace(hour=13)),
            )
        RangesModel.objects.create(
            timestamps=(datetime_1, datetime_2),
            timestamps_inner=(datetime_1, datetime_2),
        )
100
101
class ExclusionConstraintTests(PostgreSQLTestCase):
    """Tests for ExclusionConstraint: validation, repr/eq, deconstruct, DB use."""

    def get_constraints(self, table):
        """Get the constraints on the table using a new cursor."""
        with connection.cursor() as cursor:
            return connection.introspection.get_constraints(cursor, table)

    def test_invalid_condition(self):
        # A non-Q condition (here an F expression) raises ValueError.
        msg = "ExclusionConstraint.condition must be a Q instance."
        with self.assertRaisesMessage(ValueError, msg):
            ExclusionConstraint(
                index_type="GIST",
                name="exclude_invalid_condition",
                expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
                condition=F("invalid"),
            )

    def test_invalid_index_type(self):
        # An index type other than GiST/SP-GiST raises ValueError.
        msg = "Exclusion constraints only support GiST or SP-GiST indexes."
        with self.assertRaisesMessage(ValueError, msg):
            ExclusionConstraint(
                index_type="gin",
                name="exclude_invalid_index_type",
                expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
            )

    def test_invalid_expressions(self):
        # Each entry must be a 2-tuple; note that ("foo") is a plain string.
        msg = "The expressions must be a list of 2-tuples."
        for expressions in (["foo"], [("foo")], [("foo_1", "foo_2", "foo_3")]):
            with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg):
                ExclusionConstraint(
                    index_type="GIST",
                    name="exclude_invalid_expressions",
                    expressions=expressions,
                )

    def test_empty_expressions(self):
        # Both None and an empty list are rejected.
        msg = "At least one expression is required to define an exclusion constraint."
        for empty_expressions in (None, []):
            with self.subTest(empty_expressions), self.assertRaisesMessage(
                ValueError, msg
            ):
                ExclusionConstraint(
                    index_type="GIST",
                    name="exclude_empty_expressions",
                    expressions=empty_expressions,
                )

    def test_repr(self):
        # Without index_type the repr shows GIST; a condition is appended
        # only when one was given.
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                (F("room"), RangeOperators.EQUAL),
            ],
        )
        self.assertEqual(
            repr(constraint),
            "<ExclusionConstraint: index_type=GIST, expressions=["
            "(F(datespan), '&&'), (F(room), '=')]>",
        )
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
            condition=Q(cancelled=False),
            index_type="SPGiST",
        )
        self.assertEqual(
            repr(constraint),
            "<ExclusionConstraint: index_type=SPGiST, expressions=["
            "(F(datespan), '-|-')], condition=(AND: ('cancelled', False))>",
        )

    def test_eq(self):
        # Constraints sharing a name still compare unequal when their
        # expressions or condition differ.
        constraint_1 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                (F("room"), RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        constraint_2 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        constraint_3 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[("datespan", RangeOperators.OVERLAPS)],
            condition=Q(cancelled=False),
        )
        self.assertEqual(constraint_1, constraint_1)
        self.assertEqual(constraint_1, mock.ANY)
        self.assertNotEqual(constraint_1, constraint_2)
        self.assertNotEqual(constraint_1, constraint_3)
        self.assertNotEqual(constraint_2, constraint_3)
        self.assertNotEqual(constraint_1, object())

    def test_deconstruct(self):
        # Only name and expressions are expected in kwargs when no
        # index_type or condition was passed.
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
            },
        )

    def test_deconstruct_index_type(self):
        # An explicitly passed index_type is expected back in kwargs.
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            index_type="SPGIST",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "index_type": "SPGIST",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
            },
        )

    def test_deconstruct_condition(self):
        # An explicitly passed condition is expected back in kwargs.
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
                "condition": Q(cancelled=False),
            },
        )

    def _test_range_overlaps(self, constraint):
        # Shared scenario for the overlap tests: add ``constraint`` to
        # HotelReservation, then check which reservations are rejected.
        # Create exclusion constraint.
        self.assertNotIn(
            constraint.name, self.get_constraints(HotelReservation._meta.db_table)
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(HotelReservation, constraint)
        self.assertIn(
            constraint.name, self.get_constraints(HotelReservation._meta.db_table)
        )
        # Add initial reservations.
        room101 = Room.objects.create(number=101)
        room102 = Room.objects.create(number=102)
        datetimes = [
            timezone.datetime(2018, 6, 20),
            timezone.datetime(2018, 6, 24),
            timezone.datetime(2018, 6, 26),
            timezone.datetime(2018, 6, 28),
            timezone.datetime(2018, 6, 29),
        ]
        HotelReservation.objects.create(
            datespan=DateRange(datetimes[0].date(), datetimes[1].date()),
            start=datetimes[0],
            end=datetimes[1],
            room=room102,
        )
        HotelReservation.objects.create(
            datespan=DateRange(datetimes[1].date(), datetimes[3].date()),
            start=datetimes[1],
            end=datetimes[3],
            room=room102,
        )
        # Overlap dates.
        with self.assertRaises(IntegrityError), transaction.atomic():
            reservation = HotelReservation(
                datespan=(datetimes[1].date(), datetimes[2].date()),
                start=datetimes[1],
                end=datetimes[2],
                room=room102,
            )
            reservation.save()
        # Valid range.
        HotelReservation.objects.bulk_create(
            [
                # Other room.
                HotelReservation(
                    datespan=(datetimes[1].date(), datetimes[2].date()),
                    start=datetimes[1],
                    end=datetimes[2],
                    room=room101,
                ),
                # Cancelled reservation.
                HotelReservation(
                    datespan=(datetimes[1].date(), datetimes[1].date()),
                    start=datetimes[1],
                    end=datetimes[2],
                    room=room102,
                    cancelled=True,
                ),
                # Other adjacent dates.
                HotelReservation(
                    datespan=(datetimes[3].date(), datetimes[4].date()),
                    start=datetimes[3],
                    end=datetimes[4],
                    room=room102,
                ),
            ]
        )

    def test_range_overlaps_custom(self):
        # Builds the range from the start/end columns with a custom Func
        # instead of using the datespan column.
        class TsTzRange(Func):
            function = "TSTZRANGE"
            output_field = DateTimeRangeField()

        constraint = ExclusionConstraint(
            name="exclude_overlapping_reservations_custom",
            expressions=[
                (TsTzRange("start", "end", RangeBoundary()), RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        self._test_range_overlaps(constraint)

    def test_range_overlaps(self):
        # Same scenario, using the datespan column directly.
        constraint = ExclusionConstraint(
            name="exclude_overlapping_reservations",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        self._test_range_overlaps(constraint)

    def test_range_adjacent(self):
        # With (20, 50) stored, (10, 20) is rejected while (10, 19) and
        # (51, 60) are accepted.
        constraint_name = "ints_adjacent"
        self.assertNotIn(
            constraint_name, self.get_constraints(RangesModel._meta.db_table)
        )
        constraint = ExclusionConstraint(
            name=constraint_name,
            expressions=[("ints", RangeOperators.ADJACENT_TO)],
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(RangesModel, constraint)
        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
        RangesModel.objects.create(ints=(20, 50))
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(ints=(10, 20))
        RangesModel.objects.create(ints=(10, 19))
        RangesModel.objects.create(ints=(51, 60))