Tests: RangeField
Contents
__init__.py
1import unittest
2
3from django.db import connection
4from django.test import modify_settings
5from django.test import SimpleTestCase
6from django.test import TestCase
7from forms_tests.widget_tests.base import WidgetTest
8
9
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLSimpleTestCase(SimpleTestCase):
    """SimpleTestCase that is skipped unless the database vendor is PostgreSQL."""
13
14
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLTestCase(TestCase):
    """TestCase that is skipped unless the database vendor is PostgreSQL."""
18
19
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
# Append django.contrib.postgres to INSTALLED_APPS so the widget's template
# can be located.
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase):
    """Widget tests that require PostgreSQL and django.contrib.postgres."""
fields.py
1"""
2Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
3run with a backend other than PostgreSQL.
4"""
5import enum
6
7from django.db import models
8
try:
    from django.contrib.postgres.fields import (
        ArrayField,
        BigIntegerRangeField,
        CICharField,
        CIEmailField,
        CITextField,
        DateRangeField,
        DateTimeRangeField,
        DecimalRangeField,
        HStoreField,
        IntegerRangeField,
        JSONField,
    )
    from django.contrib.postgres.search import SearchVectorField
except ImportError:
    # The PostgreSQL field modules could not be imported (presumably the
    # PostgreSQL driver is missing). Bind every name to an inert stand-in so
    # models.py stays importable on other backends.

    class DummyArrayField(models.Field):
        """ArrayField stand-in: accepts base_field/size and discards them."""

        def __init__(self, base_field, size=None, **kwargs):
            super().__init__(**kwargs)

        def deconstruct(self):
            name, path, args, kwargs = super().deconstruct()
            # Emit placeholder values for the arguments __init__ dropped.
            kwargs["base_field"] = ""
            kwargs["size"] = 1
            return name, path, args, kwargs

    class DummyJSONField(models.Field):
        """JSONField stand-in: accepts an encoder and discards it."""

        def __init__(self, encoder=None, **kwargs):
            super().__init__(**kwargs)

    ArrayField = DummyArrayField
    JSONField = DummyJSONField
    # Every remaining PostgreSQL-specific field degrades to a plain Field.
    BigIntegerRangeField = CICharField = CIEmailField = CITextField = models.Field
    DateRangeField = DateTimeRangeField = DecimalRangeField = models.Field
    HStoreField = IntegerRangeField = SearchVectorField = models.Field
56
57
class EnumField(models.CharField):
    """CharField that stores ``enum.Enum`` members by their ``.value``."""

    def get_prep_value(self, value):
        # Unwrap enum members; any other value is passed through unchanged.
        if isinstance(value, enum.Enum):
            return value.value
        return value
models.py
1from django.core.serializers.json import DjangoJSONEncoder
2from django.db import models
3
4from .fields import ArrayField
5from .fields import BigIntegerRangeField
6from .fields import CICharField
7from .fields import CIEmailField
8from .fields import CITextField
9from .fields import DateRangeField
10from .fields import DateTimeRangeField
11from .fields import DecimalRangeField
12from .fields import EnumField
13from .fields import HStoreField
14from .fields import IntegerRangeField
15from .fields import JSONField
16from .fields import SearchVectorField
17
18
class Tag:
    """Small value object wrapping an integer ``tag_id``.

    Two tags are equal when their ``tag_id`` values are equal. ``__hash__`` is
    defined consistently with ``__eq__`` because defining ``__eq__`` alone
    sets ``__hash__`` to None and makes instances unhashable.
    """

    def __init__(self, tag_id):
        self.tag_id = tag_id

    def __eq__(self, other):
        # Defer to the other operand for foreign types; Python then falls back
        # to identity, so ``Tag(1) == 1`` is still False.
        if not isinstance(other, Tag):
            return NotImplemented
        return self.tag_id == other.tag_id

    def __hash__(self):
        return hash(self.tag_id)

    def __repr__(self):
        # Readable output in assertion failure messages.
        return f"{type(self).__name__}({self.tag_id!r})"
25
26
class TagField(models.SmallIntegerField):
    """SmallIntegerField whose Python-side values are :class:`Tag` objects.

    The bare integer ``tag_id`` is what gets prepared for the database.
    ``None`` passes through unchanged in both directions.
    """

    def from_db_value(self, value, expression, connection):
        # Database values are ints or None, never Tags, so to_python() yields
        # exactly the previous result (None or Tag(int(value))).
        return self.to_python(value)

    def to_python(self, value):
        if value is None or isinstance(value, Tag):
            return value
        return Tag(int(value))

    def get_prep_value(self, value):
        # Mirror to_python()/from_db_value(): None is a legal value here.
        # Previously this raised AttributeError ('NoneType' has no 'tag_id').
        if value is None:
            return None
        return value.tag_id
42
43
class PostgreSQLModel(models.Model):
    """Abstract base for test models that are specific to PostgreSQL."""

    class Meta:
        # Mark concrete subclasses as requiring the "postgresql" vendor.
        required_db_vendor = "postgresql"
        abstract = True
48
49
class IntegerArrayModel(PostgreSQLModel):
    """Integer array that may be blank and defaults to an empty list."""

    field = ArrayField(models.IntegerField(), blank=True, default=list)
52
53
class NullableIntegerArrayModel(PostgreSQLModel):
    """Nullable flat integer array plus a nullable two-level nested array."""

    field = ArrayField(models.IntegerField(), null=True, blank=True)
    field_nested = ArrayField(
        ArrayField(models.IntegerField(null=True)),
        null=True,
    )
57
58
class CharArrayModel(PostgreSQLModel):
    """Array of strings of at most 10 characters."""

    field = ArrayField(models.CharField(max_length=10))
61
62
class DateTimeArrayModel(PostgreSQLModel):
    """One array column per temporal base field type."""

    datetimes = ArrayField(models.DateTimeField())
    dates = ArrayField(models.DateField())
    times = ArrayField(models.TimeField())
67
68
class NestedIntegerArrayModel(PostgreSQLModel):
    """Two-dimensional integer array."""

    field = ArrayField(ArrayField(models.IntegerField()))
71
72
class OtherTypesArrayModel(PostgreSQLModel):
    """Arrays over less common base fields, including custom and range fields."""

    ips = ArrayField(models.GenericIPAddressField(), default=list)
    uuids = ArrayField(models.UUIDField(), default=list)
    decimals = ArrayField(
        models.DecimalField(decimal_places=2, max_digits=5),
        default=list,
    )
    tags = ArrayField(TagField(), null=True, blank=True)
    json = ArrayField(JSONField(default=dict), default=list)
    int_ranges = ArrayField(IntegerRangeField(), null=True, blank=True)
    bigint_ranges = ArrayField(BigIntegerRangeField(), null=True, blank=True)
83
84
class HStoreModel(PostgreSQLModel):
    """A nullable hstore column and a nullable array of hstores."""

    field = HStoreField(null=True, blank=True)
    array_field = ArrayField(HStoreField(), null=True)
88
89
class ArrayEnumModel(PostgreSQLModel):
    """Array whose elements are prepared through EnumField."""

    array_of_enums = ArrayField(EnumField(max_length=20))
92
93
class CharFieldModel(models.Model):
    """Single short CharField (max 16 characters)."""

    field = models.CharField(max_length=16)
96
97
class TextFieldModel(models.Model):
    """Single TextField; its string form is the stored text."""

    field = models.TextField()

    def __str__(self):
        text = self.field
        return text
103
104
class SmallAutoFieldModel(models.Model):
    """Model whose primary key is a SmallAutoField."""

    id = models.SmallAutoField(primary_key=True)
107
108
class BigAutoFieldModel(models.Model):
    """Model whose primary key is a BigAutoField."""

    id = models.BigAutoField(primary_key=True)
111
112
# Scene, Character and Line are the full text search fixtures; their rows are
# filled with content from Monty Python and the Holy Grail.
class Scene(models.Model):
    """A scene with its setting."""

    scene = models.CharField(max_length=255)
    setting = models.CharField(max_length=255)

    def __str__(self):
        title = self.scene
        return title
121
122
class Character(models.Model):
    """A speaking character, displayed by name."""

    name = models.CharField(max_length=255)

    def __str__(self):
        label = self.name
        return label
128
129
class CITestModel(PostgreSQLModel):
    """Case-insensitive text fields, keyed by a case-insensitive name."""

    name = CICharField(max_length=255, primary_key=True)
    email = CIEmailField()
    description = CITextField()
    array_field = ArrayField(CITextField(), null=True)

    def __str__(self):
        label = self.name
        return label
138
139
class Line(PostgreSQLModel):
    """One line of dialogue, with a stored search vector and search config."""

    scene = models.ForeignKey("Scene", on_delete=models.CASCADE)
    character = models.ForeignKey("Character", on_delete=models.CASCADE)
    dialogue = models.TextField(null=True, blank=True)
    dialogue_search_vector = SearchVectorField(null=True, blank=True)
    dialogue_config = models.CharField(max_length=100, null=True, blank=True)

    def __str__(self):
        # NULL/empty dialogue renders as an empty string.
        return self.dialogue if self.dialogue else ""
149
150
class RangesModel(PostgreSQLModel):
    """One nullable column per range field type, plus two "inner" ranges."""

    ints = IntegerRangeField(null=True, blank=True)
    bigints = BigIntegerRangeField(null=True, blank=True)
    decimals = DecimalRangeField(null=True, blank=True)
    timestamps = DateTimeRangeField(null=True, blank=True)
    timestamps_inner = DateTimeRangeField(null=True, blank=True)
    dates = DateRangeField(null=True, blank=True)
    dates_inner = DateRangeField(null=True, blank=True)
159
160
class RangeLookupsModel(PostgreSQLModel):
    """Scalar columns used as the left-hand side of range lookups."""

    parent = models.ForeignKey(
        RangesModel, on_delete=models.SET_NULL, null=True, blank=True
    )
    integer = models.IntegerField(null=True, blank=True)
    big_integer = models.BigIntegerField(null=True, blank=True)
    # Field name shadows the builtin ``float`` inside this class body.
    float = models.FloatField(null=True, blank=True)
    timestamp = models.DateTimeField(null=True, blank=True)
    date = models.DateField(null=True, blank=True)
    small_integer = models.SmallIntegerField(null=True, blank=True)
    decimal_field = models.DecimalField(
        decimal_places=2, max_digits=5, null=True, blank=True
    )
172
173
class JSONModel(PostgreSQLModel):
    """JSON columns with the default encoder and with DjangoJSONEncoder."""

    field = JSONField(null=True, blank=True)
    field_custom = JSONField(encoder=DjangoJSONEncoder, null=True, blank=True)
177
178
class ArrayFieldSubclass(ArrayField):
    """ArrayField subclass whose base field is always an IntegerField.

    Every constructor argument is accepted and then ignored.
    """

    def __init__(self, *args, **kwargs):
        base = models.IntegerField()
        super().__init__(base)
182
183
class AggregateTestModel(models.Model):
    """
    Fixture model for PostgreSQL-specific general aggregation functions.
    """

    char_field = models.CharField(blank=True, max_length=30)
    integer_field = models.IntegerField(null=True)
    boolean_field = models.BooleanField(null=True)
192
193
class StatTestModel(models.Model):
    """
    Fixture model for PostgreSQL-specific statistical aggregation functions.
    """

    int1 = models.IntegerField()
    int2 = models.IntegerField()
    related_field = models.ForeignKey(
        AggregateTestModel, on_delete=models.SET_NULL, null=True
    )
202
203
class NowTestModel(models.Model):
    """Single nullable timestamp that defaults to None."""

    when = models.DateTimeField(default=None, null=True)
206
207
class UUIDTestModel(models.Model):
    """Single nullable UUID that defaults to None."""

    uuid = models.UUIDField(null=True, default=None)
210
211
class Room(models.Model):
    """A room identified by a unique number."""

    number = models.IntegerField(unique=True)
214
215
class HotelReservation(PostgreSQLModel):
    """Room booking with a date range, start/end timestamps and a cancel flag."""

    room = models.ForeignKey("Room", models.CASCADE)
    datespan = DateRangeField()
    start = models.DateTimeField()
    end = models.DateTimeField()
    cancelled = models.BooleanField(default=False)
test_ranges.py
1import datetime
2import json
3from decimal import Decimal
4
5from django import forms
6from django.core import exceptions
7from django.core import serializers
8from django.db.models import DateField
9from django.db.models import DateTimeField
10from django.db.models import F
11from django.db.models import Func
12from django.db.models import Value
13from django.http import QueryDict
14from django.test import override_settings
15from django.test.utils import isolate_apps
16from django.utils import timezone
17
18from . import PostgreSQLSimpleTestCase
19from . import PostgreSQLTestCase
20from .models import BigAutoFieldModel
21from .models import PostgreSQLModel
22from .models import RangeLookupsModel
23from .models import RangesModel
24from .models import SmallAutoFieldModel
25
26try:
27 from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
28 from django.contrib.postgres import fields as pg_fields, forms as pg_forms
29 from django.contrib.postgres.validators import (
30 RangeMaxValueValidator,
31 RangeMinValueValidator,
32 )
33except ImportError:
34 pass
35
36
@isolate_apps("postgres_tests")
class BasicTests(PostgreSQLSimpleTestCase):
    """Range field behavior that needs no database queries."""

    def test_get_field_display(self):
        class Model(PostgreSQLModel):
            field = pg_fields.IntegerRangeField(
                choices=[
                    ["1-50", [((1, 25), "1-25"), ([26, 50], "26-50")]],
                    ((51, 100), "51-100"),
                ],
            )

        # Values present in choices (grouped or not, tuple or list) show their
        # label; anything else falls back to its string representation.
        cases = [
            ((1, 25), "1-25"),
            ([26, 50], "26-50"),
            ((51, 100), "51-100"),
            ((1, 2), "(1, 2)"),
            ([1, 2], "[1, 2]"),
        ]
        for value, expected in cases:
            with self.subTest(value=value, display=expected):
                self.assertEqual(Model(field=value).get_field_display(), expected)
59
60
class TestSaveLoad(PostgreSQLTestCase):
    """Round-trip range values through the database."""

    def test_all_fields(self):
        now = timezone.now()
        hour_ago = now - datetime.timedelta(hours=1)
        yesterday = now.date() - datetime.timedelta(days=1)
        instance = RangesModel(
            ints=NumericRange(0, 10),
            bigints=NumericRange(10, 20),
            decimals=NumericRange(20, 30),
            timestamps=DateTimeTZRange(hour_ago, now),
            dates=DateRange(yesterday, now.date()),
        )
        instance.save()
        fetched = RangesModel.objects.get()
        for name in ("ints", "bigints", "decimals", "timestamps", "dates"):
            self.assertEqual(getattr(instance, name), getattr(fetched, name))

    def test_range_object(self):
        rng = NumericRange(0, 10)
        RangesModel(ints=rng).save()
        self.assertEqual(rng, RangesModel.objects.get().ints)

    def test_tuple(self):
        # A plain (lower, upper) tuple is accepted and loads back as a range.
        RangesModel(ints=(0, 10)).save()
        self.assertEqual(NumericRange(0, 10), RangesModel.objects.get().ints)

    def test_range_object_boundaries(self):
        rng = NumericRange(0, 10, "[]")
        RangesModel(decimals=rng).save()
        fetched = RangesModel.objects.get().decimals
        self.assertEqual(rng, fetched)
        # The inclusive upper bound survives the round trip.
        self.assertIn(10, fetched)

    def test_unbounded(self):
        rng = NumericRange(None, None, "()")
        RangesModel(decimals=rng).save()
        self.assertEqual(rng, RangesModel.objects.get().decimals)

    def test_empty(self):
        rng = NumericRange(empty=True)
        RangesModel(ints=rng).save()
        self.assertEqual(rng, RangesModel.objects.get().ints)

    def test_null(self):
        RangesModel(ints=None).save()
        self.assertIsNone(RangesModel.objects.get().ints)

    def test_model_set_on_base_field(self):
        field = RangesModel()._meta.get_field("ints")
        # Both the range field and its base field point back at the model.
        self.assertEqual(field.model, RangesModel)
        self.assertEqual(field.base_field.model, RangesModel)
125
126
class TestRangeContainsLookup(PostgreSQLTestCase):
    """``contains`` on date/datetime ranges with scalars, ranges and expressions."""

    @classmethod
    def setUpTestData(cls):
        cls.timestamps = [
            datetime.datetime(year=2016, month=1, day=1),
            datetime.datetime(year=2016, month=1, day=2, hour=1),
            datetime.datetime(year=2016, month=1, day=2, hour=12),
            datetime.datetime(year=2016, month=1, day=3),
            datetime.datetime(year=2016, month=1, day=3, hour=1),
            datetime.datetime(year=2016, month=2, day=2),
        ]
        cls.aware_timestamps = [timezone.make_aware(ts) for ts in cls.timestamps]
        cls.dates = [
            datetime.date(year=2016, month=1, day=1),
            datetime.date(year=2016, month=1, day=2),
            datetime.date(year=2016, month=1, day=3),
            datetime.date(year=2016, month=1, day=4),
            datetime.date(year=2016, month=2, day=2),
            datetime.date(year=2016, month=2, day=3),
        ]
        # Two matching rows: one with naive outer timestamps, one with aware.
        cls.obj = RangesModel.objects.create(
            dates=(cls.dates[0], cls.dates[3]),
            dates_inner=(cls.dates[1], cls.dates[2]),
            timestamps=(cls.timestamps[0], cls.timestamps[3]),
            timestamps_inner=(cls.timestamps[1], cls.timestamps[2]),
        )
        cls.aware_obj = RangesModel.objects.create(
            dates=(cls.dates[0], cls.dates[3]),
            dates_inner=(cls.dates[1], cls.dates[2]),
            timestamps=(cls.aware_timestamps[0], cls.aware_timestamps[3]),
            timestamps_inner=(cls.timestamps[1], cls.timestamps[2]),
        )
        # Objects that don't match any queries.
        # NOTE(review): range(3, 4) yields only i == 3.
        for i in range(3, 4):
            RangesModel.objects.create(
                dates=(cls.dates[i], cls.dates[i + 1]),
                timestamps=(cls.timestamps[i], cls.timestamps[i + 1]),
            )
            RangesModel.objects.create(
                dates=(cls.dates[i], cls.dates[i + 1]),
                timestamps=(cls.aware_timestamps[i], cls.aware_timestamps[i + 1]),
            )

    def test_datetime_range_contains(self):
        filter_args = (
            self.timestamps[1],
            self.aware_timestamps[1],
            (self.timestamps[1], self.timestamps[2]),
            (self.aware_timestamps[1], self.aware_timestamps[2]),
            Value(self.dates[0], output_field=DateTimeField()),
            Func(F("dates"), function="lower", output_field=DateTimeField()),
            F("timestamps_inner"),
        )
        expected = [self.obj, self.aware_obj]
        for filter_arg in filter_args:
            with self.subTest(filter_arg=filter_arg):
                self.assertCountEqual(
                    RangesModel.objects.filter(timestamps__contains=filter_arg),
                    expected,
                )

    def test_date_range_contains(self):
        filter_args = (
            self.timestamps[1],
            (self.dates[1], self.dates[2]),
            Value(self.dates[0], output_field=DateField()),
            Func(F("timestamps"), function="lower", output_field=DateField()),
            F("dates_inner"),
        )
        expected = [self.obj, self.aware_obj]
        for filter_arg in filter_args:
            with self.subTest(filter_arg=filter_arg):
                self.assertCountEqual(
                    RangesModel.objects.filter(dates__contains=filter_arg),
                    expected,
                )
203
204
class TestQuerying(PostgreSQLTestCase):
    """Range-specific lookups on IntegerRangeField and DecimalRangeField."""

    @classmethod
    def setUpTestData(cls):
        # objs[0] = [0, 10), objs[1] = [5, 15), objs[2] = unbounded below, up
        # to 0, objs[3] = empty range, objs[4] = NULL.
        cls.objs = RangesModel.objects.bulk_create(
            [
                RangesModel(ints=NumericRange(0, 10)),
                RangesModel(ints=NumericRange(5, 15)),
                RangesModel(ints=NumericRange(None, 0)),
                RangesModel(ints=NumericRange(empty=True)),
                RangesModel(ints=None),
            ]
        )

    def test_exact(self):
        qs = RangesModel.objects.filter(ints__exact=NumericRange(0, 10))
        self.assertSequenceEqual(qs, [self.objs[0]])

    def test_isnull(self):
        qs = RangesModel.objects.filter(ints__isnull=True)
        self.assertSequenceEqual(qs, [self.objs[4]])

    def test_isempty(self):
        qs = RangesModel.objects.filter(ints__isempty=True)
        self.assertSequenceEqual(qs, [self.objs[3]])

    def test_contains(self):
        qs = RangesModel.objects.filter(ints__contains=8)
        self.assertSequenceEqual(qs, self.objs[:2])

    def test_contains_range(self):
        qs = RangesModel.objects.filter(ints__contains=NumericRange(3, 8))
        self.assertSequenceEqual(qs, [self.objs[0]])

    def test_contained_by(self):
        qs = RangesModel.objects.filter(ints__contained_by=NumericRange(0, 20))
        self.assertSequenceEqual(qs, [self.objs[0], self.objs[1], self.objs[3]])

    def test_overlap(self):
        qs = RangesModel.objects.filter(ints__overlap=NumericRange(3, 8))
        self.assertSequenceEqual(qs, self.objs[:2])

    def test_fully_lt(self):
        qs = RangesModel.objects.filter(ints__fully_lt=NumericRange(5, 10))
        self.assertSequenceEqual(qs, [self.objs[2]])

    def test_fully_gt(self):
        qs = RangesModel.objects.filter(ints__fully_gt=NumericRange(5, 10))
        self.assertSequenceEqual(qs, [])

    def test_not_lt(self):
        qs = RangesModel.objects.filter(ints__not_lt=NumericRange(5, 10))
        self.assertSequenceEqual(qs, [self.objs[1]])

    def test_not_gt(self):
        qs = RangesModel.objects.filter(ints__not_gt=NumericRange(5, 10))
        self.assertSequenceEqual(qs, [self.objs[0], self.objs[2]])

    def test_adjacent_to(self):
        qs = RangesModel.objects.filter(ints__adjacent_to=NumericRange(0, 5))
        self.assertSequenceEqual(qs, self.objs[1:3])

    def test_startswith(self):
        qs = RangesModel.objects.filter(ints__startswith=0)
        self.assertSequenceEqual(qs, [self.objs[0]])

    def test_endswith(self):
        qs = RangesModel.objects.filter(ints__endswith=0)
        self.assertSequenceEqual(qs, [self.objs[2]])

    def test_startswith_chaining(self):
        qs = RangesModel.objects.filter(ints__startswith__gte=0)
        self.assertSequenceEqual(qs, self.objs[:2])

    def test_bound_type(self):
        d0, d1, d2, d3 = RangesModel.objects.bulk_create(
            [
                RangesModel(decimals=NumericRange(None, 10)),
                RangesModel(decimals=NumericRange(10, None)),
                RangesModel(decimals=NumericRange(5, 15)),
                RangesModel(decimals=NumericRange(5, 15, "(]")),
            ]
        )
        tests = [
            ("lower_inc", True, [d1, d2]),
            ("lower_inc", False, [d0, d3]),
            ("lower_inf", True, [d0]),
            ("lower_inf", False, [d1, d2, d3]),
            ("upper_inc", True, [d3]),
            ("upper_inc", False, [d0, d1, d2]),
            ("upper_inf", True, [d1]),
            ("upper_inf", False, [d0, d2, d3]),
        ]
        for lookup, filter_arg, expected_result in tests:
            with self.subTest(lookup=lookup, filter_arg=filter_arg):
                self.assertSequenceEqual(
                    RangesModel.objects.filter(**{f"decimals__{lookup}": filter_arg}),
                    expected_result,
                )
333
334
class TestQueryingWithRanges(PostgreSQLTestCase):
    """``contained_by`` lookups from scalar fields into range values."""

    def test_date_range(self):
        objs = [
            RangeLookupsModel.objects.create(date=day)
            for day in ("2015-01-01", "2015-05-05")
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(
                date__contained_by=DateRange("2015-01-01", "2015-05-04")
            ),
            [objs[0]],
        )

    def test_date_range_datetime_field(self):
        # The __date transform lets a DateTimeField be matched by a DateRange.
        objs = [
            RangeLookupsModel.objects.create(timestamp=stamp)
            for stamp in ("2015-01-01", "2015-05-05")
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(
                timestamp__date__contained_by=DateRange("2015-01-01", "2015-05-04")
            ),
            [objs[0]],
        )

    def test_datetime_range(self):
        objs = [
            RangeLookupsModel.objects.create(timestamp=stamp)
            for stamp in ("2015-01-01T09:00:00", "2015-05-05T17:00:00")
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(
                timestamp__contained_by=DateTimeTZRange(
                    "2015-01-01T09:00", "2015-05-04T23:55"
                )
            ),
            [objs[0]],
        )

    def test_small_integer_field_contained_by(self):
        objs = [
            RangeLookupsModel.objects.create(small_integer=number)
            for number in (8, 4, -1)
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(
                small_integer__contained_by=NumericRange(4, 6)
            ),
            [objs[1]],
        )

    def test_integer_range(self):
        objs = [
            RangeLookupsModel.objects.create(integer=number) for number in (5, 99, -1)
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(integer__contained_by=NumericRange(1, 98)),
            [objs[0]],
        )

    def test_biginteger_range(self):
        objs = [
            RangeLookupsModel.objects.create(big_integer=number)
            for number in (5, 99, -1)
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(
                big_integer__contained_by=NumericRange(1, 98)
            ),
            [objs[1 - 1]],
        )

    def test_decimal_field_contained_by(self):
        objs = [
            RangeLookupsModel.objects.create(decimal_field=Decimal(text))
            for text in ("1.33", "2.88", "99.17")
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(
                decimal_field__contained_by=NumericRange(
                    Decimal("1.89"), Decimal("7.91")
                ),
            ),
            [objs[1]],
        )

    def test_float_range(self):
        objs = [
            RangeLookupsModel.objects.create(float=number) for number in (5, 99, -1)
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(float__contained_by=NumericRange(1, 98)),
            [objs[0]],
        )

    def test_small_auto_field_contained_by(self):
        objs = SmallAutoFieldModel.objects.bulk_create(
            [SmallAutoFieldModel() for _ in range(4)]
        )
        # The upper pk bound is exclusive, so only objs[1] and objs[2] match.
        self.assertSequenceEqual(
            SmallAutoFieldModel.objects.filter(
                id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
            ),
            objs[1:3],
        )

    def test_auto_field_contained_by(self):
        objs = RangeLookupsModel.objects.bulk_create(
            [RangeLookupsModel() for _ in range(4)]
        )
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(
                id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
            ),
            objs[1:3],
        )

    def test_big_auto_field_contained_by(self):
        objs = BigAutoFieldModel.objects.bulk_create(
            [BigAutoFieldModel() for _ in range(4)]
        )
        self.assertSequenceEqual(
            BigAutoFieldModel.objects.filter(
                id__contained_by=NumericRange(objs[1].pk, objs[3].pk),
            ),
            objs[1:3],
        )

    def test_f_ranges(self):
        # The range operand may be an F() reference across a relation.
        parent = RangesModel.objects.create(decimals=NumericRange(0, 10))
        objs = [
            RangeLookupsModel.objects.create(float=number, parent=parent)
            for number in (5, 99)
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.filter(float__contained_by=F("parent__decimals")),
            [objs[0]],
        )

    def test_exclude(self):
        objs = [
            RangeLookupsModel.objects.create(float=number) for number in (5, 99, -1)
        ]
        self.assertSequenceEqual(
            RangeLookupsModel.objects.exclude(float__contained_by=NumericRange(0, 100)),
            [objs[2]],
        )
491
492
class TestSerialization(PostgreSQLSimpleTestCase):
    """JSON serialization round trips for range fields on RangesModel."""

    # Serialized RangesModel. Each range value is itself a JSON-encoded string
    # nested inside the outer JSON document.
    test_data = (
        '[{"fields": {"ints": "{\\"upper\\": \\"10\\", \\"lower\\": \\"0\\", '
        '\\"bounds\\": \\"[)\\"}", "decimals": "{\\"empty\\": true}", '
        '"bigints": null, "timestamps": "{\\"upper\\": \\"2014-02-02T12:12:12+00:00\\", '
        '\\"lower\\": \\"2014-01-01T00:00:00+00:00\\", \\"bounds\\": \\"[)\\"}", '
        '"timestamps_inner": null, '
        '"dates": "{\\"upper\\": \\"2014-02-02\\", \\"lower\\": \\"2014-01-01\\", \\"bounds\\": \\"[)\\"}", '
        '"dates_inner": null }, '
        '"model": "postgres_tests.rangesmodel", "pk": null}]'
    )

    lower_date = datetime.date(2014, 1, 1)
    upper_date = datetime.date(2014, 2, 2)
    # Use the stdlib UTC constant: django.utils.timezone.utc is a deprecated
    # alias (Django 4.1) that no longer exists in Django 5.0.
    lower_dt = datetime.datetime(2014, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
    upper_dt = datetime.datetime(2014, 2, 2, 12, 12, 12, tzinfo=datetime.timezone.utc)

    def test_dumping(self):
        instance = RangesModel(
            ints=NumericRange(0, 10),
            decimals=NumericRange(empty=True),
            timestamps=DateTimeTZRange(self.lower_dt, self.upper_dt),
            dates=DateRange(self.lower_date, self.upper_date),
        )
        data = serializers.serialize("json", [instance])
        dumped = json.loads(data)
        check = json.loads(self.test_data)
        # Decode the nested range strings before comparing, so key order
        # inside each range encoding does not matter.
        for field in ("ints", "dates", "timestamps"):
            dumped[0]["fields"][field] = json.loads(dumped[0]["fields"][field])
            check[0]["fields"][field] = json.loads(check[0]["fields"][field])
        self.assertEqual(dumped, check)

    def test_loading(self):
        instance = list(serializers.deserialize("json", self.test_data))[0].object
        self.assertEqual(instance.ints, NumericRange(0, 10))
        self.assertEqual(instance.decimals, NumericRange(empty=True))
        self.assertIsNone(instance.bigints)
        self.assertEqual(instance.dates, DateRange(self.lower_date, self.upper_date))
        self.assertEqual(
            instance.timestamps, DateTimeTZRange(self.lower_dt, self.upper_dt)
        )

    def test_serialize_range_with_null(self):
        # A missing (None) bound on either side survives a round trip.
        instance = RangesModel(ints=NumericRange(None, 10))
        data = serializers.serialize("json", [instance])
        new_instance = list(serializers.deserialize("json", data))[0].object
        self.assertEqual(new_instance.ints, NumericRange(None, 10))

        instance = RangesModel(ints=NumericRange(10, None))
        data = serializers.serialize("json", [instance])
        new_instance = list(serializers.deserialize("json", data))[0].object
        self.assertEqual(new_instance.ints, NumericRange(10, None))
546
547
@isolate_apps("postgres_tests")
class TestChecks(PostgreSQLSimpleTestCase):
    """System checks for range fields."""

    def test_choices_tuple_list(self):
        # The throwaway model lives in an isolated app registry, as it does in
        # BasicTests. Without isolation it would be registered globally as
        # postgres_tests.Model and leak into other tests.
        class Model(PostgreSQLModel):
            field = pg_fields.IntegerRangeField(
                choices=[
                    ["1-50", [((1, 25), "1-25"), ([26, 50], "26-50")]],
                    ((51, 100), "51-100"),
                ],
            )

        # Grouped choices mixing tuple and list values produce no errors.
        self.assertEqual(Model._meta.get_field("field").check(), [])
559
560
class TestValidators(PostgreSQLSimpleTestCase):
    """RangeMaxValueValidator and RangeMinValueValidator."""

    def test_max(self):
        validator = RangeMaxValueValidator(5)
        # A range ending exactly at the limit passes.
        validator(NumericRange(0, 5))
        msg = "Ensure that this range is completely less than or equal to 5."
        with self.assertRaises(exceptions.ValidationError) as ctx:
            validator(NumericRange(0, 10))
        error = ctx.exception
        self.assertEqual(error.messages[0], msg)
        self.assertEqual(error.code, "max_value")
        # A range with no upper bound is rejected as well.
        with self.assertRaisesMessage(exceptions.ValidationError, msg):
            validator(NumericRange(0, None))

    def test_min(self):
        validator = RangeMinValueValidator(5)
        validator(NumericRange(10, 15))
        msg = "Ensure that this range is completely greater than or equal to 5."
        with self.assertRaises(exceptions.ValidationError) as ctx:
            validator(NumericRange(0, 10))
        error = ctx.exception
        self.assertEqual(error.messages[0], msg)
        self.assertEqual(error.code, "min_value")
        # A range with no lower bound is rejected as well.
        with self.assertRaisesMessage(exceptions.ValidationError, msg):
            validator(NumericRange(None, 10))
583
584
585class TestFormField(PostgreSQLSimpleTestCase):
def test_valid_integer(self):
    # Two numeric strings clean to a NumericRange of ints.
    cleaned = pg_forms.IntegerRangeField().clean(["1", "2"])
    self.assertEqual(cleaned, NumericRange(1, 2))
590
def test_valid_decimal(self):
    # Decimal bounds keep their full precision.
    cleaned = pg_forms.DecimalRangeField().clean(["1.12345", "2.001"])
    expected = NumericRange(Decimal("1.12345"), Decimal("2.001"))
    self.assertEqual(cleaned, expected)
595
def test_valid_timestamps(self):
    cleaned = pg_forms.DateTimeRangeField().clean(
        ["01/01/2014 00:00:00", "02/02/2014 12:12:12"]
    )
    expected = DateTimeTZRange(
        datetime.datetime(2014, 1, 1, 0, 0, 0),
        datetime.datetime(2014, 2, 2, 12, 12, 12),
    )
    self.assertEqual(cleaned, expected)
602
def test_valid_dates(self):
    cleaned = pg_forms.DateRangeField().clean(["01/01/2014", "02/02/2014"])
    expected = DateRange(datetime.date(2014, 1, 1), datetime.date(2014, 2, 2))
    self.assertEqual(cleaned, expected)
609
def test_using_split_datetime_widget(self):
    # Swapping base_field for SplitDateTimeField gives each bound two inputs
    # (date and time), i.e. four inputs in total.
    class SplitDateTimeRangeField(pg_forms.DateTimeRangeField):
        base_field = forms.SplitDateTimeField

    class SplitForm(forms.Form):
        field = SplitDateTimeRangeField()

    self.assertHTMLEqual(
        str(SplitForm()),
        """
        <tr>
            <th>
                <label for="id_field_0">Field:</label>
            </th>
            <td>
                <input id="id_field_0_0" name="field_0_0" type="text">
                <input id="id_field_0_1" name="field_0_1" type="text">
                <input id="id_field_1_0" name="field_1_0" type="text">
                <input id="id_field_1_1" name="field_1_1" type="text">
            </td>
        </tr>
        """,
    )
    bound_form = SplitForm(
        {
            "field_0_0": "01/01/2014",
            "field_0_1": "00:00:00",
            "field_1_0": "02/02/2014",
            "field_1_1": "12:12:12",
        }
    )
    self.assertTrue(bound_form.is_valid())
    expected = DateTimeTZRange(
        datetime.datetime(2014, 1, 1, 0, 0, 0),
        datetime.datetime(2014, 2, 2, 12, 12, 12),
    )
    self.assertEqual(bound_form.cleaned_data["field"], expected)
646
def test_none(self):
    # An optional field with both bounds blank cleans to None.
    optional = pg_forms.IntegerRangeField(required=False)
    self.assertIsNone(optional.clean(["", ""]))
651
def test_datetime_form_as_table(self):
    # show_hidden_initial adds one hidden "initial-" input per bound.
    class DateTimeRangeForm(forms.Form):
        datetime_field = pg_forms.DateTimeRangeField(show_hidden_initial=True)

    self.assertHTMLEqual(
        DateTimeRangeForm().as_table(),
        """
        <tr><th>
        <label for="id_datetime_field_0">Datetime field:</label>
        </th><td>
        <input type="text" name="datetime_field_0" id="id_datetime_field_0">
        <input type="text" name="datetime_field_1" id="id_datetime_field_1">
        <input type="hidden" name="initial-datetime_field_0" id="initial-id_datetime_field_0">
        <input type="hidden" name="initial-datetime_field_1" id="initial-id_datetime_field_1">
        </td></tr>
        """,
    )
    bound_form = DateTimeRangeForm(
        {
            "datetime_field_0": "2010-01-01 11:13:00",
            "datetime_field_1": "2020-12-12 16:59:00",
        }
    )
    # Submitted values are echoed into both the visible and hidden inputs.
    self.assertHTMLEqual(
        bound_form.as_table(),
        """
        <tr><th>
        <label for="id_datetime_field_0">Datetime field:</label>
        </th><td>
        <input type="text" name="datetime_field_0"
        value="2010-01-01 11:13:00" id="id_datetime_field_0">
        <input type="text" name="datetime_field_1"
        value="2020-12-12 16:59:00" id="id_datetime_field_1">
        <input type="hidden" name="initial-datetime_field_0" value="2010-01-01 11:13:00"
        id="initial-id_datetime_field_0">
        <input type="hidden" name="initial-datetime_field_1" value="2020-12-12 16:59:00"
        id="initial-id_datetime_field_1"></td></tr>
        """,
    )
692
def test_datetime_form_initial_data(self):
    class DateTimeRangeForm(forms.Form):
        datetime_field = pg_forms.DateTimeRangeField(show_hidden_initial=True)

    submitted = QueryDict(mutable=True)
    submitted.update(
        {
            "datetime_field_0": "2010-01-01 11:13:00",
            "datetime_field_1": "",
            "initial-datetime_field_0": "2010-01-01 10:12:00",
            "initial-datetime_field_1": "",
        }
    )
    # The lower bound differs from its hidden initial value -> changed.
    self.assertTrue(DateTimeRangeForm(data=submitted).has_changed())

    # Once the initial value matches the submission, nothing has changed.
    submitted["initial-datetime_field_0"] = "2010-01-01 11:13:00"
    self.assertFalse(DateTimeRangeForm(data=submitted).has_changed())
712
def test_rendering(self):
    """str() of a form renders an IntegerRangeField as two number inputs."""

    class RangeForm(forms.Form):
        ints = pg_forms.IntegerRangeField()

    expected_html = """
        <tr>
        <th><label for="id_ints_0">Ints:</label></th>
        <td>
        <input id="id_ints_0" name="ints_0" type="number">
        <input id="id_ints_1" name="ints_1" type="number">
        </td>
        </tr>
        """
    self.assertHTMLEqual(str(RangeForm()), expected_html)
729
def test_integer_lower_bound_higher(self):
    """A lower bound above the upper bound fails with code 'bound_ordering'."""
    range_field = pg_forms.IntegerRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["10", "2"])
    error = ctx.exception
    self.assertEqual(
        error.messages[0],
        "The start of the range must not exceed the end of the range.",
    )
    self.assertEqual(error.code, "bound_ordering")
739
def test_integer_open(self):
    """A blank lower bound produces a NumericRange with no lower limit."""
    range_field = pg_forms.IntegerRangeField()
    self.assertEqual(range_field.clean(["", "0"]), NumericRange(None, 0))
744
def test_integer_incorrect_data_type(self):
    """A plain string instead of a pair of bounds fails with code 'invalid'."""
    range_field = pg_forms.IntegerRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean("1")
    error = ctx.exception
    self.assertEqual(error.messages[0], "Enter two whole numbers.")
    self.assertEqual(error.code, "invalid")
751
def test_integer_invalid_lower(self):
    """A non-numeric lower bound reports the integer sub-field's error."""
    range_field = pg_forms.IntegerRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["a", "2"])
    first_message = ctx.exception.messages[0]
    self.assertEqual(first_message, "Enter a whole number.")
757
def test_integer_invalid_upper(self):
    """A non-numeric upper bound reports the integer sub-field's error."""
    range_field = pg_forms.IntegerRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["1", "b"])
    first_message = ctx.exception.messages[0]
    self.assertEqual(first_message, "Enter a whole number.")
763
def test_integer_required(self):
    """A required field rejects two blanks but accepts one given bound."""
    range_field = pg_forms.IntegerRangeField(required=True)
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["", ""])
    self.assertEqual(ctx.exception.messages[0], "This field is required.")
    self.assertEqual(range_field.clean([1, ""]), NumericRange(1, None))
771
def test_decimal_lower_bound_higher(self):
    """A lower bound above the upper bound fails with code 'bound_ordering'."""
    range_field = pg_forms.DecimalRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["1.8", "1.6"])
    error = ctx.exception
    self.assertEqual(
        error.messages[0],
        "The start of the range must not exceed the end of the range.",
    )
    self.assertEqual(error.code, "bound_ordering")
781
def test_decimal_open(self):
    """A blank lower bound produces a Decimal NumericRange open below."""
    range_field = pg_forms.DecimalRangeField()
    expected = NumericRange(None, Decimal("3.1415926"))
    self.assertEqual(range_field.clean(["", "3.1415926"]), expected)
786
def test_decimal_incorrect_data_type(self):
    """A plain string instead of a pair of bounds fails with code 'invalid'."""
    range_field = pg_forms.DecimalRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean("1.6")
    error = ctx.exception
    self.assertEqual(error.messages[0], "Enter two numbers.")
    self.assertEqual(error.code, "invalid")
793
def test_decimal_invalid_lower(self):
    """A non-numeric lower bound reports the decimal sub-field's error."""
    range_field = pg_forms.DecimalRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["a", "3.1415926"])
    first_message = ctx.exception.messages[0]
    self.assertEqual(first_message, "Enter a number.")
799
def test_decimal_invalid_upper(self):
    """A non-numeric upper bound reports the decimal sub-field's error."""
    range_field = pg_forms.DecimalRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["1.61803399", "b"])
    first_message = ctx.exception.messages[0]
    self.assertEqual(first_message, "Enter a number.")
805
def test_decimal_required(self):
    """A required field rejects two blanks but accepts one given bound."""
    range_field = pg_forms.DecimalRangeField(required=True)
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["", ""])
    self.assertEqual(ctx.exception.messages[0], "This field is required.")
    expected = NumericRange(Decimal("1.61803399"), None)
    self.assertEqual(range_field.clean(["1.61803399", ""]), expected)
813
def test_date_lower_bound_higher(self):
    """A start date after the end date fails with code 'bound_ordering'."""
    range_field = pg_forms.DateRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["2013-04-09", "1976-04-16"])
    error = ctx.exception
    self.assertEqual(
        error.messages[0],
        "The start of the range must not exceed the end of the range.",
    )
    self.assertEqual(error.code, "bound_ordering")
823
def test_date_open(self):
    """A blank start date produces a DateRange open below."""
    range_field = pg_forms.DateRangeField()
    expected = DateRange(None, datetime.date(2013, 4, 9))
    self.assertEqual(range_field.clean(["", "2013-04-09"]), expected)
828
def test_date_incorrect_data_type(self):
    """A plain string instead of a pair of dates fails with code 'invalid'."""
    range_field = pg_forms.DateRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean("1")
    error = ctx.exception
    self.assertEqual(error.messages[0], "Enter two valid dates.")
    self.assertEqual(error.code, "invalid")
835
def test_date_invalid_lower(self):
    """An unparsable start date reports the date sub-field's error."""
    range_field = pg_forms.DateRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["a", "2013-04-09"])
    first_message = ctx.exception.messages[0]
    self.assertEqual(first_message, "Enter a valid date.")
841
def test_date_invalid_upper(self):
    """An unparsable end date reports the date sub-field's error."""
    range_field = pg_forms.DateRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["2013-04-09", "b"])
    first_message = ctx.exception.messages[0]
    self.assertEqual(first_message, "Enter a valid date.")
847
def test_date_required(self):
    """A required field rejects two blanks but accepts one given date."""
    range_field = pg_forms.DateRangeField(required=True)
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["", ""])
    self.assertEqual(ctx.exception.messages[0], "This field is required.")
    expected = DateRange(datetime.date(1976, 4, 16), None)
    self.assertEqual(range_field.clean(["1976-04-16", ""]), expected)
855
def test_date_has_changed_first(self):
    """A change to only the start date is detected by has_changed()."""
    initial = ["2010-01-01", "2020-12-12"]
    submitted = ["2010-01-31", "2020-12-12"]
    self.assertTrue(pg_forms.DateRangeField().has_changed(initial, submitted))
863
def test_date_has_changed_last(self):
    """A change to only the end date is detected by has_changed()."""
    initial = ["2010-01-01", "2020-12-12"]
    submitted = ["2010-01-01", "2020-12-31"]
    self.assertTrue(pg_forms.DateRangeField().has_changed(initial, submitted))
871
def test_datetime_lower_bound_higher(self):
    """A start after the end (by one minute) fails with 'bound_ordering'."""
    range_field = pg_forms.DateTimeRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["2006-10-25 14:59", "2006-10-25 14:58"])
    error = ctx.exception
    self.assertEqual(
        error.messages[0],
        "The start of the range must not exceed the end of the range.",
    )
    self.assertEqual(error.code, "bound_ordering")
881
def test_datetime_open(self):
    """A blank start produces a DateTimeTZRange open below."""
    range_field = pg_forms.DateTimeRangeField()
    expected = DateTimeTZRange(None, datetime.datetime(2013, 4, 9, 11, 45))
    self.assertEqual(range_field.clean(["", "2013-04-09 11:45"]), expected)
888
def test_datetime_incorrect_data_type(self):
    """A single string instead of a pair of bounds fails with 'invalid'."""
    range_field = pg_forms.DateTimeRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean("2013-04-09 11:45")
    error = ctx.exception
    self.assertEqual(error.messages[0], "Enter two valid date/times.")
    self.assertEqual(error.code, "invalid")
895
def test_datetime_invalid_lower(self):
    """An unparsable start reports the date/time sub-field's error."""
    range_field = pg_forms.DateTimeRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["45", "2013-04-09 11:45"])
    first_message = ctx.exception.messages[0]
    self.assertEqual(first_message, "Enter a valid date/time.")
901
def test_datetime_invalid_upper(self):
    """An unparsable end reports the date/time sub-field's error."""
    range_field = pg_forms.DateTimeRangeField()
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["2013-04-09 11:45", "sweet pickles"])
    first_message = ctx.exception.messages[0]
    self.assertEqual(first_message, "Enter a valid date/time.")
907
def test_datetime_required(self):
    """A required field rejects two blanks but accepts one given bound."""
    range_field = pg_forms.DateTimeRangeField(required=True)
    with self.assertRaises(exceptions.ValidationError) as ctx:
        range_field.clean(["", ""])
    self.assertEqual(ctx.exception.messages[0], "This field is required.")
    lower = datetime.datetime(2013, 4, 9, 11, 45)
    self.assertEqual(
        range_field.clean(["2013-04-09 11:45", ""]), DateTimeTZRange(lower, None)
    )
917
@override_settings(USE_TZ=True, TIME_ZONE="Africa/Johannesburg")
def test_datetime_prepare_value(self):
    """
    prepare_value() turns a DateTimeTZRange into a list of bounds shown in
    the current time zone: 16:06:33 UTC is expected back as the naive
    18:06:33, and the missing upper bound stays None.
    """
    field = pg_forms.DateTimeRangeField()
    # Use the stdlib UTC singleton; django.utils.timezone.utc is a deprecated
    # alias (Django 4.1) that was removed in Django 5.0.
    utc_start = datetime.datetime(
        2015, 5, 22, 16, 6, 33, tzinfo=datetime.timezone.utc
    )
    value = field.prepare_value(DateTimeTZRange(utc_start, None))
    self.assertEqual(value, [datetime.datetime(2015, 5, 22, 18, 6, 33), None])
927
def test_datetime_has_changed_first(self):
    """A change to only the lower date/time is detected by has_changed()."""
    initial = ["2010-01-01 00:00", "2020-12-12 00:00"]
    submitted = ["2010-01-31 23:00", "2020-12-12 00:00"]
    self.assertTrue(pg_forms.DateTimeRangeField().has_changed(initial, submitted))
935
def test_datetime_has_changed_last(self):
    """A change to only the upper date/time is detected by has_changed()."""
    initial = ["2010-01-01 00:00", "2020-12-12 00:00"]
    submitted = ["2010-01-01 00:00", "2020-12-31 23:00"]
    self.assertTrue(pg_forms.DateTimeRangeField().has_changed(initial, submitted))
943
def test_model_field_formfield_integer(self):
    """The IntegerRangeField model field builds an IntegerRangeField form field."""
    self.assertIsInstance(
        pg_fields.IntegerRangeField().formfield(), pg_forms.IntegerRangeField
    )
948
def test_model_field_formfield_biginteger(self):
    """BigIntegerRangeField also maps to the IntegerRangeField form field."""
    self.assertIsInstance(
        pg_fields.BigIntegerRangeField().formfield(), pg_forms.IntegerRangeField
    )
953
def test_model_field_formfield_float(self):
    """
    DecimalRangeField maps to the DecimalRangeField form field.

    Despite "float" in the method name, no float range field is involved.
    """
    self.assertIsInstance(
        pg_fields.DecimalRangeField().formfield(), pg_forms.DecimalRangeField
    )
958
def test_model_field_formfield_date(self):
    """The DateRangeField model field builds a DateRangeField form field."""
    self.assertIsInstance(
        pg_fields.DateRangeField().formfield(), pg_forms.DateRangeField
    )
963
def test_model_field_formfield_datetime(self):
    """The DateTimeRangeField model field builds a DateTimeRangeField form field."""
    self.assertIsInstance(
        pg_fields.DateTimeRangeField().formfield(), pg_forms.DateTimeRangeField
    )
968
def test_has_changed(self):
    """
    For every range form field, has_changed() is True against a None
    initial value or when either bound was blank initially, and False when
    initial and submitted values are identical.
    """
    cases = [
        (pg_forms.DateRangeField(), ["2010-01-01", "2020-12-12"]),
        (pg_forms.DateTimeRangeField(), ["2010-01-01 11:13", "2020-12-12 14:52"]),
        (pg_forms.IntegerRangeField(), [1, 2]),
        (pg_forms.DecimalRangeField(), ["1.12345", "2.001"]),
    ]
    for range_field, bounds in cases:
        lower, upper = bounds
        with self.subTest(field=type(range_field).__name__):
            self.assertTrue(range_field.has_changed(None, bounds))
            self.assertTrue(range_field.has_changed([lower, ""], bounds))
            self.assertTrue(range_field.has_changed(["", upper], bounds))
            self.assertFalse(range_field.has_changed(bounds, bounds))
981
982
class TestWidget(PostgreSQLSimpleTestCase):
    """Rendering of the two-input widget used by range form fields."""

    def test_range_widget(self):
        """
        The DateTimeRangeField widget renders one text input per bound: both
        empty for "" or None, and filled from the bounds of a range value.
        """
        widget = pg_forms.ranges.DateTimeRangeField().widget
        blank_inputs = (
            '<input type="text" name="datetimerange_0">'
            '<input type="text" name="datetimerange_1">'
        )
        for empty_value in ("", None):
            self.assertHTMLEqual(
                widget.render("datetimerange", empty_value), blank_inputs
            )

        dt_range = DateTimeTZRange(
            datetime.datetime(2006, 1, 10, 7, 30),
            datetime.datetime(2006, 2, 12, 9, 50),
        )
        self.assertHTMLEqual(
            widget.render("datetimerange", dt_range),
            '<input type="text" name="datetimerange_0" value="2006-01-10 07:30:00">'
            '<input type="text" name="datetimerange_1" value="2006-02-12 09:50:00">',
        )
test_constraints.py ¶
1import datetime
2from unittest import mock
3
4from django.db import connection
5from django.db import IntegrityError
6from django.db import transaction
7from django.db.models import CheckConstraint
8from django.db.models import F
9from django.db.models import Func
10from django.db.models import Q
11from django.utils import timezone
12
13from . import PostgreSQLTestCase
14from .models import HotelReservation
15from .models import RangesModel
16from .models import Room
17
18try:
19 from django.contrib.postgres.constraints import ExclusionConstraint
20 from django.contrib.postgres.fields import (
21 DateTimeRangeField,
22 RangeBoundary,
23 RangeOperators,
24 )
25
26 from psycopg2.extras import DateRange, NumericRange
27except ImportError:
28 pass
29
30
class SchemaTests(PostgreSQLTestCase):
    """CheckConstraints built from range lookups, added with the schema editor."""

    def get_constraints(self, table):
        """Get the constraints on the table using a new cursor."""
        with connection.cursor() as cursor:
            found = connection.introspection.get_constraints(cursor, table)
        return found

    def _add_ranges_constraint(self, constraint):
        # Add *constraint* to RangesModel's table, asserting it is absent
        # beforehand and present afterwards.
        table = RangesModel._meta.db_table
        self.assertNotIn(constraint.name, self.get_constraints(table))
        with connection.schema_editor() as editor:
            editor.add_constraint(RangesModel, constraint)
        self.assertIn(constraint.name, self.get_constraints(table))

    def test_check_constraint_range_value(self):
        """ints must lie within NumericRange(10, 30)."""
        self._add_ranges_constraint(
            CheckConstraint(
                check=Q(ints__contained_by=NumericRange(10, 30)),
                name="ints_between",
            )
        )
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(ints=(20, 50))
        RangesModel.objects.create(ints=(10, 30))

    def test_check_constraint_daterange_contains(self):
        """dates must contain dates_inner."""
        self._add_ranges_constraint(
            CheckConstraint(
                check=Q(dates__contains=F("dates_inner")),
                name="dates_contains",
            )
        )
        start = datetime.date(2016, 1, 1)
        end = datetime.date(2016, 1, 4)
        # An inner range ending one day later is not contained.
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(
                dates=(start, end),
                dates_inner=(start, end.replace(day=5)),
            )
        RangesModel.objects.create(dates=(start, end), dates_inner=(start, end))

    def test_check_constraint_datetimerange_contains(self):
        """timestamps must contain timestamps_inner."""
        self._add_ranges_constraint(
            CheckConstraint(
                check=Q(timestamps__contains=F("timestamps_inner")),
                name="timestamps_contains",
            )
        )
        start = datetime.datetime(2016, 1, 1)
        end = datetime.datetime(2016, 1, 2, 12)
        # An inner range ending one hour later is not contained.
        with self.assertRaises(IntegrityError), transaction.atomic():
            RangesModel.objects.create(
                timestamps=(start, end),
                timestamps_inner=(start, end.replace(hour=13)),
            )
        RangesModel.objects.create(
            timestamps=(start, end),
            timestamps_inner=(start, end),
        )
100
101
102class ExclusionConstraintTests(PostgreSQLTestCase):
def get_constraints(self, table):
    """Get the constraints on the table using a new cursor."""
    introspection = connection.introspection
    with connection.cursor() as cursor:
        found = introspection.get_constraints(cursor, table)
    return found
107
def test_invalid_condition(self):
    """A condition that is not a Q object is rejected with ValueError."""
    expected_message = "ExclusionConstraint.condition must be a Q instance."
    with self.assertRaisesMessage(ValueError, expected_message):
        ExclusionConstraint(
            index_type="GIST",
            name="exclude_invalid_condition",
            expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
            condition=F("invalid"),
        )
117
def test_invalid_index_type(self):
    """Index types other than GiST / SP-GiST (here GIN) are rejected."""
    expected_message = "Exclusion constraints only support GiST or SP-GiST indexes."
    with self.assertRaisesMessage(ValueError, expected_message):
        ExclusionConstraint(
            index_type="gin",
            name="exclude_invalid_index_type",
            expressions=[(F("datespan"), RangeOperators.OVERLAPS)],
        )
126
def test_invalid_expressions(self):
    """
    Every item of ``expressions`` must be a 2-tuple: a bare string, a
    1-tuple and a 3-tuple are each rejected with ValueError.
    """
    msg = "The expressions must be a list of 2-tuples."
    # ("foo",) needs the trailing comma: ("foo") is just the string "foo",
    # which would repeat the first case instead of covering a 1-tuple.
    for expressions in (["foo"], [("foo",)], [("foo_1", "foo_2", "foo_3")]):
        with self.subTest(expressions), self.assertRaisesMessage(ValueError, msg):
            ExclusionConstraint(
                index_type="GIST",
                name="exclude_invalid_expressions",
                expressions=expressions,
            )
136
def test_empty_expressions(self):
    """Both None and an empty list are rejected as missing expressions."""
    msg = "At least one expression is required to define an exclusion constraint."
    for empty_expressions in (None, []):
        with self.subTest(empty_expressions):
            with self.assertRaisesMessage(ValueError, msg):
                ExclusionConstraint(
                    index_type="GIST",
                    name="exclude_empty_expressions",
                    expressions=empty_expressions,
                )
148
def test_repr(self):
    """
    repr() lists index_type and expressions (index_type shows GIST when not
    given) and appends condition only when one is set.
    """
    default_index = ExclusionConstraint(
        name="exclude_overlapping",
        expressions=[
            (F("datespan"), RangeOperators.OVERLAPS),
            (F("room"), RangeOperators.EQUAL),
        ],
    )
    self.assertEqual(
        repr(default_index),
        "<ExclusionConstraint: index_type=GIST, expressions=["
        "(F(datespan), '&&'), (F(room), '=')]>",
    )

    spgist_with_condition = ExclusionConstraint(
        name="exclude_overlapping",
        expressions=[(F("datespan"), RangeOperators.ADJACENT_TO)],
        condition=Q(cancelled=False),
        index_type="SPGiST",
    )
    self.assertEqual(
        repr(spgist_with_condition),
        "<ExclusionConstraint: index_type=SPGiST, expressions=["
        "(F(datespan), '-|-')], condition=(AND: ('cancelled', False))>",
    )
173
def test_eq(self):
    """
    A constraint equals itself and mock.ANY, and differs from constraints
    with other expressions or condition, and from unrelated objects.
    """
    f_exprs_with_condition = ExclusionConstraint(
        name="exclude_overlapping",
        expressions=[
            (F("datespan"), RangeOperators.OVERLAPS),
            (F("room"), RangeOperators.EQUAL),
        ],
        condition=Q(cancelled=False),
    )
    str_exprs_no_condition = ExclusionConstraint(
        name="exclude_overlapping",
        expressions=[
            ("datespan", RangeOperators.OVERLAPS),
            ("room", RangeOperators.EQUAL),
        ],
    )
    single_expr_with_condition = ExclusionConstraint(
        name="exclude_overlapping",
        expressions=[("datespan", RangeOperators.OVERLAPS)],
        condition=Q(cancelled=False),
    )

    self.assertEqual(f_exprs_with_condition, f_exprs_with_condition)
    self.assertEqual(f_exprs_with_condition, mock.ANY)
    for left, right in (
        (f_exprs_with_condition, str_exprs_no_condition),
        (f_exprs_with_condition, single_expr_with_condition),
        (str_exprs_no_condition, single_expr_with_condition),
    ):
        self.assertNotEqual(left, right)
    self.assertNotEqual(f_exprs_with_condition, object())
201
def test_deconstruct(self):
    """deconstruct() emits only name and expressions when nothing else is set."""
    constraint = ExclusionConstraint(
        name="exclude_overlapping",
        expressions=[
            ("datespan", RangeOperators.OVERLAPS),
            ("room", RangeOperators.EQUAL),
        ],
    )
    expected_kwargs = {
        "name": "exclude_overlapping",
        "expressions": [
            ("datespan", RangeOperators.OVERLAPS),
            ("room", RangeOperators.EQUAL),
        ],
    }
    path, args, kwargs = constraint.deconstruct()
    self.assertEqual(path, "django.contrib.postgres.constraints.ExclusionConstraint")
    self.assertEqual(args, ())
    self.assertEqual(kwargs, expected_kwargs)
225
def test_deconstruct_index_type(self):
    """An explicitly passed index_type is preserved by deconstruct()."""
    constraint = ExclusionConstraint(
        name="exclude_overlapping",
        index_type="SPGIST",
        expressions=[
            ("datespan", RangeOperators.OVERLAPS),
            ("room", RangeOperators.EQUAL),
        ],
    )
    expected_kwargs = {
        "name": "exclude_overlapping",
        "index_type": "SPGIST",
        "expressions": [
            ("datespan", RangeOperators.OVERLAPS),
            ("room", RangeOperators.EQUAL),
        ],
    }
    path, args, kwargs = constraint.deconstruct()
    self.assertEqual(path, "django.contrib.postgres.constraints.ExclusionConstraint")
    self.assertEqual(args, ())
    self.assertEqual(kwargs, expected_kwargs)
251
def test_deconstruct_condition(self):
    """A condition is preserved by deconstruct()."""
    constraint = ExclusionConstraint(
        name="exclude_overlapping",
        expressions=[
            ("datespan", RangeOperators.OVERLAPS),
            ("room", RangeOperators.EQUAL),
        ],
        condition=Q(cancelled=False),
    )
    expected_kwargs = {
        "name": "exclude_overlapping",
        "expressions": [
            ("datespan", RangeOperators.OVERLAPS),
            ("room", RangeOperators.EQUAL),
        ],
        "condition": Q(cancelled=False),
    }
    path, args, kwargs = constraint.deconstruct()
    self.assertEqual(path, "django.contrib.postgres.constraints.ExclusionConstraint")
    self.assertEqual(args, ())
    self.assertEqual(kwargs, expected_kwargs)
277
def _test_range_overlaps(self, constraint):
    """
    Add *constraint* to HotelReservation and exercise it.

    The checks assume *constraint* forbids overlapping reservation ranges in
    the same room for non-cancelled rows, as the constraints built by the
    test_range_overlaps* methods do. The checks are:
    - an overlapping booking of room 102 raises IntegrityError;
    - a booking of another room is accepted;
    - a cancelled overlapping booking is accepted;
    - an adjacent booking is accepted.
    """
    table = HotelReservation._meta.db_table
    # Create exclusion constraint.
    self.assertNotIn(constraint.name, self.get_constraints(table))
    with connection.schema_editor() as editor:
        editor.add_constraint(HotelReservation, constraint)
    self.assertIn(constraint.name, self.get_constraints(table))
    # Add initial reservations.
    room101 = Room.objects.create(number=101)
    room102 = Room.objects.create(number=102)
    # Use the datetime module directly instead of relying on
    # django.utils.timezone incidentally re-exporting the datetime class.
    datetimes = [
        datetime.datetime(2018, 6, 20),
        datetime.datetime(2018, 6, 24),
        datetime.datetime(2018, 6, 26),
        datetime.datetime(2018, 6, 28),
        datetime.datetime(2018, 6, 29),
    ]
    HotelReservation.objects.create(
        datespan=DateRange(datetimes[0].date(), datetimes[1].date()),
        start=datetimes[0],
        end=datetimes[1],
        room=room102,
    )
    HotelReservation.objects.create(
        datespan=DateRange(datetimes[1].date(), datetimes[3].date()),
        start=datetimes[1],
        end=datetimes[3],
        room=room102,
    )
    # Overlap dates.
    with self.assertRaises(IntegrityError), transaction.atomic():
        reservation = HotelReservation(
            datespan=(datetimes[1].date(), datetimes[2].date()),
            start=datetimes[1],
            end=datetimes[2],
            room=room102,
        )
        reservation.save()
    # Valid range.
    HotelReservation.objects.bulk_create(
        [
            # Other room.
            HotelReservation(
                datespan=(datetimes[1].date(), datetimes[2].date()),
                start=datetimes[1],
                end=datetimes[2],
                room=room101,
            ),
            # Cancelled reservation. Its datespan overlaps the 06-24..06-28
            # room 102 booking, so it is only accepted because the constraint
            # skips cancelled rows. An empty range such as (06-24, 06-24)
            # overlaps nothing and would not exercise the condition at all.
            HotelReservation(
                datespan=(datetimes[1].date(), datetimes[2].date()),
                start=datetimes[1],
                end=datetimes[2],
                room=room102,
                cancelled=True,
            ),
            # Other adjacent dates.
            HotelReservation(
                datespan=(datetimes[3].date(), datetimes[4].date()),
                start=datetimes[3],
                end=datetimes[4],
                room=room102,
            ),
        ]
    )
346
def test_range_overlaps_custom(self):
    """
    A Func building a TSTZRANGE from start/end can be the range operand of
    an exclusion constraint.
    """

    class TsTzRange(Func):
        function = "TSTZRANGE"
        output_field = DateTimeRangeField()

    stay = TsTzRange("start", "end", RangeBoundary())
    self._test_range_overlaps(
        ExclusionConstraint(
            name="exclude_overlapping_reservations_custom",
            expressions=[
                (stay, RangeOperators.OVERLAPS),
                ("room", RangeOperators.EQUAL),
            ],
            condition=Q(cancelled=False),
        )
    )
361
def test_range_overlaps(self):
    """Exclude overlapping datespan values in the same room unless cancelled."""
    overlap_in_same_room = [
        (F("datespan"), RangeOperators.OVERLAPS),
        ("room", RangeOperators.EQUAL),
    ]
    self._test_range_overlaps(
        ExclusionConstraint(
            name="exclude_overlapping_reservations",
            expressions=overlap_in_same_room,
            condition=Q(cancelled=False),
        )
    )
372
373 def test_range_adjacent(self):
374 constraint_name = "ints_adjacent"
375 self.assertNotIn(
376 constraint_name, self.get_constraints(RangesModel._meta.db_table)
377 )
378 constraint = ExclusionConstraint(
379 name=constraint_name,
380 expressions=[("ints", RangeOperators.ADJACENT_TO)],
381 )
382 with connection.schema_editor() as editor:
383 editor.add_constraint(RangesModel, constraint)
384 self.assertIn(constraint_name, self.get_constraints(RangesModel._meta.db_table))
385 RangesModel.objects.create(ints=(20, 50))
386 with self.assertRaises(IntegrityError), transaction.atomic():
387 RangesModel.objects.create(ints=(10, 20))
388 RangesModel.objects.create(ints=(10, 19))
389 RangesModel.objects.create(ints=(51, 60))