PostgreSQL ExclusionConstraint (django/contrib/postgres/constraints.py)

Description

class ExclusionConstraint(*, name, expressions, index_type=None, condition=None)

Creates an exclusion constraint in the database.

Internally, PostgreSQL implements exclusion constraints using indexes. The default index type is GiST. To use it with scalar types such as integers (for example, comparing a foreign key with RangeOperators.EQUAL), you need to activate the btree_gist extension on PostgreSQL.

You can install it using the BtreeGistExtension migration operation.

If you attempt to insert a new row that conflicts with an existing row, an IntegrityError is raised. The same error is raised when an update makes a row conflict with an existing row.

name

ExclusionConstraint.name

The name of the constraint.

expressions

ExclusionConstraint.expressions

An iterable of 2-tuples. The first element is an expression or string. The second element is an SQL operator represented as a string. To avoid typos, you may use RangeOperators, which maps the operators to strings. For example:

# Each pair is (field name or expression, SQL operator). A string is treated
# as a field name; F() references the field explicitly. Here, two rows
# conflict when their timespans are adjacent (-|-) and their rooms are
# equal (=).
expressions=[
    ('timespan', RangeOperators.ADJACENT_TO),
    (F('room'), RangeOperators.EQUAL),
]

Restrictions on operators.

Only commutative operators can be used in exclusion constraints.

index_type

ExclusionConstraint.index_type

The index type of the constraint. Accepted values are GIST or SPGIST. Matching is case insensitive. If not provided, the default index type is GIST.

condition

ExclusionConstraint.condition

A Q object that specifies the condition to restrict a constraint to a subset of rows. For example, condition=Q(cancelled=False).

These conditions have the same database restrictions as django.db.models.Index.condition.

Examples

The following example restricts overlapping reservations in the same room, not taking canceled reservations into account:

from django.contrib.postgres.constraints import ExclusionConstraint
from django.contrib.postgres.fields import DateTimeRangeField, RangeOperators
from django.db import models
from django.db.models import Q

class Room(models.Model):
    """A room that Reservation rows reference through a foreign key."""

    number = models.IntegerField()


class Reservation(models.Model):
    """A booking of a room for a span of time; cancelled bookings don't block others."""

    room = models.ForeignKey('Room', on_delete=models.CASCADE)
    timespan = DateTimeRangeField()
    cancelled = models.BooleanField(default=False)

    class Meta:
        constraints = [
            ExclusionConstraint(
                name='exclude_overlapping_reservations',
                # Two rows conflict when their timespans overlap (&&) AND
                # they are for the same room (=). Comparing the integer room
                # column under the default GiST index needs btree_gist.
                expressions=[
                    ('timespan', RangeOperators.OVERLAPS),
                    ('room', RangeOperators.EQUAL),
                ],
                # Only rows that are not cancelled take part in the check.
                condition=Q(cancelled=False),
            ),
        ]

In case your model defines a range using two fields, instead of the native PostgreSQL range types, you should write an expression that uses the equivalent function (e.g. TsTzRange()) and pass it the range boundaries (delimiters) to use for the fields. Most often, the delimiters will be '[)', meaning that the lower bound is inclusive and the upper bound is exclusive. You may use RangeBoundary, which provides an expression mapping for the range boundaries. For example:

 1 from django.contrib.postgres.constraints import ExclusionConstraint
 2 from django.contrib.postgres.fields import (
 3     DateTimeRangeField,
 4     RangeBoundary,
 5     RangeOperators,
 6 )
 7 from django.db import models
 8 from django.db.models import Func, Q
 9
10
class TsTzRange(Func):
    """SQL ``TSTZRANGE(...)`` call whose result is typed as a DateTimeRangeField."""

    output_field = DateTimeRangeField()
    function = 'TSTZRANGE'
14
15
class Reservation(models.Model):
    """A booking whose period is stored as two separate datetime columns."""

    room = models.ForeignKey('Room', on_delete=models.CASCADE)
    start = models.DateTimeField()
    end = models.DateTimeField()
    cancelled = models.BooleanField(default=False)

    class Meta:
        constraints = [
            ExclusionConstraint(
                name='exclude_overlapping_reservations',
                expressions=(
                    # Build the range on the fly from start/end; RangeBoundary()
                    # supplies the bounds argument of TSTZRANGE.
                    (
                        TsTzRange('start', 'end', RangeBoundary()),
                        RangeOperators.OVERLAPS,
                    ),
                    ('room', RangeOperators.EQUAL),
                ),
                # Cancelled rows are ignored by the constraint.
                condition=Q(cancelled=False),
            ),
        ]

Like the first example, this one restricts overlapping reservations in the same room, not taking canceled reservations into account.

django/contrib/postgres/constraints.py

  1from django.db.backends.ddl_references import Statement
  2from django.db.backends.ddl_references import Table
  3from django.db.models import F
  4from django.db.models import Q
  5from django.db.models.constraints import BaseConstraint
  6from django.db.models.sql import Query
  7
  8__all__ = ["ExclusionConstraint"]
  9
 10
 11class ExclusionConstraint(BaseConstraint):
 12    template = (
 13        "CONSTRAINT %(name)s EXCLUDE USING %(index_type)s (%(expressions)s)%(where)s"
 14    )
 15
 16    def __init__(self, *, name, expressions, index_type=None, condition=None):
 17        if index_type and index_type.lower() not in {"gist", "spgist"}:
 18            raise ValueError(
 19                "Exclusion constraints only support GiST or SP-GiST indexes."
 20            )
 21        if not expressions:
 22            raise ValueError(
 23                "At least one expression is required to define an exclusion "
 24                "constraint."
 25            )
 26        if not all(
 27            isinstance(expr, (list, tuple)) and len(expr) == 2 for expr in expressions
 28        ):
 29            raise ValueError("The expressions must be a list of 2-tuples.")
 30        if not isinstance(condition, (type(None), Q)):
 31            raise ValueError("ExclusionConstraint.condition must be a Q instance.")
 32        self.expressions = expressions
 33        self.index_type = index_type or "GIST"
 34        self.condition = condition
 35        super().__init__(name=name)
 36
 37    def _get_expression_sql(self, compiler, connection, query):
 38        expressions = []
 39        for expression, operator in self.expressions:
 40            if isinstance(expression, str):
 41                expression = F(expression)
 42            if isinstance(expression, F):
 43                expression = expression.resolve_expression(query=query, simple_col=True)
 44            else:
 45                expression = expression.resolve_expression(query=query)
 46            sql, params = expression.as_sql(compiler, connection)
 47            expressions.append("%s WITH %s" % (sql % params, operator))
 48        return expressions
 49
 50    def _get_condition_sql(self, compiler, schema_editor, query):
 51        if self.condition is None:
 52            return None
 53        where = query.build_where(self.condition)
 54        sql, params = where.as_sql(compiler, schema_editor.connection)
 55        return sql % tuple(schema_editor.quote_value(p) for p in params)
 56
 57    def constraint_sql(self, model, schema_editor):
 58        query = Query(model)
 59        compiler = query.get_compiler(connection=schema_editor.connection)
 60        expressions = self._get_expression_sql(
 61            compiler, schema_editor.connection, query
 62        )
 63        condition = self._get_condition_sql(compiler, schema_editor, query)
 64        return self.template % {
 65            "name": schema_editor.quote_name(self.name),
 66            "index_type": self.index_type,
 67            "expressions": ", ".join(expressions),
 68            "where": " WHERE (%s)" % condition if condition else "",
 69        }
 70
 71    def create_sql(self, model, schema_editor):
 72        return Statement(
 73            "ALTER TABLE %(table)s ADD %(constraint)s",
 74            table=Table(model._meta.db_table, schema_editor.quote_name),
 75            constraint=self.constraint_sql(model, schema_editor),
 76        )
 77
 78    def remove_sql(self, model, schema_editor):
 79        return schema_editor._delete_constraint_sql(
 80            schema_editor.sql_delete_check,
 81            model,
 82            schema_editor.quote_name(self.name),
 83        )
 84
 85    def deconstruct(self):
 86        path, args, kwargs = super().deconstruct()
 87        kwargs["expressions"] = self.expressions
 88        if self.condition is not None:
 89            kwargs["condition"] = self.condition
 90        if self.index_type.lower() != "gist":
 91            kwargs["index_type"] = self.index_type
 92        return path, args, kwargs
 93
 94    def __eq__(self, other):
 95        return (
 96            isinstance(other, self.__class__)
 97            and self.name == other.name
 98            and self.index_type == other.index_type
 99            and self.expressions == other.expressions
100            and self.condition == other.condition
101        )
102
103    def __repr__(self):
104        return "<%s: index_type=%s, expressions=%s%s>" % (
105            self.__class__.__qualname__,
106            self.index_type,
107            self.expressions,
108            "" if self.condition is None else ", condition=%s" % self.condition,
109        )

tests/postgres_tests/__init__.py

 1import unittest
 2
 3from django.db import connection
 4from django.test import modify_settings
 5from django.test import SimpleTestCase
 6from django.test import TestCase
 7from forms_tests.widget_tests.base import WidgetTest
 8
 9
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLSimpleTestCase(SimpleTestCase):
    """SimpleTestCase that is skipped unless the default database is PostgreSQL."""
13
14
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLTestCase(TestCase):
    """TestCase that is skipped unless the default database is PostgreSQL."""
18
19
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase):
    """
    Widget test case for PostgreSQL-only widgets.

    django.contrib.postgres is appended to INSTALLED_APPS so that the widget's
    template can be located.
    """

tests/postgres_tests/fields.py

 1"""
 2Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
 3run with a backend other than PostgreSQL.
 4"""
 5import enum
 6
 7from django.db import models
 8
 9try:
10    from django.contrib.postgres.fields import (
11        ArrayField,
12        BigIntegerRangeField,
13        CICharField,
14        CIEmailField,
15        CITextField,
16        DateRangeField,
17        DateTimeRangeField,
18        DecimalRangeField,
19        HStoreField,
20        IntegerRangeField,
21        JSONField,
22    )
23    from django.contrib.postgres.search import SearchVectorField
24except ImportError:
25
26    class DummyArrayField(models.Field):
27        def __init__(self, base_field, size=None, **kwargs):
28            super().__init__(**kwargs)
29
30        def deconstruct(self):
31            name, path, args, kwargs = super().deconstruct()
32            kwargs.update(
33                {
34                    "base_field": "",
35                    "size": 1,
36                }
37            )
38            return name, path, args, kwargs
39
40    class DummyJSONField(models.Field):
41        def __init__(self, encoder=None, **kwargs):
42            super().__init__(**kwargs)
43
44    ArrayField = DummyArrayField
45    BigIntegerRangeField = models.Field
46    CICharField = models.Field
47    CIEmailField = models.Field
48    CITextField = models.Field
49    DateRangeField = models.Field
50    DateTimeRangeField = models.Field
51    DecimalRangeField = models.Field
52    HStoreField = models.Field
53    IntegerRangeField = models.Field
54    JSONField = DummyJSONField
55    SearchVectorField = models.Field
56
57
class EnumField(models.CharField):
    """CharField that accepts enum members and stores their ``value``."""

    def get_prep_value(self, value):
        if isinstance(value, enum.Enum):
            return value.value
        return value

tests/postgres_tests/models.py

  1from django.core.serializers.json import DjangoJSONEncoder
  2from django.db import models
  3
  4from .fields import ArrayField
  5from .fields import BigIntegerRangeField
  6from .fields import CICharField
  7from .fields import CIEmailField
  8from .fields import CITextField
  9from .fields import DateRangeField
 10from .fields import DateTimeRangeField
 11from .fields import DecimalRangeField
 12from .fields import EnumField
 13from .fields import HStoreField
 14from .fields import IntegerRangeField
 15from .fields import JSONField
 16from .fields import SearchVectorField
 17
 18
 19class Tag:
 20    def __init__(self, tag_id):
 21        self.tag_id = tag_id
 22
 23    def __eq__(self, other):
 24        return isinstance(other, Tag) and self.tag_id == other.tag_id
 25
 26
 27class TagField(models.SmallIntegerField):
 28    def from_db_value(self, value, expression, connection):
 29        if value is None:
 30            return value
 31        return Tag(int(value))
 32
 33    def to_python(self, value):
 34        if isinstance(value, Tag):
 35            return value
 36        if value is None:
 37            return value
 38        return Tag(int(value))
 39
 40    def get_prep_value(self, value):
 41        return value.tag_id
 42
 43
 44class PostgreSQLModel(models.Model):
 45    class Meta:
 46        abstract = True
 47        required_db_vendor = "postgresql"
 48
 49
 50class IntegerArrayModel(PostgreSQLModel):
 51    field = ArrayField(models.IntegerField(), default=list, blank=True)
 52
 53
 54class NullableIntegerArrayModel(PostgreSQLModel):
 55    field = ArrayField(models.IntegerField(), blank=True, null=True)
 56    field_nested = ArrayField(ArrayField(models.IntegerField(null=True)), null=True)
 57
 58
 59class CharArrayModel(PostgreSQLModel):
 60    field = ArrayField(models.CharField(max_length=10))
 61
 62
 63class DateTimeArrayModel(PostgreSQLModel):
 64    datetimes = ArrayField(models.DateTimeField())
 65    dates = ArrayField(models.DateField())
 66    times = ArrayField(models.TimeField())
 67
 68
 69class NestedIntegerArrayModel(PostgreSQLModel):
 70    field = ArrayField(ArrayField(models.IntegerField()))
 71
 72
 73class OtherTypesArrayModel(PostgreSQLModel):
 74    ips = ArrayField(models.GenericIPAddressField(), default=list)
 75    uuids = ArrayField(models.UUIDField(), default=list)
 76    decimals = ArrayField(
 77        models.DecimalField(max_digits=5, decimal_places=2), default=list
 78    )
 79    tags = ArrayField(TagField(), blank=True, null=True)
 80    json = ArrayField(JSONField(default=dict), default=list)
 81    int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)
 82    bigint_ranges = ArrayField(BigIntegerRangeField(), blank=True, null=True)
 83
 84
 85class HStoreModel(PostgreSQLModel):
 86    field = HStoreField(blank=True, null=True)
 87    array_field = ArrayField(HStoreField(), null=True)
 88
 89
 90class ArrayEnumModel(PostgreSQLModel):
 91    array_of_enums = ArrayField(EnumField(max_length=20))
 92
 93
 94class CharFieldModel(models.Model):
 95    field = models.CharField(max_length=16)
 96
 97
 98class TextFieldModel(models.Model):
 99    field = models.TextField()
100
101    def __str__(self):
102        return self.field
103
104
# Scene/Character/Line models are used to test full text search. They're
# populated with content from Monty Python and the Holy Grail.
class Scene(models.Model):
    """A scene and its setting; ``str()`` returns the scene."""

    scene = models.CharField(max_length=255)
    setting = models.CharField(max_length=255)

    def __str__(self):
        title = self.scene
        return title
113
114
class Character(models.Model):
    """A named character; ``str()`` returns the name."""

    name = models.CharField(max_length=255)

    def __str__(self):
        label = self.name
        return label
120
121
class CITestModel(PostgreSQLModel):
    """Case-insensitive text fields, keyed by a case-insensitive name."""

    name = CICharField(max_length=255, primary_key=True)
    email = CIEmailField()
    description = CITextField()
    array_field = ArrayField(CITextField(), null=True)

    def __str__(self):
        label = self.name
        return label
130
131
class Line(PostgreSQLModel):
    """A line of dialogue by a character in a scene, with search-vector columns."""

    scene = models.ForeignKey("Scene", on_delete=models.CASCADE)
    character = models.ForeignKey("Character", on_delete=models.CASCADE)
    dialogue = models.TextField(null=True, blank=True)
    dialogue_search_vector = SearchVectorField(null=True, blank=True)
    dialogue_config = models.CharField(max_length=100, null=True, blank=True)

    def __str__(self):
        return self.dialogue if self.dialogue else ""
141
142
class RangesModel(PostgreSQLModel):
    """One nullable column per PostgreSQL range type (plus inner date/time ranges)."""

    ints = IntegerRangeField(null=True, blank=True)
    bigints = BigIntegerRangeField(null=True, blank=True)
    decimals = DecimalRangeField(null=True, blank=True)
    timestamps = DateTimeRangeField(null=True, blank=True)
    timestamps_inner = DateTimeRangeField(null=True, blank=True)
    dates = DateRangeField(null=True, blank=True)
    dates_inner = DateRangeField(null=True, blank=True)
151
152
class RangeLookupsModel(PostgreSQLModel):
    """Scalar columns to look up against RangesModel ranges."""

    parent = models.ForeignKey(
        RangesModel, on_delete=models.SET_NULL, null=True, blank=True
    )
    integer = models.IntegerField(null=True, blank=True)
    big_integer = models.BigIntegerField(null=True, blank=True)
    float = models.FloatField(null=True, blank=True)
    timestamp = models.DateTimeField(null=True, blank=True)
    date = models.DateField(null=True, blank=True)
160
161
class JSONModel(PostgreSQLModel):
    """Nullable JSON columns with the default and DjangoJSONEncoder encoders."""

    field = JSONField(null=True, blank=True)
    field_custom = JSONField(encoder=DjangoJSONEncoder, null=True, blank=True)
165
166
class ArrayFieldSubclass(ArrayField):
    """
    ArrayField subclass that always uses an IntegerField base.

    Any positional or keyword arguments are accepted and discarded.
    """

    def __init__(self, *args, **kwargs):
        base = models.IntegerField()
        super().__init__(base)
170
171
class AggregateTestModel(models.Model):
    """
    Columns used by the PostgreSQL-specific general aggregation tests.
    """

    char_field = models.CharField(max_length=30, blank=True)
    integer_field = models.IntegerField(null=True)
    boolean_field = models.BooleanField(null=True)
180
181
class StatTestModel(models.Model):
    """
    Integer pairs used by the PostgreSQL-specific statistics aggregation tests.
    """

    int1 = models.IntegerField()
    int2 = models.IntegerField()
    related_field = models.ForeignKey(
        AggregateTestModel, on_delete=models.SET_NULL, null=True
    )
190
191
class NowTestModel(models.Model):
    """Nullable datetime column that defaults to None."""

    when = models.DateTimeField(default=None, null=True)
194
195
class UUIDTestModel(models.Model):
    """Nullable UUID column that defaults to None."""

    uuid = models.UUIDField(null=True, default=None)
198
199
class Room(models.Model):
    """A room identified by a unique number; referenced by HotelReservation."""

    number = models.IntegerField(unique=True)
202
203
class HotelReservation(PostgreSQLModel):
    """
    Room booking stored both as a date range and as start/end datetimes,
    so exclusion constraints can be tested on either representation.
    """

    room = models.ForeignKey("Room", on_delete=models.CASCADE)
    datespan = DateRangeField()
    start = models.DateTimeField()
    end = models.DateTimeField()
    cancelled = models.BooleanField(default=False)

tests/postgres_tests/test_constraints.py

  1import datetime
  2from unittest import mock
  3
  4from django.db import connection
  5from django.db import transaction
  6from django.db.models import F
  7from django.db.models import Func
  8from django.db.models import Q
  9from django.db.models.constraints import CheckConstraint
 10from django.db.utils import IntegrityError
 11from django.utils import timezone
 12
 13from . import PostgreSQLTestCase
 14from .models import HotelReservation
 15from .models import RangesModel
 16from .models import Room
 17
 18try:
 19    from django.contrib.postgres.constraints import ExclusionConstraint
 20    from django.contrib.postgres.fields import (
 21        DateTimeRangeField,
 22        RangeBoundary,
 23        RangeOperators,
 24    )
 25
 26    from psycopg2.extras import DateRange, NumericRange
 27except ImportError:
 28    pass
 29
 30
 31class SchemaTests(PostgreSQLTestCase):
 32    def get_constraints(self, table):
 33        """Get the constraints on the table using a new cursor."""
 34        with connection.cursor() as cursor:
 35            return connection.introspection.get_constraints(cursor, table)
 36
 37    def test_check_constraint_range_value(self):
 38        constraint_name = "ints_between"
 39        self.assertNotIn(
 40            constraint_name, self.get_constraints(RangesModel._meta.db_table)
 41        )
 42        constraint = CheckConstraint(
 43            check=Q(ints__contained_by=NumericRange(10, 30)),
 44            name=constraint_name,
 45        )
 46        with connection.schema_editor() as editor:
 47            editor.add_constraint(RangesModel, constraint)
 48        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
 49        with self.assertRaises(IntegrityError), transaction.atomic():
 50            RangesModel.objects.create(ints=(20, 50))
 51        RangesModel.objects.create(ints=(10, 30))
 52
 53    def test_check_constraint_daterange_contains(self):
 54        constraint_name = "dates_contains"
 55        self.assertNotIn(
 56            constraint_name, self.get_constraints(RangesModel._meta.db_table)
 57        )
 58        constraint = CheckConstraint(
 59            check=Q(dates__contains=F("dates_inner")),
 60            name=constraint_name,
 61        )
 62        with connection.schema_editor() as editor:
 63            editor.add_constraint(RangesModel, constraint)
 64        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
 65        date_1 = datetime.date(2016, 1, 1)
 66        date_2 = datetime.date(2016, 1, 4)
 67        with self.assertRaises(IntegrityError), transaction.atomic():
 68            RangesModel.objects.create(
 69                dates=(date_1, date_2),
 70                dates_inner=(date_1, date_2.replace(day=5)),
 71            )
 72        RangesModel.objects.create(
 73            dates=(date_1, date_2),
 74            dates_inner=(date_1, date_2),
 75        )
 76
 77    def test_check_constraint_datetimerange_contains(self):
 78        constraint_name = "timestamps_contains"
 79        self.assertNotIn(
 80            constraint_name, self.get_constraints(RangesModel._meta.db_table)
 81        )
 82        constraint = CheckConstraint(
 83            check=Q(timestamps__contains=F("timestamps_inner")),
 84            name=constraint_name,
 85        )
 86        with connection.schema_editor() as editor:
 87            editor.add_constraint(RangesModel, constraint)
 88        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
 89        datetime_1 = datetime.datetime(2016, 1, 1)
 90        datetime_2 = datetime.datetime(2016, 1, 2, 12)
 91        with self.assertRaises(IntegrityError), transaction.atomic():
 92            RangesModel.objects.create(
 93                timestamps=(datetime_1, datetime_2),
 94                timestamps_inner=(datetime_1, datetime_2.replace(hour=13)),
 95            )
 96        RangesModel.objects.create(
 97            timestamps=(datetime_1, datetime_2),
 98            timestamps_inner=(datetime_1, datetime_2),
 99        )
100
101
class ExclusionConstraintTests(PostgreSQLTestCase):
    """Validation, equality, deconstruction and database behavior of ExclusionConstraint."""

    def get_constraints(self, table):
        """Get the constraints on the table using a new cursor."""
        with connection.cursor() as cursor:
            return connection.introspection.get_constraints(cursor, table)

    def test_invalid_condition(self):
        msg = "ExclusionConstraint.condition must be a Q instance."
        with self.assertRaisesMessage(ValueError, msg):
            ExclusionConstraint(
                index_type="GIST",
                name="exclude_invalid_condition",
                expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
                condition=F("invalid"),
            )

    def test_invalid_index_type(self):
        msg = "Exclusion constraints only support GiST or SP-GiST indexes."
        with self.assertRaisesMessage(ValueError, msg):
            ExclusionConstraint(
                index_type="gin",
                name="exclude_invalid_index_type",
                expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
            )

    def test_invalid_expressions(self):
        msg = "The expressions must be a list of 2-tuples."
        # ("foo",) is a 1-tuple; the former ("foo") was just the string "foo"
        # and duplicated the first case.
        for expressions in (["foo"], [("foo",)], [("foo_1", "foo_2", "foo_3")]):
            with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg):
                ExclusionConstraint(
                    index_type="GIST",
                    name="exclude_invalid_expressions",
                    expressions=expressions,
                )

    def test_empty_expressions(self):
        msg = "At least one expression is required to define an exclusion constraint."
        for empty_expressions in (None, []):
            with self.subTest(empty_expressions), self.assertRaisesMessage(
                ValueError, msg
            ):
                ExclusionConstraint(
                    index_type="GIST",
                    name="exclude_empty_expressions",
                    expressions=empty_expressions,
                )

    def test_repr(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                (F("room"), RangeOperators.EQUAL),
            ],
        )
        self.assertEqual(
            repr(constraint),
            "<ExclusionConstraint: index_type=GIST, expressions=["
            "(F(datespan), '&&'), (F(room), '=')]>",
        )
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
            condition=Q(cancelled=False),
            index_type="SPGiST",
        )
        self.assertEqual(
            repr(constraint),
            "<ExclusionConstraint: index_type=SPGiST, expressions=["
            "(F(datespan), '-|-')], condition=(AND: ('cancelled', False))>",
        )

    def test_eq(self):
        constraint_1 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                (F("room"), RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        constraint_2 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        constraint_3 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[("datespan", RangeOperators.OVERLAPS)],
            condition=Q(cancelled=False),
        )
        # Same as constraint_1 except for the index type.
        constraint_4 = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                (F("room"), RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
            index_type="SPGIST",
        )
        self.assertEqual(constraint_1, constraint_1)
        self.assertEqual(constraint_1, mock.ANY)
        self.assertNotEqual(constraint_1, constraint_2)
        self.assertNotEqual(constraint_1, constraint_3)
        self.assertNotEqual(constraint_1, constraint_4)
        self.assertNotEqual(constraint_2, constraint_3)
        self.assertNotEqual(constraint_1, object())

    def test_deconstruct(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
            },
        )

    def test_deconstruct_index_type(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            index_type="SPGIST",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "index_type": "SPGIST",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
            },
        )

    def test_deconstruct_condition(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping",
            expressions=[
                ("datespan", RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        path, args, kwargs = constraint.deconstruct()
        self.assertEqual(
            path, "django.contrib.postgres.constraints.ExclusionConstraint"
        )
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs,
            {
                "name": "exclude_overlapping",
                "expressions": [
                    ("datespan", RangeOperators.OVERLAPS),
                    ("room", RangeOperators.EQUAL),
                ],
                "condition": Q(cancelled=False),
            },
        )

    def _test_range_overlaps(self, constraint):
        """Install ``constraint`` and check overlapping/non-overlapping inserts."""
        # Create exclusion constraint.
        self.assertNotIn(
            constraint.name, self.get_constraints(HotelReservation._meta.db_table)
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(HotelReservation, constraint)
        self.assertIn(
            constraint.name, self.get_constraints(HotelReservation._meta.db_table)
        )
        # Add initial reservations.
        room101 = Room.objects.create(number=101)
        room102 = Room.objects.create(number=102)
        datetimes = [
            timezone.datetime(2018, 6, 20),
            timezone.datetime(2018, 6, 24),
            timezone.datetime(2018, 6, 26),
            timezone.datetime(2018, 6, 28),
            timezone.datetime(2018, 6, 29),
        ]
        HotelReservation.objects.create(
            datespan=DateRange(datetimes[0].date(), datetimes[1].date()),
            start=datetimes[0],
            end=datetimes[1],
            room=room102,
        )
        HotelReservation.objects.create(
            datespan=DateRange(datetimes[1].date(), datetimes[3].date()),
            start=datetimes[1],
            end=datetimes[3],
            room=room102,
        )
        # Overlap dates.
        with self.assertRaises(IntegrityError), transaction.atomic():
            reservation = HotelReservation(
                datespan=(datetimes[1].date(), datetimes[2].date()),
                start=datetimes[1],
                end=datetimes[2],
                room=room102,
            )
            reservation.save()
        # Valid range.
        HotelReservation.objects.bulk_create(
            [
                # Other room.
                HotelReservation(
                    datespan=(datetimes[1].date(), datetimes[2].date()),
                    start=datetimes[1],
                    end=datetimes[2],
                    room=room101,
                ),
                # Cancelled reservation. Its datespan genuinely overlaps an
                # existing room102 booking, so only the condition lets it in
                # (an empty [d1, d1) span would overlap nothing at all).
                HotelReservation(
                    datespan=(datetimes[1].date(), datetimes[2].date()),
                    start=datetimes[1],
                    end=datetimes[2],
                    room=room102,
                    cancelled=True,
                ),
                # Other adjacent dates.
                HotelReservation(
                    datespan=(datetimes[3].date(), datetimes[4].date()),
                    start=datetimes[3],
                    end=datetimes[4],
                    room=room102,
                ),
            ]
        )

    def test_range_overlaps_custom(self):
        class TsTzRange(Func):
            function = "TSTZRANGE"
            output_field = DateTimeRangeField()

        constraint = ExclusionConstraint(
            name="exclude_overlapping_reservations_custom",
            expressions=[
                (TsTzRange("start", "end", RangeBoundary()), RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        self._test_range_overlaps(constraint)

    def test_range_overlaps(self):
        constraint = ExclusionConstraint(
            name="exclude_overlapping_reservations",
            expressions=[
                (F("datespan"), RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
        self._test_range_overlaps(constraint)

    def test_range_adjacent(self):
        constraint_name = "ints_adjacent"
        self.assertNotIn(
            constraint_name, self.get_constraints(RangesModel._meta.db_table)
        )
        constraint = ExclusionConstraint(
            name=constraint_name,
            expressions=[("ints", RangeOperators.ADJACENT_TO)],
        )
        with connection.schema_editor() as editor:
            editor.add_constraint(RangesModel, constraint)
        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
        RangesModel.objects.create(ints=(20, 50))
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(ints=(10, 20))
        RangesModel.objects.create(ints=(10, 19))
        RangesModel.objects.create(ints=(51, 60))