Tests for RangeField

__init__.py

 1import unittest
 2
 3from django.db import connection
 4from django.test import modify_settings
 5from django.test import SimpleTestCase
 6from django.test import TestCase
 7from forms_tests.widget_tests.base import WidgetTest
 8
 9
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLSimpleTestCase(SimpleTestCase):
    """SimpleTestCase that is skipped unless the default connection is PostgreSQL."""
13
14
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLTestCase(TestCase):
    """Database-backed TestCase that is skipped unless the connection is PostgreSQL."""
18
19
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
# Appending django.contrib.postgres to INSTALLED_APPS lets the widget's
# template be located.
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase):
    """PostgreSQL-only widget tests with django.contrib.postgres installed."""

fields.py

 1"""
 2Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
 3run with a backend other than PostgreSQL.
 4"""
 5import enum
 6
 7from django.db import models
 8
 9try:
10    from django.contrib.postgres.fields import (
11        ArrayField,
12        BigIntegerRangeField,
13        CICharField,
14        CIEmailField,
15        CITextField,
16        DateRangeField,
17        DateTimeRangeField,
18        DecimalRangeField,
19        HStoreField,
20        IntegerRangeField,
21        JSONField,
22    )
23    from django.contrib.postgres.search import SearchVectorField
24except ImportError:
25
26    class DummyArrayField(models.Field):
27        def __init__(self, base_field, size=None, **kwargs):
28            super().__init__(**kwargs)
29
30        def deconstruct(self):
31            name, path, args, kwargs = super().deconstruct()
32            kwargs.update(
33                {
34                    "base_field": "",
35                    "size": 1,
36                }
37            )
38            return name, path, args, kwargs
39
40    class DummyJSONField(models.Field):
41        def __init__(self, encoder=None, **kwargs):
42            super().__init__(**kwargs)
43
44    ArrayField = DummyArrayField
45    BigIntegerRangeField = models.Field
46    CICharField = models.Field
47    CIEmailField = models.Field
48    CITextField = models.Field
49    DateRangeField = models.Field
50    DateTimeRangeField = models.Field
51    DecimalRangeField = models.Field
52    HStoreField = models.Field
53    IntegerRangeField = models.Field
54    JSONField = DummyJSONField
55    SearchVectorField = models.Field
56
57
class EnumField(models.CharField):
    """CharField that stores the ``value`` of :class:`enum.Enum` members."""

    def get_prep_value(self, value):
        # Unwrap enum members; pass every other value through unchanged.
        if isinstance(value, enum.Enum):
            return value.value
        return value

models.py

  1from django.core.serializers.json import DjangoJSONEncoder
  2from django.db import models
  3
  4from .fields import ArrayField
  5from .fields import BigIntegerRangeField
  6from .fields import CICharField
  7from .fields import CIEmailField
  8from .fields import CITextField
  9from .fields import DateRangeField
 10from .fields import DateTimeRangeField
 11from .fields import DecimalRangeField
 12from .fields import EnumField
 13from .fields import HStoreField
 14from .fields import IntegerRangeField
 15from .fields import JSONField
 16from .fields import SearchVectorField
 17
 18
 19class Tag:
 20    def __init__(self, tag_id):
 21        self.tag_id = tag_id
 22
 23    def __eq__(self, other):
 24        return isinstance(other, Tag) and self.tag_id == other.tag_id
 25
 26
 27class TagField(models.SmallIntegerField):
 28    def from_db_value(self, value, expression, connection):
 29        if value is None:
 30            return value
 31        return Tag(int(value))
 32
 33    def to_python(self, value):
 34        if isinstance(value, Tag):
 35            return value
 36        if value is None:
 37            return value
 38        return Tag(int(value))
 39
 40    def get_prep_value(self, value):
 41        return value.tag_id
 42
 43
 44class PostgreSQLModel(models.Model):
 45    class Meta:
 46        abstract = True
 47        required_db_vendor = "postgresql"
 48
 49
 50class IntegerArrayModel(PostgreSQLModel):
 51    field = ArrayField(models.IntegerField(), default=list, blank=True)
 52
 53
 54class NullableIntegerArrayModel(PostgreSQLModel):
 55    field = ArrayField(models.IntegerField(), blank=True, null=True)
 56    field_nested = ArrayField(ArrayField(models.IntegerField(null=True)), null=True)
 57
 58
 59class CharArrayModel(PostgreSQLModel):
 60    field = ArrayField(models.CharField(max_length=10))
 61
 62
 63class DateTimeArrayModel(PostgreSQLModel):
 64    datetimes = ArrayField(models.DateTimeField())
 65    dates = ArrayField(models.DateField())
 66    times = ArrayField(models.TimeField())
 67
 68
 69class NestedIntegerArrayModel(PostgreSQLModel):
 70    field = ArrayField(ArrayField(models.IntegerField()))
 71
 72
 73class OtherTypesArrayModel(PostgreSQLModel):
 74    ips = ArrayField(models.GenericIPAddressField(), default=list)
 75    uuids = ArrayField(models.UUIDField(), default=list)
 76    decimals = ArrayField(
 77        models.DecimalField(max_digits=5, decimal_places=2), default=list
 78    )
 79    tags = ArrayField(TagField(), blank=True, null=True)
 80    json = ArrayField(JSONField(default=dict), default=list)
 81    int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)
 82    bigint_ranges = ArrayField(BigIntegerRangeField(), blank=True, null=True)
 83
 84
 85class HStoreModel(PostgreSQLModel):
 86    field = HStoreField(blank=True, null=True)
 87    array_field = ArrayField(HStoreField(), null=True)
 88
 89
 90class ArrayEnumModel(PostgreSQLModel):
 91    array_of_enums = ArrayField(EnumField(max_length=20))
 92
 93
 94class CharFieldModel(models.Model):
 95    field = models.CharField(max_length=16)
 96
 97
 98class TextFieldModel(models.Model):
 99    field = models.TextField()
100
101    def __str__(self):
102        return self.field
103
104
class SmallAutoFieldModel(models.Model):
    """Model whose explicit primary key is a SmallAutoField."""

    id = models.SmallAutoField(primary_key=True)
107
108
class BigAutoFieldModel(models.Model):
    """Model whose explicit primary key is a BigAutoField."""

    id = models.BigAutoField(primary_key=True)
111
112
# The Scene, Character and Line models back the full text search tests, which
# populate them with material from "Monty Python and the Holy Grail".
class Scene(models.Model):
    """A scene and its setting; displayed by the ``scene`` text."""

    scene = models.CharField(max_length=255)
    setting = models.CharField(max_length=255)

    def __str__(self):
        return self.scene
121
122
class Character(models.Model):
    """A named character; displayed by its name."""

    name = models.CharField(max_length=255)

    def __str__(self):
        return self.name
128
129
class CITestModel(PostgreSQLModel):
    """Exercises CICharField, CIEmailField, CITextField and an array of CITextField."""

    name = CICharField(max_length=255, primary_key=True)
    email = CIEmailField()
    description = CITextField()
    array_field = ArrayField(base_field=CITextField(), null=True)

    def __str__(self):
        return self.name
138
139
class Line(PostgreSQLModel):
    """A line of dialogue tied to a Scene and a Character."""

    scene = models.ForeignKey("Scene", on_delete=models.CASCADE)
    character = models.ForeignKey("Character", on_delete=models.CASCADE)
    dialogue = models.TextField(null=True, blank=True)
    dialogue_search_vector = SearchVectorField(null=True, blank=True)
    dialogue_config = models.CharField(max_length=100, null=True, blank=True)

    def __str__(self):
        # dialogue is nullable; never hand None back from __str__.
        return self.dialogue or ""
149
150
class RangesModel(PostgreSQLModel):
    """One nullable column per PostgreSQL range field type.

    ``timestamps_inner`` and ``dates_inner`` provide a second range of the
    same type to compare against.
    """

    ints = IntegerRangeField(null=True, blank=True)
    bigints = BigIntegerRangeField(null=True, blank=True)
    decimals = DecimalRangeField(null=True, blank=True)
    timestamps = DateTimeRangeField(null=True, blank=True)
    timestamps_inner = DateTimeRangeField(null=True, blank=True)
    dates = DateRangeField(null=True, blank=True)
    dates_inner = DateRangeField(null=True, blank=True)
159
160
class RangeLookupsModel(PostgreSQLModel):
    """Nullable scalar columns to filter against ranges, plus an optional parent."""

    parent = models.ForeignKey(
        RangesModel, on_delete=models.SET_NULL, null=True, blank=True
    )
    integer = models.IntegerField(null=True, blank=True)
    big_integer = models.BigIntegerField(null=True, blank=True)
    # NOTE: the field name shadows the builtin only inside this class body.
    float = models.FloatField(null=True, blank=True)
    timestamp = models.DateTimeField(null=True, blank=True)
    date = models.DateField(null=True, blank=True)
    small_integer = models.SmallIntegerField(null=True, blank=True)
    decimal_field = models.DecimalField(
        decimal_places=2, max_digits=5, null=True, blank=True
    )
172
173
class JSONModel(PostgreSQLModel):
    """Nullable JSON columns using the default encoder and DjangoJSONEncoder."""

    field = JSONField(null=True, blank=True)
    field_custom = JSONField(encoder=DjangoJSONEncoder, null=True, blank=True)
177
178
class ArrayFieldSubclass(ArrayField):
    """ArrayField subclass that always uses an IntegerField base field.

    Every constructor argument is accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(base_field=models.IntegerField())
182
183
class AggregateTestModel(models.Model):
    """
    Backs the tests for the general-purpose PostgreSQL aggregate functions.
    """

    char_field = models.CharField(blank=True, max_length=30)
    integer_field = models.IntegerField(null=True)
    boolean_field = models.BooleanField(null=True)
192
193
class StatTestModel(models.Model):
    """
    Backs the tests for the PostgreSQL statistical aggregate functions.
    """

    int1 = models.IntegerField()
    int2 = models.IntegerField()
    related_field = models.ForeignKey(
        AggregateTestModel, on_delete=models.SET_NULL, null=True
    )
202
203
class NowTestModel(models.Model):
    """Single nullable datetime column that defaults to None."""

    when = models.DateTimeField(default=None, null=True)
206
207
class UUIDTestModel(models.Model):
    """Single nullable UUID column that defaults to None."""

    uuid = models.UUIDField(null=True, default=None)
210
211
class Room(models.Model):
    """Hotel room identified by a unique number."""

    number = models.IntegerField(unique=True)
214
215
class HotelReservation(PostgreSQLModel):
    """Booking of a Room over a date range, with start/end datetimes."""

    room = models.ForeignKey(to="Room", on_delete=models.CASCADE)
    datespan = DateRangeField()
    start = models.DateTimeField()
    end = models.DateTimeField()
    cancelled = models.BooleanField(default=False)

test_ranges.py

   1import datetime
   2import json
   3from decimal import Decimal
   4
   5from django import forms
   6from django.core import exceptions
   7from django.core import serializers
   8from django.db.models import DateField
   9from django.db.models import DateTimeField
  10from django.db.models import F
  11from django.db.models import Func
  12from django.db.models import Value
  13from django.http import QueryDict
  14from django.test import override_settings
  15from django.test.utils import isolate_apps
  16from django.utils import timezone
  17
  18from . import PostgreSQLSimpleTestCase
  19from . import PostgreSQLTestCase
  20from .models import BigAutoFieldModel
  21from .models import PostgreSQLModel
  22from .models import RangeLookupsModel
  23from .models import RangesModel
  24from .models import SmallAutoFieldModel
  25
  26try:
  27    from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
  28    from django.contrib.postgres import fields as pg_fields, forms as pg_forms
  29    from django.contrib.postgres.validators import (
  30        RangeMaxValueValidator,
  31        RangeMinValueValidator,
  32    )
  33except ImportError:
  34    pass
  35
  36
  37@isolate_apps("postgres_tests")
  38class BasicTests(PostgreSQLSimpleTestCase):
  39    def test_get_field_display(self):
  40        class Model(PostgreSQLModel):
  41            field = pg_fields.IntegerRangeField(
  42                choices=[
  43                    ["1-50", [((1, 25), "1-25"), ([26, 50], "26-50")]],
  44                    ((51, 100), "51-100"),
  45                ],
  46            )
  47
  48        tests = (
  49            ((1, 25), "1-25"),
  50            ([26, 50], "26-50"),
  51            ((51, 100), "51-100"),
  52            ((1, 2), "(1, 2)"),
  53            ([1, 2], "[1, 2]"),
  54        )
  55        for value, display in tests:
  56            with self.subTest(value=value, display=display):
  57                instance = Model(field=value)
  58                self.assertEqual(instance.get_field_display(), display)
  59
  60
  61class TestSaveLoad(PostgreSQLTestCase):
  62    def test_all_fields(self):
  63        now = timezone.now()
  64        instance = RangesModel(
  65            ints=NumericRange(0, 10),
  66            bigints=NumericRange(10, 20),
  67            decimals=NumericRange(20, 30),
  68            timestamps=DateTimeTZRange(now - datetime.timedelta(hours=1), now),
  69            dates=DateRange(now.date() - datetime.timedelta(days=1), now.date()),
  70        )
  71        instance.save()
  72        loaded = RangesModel.objects.get()
  73        self.assertEqual(instance.ints, loaded.ints)
  74        self.assertEqual(instance.bigints, loaded.bigints)
  75        self.assertEqual(instance.decimals, loaded.decimals)
  76        self.assertEqual(instance.timestamps, loaded.timestamps)
  77        self.assertEqual(instance.dates, loaded.dates)
  78
  79    def test_range_object(self):
  80        r = NumericRange(0, 10)
  81        instance = RangesModel(ints=r)
  82        instance.save()
  83        loaded = RangesModel.objects.get()
  84        self.assertEqual(r, loaded.ints)
  85
  86    def test_tuple(self):
  87        instance = RangesModel(ints=(0, 10))
  88        instance.save()
  89        loaded = RangesModel.objects.get()
  90        self.assertEqual(NumericRange(0, 10), loaded.ints)
  91
  92    def test_range_object_boundaries(self):
  93        r = NumericRange(0, 10, "[]")
  94        instance = RangesModel(decimals=r)
  95        instance.save()
  96        loaded = RangesModel.objects.get()
  97        self.assertEqual(r, loaded.decimals)
  98        self.assertIn(10, loaded.decimals)
  99
 100    def test_unbounded(self):
 101        r = NumericRange(None, None, "()")
 102        instance = RangesModel(decimals=r)
 103        instance.save()
 104        loaded = RangesModel.objects.get()
 105        self.assertEqual(r, loaded.decimals)
 106
 107    def test_empty(self):
 108        r = NumericRange(empty=True)
 109        instance = RangesModel(ints=r)
 110        instance.save()
 111        loaded = RangesModel.objects.get()
 112        self.assertEqual(r, loaded.ints)
 113
 114    def test_null(self):
 115        instance = RangesModel(ints=None)
 116        instance.save()
 117        loaded = RangesModel.objects.get()
 118        self.assertIsNone(loaded.ints)
 119
 120    def test_model_set_on_base_field(self):
 121        instance = RangesModel()
 122        field = instance._meta.get_field("ints")
 123        self.assertEqual(field.model, RangesModel)
 124        self.assertEqual(field.base_field.model, RangesModel)
 125
 126
 127class TestRangeContainsLookup(PostgreSQLTestCase):
 128    @classmethod
 129    def setUpTestData(cls):
 130        cls.timestamps = [
 131            datetime.datetime(year=2016, month=1, day=1),
 132            datetime.datetime(year=2016, month=1, day=2, hour=1),
 133            datetime.datetime(year=2016, month=1, day=2, hour=12),
 134            datetime.datetime(year=2016, month=1, day=3),
 135            datetime.datetime(year=2016, month=1, day=3, hour=1),
 136            datetime.datetime(year=2016, month=2, day=2),
 137        ]
 138        cls.aware_timestamps = [
 139            timezone.make_aware(timestamp) for timestamp in cls.timestamps
 140        ]
 141        cls.dates = [
 142            datetime.date(year=2016, month=1, day=1),
 143            datetime.date(year=2016, month=1, day=2),
 144            datetime.date(year=2016, month=1, day=3),
 145            datetime.date(year=2016, month=1, day=4),
 146            datetime.date(year=2016, month=2, day=2),
 147            datetime.date(year=2016, month=2, day=3),
 148        ]
 149        cls.obj = RangesModel.objects.create(
 150            dates=(cls.dates[0], cls.dates[3]),
 151            dates_inner=(cls.dates[1], cls.dates[2]),
 152            timestamps=(cls.timestamps[0], cls.timestamps[3]),
 153            timestamps_inner=(cls.timestamps[1], cls.timestamps[2]),
 154        )
 155        cls.aware_obj = RangesModel.objects.create(
 156            dates=(cls.dates[0], cls.dates[3]),
 157            dates_inner=(cls.dates[1], cls.dates[2]),
 158            timestamps=(cls.aware_timestamps[0], cls.aware_timestamps[3]),
 159            timestamps_inner=(cls.timestamps[1], cls.timestamps[2]),
 160        )
 161        # Objects that don't match any queries.
 162        for i in range(3, 4):
 163            RangesModel.objects.create(
 164                dates=(cls.dates[i], cls.dates[i + 1]),
 165                timestamps=(cls.timestamps[i], cls.timestamps[i + 1]),
 166            )
 167            RangesModel.objects.create(
 168                dates=(cls.dates[i], cls.dates[i + 1]),
 169                timestamps=(cls.aware_timestamps[i], cls.aware_timestamps[i + 1]),
 170            )
 171
 172    def test_datetime_range_contains(self):
 173        filter_args = (
 174            self.timestamps[1],
 175            self.aware_timestamps[1],
 176            (self.timestamps[1], self.timestamps[2]),
 177            (self.aware_timestamps[1], self.aware_timestamps[2]),
 178            Value(self.dates[0], output_field=DateTimeField()),
 179            Func(F("dates"), function="lower", output_field=DateTimeField()),
 180            F("timestamps_inner"),
 181        )
 182        for filter_arg in filter_args:
 183            with self.subTest(filter_arg=filter_arg):
 184                self.assertCountEqual(
 185                    RangesModel.objects.filter(**{"timestamps__contains": filter_arg}),
 186                    [self.obj, self.aware_obj],
 187                )
 188
 189    def test_date_range_contains(self):
 190        filter_args = (
 191            self.timestamps[1],
 192            (self.dates[1], self.dates[2]),
 193            Value(self.dates[0], output_field=DateField()),
 194            Func(F("timestamps"), function="lower", output_field=DateField()),
 195            F("dates_inner"),
 196        )
 197        for filter_arg in filter_args:
 198            with self.subTest(filter_arg=filter_arg):
 199                self.assertCountEqual(
 200                    RangesModel.objects.filter(**{"dates__contains": filter_arg}),
 201                    [self.obj, self.aware_obj],
 202                )
 203
 204
 205class TestQuerying(PostgreSQLTestCase):
 206    @classmethod
 207    def setUpTestData(cls):
 208        cls.objs = RangesModel.objects.bulk_create(
 209            [
 210                RangesModel(ints=NumericRange(0, 10)),
 211                RangesModel(ints=NumericRange(5, 15)),
 212                RangesModel(ints=NumericRange(None, 0)),
 213                RangesModel(ints=NumericRange(empty=True)),
 214                RangesModel(ints=None),
 215            ]
 216        )
 217
 218    def test_exact(self):
 219        self.assertSequenceEqual(
 220            RangesModel.objects.filter(ints__exact=NumericRange(0, 10)),
 221            [self.objs[0]],
 222        )
 223
 224    def test_isnull(self):
 225        self.assertSequenceEqual(
 226            RangesModel.objects.filter(ints__isnull=True),
 227            [self.objs[4]],
 228        )
 229
 230    def test_isempty(self):
 231        self.assertSequenceEqual(
 232            RangesModel.objects.filter(ints__isempty=True),
 233            [self.objs[3]],
 234        )
 235
 236    def test_contains(self):
 237        self.assertSequenceEqual(
 238            RangesModel.objects.filter(ints__contains=8),
 239            [self.objs[0], self.objs[1]],
 240        )
 241
 242    def test_contains_range(self):
 243        self.assertSequenceEqual(
 244            RangesModel.objects.filter(ints__contains=NumericRange(3, 8)),
 245            [self.objs[0]],
 246        )
 247
 248    def test_contained_by(self):
 249        self.assertSequenceEqual(
 250            RangesModel.objects.filter(ints__contained_by=NumericRange(0, 20)),
 251            [self.objs[0], self.objs[1], self.objs[3]],
 252        )
 253
 254    def test_overlap(self):
 255        self.assertSequenceEqual(
 256            RangesModel.objects.filter(ints__overlap=NumericRange(3, 8)),
 257            [self.objs[0], self.objs[1]],
 258        )
 259
 260    def test_fully_lt(self):
 261        self.assertSequenceEqual(
 262            RangesModel.objects.filter(ints__fully_lt=NumericRange(5, 10)),
 263            [self.objs[2]],
 264        )
 265
 266    def test_fully_gt(self):
 267        self.assertSequenceEqual(
 268            RangesModel.objects.filter(ints__fully_gt=NumericRange(5, 10)),
 269            [],
 270        )
 271
 272    def test_not_lt(self):
 273        self.assertSequenceEqual(
 274            RangesModel.objects.filter(ints__not_lt=NumericRange(5, 10)),
 275            [self.objs[1]],
 276        )
 277
 278    def test_not_gt(self):
 279        self.assertSequenceEqual(
 280            RangesModel.objects.filter(ints__not_gt=NumericRange(5, 10)),
 281            [self.objs[0], self.objs[2]],
 282        )
 283
 284    def test_adjacent_to(self):
 285        self.assertSequenceEqual(
 286            RangesModel.objects.filter(ints__adjacent_to=NumericRange(0, 5)),
 287            [self.objs[1], self.objs[2]],
 288        )
 289
 290    def test_startswith(self):
 291        self.assertSequenceEqual(
 292            RangesModel.objects.filter(ints__startswith=0),
 293            [self.objs[0]],
 294        )
 295
 296    def test_endswith(self):
 297        self.assertSequenceEqual(
 298            RangesModel.objects.filter(ints__endswith=0),
 299            [self.objs[2]],
 300        )
 301
 302    def test_startswith_chaining(self):
 303        self.assertSequenceEqual(
 304            RangesModel.objects.filter(ints__startswith__gte=0),
 305            [self.objs[0], self.objs[1]],
 306        )
 307
 308    def test_bound_type(self):
 309        decimals = RangesModel.objects.bulk_create(
 310            [
 311                RangesModel(decimals=NumericRange(None, 10)),
 312                RangesModel(decimals=NumericRange(10, None)),
 313                RangesModel(decimals=NumericRange(5, 15)),
 314                RangesModel(decimals=NumericRange(5, 15, "(]")),
 315            ]
 316        )
 317        tests = [
 318            ("lower_inc", True, [decimals[1], decimals[2]]),
 319            ("lower_inc", False, [decimals[0], decimals[3]]),
 320            ("lower_inf", True, [decimals[0]]),
 321            ("lower_inf", False, [decimals[1], decimals[2], decimals[3]]),
 322            ("upper_inc", True, [decimals[3]]),
 323            ("upper_inc", False, [decimals[0], decimals[1], decimals[2]]),
 324            ("upper_inf", True, [decimals[1]]),
 325            ("upper_inf", False, [decimals[0], decimals[2], decimals[3]]),
 326        ]
 327        for lookup, filter_arg, excepted_result in tests:
 328            with self.subTest(lookup=lookup, filter_arg=filter_arg):
 329                self.assertSequenceEqual(
 330                    RangesModel.objects.filter(**{"decimals__%s" % lookup: filter_arg}),
 331                    excepted_result,
 332                )
 333
 334
 335class TestQueryingWithRanges(PostgreSQLTestCase):
 336    def test_date_range(self):
 337        objs = [
 338            RangeLookupsModel.objects.create(date="2015-01-01"),
 339            RangeLookupsModel.objects.create(date="2015-05-05"),
 340        ]
 341        self.assertSequenceEqual(
 342            RangeLookupsModel.objects.filter(
 343                date__contained_by=DateRange("2015-01-01", "2015-05-04")
 344            ),
 345            [objs[0]],
 346        )
 347
 348    def test_date_range_datetime_field(self):
 349        objs = [
 350            RangeLookupsModel.objects.create(timestamp="2015-01-01"),
 351            RangeLookupsModel.objects.create(timestamp="2015-05-05"),
 352        ]
 353        self.assertSequenceEqual(
 354            RangeLookupsModel.objects.filter(
 355                timestamp__date__contained_by=DateRange("2015-01-01", "2015-05-04")
 356            ),
 357            [objs[0]],
 358        )
 359
 360    def test_datetime_range(self):
 361        objs = [
 362            RangeLookupsModel.objects.create(timestamp="2015-01-01T09:00:00"),
 363            RangeLookupsModel.objects.create(timestamp="2015-05-05T17:00:00"),
 364        ]
 365        self.assertSequenceEqual(
 366            RangeLookupsModel.objects.filter(
 367                timestamp__contained_by=DateTimeTZRange(
 368                    "2015-01-01T09:00", "2015-05-04T23:55"
 369                )
 370            ),
 371            [objs[0]],
 372        )
 373
 374    def test_small_integer_field_contained_by(self):
 375        objs = [
 376            RangeLookupsModel.objects.create(small_integer=8),
 377            RangeLookupsModel.objects.create(small_integer=4),
 378            RangeLookupsModel.objects.create(small_integer=-1),
 379        ]
 380        self.assertSequenceEqual(
 381            RangeLookupsModel.objects.filter(
 382                small_integer__contained_by=NumericRange(4, 6)
 383            ),
 384            [objs[1]],
 385        )
 386
 387    def test_integer_range(self):
 388        objs = [
 389            RangeLookupsModel.objects.create(integer=5),
 390            RangeLookupsModel.objects.create(integer=99),
 391            RangeLookupsModel.objects.create(integer=-1),
 392        ]
 393        self.assertSequenceEqual(
 394            RangeLookupsModel.objects.filter(integer__contained_by=NumericRange(1, 98)),
 395            [objs[0]],
 396        )
 397
 398    def test_biginteger_range(self):
 399        objs = [
 400            RangeLookupsModel.objects.create(big_integer=5),
 401            RangeLookupsModel.objects.create(big_integer=99),
 402            RangeLookupsModel.objects.create(big_integer=-1),
 403        ]
 404        self.assertSequenceEqual(
 405            RangeLookupsModel.objects.filter(
 406                big_integer__contained_by=NumericRange(1, 98)
 407            ),
 408            [objs[0]],
 409        )
 410
 411    def test_decimal_field_contained_by(self):
 412        objs = [
 413            RangeLookupsModel.objects.create(decimal_field=Decimal("1.33")),
 414            RangeLookupsModel.objects.create(decimal_field=Decimal("2.88")),
 415            RangeLookupsModel.objects.create(decimal_field=Decimal("99.17")),
 416        ]
 417        self.assertSequenceEqual(
 418            RangeLookupsModel.objects.filter(
 419                decimal_field__contained_by=NumericRange(
 420                    Decimal("1.89"), Decimal("7.91")
 421                ),
 422            ),
 423            [objs[1]],
 424        )
 425
 426    def test_float_range(self):
 427        objs = [
 428            RangeLookupsModel.objects.create(float=5),
 429            RangeLookupsModel.objects.create(float=99),
 430            RangeLookupsModel.objects.create(float=-1),
 431        ]
 432        self.assertSequenceEqual(
 433            RangeLookupsModel.objects.filter(float__contained_by=NumericRange(1, 98)),
 434            [objs[0]],
 435        )
 436
 437    def test_small_auto_field_contained_by(self):
 438        objs = SmallAutoFieldModel.objects.bulk_create(
 439            [SmallAutoFieldModel() for i in range(1, 5)]
 440        )
 441        self.assertSequenceEqual(
 442            SmallAutoFieldModel.objects.filter(
 443                id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
 444            ),
 445            objs[1:3],
 446        )
 447
 448    def test_auto_field_contained_by(self):
 449        objs = RangeLookupsModel.objects.bulk_create(
 450            [RangeLookupsModel() for i in range(1, 5)]
 451        )
 452        self.assertSequenceEqual(
 453            RangeLookupsModel.objects.filter(
 454                id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
 455            ),
 456            objs[1:3],
 457        )
 458
 459    def test_big_auto_field_contained_by(self):
 460        objs = BigAutoFieldModel.objects.bulk_create(
 461            [BigAutoFieldModel() for i in range(1, 5)]
 462        )
 463        self.assertSequenceEqual(
 464            BigAutoFieldModel.objects.filter(
 465                id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
 466            ),
 467            objs[1:3],
 468        )
 469
 470    def test_f_ranges(self):
 471        parent = RangesModel.objects.create(decimals=NumericRange(0, 10))
 472        objs = [
 473            RangeLookupsModel.objects.create(float=5, parent=parent),
 474            RangeLookupsModel.objects.create(float=99, parent=parent),
 475        ]
 476        self.assertSequenceEqual(
 477            RangeLookupsModel.objects.filter(float__contained_by=F("parent__decimals")),
 478            [objs[0]],
 479        )
 480
 481    def test_exclude(self):
 482        objs = [
 483            RangeLookupsModel.objects.create(float=5),
 484            RangeLookupsModel.objects.create(float=99),
 485            RangeLookupsModel.objects.create(float=-1),
 486        ]
 487        self.assertSequenceEqual(
 488            RangeLookupsModel.objects.exclude(float__contained_by=NumericRange(0, 100)),
 489            [objs[2]],
 490        )
 491
 492
 493class TestSerialization(PostgreSQLSimpleTestCase):
 494    test_data = (
 495        '[{"fields": {"ints": "{\\"upper\\": \\"10\\", \\"lower\\": \\"0\\", '
 496        '\\"bounds\\": \\"[)\\"}", "decimals": "{\\"empty\\": true}", '
 497        '"bigints": null, "timestamps": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", '
 498        '\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"[)\\"}", '
 499        '"timestamps_inner": null, '
 500        '"dates": "{\\"upper\\": \\"2014-02-02\\", \\"lower\\": \\"2014-01-01\\", \\"bounds\\": \\"[)\\"}", '
 501        '"dates_inner": null }, '
 502        '"model": "postgres_tests.rangesmodel", "pk": null}]'
 503    )
 504
 505    lower_date = datetime.date(2014, 1, 1)
 506    upper_date = datetime.date(2014, 2, 2)
 507    lower_dt = datetime.datetime(2014, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
 508    upper_dt = datetime.datetime(2014, 2, 2, 12, 12, 12, tzinfo=timezone.utc)
 509
 510    def test_dumping(self):
 511        instance = RangesModel(
 512            ints=NumericRange(0, 10),
 513            decimals=NumericRange(empty=True),
 514            timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt),
 515            dates=DateRange(self.lower_date, self.upper_date),
 516        )
 517        data = serializers.serialize("json", [instance])
 518        dumped = json.loads(data)
 519        for field in ("ints", "dates", "timestamps"):
 520            dumped[0]["fields"][field] = json.loads(dumped[0]["fields"][field])
 521        check = json.loads(self.test_data)
 522        for field in ("ints", "dates", "timestamps"):
 523            check[0]["fields"][field] = json.loads(check[0]["fields"][field])
 524        self.assertEqual(dumped, check)
 525
 526    def test_loading(self):
 527        instance = list(serializers.deserialize("json", self.test_data))[0].object
 528        self.assertEqual(instance.ints, NumericRange(0, 10))
 529        self.assertEqual(instance.decimals, NumericRange(empty=True))
 530        self.assertIsNone(instance.bigints)
 531        self.assertEqual(instance.dates, DateRange(self.lower_date, self.upper_date))
 532        self.assertEqual(
 533            instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt)
 534        )
 535
 536    def test_serialize_range_with_null(self):
 537        instance = RangesModel(ints=NumericRange(None, 10))
 538        data = serializers.serialize("json", [instance])
 539        new_instance = list(serializers.deserialize("json", data))[0].object
 540        self.assertEqual(new_instance.ints, NumericRange(None, 10))
 541
 542        instance = RangesModel(ints=NumericRange(10, None))
 543        data = serializers.serialize("json", [instance])
 544        new_instance = list(serializers.deserialize("json", data))[0].object
 545        self.assertEqual(new_instance.ints, NumericRange(10, None))
 546
 547
 548class TestChecks(PostgreSQLSimpleTestCase):
 549    def test_choices_tuple_list(self):
 550        class Model(PostgreSQLModel):
 551            field = pg_fields.IntegerRangeField(
 552                choices=[
 553                    ["1-50", [((1, 25), "1-25"), ([26, 50], "26-50")]],
 554                    ((51, 100), "51-100"),
 555                ],
 556            )
 557
 558        self.assertEqual(Model._meta.get_field("field").check(), [])
 559
 560
 561class TestValidators(PostgreSQLSimpleTestCase):
 562    def test_max(self):
 563        validator = RangeMaxValueValidator(5)
 564        validator(NumericRange(0, 5))
 565        msg = "Ensure that this range is completely less than or equal to 5."
 566        with self.assertRaises(exceptions.ValidationError) as cm:
 567            validator(NumericRange(0, 10))
 568        self.assertEqual(cm.exception.messages[0], msg)
 569        self.assertEqual(cm.exception.code, "max_value")
 570        with self.assertRaisesMessage(exceptions.ValidationError, msg):
 571            validator(NumericRange(0, None))  # an unbound range
 572
 573    def test_min(self):
 574        validator = RangeMinValueValidator(5)
 575        validator(NumericRange(10, 15))
 576        msg = "Ensure that this range is completely greater than or equal to 5."
 577        with self.assertRaises(exceptions.ValidationError) as cm:
 578            validator(NumericRange(0, 10))
 579        self.assertEqual(cm.exception.messages[0], msg)
 580        self.assertEqual(cm.exception.code, "min_value")
 581        with self.assertRaisesMessage(exceptions.ValidationError, msg):
 582            validator(NumericRange(None, 10))  # an unbound range
 583
 584
 585class TestFormField(PostgreSQLSimpleTestCase):
 586    def test_valid_integer(self):
 587        field = pg_forms.IntegerRangeField()
 588        value = field.clean(["1", "2"])
 589        self.assertEqual(value, NumericRange(1, 2))
 590
 591    def test_valid_decimal(self):
 592        field = pg_forms.DecimalRangeField()
 593        value = field.clean(["1.12345", "2.001"])
 594        self.assertEqual(value, NumericRange(Decimal("1.12345"), Decimal("2.001")))
 595
 596    def test_valid_timestamps(self):
 597        field = pg_forms.DateTimeRangeField()
 598        value = field.clean(["01/01/2014 00:00:00", "02/02/2014 12:12:12"])
 599        lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
 600        upper = datetime.datetime(2014, 2, 2, 12, 12, 12)
 601        self.assertEqual(value, DateTimeTZRange(lower, upper))
 602
 603    def test_valid_dates(self):
 604        field = pg_forms.DateRangeField()
 605        value = field.clean(["01/01/2014", "02/02/2014"])
 606        lower = datetime.date(2014, 1, 1)
 607        upper = datetime.date(2014, 2, 2)
 608        self.assertEqual(value, DateRange(lower, upper))
 609
 610    def test_using_split_datetime_widget(self):
 611        class SplitDateTimeRangeField(pg_forms.DateTimeRangeField):
 612            base_field = forms.SplitDateTimeField
 613
 614        class SplitForm(forms.Form):
 615            field = SplitDateTimeRangeField()
 616
 617        form = SplitForm()
 618        self.assertHTMLEqual(
 619            str(form),
 620            """
 621            <tr>
 622                <th>
 623                <label for="id_field_0">Field:</label>
 624                </th>
 625                <td>
 626                    <input id="id_field_0_0" name="field_0_0" type="text">
 627                    <input id="id_field_0_1" name="field_0_1" type="text">
 628                    <input id="id_field_1_0" name="field_1_0" type="text">
 629                    <input id="id_field_1_1" name="field_1_1" type="text">
 630                </td>
 631            </tr>
 632        """,
 633        )
 634        form = SplitForm(
 635            {
 636                "field_0_0": "01/01/2014",
 637                "field_0_1": "00:00:00",
 638                "field_1_0": "02/02/2014",
 639                "field_1_1": "12:12:12",
 640            }
 641        )
 642        self.assertTrue(form.is_valid())
 643        lower = datetime.datetime(2014, 1, 1, 0, 0, 0)
 644        upper = datetime.datetime(2014, 2, 2, 12, 12, 12)
 645        self.assertEqual(form.cleaned_data["field"], DateTimeTZRange(lower, upper))
 646
 647    def test_none(self):
 648        field = pg_forms.IntegerRangeField(required=False)
 649        value = field.clean(["", ""])
 650        self.assertIsNone(value)
 651
 652    def test_datetime_form_as_table(self):
 653        class DateTimeRangeForm(forms.Form):
 654            datetime_field = pg_forms.DateTimeRangeField(show_hidden_initial=True)
 655
 656        form = DateTimeRangeForm()
 657        self.assertHTMLEqual(
 658            form.as_table(),
 659            """
 660            <tr><th>
 661            <label for="id_datetime_field_0">Datetime field:</label>
 662            </th><td>
 663            <input type="text" name="datetime_field_0" id="id_datetime_field_0">
 664            <input type="text" name="datetime_field_1" id="id_datetime_field_1">
 665            <input type="hidden" name="initial-datetime_field_0" id="initial-id_datetime_field_0">
 666            <input type="hidden" name="initial-datetime_field_1" id="initial-id_datetime_field_1">
 667            </td></tr>
 668            """,
 669        )
 670        form = DateTimeRangeForm(
 671            {
 672                "datetime_field_0": "2010-01-01 11:13:00",
 673                "datetime_field_1": "2020-12-12 16:59:00",
 674            }
 675        )
 676        self.assertHTMLEqual(
 677            form.as_table(),
 678            """
 679            <tr><th>
 680            <label for="id_datetime_field_0">Datetime field:</label>
 681            </th><td>
 682            <input type="text" name="datetime_field_0"
 683            value="2010-01-01 11:13:00" id="id_datetime_field_0">
 684            <input type="text" name="datetime_field_1"
 685            value="2020-12-12 16:59:00" id="id_datetime_field_1">
 686            <input type="hidden" name="initial-datetime_field_0" value="2010-01-01 11:13:00"
 687            id="initial-id_datetime_field_0">
 688            <input type="hidden" name="initial-datetime_field_1" value="2020-12-12 16:59:00"
 689            id="initial-id_datetime_field_1"></td></tr>
 690            """,
 691        )
 692
 693    def test_datetime_form_initial_data(self):
 694        class DateTimeRangeForm(forms.Form):
 695            datetime_field = pg_forms.DateTimeRangeField(show_hidden_initial=True)
 696
 697        data = QueryDict(mutable=True)
 698        data.update(
 699            {
 700                "datetime_field_0": "2010-01-01 11:13:00",
 701                "datetime_field_1": "",
 702                "initial-datetime_field_0": "2010-01-01 10:12:00",
 703                "initial-datetime_field_1": "",
 704            }
 705        )
 706        form = DateTimeRangeForm(data=data)
 707        self.assertTrue(form.has_changed())
 708
 709        data["initial-datetime_field_0"] = "2010-01-01 11:13:00"
 710        form = DateTimeRangeForm(data=data)
 711        self.assertFalse(form.has_changed())
 712
 713    def test_rendering(self):
 714        class RangeForm(forms.Form):
 715            ints = pg_forms.IntegerRangeField()
 716
 717        self.assertHTMLEqual(
 718            str(RangeForm()),
 719            """
 720        <tr>
 721            <th><label for="id_ints_0">Ints:</label></th>
 722            <td>
 723                <input id="id_ints_0" name="ints_0" type="number">
 724                <input id="id_ints_1" name="ints_1" type="number">
 725            </td>
 726        </tr>
 727        """,
 728        )
 729
 730    def test_integer_lower_bound_higher(self):
 731        field = pg_forms.IntegerRangeField()
 732        with self.assertRaises(exceptions.ValidationError) as cm:
 733            field.clean(["10", "2"])
 734        self.assertEqual(
 735            cm.exception.messages[0],
 736            "The start of the range must not exceed the end of the range.",
 737        )
 738        self.assertEqual(cm.exception.code, "bound_ordering")
 739
 740    def test_integer_open(self):
 741        field = pg_forms.IntegerRangeField()
 742        value = field.clean(["", "0"])
 743        self.assertEqual(value, NumericRange(None, 0))
 744
 745    def test_integer_incorrect_data_type(self):
 746        field = pg_forms.IntegerRangeField()
 747        with self.assertRaises(exceptions.ValidationError) as cm:
 748            field.clean("1")
 749        self.assertEqual(cm.exception.messages[0], "Enter two whole numbers.")
 750        self.assertEqual(cm.exception.code, "invalid")
 751
 752    def test_integer_invalid_lower(self):
 753        field = pg_forms.IntegerRangeField()
 754        with self.assertRaises(exceptions.ValidationError) as cm:
 755            field.clean(["a", "2"])
 756        self.assertEqual(cm.exception.messages[0], "Enter a whole number.")
 757
 758    def test_integer_invalid_upper(self):
 759        field = pg_forms.IntegerRangeField()
 760        with self.assertRaises(exceptions.ValidationError) as cm:
 761            field.clean(["1", "b"])
 762        self.assertEqual(cm.exception.messages[0], "Enter a whole number.")
 763
 764    def test_integer_required(self):
 765        field = pg_forms.IntegerRangeField(required=True)
 766        with self.assertRaises(exceptions.ValidationError) as cm:
 767            field.clean(["", ""])
 768        self.assertEqual(cm.exception.messages[0], "This field is required.")
 769        value = field.clean([1, ""])
 770        self.assertEqual(value, NumericRange(1, None))
 771
 772    def test_decimal_lower_bound_higher(self):
 773        field = pg_forms.DecimalRangeField()
 774        with self.assertRaises(exceptions.ValidationError) as cm:
 775            field.clean(["1.8", "1.6"])
 776        self.assertEqual(
 777            cm.exception.messages[0],
 778            "The start of the range must not exceed the end of the range.",
 779        )
 780        self.assertEqual(cm.exception.code, "bound_ordering")
 781
 782    def test_decimal_open(self):
 783        field = pg_forms.DecimalRangeField()
 784        value = field.clean(["", "3.1415926"])
 785        self.assertEqual(value, NumericRange(None, Decimal("3.1415926")))
 786
 787    def test_decimal_incorrect_data_type(self):
 788        field = pg_forms.DecimalRangeField()
 789        with self.assertRaises(exceptions.ValidationError) as cm:
 790            field.clean("1.6")
 791        self.assertEqual(cm.exception.messages[0], "Enter two numbers.")
 792        self.assertEqual(cm.exception.code, "invalid")
 793
 794    def test_decimal_invalid_lower(self):
 795        field = pg_forms.DecimalRangeField()
 796        with self.assertRaises(exceptions.ValidationError) as cm:
 797            field.clean(["a", "3.1415926"])
 798        self.assertEqual(cm.exception.messages[0], "Enter a number.")
 799
 800    def test_decimal_invalid_upper(self):
 801        field = pg_forms.DecimalRangeField()
 802        with self.assertRaises(exceptions.ValidationError) as cm:
 803            field.clean(["1.61803399", "b"])
 804        self.assertEqual(cm.exception.messages[0], "Enter a number.")
 805
 806    def test_decimal_required(self):
 807        field = pg_forms.DecimalRangeField(required=True)
 808        with self.assertRaises(exceptions.ValidationError) as cm:
 809            field.clean(["", ""])
 810        self.assertEqual(cm.exception.messages[0], "This field is required.")
 811        value = field.clean(["1.61803399", ""])
 812        self.assertEqual(value, NumericRange(Decimal("1.61803399"), None))
 813
 814    def test_date_lower_bound_higher(self):
 815        field = pg_forms.DateRangeField()
 816        with self.assertRaises(exceptions.ValidationError) as cm:
 817            field.clean(["2013-04-09", "1976-04-16"])
 818        self.assertEqual(
 819            cm.exception.messages[0],
 820            "The start of the range must not exceed the end of the range.",
 821        )
 822        self.assertEqual(cm.exception.code, "bound_ordering")
 823
 824    def test_date_open(self):
 825        field = pg_forms.DateRangeField()
 826        value = field.clean(["", "2013-04-09"])
 827        self.assertEqual(value, DateRange(None, datetime.date(2013, 4, 9)))
 828
 829    def test_date_incorrect_data_type(self):
 830        field = pg_forms.DateRangeField()
 831        with self.assertRaises(exceptions.ValidationError) as cm:
 832            field.clean("1")
 833        self.assertEqual(cm.exception.messages[0], "Enter two valid dates.")
 834        self.assertEqual(cm.exception.code, "invalid")
 835
 836    def test_date_invalid_lower(self):
 837        field = pg_forms.DateRangeField()
 838        with self.assertRaises(exceptions.ValidationError) as cm:
 839            field.clean(["a", "2013-04-09"])
 840        self.assertEqual(cm.exception.messages[0], "Enter a valid date.")
 841
 842    def test_date_invalid_upper(self):
 843        field = pg_forms.DateRangeField()
 844        with self.assertRaises(exceptions.ValidationError) as cm:
 845            field.clean(["2013-04-09", "b"])
 846        self.assertEqual(cm.exception.messages[0], "Enter a valid date.")
 847
 848    def test_date_required(self):
 849        field = pg_forms.DateRangeField(required=True)
 850        with self.assertRaises(exceptions.ValidationError) as cm:
 851            field.clean(["", ""])
 852        self.assertEqual(cm.exception.messages[0], "This field is required.")
 853        value = field.clean(["1976-04-16", ""])
 854        self.assertEqual(value, DateRange(datetime.date(1976, 4, 16), None))
 855
 856    def test_date_has_changed_first(self):
 857        self.assertTrue(
 858            pg_forms.DateRangeField().has_changed(
 859                ["2010-01-01", "2020-12-12"],
 860                ["2010-01-31", "2020-12-12"],
 861            )
 862        )
 863
 864    def test_date_has_changed_last(self):
 865        self.assertTrue(
 866            pg_forms.DateRangeField().has_changed(
 867                ["2010-01-01", "2020-12-12"],
 868                ["2010-01-01", "2020-12-31"],
 869            )
 870        )
 871
 872    def test_datetime_lower_bound_higher(self):
 873        field = pg_forms.DateTimeRangeField()
 874        with self.assertRaises(exceptions.ValidationError) as cm:
 875            field.clean(["2006-10-25 14:59", "2006-10-25 14:58"])
 876        self.assertEqual(
 877            cm.exception.messages[0],
 878            "The start of the range must not exceed the end of the range.",
 879        )
 880        self.assertEqual(cm.exception.code, "bound_ordering")
 881
 882    def test_datetime_open(self):
 883        field = pg_forms.DateTimeRangeField()
 884        value = field.clean(["", "2013-04-09 11:45"])
 885        self.assertEqual(
 886            value, DateTimeTZRange(None, datetime.datetime(2013, 4, 9, 11, 45))
 887        )
 888
 889    def test_datetime_incorrect_data_type(self):
 890        field = pg_forms.DateTimeRangeField()
 891        with self.assertRaises(exceptions.ValidationError) as cm:
 892            field.clean("2013-04-09 11:45")
 893        self.assertEqual(cm.exception.messages[0], "Enter two valid date/times.")
 894        self.assertEqual(cm.exception.code, "invalid")
 895
 896    def test_datetime_invalid_lower(self):
 897        field = pg_forms.DateTimeRangeField()
 898        with self.assertRaises(exceptions.ValidationError) as cm:
 899            field.clean(["45", "2013-04-09 11:45"])
 900        self.assertEqual(cm.exception.messages[0], "Enter a valid date/time.")
 901
 902    def test_datetime_invalid_upper(self):
 903        field = pg_forms.DateTimeRangeField()
 904        with self.assertRaises(exceptions.ValidationError) as cm:
 905            field.clean(["2013-04-09 11:45", "sweet pickles"])
 906        self.assertEqual(cm.exception.messages[0], "Enter a valid date/time.")
 907
 908    def test_datetime_required(self):
 909        field = pg_forms.DateTimeRangeField(required=True)
 910        with self.assertRaises(exceptions.ValidationError) as cm:
 911            field.clean(["", ""])
 912        self.assertEqual(cm.exception.messages[0], "This field is required.")
 913        value = field.clean(["2013-04-09 11:45", ""])
 914        self.assertEqual(
 915            value, DateTimeTZRange(datetime.datetime(2013, 4, 9, 11, 45), None)
 916        )
 917
 918    @override_settings(USE_TZ=True, TIME_ZONE="Africa/Johannesburg")
 919    def test_datetime_prepare_value(self):
 920        field = pg_forms.DateTimeRangeField()
 921        value = field.prepare_value(
 922            DateTimeTZRange(
 923                datetime.datetime(2015, 5, 22, 16, 6, 33, tzinfo=timezone.utc), None
 924            )
 925        )
 926        self.assertEqual(value, [datetime.datetime(2015, 5, 22, 18, 6, 33), None])
 927
 928    def test_datetime_has_changed_first(self):
 929        self.assertTrue(
 930            pg_forms.DateTimeRangeField().has_changed(
 931                ["2010-01-01 00:00", "2020-12-12 00:00"],
 932                ["2010-01-31 23:00", "2020-12-12 00:00"],
 933            )
 934        )
 935
 936    def test_datetime_has_changed_last(self):
 937        self.assertTrue(
 938            pg_forms.DateTimeRangeField().has_changed(
 939                ["2010-01-01 00:00", "2020-12-12 00:00"],
 940                ["2010-01-01 00:00", "2020-12-31 23:00"],
 941            )
 942        )
 943
 944    def test_model_field_formfield_integer(self):
 945        model_field = pg_fields.IntegerRangeField()
 946        form_field = model_field.formfield()
 947        self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
 948
 949    def test_model_field_formfield_biginteger(self):
 950        model_field = pg_fields.BigIntegerRangeField()
 951        form_field = model_field.formfield()
 952        self.assertIsInstance(form_field, pg_forms.IntegerRangeField)
 953
 954    def test_model_field_formfield_float(self):
 955        model_field = pg_fields.DecimalRangeField()
 956        form_field = model_field.formfield()
 957        self.assertIsInstance(form_field, pg_forms.DecimalRangeField)
 958
 959    def test_model_field_formfield_date(self):
 960        model_field = pg_fields.DateRangeField()
 961        form_field = model_field.formfield()
 962        self.assertIsInstance(form_field, pg_forms.DateRangeField)
 963
 964    def test_model_field_formfield_datetime(self):
 965        model_field = pg_fields.DateTimeRangeField()
 966        form_field = model_field.formfield()
 967        self.assertIsInstance(form_field, pg_forms.DateTimeRangeField)
 968
 969    def test_has_changed(self):
 970        for field, value in (
 971            (pg_forms.DateRangeField(), ["2010-01-01", "2020-12-12"]),
 972            (pg_forms.DateTimeRangeField(), ["2010-01-01 11:13", "2020-12-12 14:52"]),
 973            (pg_forms.IntegerRangeField(), [1, 2]),
 974            (pg_forms.DecimalRangeField(), ["1.12345", "2.001"]),
 975        ):
 976            with self.subTest(field=field.__class__.__name__):
 977                self.assertTrue(field.has_changed(None, value))
 978                self.assertTrue(field.has_changed([value[0], ""], value))
 979                self.assertTrue(field.has_changed(["", value[1]], value))
 980                self.assertFalse(field.has_changed(value, value))
 981
 982
 983class TestWidget(PostgreSQLSimpleTestCase):
 984    def test_range_widget(self):
 985        f = pg_forms.ranges.DateTimeRangeField()
 986        self.assertHTMLEqual(
 987            f.widget.render("datetimerange", ""),
 988            '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">',
 989        )
 990        self.assertHTMLEqual(
 991            f.widget.render("datetimerange", None),
 992            '<input type="text" name="datetimerange_0"><input type="text" name="datetimerange_1">',
 993        )
 994        dt_range = DateTimeTZRange(
 995            datetime.datetime(2006, 1, 10, 7, 30), datetime.datetime(2006, 2, 12, 9, 50)
 996        )
 997        self.assertHTMLEqual(
 998            f.widget.render("datetimerange", dt_range),
 999            '<input type="text" name="datetimerange_0" value="2006-01-10 07:30:00">'
1000            '<input type="text" name="datetimerange_1" value="2006-02-12 09:50:00">',
1001        )

test_constraints.py

  1import datetime
  2from unittest import mock
  3
  4from django.db import connection
  5from django.db import IntegrityError
  6from django.db import transaction
  7from django.db.models import CheckConstraint
  8from django.db.models import F
  9from django.db.models import Func
 10from django.db.models import Q
 11from django.utils import timezone
 12
 13from . import PostgreSQLTestCase
 14from .models import HotelReservation
 15from .models import RangesModel
 16from .models import Room
 17
 18try:
 19    from django.contrib.postgres.constraints import ExclusionConstraint
 20    from django.contrib.postgres.fields import (
 21        DateTimeRangeField,
 22        RangeBoundary,
 23        RangeOperators,
 24    )
 25
 26    from psycopg2.extras import DateRange, NumericRange
 27except ImportError:
 28    pass
 29
 30
 31class SchemaTests(PostgreSQLTestCase):
 32    def get_constraints(self, table):
 33        """Get the constraints on the table using a new cursor."""
 34        with connection.cursor() as cursor:
 35            return connection.introspection.get_constraints(cursor, table)
 36
 37    def test_check_constraint_range_value(self):
 38        constraint_name = "ints_between"
 39        self.assertNotIn(
 40            constraint_name, self.get_constraints(RangesModel._meta.db_table)
 41        )
 42        constraint = CheckConstraint(
 43            check=Q(ints__contained_by=NumericRange(10, 30)),
 44            name=constraint_name,
 45        )
 46        with connection.schema_editor() as editor:
 47            editor.add_constraint(RangesModel, constraint)
 48        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
 49        with self.assertRaises(IntegrityError), transaction.atomic():
 50            RangesModel.objects.create(ints=(20, 50))
 51        RangesModel.objects.create(ints=(10, 30))
 52
 53    def test_check_constraint_daterange_contains(self):
 54        constraint_name = "dates_contains"
 55        self.assertNotIn(
 56            constraint_name, self.get_constraints(RangesModel._meta.db_table)
 57        )
 58        constraint = CheckConstraint(
 59            check=Q(dates__contains=F("dates_inner")),
 60            name=constraint_name,
 61        )
 62        with connection.schema_editor() as editor:
 63            editor.add_constraint(RangesModel, constraint)
 64        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
 65        date_1 = datetime.date(2016, 1, 1)
 66        date_2 = datetime.date(2016, 1, 4)
 67        with self.assertRaises(IntegrityError), transaction.atomic():
 68            RangesModel.objects.create(
 69                dates=(date_1, date_2),
 70                dates_inner=(date_1, date_2.replace(day=5)),
 71            )
 72        RangesModel.objects.create(
 73            dates=(date_1, date_2),
 74            dates_inner=(date_1, date_2),
 75        )
 76
 77    def test_check_constraint_datetimerange_contains(self):
 78        constraint_name = "timestamps_contains"
 79        self.assertNotIn(
 80            constraint_name, self.get_constraints(RangesModel._meta.db_table)
 81        )
 82        constraint = CheckConstraint(
 83            check=Q(timestamps__contains=F("timestamps_inner")),
 84            name=constraint_name,
 85        )
 86        with connection.schema_editor() as editor:
 87            editor.add_constraint(RangesModel, constraint)
 88        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
 89        datetime_1 = datetime.datetime(2016, 1, 1)
 90        datetime_2 = datetime.datetime(2016, 1, 2, 12)
 91        with self.assertRaises(IntegrityError), transaction.atomic():
 92            RangesModel.objects.create(
 93                timestamps=(datetime_1, datetime_2),
 94                timestamps_inner=(datetime_1, datetime_2.replace(hour=13)),
 95            )
 96        RangesModel.objects.create(
 97            timestamps=(datetime_1, datetime_2),
 98            timestamps_inner=(datetime_1, datetime_2),
 99        )
100
101
102class ExclusionConstraintTests(PostgreSQLTestCase):
103    def get_constraints(self, table):
104        """Get the constraints on the table using a new cursor."""
105        with connection.cursor() as cursor:
106            return connection.introspection.get_constraints(cursor, table)
107
108    def test_invalid_condition(self):
109        msg = "ExclusionConstraint.condition must be a Q instance."
110        with self.assertRaisesMessage(ValueError, msg):
111            ExclusionConstraint(
112                index_type="GIST",
113                name="exclude_invalid_condition",
114                expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
115                condition=F("invalid"),
116            )
117
118    def test_invalid_index_type(self):
119        msg = "Exclusion constraints only support GiST or SP-GiST indexes."
120        with self.assertRaisesMessage(ValueError, msg):
121            ExclusionConstraint(
122                index_type="gin",
123                name="exclude_invalid_index_type",
124                expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
125            )
126
127    def test_invalid_expressions(self):
128        msg = "The expressions must be a list of 2-tuples."
129        for expressions in (["foo"], [("foo")], [("foo_1", "foo_2", "foo_3")]):
130            with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg):
131                ExclusionConstraint(
132                    index_type="GIST",
133                    name="exclude_invalid_expressions",
134                    expressions=expressions,
135                )
136
137    def test_empty_expressions(self):
138        msg = "At least one expression is required to define an exclusion constraint."
139        for empty_expressions in (None, []):
140            with self.subTest(empty_expressions), self.assertRaisesMessage(
141                ValueError, msg
142            ):
143                ExclusionConstraint(
144                    index_type="GIST",
145                    name="exclude_empty_expressions",
146                    expressions=empty_expressions,
147                )
148
149    def test_repr(self):
150        constraint = ExclusionConstraint(
151            name="exclude_overlapping",
152            expressions=[
153                (F("datespan"), RangeOperators.OVERLAPS),
154                (F("room"), RangeOperators.EQUAL),
155            ],
156        )
157        self.assertEqual(
158            repr(constraint),
159            "<ExclusionConstraint: index_type=GIST, expressions=["
160            "(F(datespan), '&&'), (F(room), '=')]>",
161        )
162        constraint = ExclusionConstraint(
163            name="exclude_overlapping",
164            expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
165            condition=Q(cancelled=False),
166            index_type="SPGiST",
167        )
168        self.assertEqual(
169            repr(constraint),
170            "<ExclusionConstraint: index_type=SPGiST, expressions=["
171            "(F(datespan), '-|-')], condition=(AND: ('cancelled', False))>",
172        )
173
174    def test_eq(self):
175        constraint_1 = ExclusionConstraint(
176            name="exclude_overlapping",
177            expressions=[
178                (F("datespan"), RangeOperators.OVERLAPS),
179                (F("room"), RangeOperators.EQUAL),
180            ],
181            condition=Q(cancelled=False),
182        )
183        constraint_2 = ExclusionConstraint(
184            name="exclude_overlapping",
185            expressions=[
186                ("datespan", RangeOperators.OVERLAPS),
187                ("room", RangeOperators.EQUAL),
188            ],
189        )
190        constraint_3 = ExclusionConstraint(
191            name="exclude_overlapping",
192            expressions=[("datespan", RangeOperators.OVERLAPS)],
193            condition=Q(cancelled=False),
194        )
195        self.assertEqual(constraint_1, constraint_1)
196        self.assertEqual(constraint_1, mock.ANY)
197        self.assertNotEqual(constraint_1, constraint_2)
198        self.assertNotEqual(constraint_1, constraint_3)
199        self.assertNotEqual(constraint_2, constraint_3)
200        self.assertNotEqual(constraint_1, object())
201
202    def test_deconstruct(self):
203        constraint = ExclusionConstraint(
204            name="exclude_overlapping",
205            expressions=[
206                ("datespan", RangeOperators.OVERLAPS),
207                ("room", RangeOperators.EQUAL),
208            ],
209        )
210        path, args, kwargs = constraint.deconstruct()
211        self.assertEqual(
212            path, "django.contrib.postgres.constraints.ExclusionConstraint"
213        )
214        self.assertEqual(args, ())
215        self.assertEqual(
216            kwargs,
217            {
218                "name": "exclude_overlapping",
219                "expressions": [
220                    ("datespan", RangeOperators.OVERLAPS),
221                    ("room", RangeOperators.EQUAL),
222                ],
223            },
224        )
225
226    def test_deconstruct_index_type(self):
227        constraint = ExclusionConstraint(
228            name="exclude_overlapping",
229            index_type="SPGIST",
230            expressions=[
231                ("datespan", RangeOperators.OVERLAPS),
232                ("room", RangeOperators.EQUAL),
233            ],
234        )
235        path, args, kwargs = constraint.deconstruct()
236        self.assertEqual(
237            path, "django.contrib.postgres.constraints.ExclusionConstraint"
238        )
239        self.assertEqual(args, ())
240        self.assertEqual(
241            kwargs,
242            {
243                "name": "exclude_overlapping",
244                "index_type": "SPGIST",
245                "expressions": [
246                    ("datespan", RangeOperators.OVERLAPS),
247                    ("room", RangeOperators.EQUAL),
248                ],
249            },
250        )
251
252    def test_deconstruct_condition(self):
253        constraint = ExclusionConstraint(
254            name="exclude_overlapping",
255            expressions=[
256                ("datespan", RangeOperators.OVERLAPS),
257                ("room", RangeOperators.EQUAL),
258            ],
259            condition=Q(cancelled=False),
260        )
261        path, args, kwargs = constraint.deconstruct()
262        self.assertEqual(
263            path, "django.contrib.postgres.constraints.ExclusionConstraint"
264        )
265        self.assertEqual(args, ())
266        self.assertEqual(
267            kwargs,
268            {
269                "name": "exclude_overlapping",
270                "expressions": [
271                    ("datespan", RangeOperators.OVERLAPS),
272                    ("room", RangeOperators.EQUAL),
273                ],
274                "condition": Q(cancelled=False),
275            },
276        )
277
278    def _test_range_overlaps(self, constraint):
279        # Create exclusion constraint.
280        self.assertNotIn(
281            constraint.name, self.get_constraints(HotelReservation._meta.db_table)
282        )
283        with connection.schema_editor() as editor:
284            editor.add_constraint(HotelReservation, constraint)
285        self.assertIn(
286            constraint.name, self.get_constraints(HotelReservation._meta.db_table)
287        )
288        # Add initial reservations.
289        room101 = Room.objects.create(number=101)
290        room102 = Room.objects.create(number=102)
291        datetimes = [
292            timezone.datetime(2018, 6, 20),
293            timezone.datetime(2018, 6, 24),
294            timezone.datetime(2018, 6, 26),
295            timezone.datetime(2018, 6, 28),
296            timezone.datetime(2018, 6, 29),
297        ]
298        HotelReservation.objects.create(
299            datespan=DateRange(datetimes[0].date(), datetimes[1].date()),
300            start=datetimes[0],
301            end=datetimes[1],
302            room=room102,
303        )
304        HotelReservation.objects.create(
305            datespan=DateRange(datetimes[1].date(), datetimes[3].date()),
306            start=datetimes[1],
307            end=datetimes[3],
308            room=room102,
309        )
310        # Overlap dates.
311        with self.assertRaises(IntegrityError), transaction.atomic():
312            reservation = HotelReservation(
313                datespan=(datetimes[1].date(), datetimes[2].date()),
314                start=datetimes[1],
315                end=datetimes[2],
316                room=room102,
317            )
318            reservation.save()
319        # Valid range.
320        HotelReservation.objects.bulk_create(
321            [
322                # Other room.
323                HotelReservation(
324                    datespan=(datetimes[1].date(), datetimes[2].date()),
325                    start=datetimes[1],
326                    end=datetimes[2],
327                    room=room101,
328                ),
329                # Cancelled reservation.
330                HotelReservation(
331                    datespan=(datetimes[1].date(), datetimes[1].date()),
332                    start=datetimes[1],
333                    end=datetimes[2],
334                    room=room102,
335                    cancelled=True,
336                ),
337                # Other adjacent dates.
338                HotelReservation(
339                    datespan=(datetimes[3].date(), datetimes[4].date()),
340                    start=datetimes[3],
341                    end=datetimes[4],
342                    room=room102,
343                ),
344            ]
345        )
346
347    def test_range_overlaps_custom(self):
348        class TsTzRange(Func):
349            function = "TSTZRANGE"
350            output_field = DateTimeRangeField()
351
352        constraint = ExclusionConstraint(
353            name="exclude_overlapping_reservations_custom",
354            expressions=[
355                (TsTzRange("start", "end", RangeBoundary()), RangeOperators.OVERLAPS),
356                ("room", RangeOperators.EQUAL),
357            ],
358            condition=Q(cancelled=False),
359        )
360        self._test_range_overlaps(constraint)
361
362    def test_range_overlaps(self):
363        constraint = ExclusionConstraint(
364            name="exclude_overlapping_reservations",
365            expressions=[
366                (F("datespan"), RangeOperators.OVERLAPS),
367                ("room", RangeOperators.EQUAL),
368            ],
369            condition=Q(cancelled=False),
370        )
371        self._test_range_overlaps(constraint)
372
373    def test_range_adjacent(self):
374        constraint_name = "ints_adjacent"
375        self.assertNotIn(
376            constraint_name, self.get_constraints(RangesModel._meta.db_table)
377        )
378        constraint = ExclusionConstraint(
379            name=constraint_name,
380            expressions=[("ints", RangeOperators.ADJACENT_TO)],
381        )
382        with connection.schema_editor() as editor:
383            editor.add_constraint(RangesModel, constraint)
384        self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
385        RangesModel.objects.create(ints=(20, 50))
386        with self.assertRaises(IntegrityError), transaction.atomic():
387            RangesModel.objects.create(ints=(10, 20))
388        RangesModel.objects.create(ints=(10, 19))
389        RangesModel.objects.create(ints=(51, 60))