Tests: ArrayField
Contents
__init__.py
1import unittest
2
3from django.db import connection
4from django.test import modify_settings
5from django.test import SimpleTestCase
6from django.test import TestCase
7from forms_tests.widget_tests.base import WidgetTest
8
9
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLSimpleTestCase(SimpleTestCase):
    """SimpleTestCase that is skipped unless the database vendor is PostgreSQL."""
13
14
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class PostgreSQLTestCase(TestCase):
    """TestCase that is skipped unless the database vendor is PostgreSQL."""
18
19
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
# Installing django.contrib.postgres makes the widget's template locatable.
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase):
    """WidgetTest base for PostgreSQL-specific form widgets."""
fields.py
1"""
2Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
3run with a backend other than PostgreSQL.
4"""
5import enum
6
7from django.db import models
8
# Use the real PostgreSQL fields when they can be imported; otherwise bind the
# same names to placeholder fields so modules importing from here still load
# when the tests run against a non-PostgreSQL backend.
try:
    from django.contrib.postgres.fields import (
        ArrayField,
        BigIntegerRangeField,
        CICharField,
        CIEmailField,
        CITextField,
        DateRangeField,
        DateTimeRangeField,
        DecimalRangeField,
        HStoreField,
        IntegerRangeField,
        JSONField,
    )
    from django.contrib.postgres.search import SearchVectorField
except ImportError:

    class DummyArrayField(models.Field):
        """Placeholder for ArrayField that accepts its signature but stores nothing extra."""

        def __init__(self, base_field, size=None, **kwargs):
            # base_field and size are accepted only for call compatibility;
            # they are not stored.
            super().__init__(**kwargs)

        def deconstruct(self):
            name, path, args, kwargs = super().deconstruct()
            # The real values were never stored, so fixed placeholder values
            # are reported for the array-specific kwargs.
            kwargs.update(
                {
                    "base_field": "",
                    "size": 1,
                }
            )
            return name, path, args, kwargs

    class DummyJSONField(models.Field):
        """Placeholder for JSONField that accepts and ignores ``encoder``."""

        def __init__(self, encoder=None, **kwargs):
            super().__init__(**kwargs)

    # Every other PostgreSQL-only field degrades to a plain models.Field.
    ArrayField = DummyArrayField
    BigIntegerRangeField = models.Field
    CICharField = models.Field
    CIEmailField = models.Field
    CITextField = models.Field
    DateRangeField = models.Field
    DateTimeRangeField = models.Field
    DecimalRangeField = models.Field
    HStoreField = models.Field
    IntegerRangeField = models.Field
    JSONField = DummyJSONField
    SearchVectorField = models.Field
56
57
class EnumField(models.CharField):
    """CharField that accepts ``enum.Enum`` members and stores their ``value``."""

    def get_prep_value(self, value):
        """Prepare ``value`` for the database, unwrapping Enum members first.

        The unwrap must happen before delegating: CharField's preparation
        coerces non-string values with ``str()``, which would turn an Enum
        member into ``"ClassName.MEMBER"`` rather than its value. Delegating
        afterwards keeps CharField's lazy-string resolution and str
        normalisation, which the previous implementation skipped.
        """
        if isinstance(value, enum.Enum):
            value = value.value
        return super().get_prep_value(value)
models.py
1from django.core.serializers.json import DjangoJSONEncoder
2from django.db import models
3
4from .fields import ArrayField
5from .fields import BigIntegerRangeField
6from .fields import CICharField
7from .fields import CIEmailField
8from .fields import CITextField
9from .fields import DateRangeField
10from .fields import DateTimeRangeField
11from .fields import DecimalRangeField
12from .fields import EnumField
13from .fields import HStoreField
14from .fields import IntegerRangeField
15from .fields import JSONField
16from .fields import SearchVectorField
17
18
class Tag:
    """Small value object that TagField persists as its integer ``tag_id``."""

    def __init__(self, tag_id):
        self.tag_id = tag_id

    def __eq__(self, other):
        # Defer to the other operand for foreign types. Python then falls
        # back to identity, so Tag(1) == 1 is still False.
        if not isinstance(other, Tag):
            return NotImplemented
        return self.tag_id == other.tag_id

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None. Hash on the same key
        # used for equality so Tags can live in sets and dict keys.
        return hash(self.tag_id)

    def __repr__(self):
        # Readable output in assertion failures that compare lists of Tags.
        return f"{type(self).__name__}({self.tag_id!r})"
25
26
class TagField(models.SmallIntegerField):
    """SmallIntegerField whose Python-side values are ``Tag`` instances."""

    def from_db_value(self, value, expression, connection):
        """Wrap a non-NULL database integer in a Tag."""
        if value is None:
            return value
        return Tag(int(value))

    def to_python(self, value):
        """Return ``value`` as a Tag; None and existing Tags pass through."""
        if value is None or isinstance(value, Tag):
            return value
        return Tag(int(value))

    def get_prep_value(self, value):
        """Return the integer to store for a Tag, or None for NULL.

        Previously None reached ``value.tag_id`` and raised AttributeError.
        """
        if value is None:
            return None
        return value.tag_id
42
43
class PostgreSQLModel(models.Model):
    """Abstract base for models that are only created on PostgreSQL."""

    class Meta:
        required_db_vendor = "postgresql"
        abstract = True
48
49
class IntegerArrayModel(PostgreSQLModel):
    """Non-nullable integer array defaulting to a fresh empty list."""

    field = ArrayField(models.IntegerField(), blank=True, default=list)
52
53
class NullableIntegerArrayModel(PostgreSQLModel):
    """Nullable integer array plus a nullable 2-D array of nullable integers."""

    field = ArrayField(models.IntegerField(), null=True, blank=True)
    field_nested = ArrayField(
        ArrayField(models.IntegerField(null=True)),
        null=True,
    )
57
58
class CharArrayModel(PostgreSQLModel):
    """Array of short strings (max_length=10 per element)."""

    field = ArrayField(models.CharField(max_length=10))
61
62
class DateTimeArrayModel(PostgreSQLModel):
    """Arrays of datetimes, dates and times."""

    datetimes = ArrayField(base_field=models.DateTimeField())
    dates = ArrayField(base_field=models.DateField())
    times = ArrayField(base_field=models.TimeField())
67
68
class NestedIntegerArrayModel(PostgreSQLModel):
    """Two-dimensional integer array."""

    field = ArrayField(base_field=ArrayField(base_field=models.IntegerField()))
71
72
class OtherTypesArrayModel(PostgreSQLModel):
    """Arrays whose elements use less common base field types."""

    ips = ArrayField(models.GenericIPAddressField(), default=list)
    uuids = ArrayField(models.UUIDField(), default=list)
    decimals = ArrayField(
        models.DecimalField(max_digits=5, decimal_places=2),
        default=list,
    )
    tags = ArrayField(TagField(), null=True, blank=True)
    json = ArrayField(JSONField(default=dict), default=list)
    int_ranges = ArrayField(IntegerRangeField(), null=True, blank=True)
    bigint_ranges = ArrayField(BigIntegerRangeField(), null=True, blank=True)
83
84
class HStoreModel(PostgreSQLModel):
    """Nullable hstore column and a nullable array of hstore values."""

    field = HStoreField(null=True, blank=True)
    array_field = ArrayField(base_field=HStoreField(), null=True)
88
89
class ArrayEnumModel(PostgreSQLModel):
    """Array whose elements are stored through EnumField."""

    array_of_enums = ArrayField(base_field=EnumField(max_length=20))
92
93
class CharFieldModel(models.Model):
    """Plain model with a single 16-character CharField."""

    field = models.CharField(max_length=16)
96
97
class TextFieldModel(models.Model):
    """Plain model with a single TextField, rendered as its text."""

    field = models.TextField()

    def __str__(self):
        text = self.field
        return text
103
104
class SmallAutoFieldModel(models.Model):
    """Model whose primary key is a SmallAutoField."""

    id = models.SmallAutoField(primary_key=True)
107
108
class BigAutoFieldModel(models.Model):
    """Model whose primary key is a BigAutoField."""

    id = models.BigAutoField(primary_key=True)
111
112
# Scene/Character/Line models are used to test full text search. They're
# populated with content from Monty Python and the Holy Grail.
class Scene(models.Model):
    """A scene of the film, identified by its ``scene`` label."""

    scene = models.CharField(max_length=255)
    setting = models.CharField(max_length=255)

    def __str__(self):
        label = self.scene
        return label
121
122
class Character(models.Model):
    """A speaking character, rendered by name."""

    name = models.CharField(max_length=255)

    def __str__(self):
        label = self.name
        return label
128
129
class CITestModel(PostgreSQLModel):
    """Case-insensitive text fields, including a CIText primary key."""

    name = CICharField(max_length=255, primary_key=True)
    email = CIEmailField()
    description = CITextField()
    array_field = ArrayField(base_field=CITextField(), null=True)

    def __str__(self):
        label = self.name
        return label
138
139
class Line(PostgreSQLModel):
    """A line of dialogue spoken by a Character within a Scene."""

    scene = models.ForeignKey("Scene", on_delete=models.CASCADE)
    character = models.ForeignKey("Character", on_delete=models.CASCADE)
    dialogue = models.TextField(null=True, blank=True)
    dialogue_search_vector = SearchVectorField(null=True, blank=True)
    dialogue_config = models.CharField(max_length=100, null=True, blank=True)

    def __str__(self):
        # Falsy dialogue (None or "") renders as an empty string.
        return self.dialogue if self.dialogue else ""
149
150
class RangesModel(PostgreSQLModel):
    """One nullable column per PostgreSQL range field type."""

    ints = IntegerRangeField(null=True, blank=True)
    bigints = BigIntegerRangeField(null=True, blank=True)
    decimals = DecimalRangeField(null=True, blank=True)
    timestamps = DateTimeRangeField(null=True, blank=True)
    timestamps_inner = DateTimeRangeField(null=True, blank=True)
    dates = DateRangeField(null=True, blank=True)
    dates_inner = DateRangeField(null=True, blank=True)
159
160
class RangeLookupsModel(PostgreSQLModel):
    """Scalar columns of various types, optionally linked to a RangesModel."""

    parent = models.ForeignKey(
        RangesModel, on_delete=models.SET_NULL, null=True, blank=True
    )
    integer = models.IntegerField(null=True, blank=True)
    big_integer = models.BigIntegerField(null=True, blank=True)
    # NOTE: this field name shadows the builtin ``float`` for the rest of the
    # class body; it is kept because it is the column/field name.
    float = models.FloatField(null=True, blank=True)
    timestamp = models.DateTimeField(null=True, blank=True)
    date = models.DateField(null=True, blank=True)
    small_integer = models.SmallIntegerField(null=True, blank=True)
    decimal_field = models.DecimalField(
        max_digits=5,
        decimal_places=2,
        null=True,
        blank=True,
    )
172
173
class JSONModel(PostgreSQLModel):
    """Nullable JSON columns with the default and a DjangoJSONEncoder encoder."""

    field = JSONField(null=True, blank=True)
    field_custom = JSONField(encoder=DjangoJSONEncoder, null=True, blank=True)
177
178
class ArrayFieldSubclass(ArrayField):
    """ArrayField subclass whose base field is always a fresh IntegerField."""

    def __init__(self, *args, **kwargs):
        # Every argument is intentionally ignored. Presumably this lets
        # reconstruction from deconstruct() output (which carries base_field)
        # avoid clashing with the hard-coded base field — TODO confirm.
        integer_base = models.IntegerField()
        super().__init__(integer_base)
182
183
class AggregateTestModel(models.Model):
    """Rows exercised by the postgres-specific general aggregate functions."""

    char_field = models.CharField(blank=True, max_length=30)
    integer_field = models.IntegerField(null=True)
    boolean_field = models.BooleanField(null=True)
192
193
class StatTestModel(models.Model):
    """Rows exercised by the postgres-specific statistical aggregates."""

    int1 = models.IntegerField()
    int2 = models.IntegerField()
    related_field = models.ForeignKey(
        AggregateTestModel, on_delete=models.SET_NULL, null=True
    )
202
203
class NowTestModel(models.Model):
    """Single nullable datetime, NULL unless set explicitly."""

    when = models.DateTimeField(default=None, null=True)
206
207
class UUIDTestModel(models.Model):
    """Single nullable UUID, NULL unless set explicitly."""

    uuid = models.UUIDField(null=True, default=None)
210
211
class Room(models.Model):
    """A room identified by a unique number."""

    number = models.IntegerField(unique=True)
214
215
class HotelReservation(PostgreSQLModel):
    """A booking of a Room over a date range and a start/end timestamp pair."""

    room = models.ForeignKey(to="Room", on_delete=models.CASCADE)
    datespan = DateRangeField()
    start = models.DateTimeField()
    end = models.DateTimeField()
    cancelled = models.BooleanField(default=False)
test_array.py
1import decimal
2import enum
3import json
4import unittest
5import uuid
6
7from django import forms
8from django.core import checks
9from django.core import exceptions
10from django.core import serializers
11from django.core import validators
12from django.core.exceptions import FieldError
13from django.core.management import call_command
14from django.db import connection
15from django.db import IntegrityError
16from django.db import models
17from django.db.models.expressions import RawSQL
18from django.db.models.functions import Cast
19from django.test import modify_settings
20from django.test import override_settings
21from django.test import TransactionTestCase
22from django.test.utils import isolate_apps
23from django.utils import timezone
24
25from . import PostgreSQLSimpleTestCase
26from . import PostgreSQLTestCase
27from . import PostgreSQLWidgetTestCase
28from .models import ArrayEnumModel
29from .models import ArrayFieldSubclass
30from .models import CharArrayModel
31from .models import DateTimeArrayModel
32from .models import IntegerArrayModel
33from .models import NestedIntegerArrayModel
34from .models import NullableIntegerArrayModel
35from .models import OtherTypesArrayModel
36from .models import PostgreSQLModel
37from .models import Tag
38
39try:
40 from django.contrib.postgres.aggregates import ArrayAgg
41 from django.contrib.postgres.fields import ArrayField
42 from django.contrib.postgres.fields.array import IndexTransform, SliceTransform
43 from django.contrib.postgres.forms import (
44 SimpleArrayField,
45 SplitArrayField,
46 SplitArrayWidget,
47 )
48 from django.db.backends.postgresql.base import PSYCOPG2_VERSION
49 from psycopg2.extras import NumericRange
50except ImportError:
51 pass
52
53
@isolate_apps("postgres_tests")
class BasicTests(PostgreSQLSimpleTestCase):
    """get_FOO_display() behaviour for ArrayField choices."""

    def _assert_field_displays(self, model, cases):
        """Check that ``model(field=value).get_field_display()`` matches each expected label."""
        for value, expected in cases:
            with self.subTest(value=value, display=expected):
                self.assertEqual(model(field=value).get_field_display(), expected)

    def test_get_field_display(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(
                models.CharField(max_length=16),
                choices=[
                    ["Media", [(["vinyl", "cd"], "Audio")]],
                    (("mp3", "mp4"), "Digital"),
                ],
            )

        self._assert_field_displays(
            MyModel,
            (
                (["vinyl", "cd"], "Audio"),
                (("mp3", "mp4"), "Digital"),
                # Values not among the choices are displayed as their string form.
                (("a", "b"), "('a', 'b')"),
                (["c", "d"], "['c', 'd']"),
            ),
        )

    def test_get_field_display_nested_array(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(
                ArrayField(models.CharField(max_length=16)),
                choices=[
                    [
                        "Media",
                        [([["vinyl", "cd"], ("x",)], "Audio")],
                    ],
                    ((["mp3"], ("mp4",)), "Digital"),
                ],
            )

        self._assert_field_displays(
            MyModel,
            (
                ([["vinyl", "cd"], ("x",)], "Audio"),
                ((["mp3"], ("mp4",)), "Digital"),
                ((("a", "b"), ("c",)), "(('a', 'b'), ('c',))"),
                ([["a", "b"], ["c"]], "[['a', 'b'], ['c']]"),
            ),
        )
100
101
class TestSaveLoad(PostgreSQLTestCase):
    """Round-trip ArrayField values through save() and a fresh query."""

    def _save_and_fetch(self, obj):
        """Save ``obj`` and return its model's single row re-read from the database."""
        obj.save()
        return type(obj).objects.get()

    def test_integer(self):
        original = IntegerArrayModel(field=[1, 2, 3])
        fetched = self._save_and_fetch(original)
        self.assertEqual(original.field, fetched.field)

    def test_char(self):
        original = CharArrayModel(field=["hello", "goodbye"])
        fetched = self._save_and_fetch(original)
        self.assertEqual(original.field, fetched.field)

    def test_dates(self):
        original = DateTimeArrayModel(
            datetimes=[timezone.now()],
            dates=[timezone.now().date()],
            times=[timezone.now().time()],
        )
        fetched = self._save_and_fetch(original)
        for attr in ("datetimes", "dates", "times"):
            self.assertEqual(getattr(original, attr), getattr(fetched, attr))

    def test_tuples(self):
        original = IntegerArrayModel(field=(1,))
        fetched = self._save_and_fetch(original)
        self.assertSequenceEqual(original.field, fetched.field)

    def test_integers_passed_as_strings(self):
        # Checks that get_prep_value is deferred properly.
        fetched = self._save_and_fetch(IntegerArrayModel(field=["1"]))
        self.assertEqual(fetched.field, [1])

    def test_default_null(self):
        original = NullableIntegerArrayModel()
        original.save()
        fetched = NullableIntegerArrayModel.objects.get(pk=original.pk)
        self.assertIsNone(fetched.field)
        self.assertEqual(original.field, fetched.field)

    def test_null_handling(self):
        original = NullableIntegerArrayModel(field=None)
        fetched = self._save_and_fetch(original)
        self.assertEqual(original.field, fetched.field)

        # The non-nullable array column rejects NULL in the database.
        with self.assertRaises(IntegrityError):
            IntegerArrayModel(field=None).save()

    def test_nested(self):
        original = NestedIntegerArrayModel(field=[[1, 2], [3, 4]])
        fetched = self._save_and_fetch(original)
        self.assertEqual(original.field, fetched.field)

    def test_other_array_types(self):
        original = OtherTypesArrayModel(
            ips=["192.168.0.1", "::1"],
            uuids=[uuid.uuid4()],
            decimals=[decimal.Decimal(1.25), 1.75],
            tags=[Tag(1), Tag(2), Tag(3)],
            json=[{"a": 1}, {"b": 2}],
            int_ranges=[NumericRange(10, 20), NumericRange(30, 40)],
            bigint_ranges=[
                NumericRange(7000000000, 10000000000),
                NumericRange(50000000000, 70000000000),
            ],
        )
        fetched = self._save_and_fetch(original)
        for attr in (
            "ips",
            "uuids",
            "decimals",
            "tags",
            "json",
            "int_ranges",
            "bigint_ranges",
        ):
            self.assertEqual(getattr(original, attr), getattr(fetched, attr))

    def test_null_from_db_value_handling(self):
        obj = OtherTypesArrayModel.objects.create(
            ips=["192.168.0.1", "::1"],
            uuids=[uuid.uuid4()],
            decimals=[decimal.Decimal(1.25), 1.75],
            tags=None,
        )
        obj.refresh_from_db()
        self.assertIsNone(obj.tags)
        self.assertEqual(obj.json, [])
        self.assertIsNone(obj.int_ranges)
        self.assertIsNone(obj.bigint_ranges)

    def test_model_set_on_base_field(self):
        array_field = IntegerArrayModel()._meta.get_field("field")
        self.assertEqual(array_field.model, IntegerArrayModel)
        self.assertEqual(array_field.base_field.model, IntegerArrayModel)

    def test_nested_nullable_base_field(self):
        if PSYCOPG2_VERSION < (2, 7, 5):
            self.skipTest("See https://github.com/psycopg/psycopg2/issues/325")
        obj = NullableIntegerArrayModel.objects.create(
            field_nested=[[None, None], [None, None]],
        )
        self.assertEqual(obj.field_nested, [[None, None], [None, None]])
212
213
class TestQuerying(PostgreSQLTestCase):
    """Lookups, index/slice transforms and ordering on ArrayField columns."""

    @classmethod
    def setUpTestData(cls):
        cls.objs = NullableIntegerArrayModel.objects.bulk_create(
            [
                NullableIntegerArrayModel(field=[1]),
                NullableIntegerArrayModel(field=[2]),
                NullableIntegerArrayModel(field=[2, 3]),
                NullableIntegerArrayModel(field=[20, 30, 40]),
                NullableIntegerArrayModel(field=None),
            ]
        )

    def test_empty_list(self):
        NullableIntegerArrayModel.objects.create(field=[])
        obj = (
            NullableIntegerArrayModel.objects.annotate(
                empty_array=models.Value(
                    [], output_field=ArrayField(models.IntegerField())
                ),
            )
            .filter(field=models.F("empty_array"))
            .get()
        )
        self.assertEqual(obj.field, [])
        self.assertEqual(obj.empty_array, [])

    def test_exact(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1]
        )

    def test_exact_charfield(self):
        instance = CharArrayModel.objects.create(field=["text"])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field=["text"]), [instance]
        )

    def test_exact_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), [instance]
        )

    def test_isnull(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:]
        )

    def test_gt(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4]
        )

    def test_lt(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1]
        )

    def test_in(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]),
            self.objs[:2],
        )

    def test_in_subquery(self):
        IntegerArrayModel.objects.create(field=[2, 3])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__in=IntegerArrayModel.objects.all().values_list(
                    "field", flat=True
                )
            ),
            self.objs[2:3],
        )

    @unittest.expectedFailure
    def test_in_including_F_object(self):
        # This test asserts that Array objects passed to filters can be
        # constructed to contain F objects. This currently doesn't work as the
        # psycopg2 mogrify method that generates the ARRAY() syntax is
        # expecting literals, not column references (#27095).
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[[models.F("id")]]),
            self.objs[:2],
        )

    def test_in_as_F_object(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__in=[models.F("field")]),
            self.objs[:4],
        )

    def test_contained_by(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
            self.objs[:2],
        )

    @unittest.expectedFailure
    def test_contained_by_including_F_object(self):
        # This test asserts that Array objects passed to filters can be
        # constructed to contain F objects. This currently doesn't work as the
        # psycopg2 mogrify method that generates the ARRAY() syntax is
        # expecting literals, not column references (#27095).
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__contained_by=[models.F("id"), 2]
            ),
            self.objs[:2],
        )

    def test_contains(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__contains=[2]),
            self.objs[1:3],
        )

    def test_icontains(self):
        # Using the __icontains lookup with ArrayField is inefficient.
        instance = CharArrayModel.objects.create(field=["FoO"])
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__icontains="foo"), [instance]
        )

    def test_contains_charfield(self):
        # Regression for #22907
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__contains=["text"]), []
        )

    def test_contained_by_charfield(self):
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__contained_by=["text"]), []
        )

    def test_overlap_charfield(self):
        self.assertSequenceEqual(
            CharArrayModel.objects.filter(field__overlap=["text"]), []
        )

    def test_lookups_autofield_array(self):
        qs = (
            NullableIntegerArrayModel.objects.filter(
                field__0__isnull=False,
            )
            .values("field__0")
            .annotate(
                arrayagg=ArrayAgg("id"),
            )
            .order_by("field__0")
        )
        tests = (
            ("contained_by", [self.objs[1].pk, self.objs[2].pk, 0], [2]),
            ("contains", [self.objs[2].pk], [2]),
            ("exact", [self.objs[3].pk], [20]),
            ("overlap", [self.objs[1].pk, self.objs[3].pk], [2, 20]),
        )
        for lookup, value, expected in tests:
            with self.subTest(lookup=lookup):
                self.assertSequenceEqual(
                    qs.filter(
                        **{"arrayagg__" + lookup: value},
                    ).values_list("field__0", flat=True),
                    expected,
                )

    def test_index(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
        )

    def test_index_chained(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0__lt=3), self.objs[0:3]
        )

    def test_index_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance]
        )

    @unittest.expectedFailure
    def test_index_used_on_nested_data(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance]
        )

    def test_index_transform_expression(self):
        expr = RawSQL("string_to_array(%s, ';')", ["1;2"])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__0=Cast(
                    IndexTransform(1, models.IntegerField, expr),
                    output_field=models.IntegerField(),
                ),
            ),
            self.objs[:1],
        )

    def test_overlap(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
            self.objs[0:3],
        )

    def test_len(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3]
        )

    def test_len_empty_array(self):
        obj = NullableIntegerArrayModel.objects.create(field=[])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__len=0), [obj]
        )

    def test_slice(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0_1=[2]), self.objs[1:3]
        )

        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3]
        )

    def test_order_by_slice(self):
        more_objs = (
            NullableIntegerArrayModel.objects.create(field=[1, 637]),
            NullableIntegerArrayModel.objects.create(field=[2, 1]),
            NullableIntegerArrayModel.objects.create(field=[3, -98123]),
            NullableIntegerArrayModel.objects.create(field=[4, 2]),
        )
        # field__1 is the second element (indexes are 0-based, see
        # test_index). Rows without a second element, and the NULL array, sort
        # as NULL and come last in ascending order. PostgreSQL does not
        # specify the order among tied NULLs, so "-id" breaks the tie and keeps
        # the expected sequence deterministic.
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.order_by("field__1", "-id"),
            [
                more_objs[2],
                more_objs[1],
                more_objs[3],
                self.objs[2],
                self.objs[3],
                more_objs[0],
                self.objs[4],
                self.objs[1],
                self.objs[0],
            ],
        )

    @unittest.expectedFailure
    def test_slice_nested(self):
        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
        self.assertSequenceEqual(
            NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance]
        )

    def test_slice_transform_expression(self):
        expr = RawSQL("string_to_array(%s, ';')", ["9;2;3"])
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                field__0_2=SliceTransform(2, 3, expr)
            ),
            self.objs[2:3],
        )

    def test_usage_in_subquery(self):
        self.assertSequenceEqual(
            NullableIntegerArrayModel.objects.filter(
                id__in=NullableIntegerArrayModel.objects.filter(field__len=3)
            ),
            [self.objs[3]],
        )

    def test_enum_lookup(self):
        class TestEnum(enum.Enum):
            VALUE_1 = "value_1"

        instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
        self.assertSequenceEqual(
            ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
            [instance],
        )

    def test_unsupported_lookup(self):
        msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
        with self.assertRaisesMessage(FieldError, msg):
            list(NullableIntegerArrayModel.objects.filter(field__0_bar=[2]))

        msg = "Unsupported lookup '0bar' for ArrayField or join on the field not permitted."
        with self.assertRaisesMessage(FieldError, msg):
            list(NullableIntegerArrayModel.objects.filter(field__0bar=[2]))

    def test_grouping_by_annotations_with_array_field_param(self):
        value = models.Value([1], output_field=ArrayField(models.IntegerField()))
        self.assertEqual(
            NullableIntegerArrayModel.objects.annotate(
                array_length=models.Func(value, 1, function="ARRAY_LENGTH"),
            )
            .values("array_length")
            .annotate(
                count=models.Count("pk"),
            )
            .get()["array_length"],
            1,
        )
520
521
class TestDateTimeExactQuerying(PostgreSQLTestCase):
    """Exact-match filtering on arrays of datetimes, dates and times."""

    @classmethod
    def setUpTestData(cls):
        moment = timezone.now()
        cls.datetimes = [moment]
        cls.dates = [moment.date()]
        cls.times = [moment.time()]
        row = DateTimeArrayModel.objects.create(
            datetimes=cls.datetimes, dates=cls.dates, times=cls.times
        )
        cls.objs = [row]

    def _assert_exact(self, **lookup):
        """Assert that filtering by ``lookup`` returns exactly the fixture row."""
        self.assertSequenceEqual(DateTimeArrayModel.objects.filter(**lookup), self.objs)

    def test_exact_datetimes(self):
        self._assert_exact(datetimes=self.datetimes)

    def test_exact_dates(self):
        self._assert_exact(dates=self.dates)

    def test_exact_times(self):
        self._assert_exact(times=self.times)
549
550
class TestOtherTypesExactQuerying(PostgreSQLTestCase):
    """Exact-match filtering on arrays of IPs, UUIDs, decimals and Tags."""

    @classmethod
    def setUpTestData(cls):
        cls.ips = ["192.168.0.1", "::1"]
        cls.uuids = [uuid.uuid4()]
        cls.decimals = [decimal.Decimal(1.25), 1.75]
        cls.tags = [Tag(1), Tag(2), Tag(3)]
        row = OtherTypesArrayModel.objects.create(
            ips=cls.ips,
            uuids=cls.uuids,
            decimals=cls.decimals,
            tags=cls.tags,
        )
        cls.objs = [row]

    def _assert_exact(self, **lookup):
        """Assert that filtering by ``lookup`` returns exactly the fixture row."""
        self.assertSequenceEqual(
            OtherTypesArrayModel.objects.filter(**lookup), self.objs
        )

    def test_exact_ip_addresses(self):
        self._assert_exact(ips=self.ips)

    def test_exact_uuids(self):
        self._assert_exact(uuids=self.uuids)

    def test_exact_decimals(self):
        self._assert_exact(decimals=self.decimals)

    def test_exact_tags(self):
        self._assert_exact(tags=self.tags)
586
587
@isolate_apps("postgres_tests")
class TestChecks(PostgreSQLSimpleTestCase):
    """System checks reported for ArrayField definitions."""

    def _assert_only_missing_max_length(self, errors):
        """Assert ``errors`` holds exactly one postgres.E001 about max_length."""
        self.assertEqual(len(errors), 1)
        # The inner CharField is missing a max_length.
        self.assertEqual(errors[0].id, "postgres.E001")
        self.assertIn("max_length", errors[0].msg)

    def test_field_checks(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(models.CharField())

        self._assert_only_missing_max_length(MyModel().check())

    def test_invalid_base_fields(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(
                models.ManyToManyField("postgres_tests.IntegerArrayModel")
            )

        errors = MyModel().check()
        self.assertEqual(len(errors), 1)
        self.assertEqual(errors[0].id, "postgres.E002")

    def test_invalid_default(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(models.IntegerField(), default=[])

        expected_warning = checks.Warning(
            msg=(
                "ArrayField default should be a callable instead of an "
                "instance so that it's not shared between all field "
                "instances."
            ),
            hint="Use a callable instead, e.g., use `list` instead of `[]`.",
            obj=MyModel._meta.get_field("field"),
            id="fields.E010",
        )
        self.assertEqual(MyModel().check(), [expected_warning])

    def test_valid_default(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(models.IntegerField(), default=list)

        self.assertEqual(MyModel().check(), [])

    def test_valid_default_none(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(models.IntegerField(), default=None)

        self.assertEqual(MyModel().check(), [])

    def test_nested_field_checks(self):
        """
        Nested ArrayFields are permitted.
        """

        class MyModel(PostgreSQLModel):
            field = ArrayField(ArrayField(models.CharField()))

        self._assert_only_missing_max_length(MyModel().check())

    def test_choices_tuple_list(self):
        class MyModel(PostgreSQLModel):
            field = ArrayField(
                models.CharField(max_length=16),
                choices=[
                    [
                        "Media",
                        [(["vinyl", "cd"], "Audio"), (("vhs", "dvd"), "Video")],
                    ],
                    (["mp3", "mp4"], "Digital"),
                ],
            )

        self.assertEqual(MyModel._meta.get_field("field").check(), [])
676
677
@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
class TestMigrations(TransactionTestCase):
    """ArrayField deconstruction and migrations that create array columns."""

    available_apps = ["postgres_tests"]

    def test_deconstruct(self):
        field = ArrayField(models.IntegerField())
        name, path, args, kwargs = field.deconstruct()
        new = ArrayField(*args, **kwargs)
        self.assertEqual(type(new.base_field), type(field.base_field))
        self.assertIsNot(new.base_field, field.base_field)

    def test_deconstruct_with_size(self):
        field = ArrayField(models.IntegerField(), size=3)
        name, path, args, kwargs = field.deconstruct()
        new = ArrayField(*args, **kwargs)
        self.assertEqual(new.size, field.size)

    def test_deconstruct_args(self):
        field = ArrayField(models.CharField(max_length=20))
        name, path, args, kwargs = field.deconstruct()
        new = ArrayField(*args, **kwargs)
        self.assertEqual(new.base_field.max_length, field.base_field.max_length)

    def test_subclass_deconstruct(self):
        field = ArrayField(models.IntegerField())
        name, path, args, kwargs = field.deconstruct()
        self.assertEqual(path, "django.contrib.postgres.fields.ArrayField")

        field = ArrayFieldSubclass()
        name, path, args, kwargs = field.deconstruct()
        self.assertEqual(path, "postgres_tests.models.ArrayFieldSubclass")

    @override_settings(
        MIGRATION_MODULES={
            "postgres_tests": "postgres_tests.array_default_migrations",
        }
    )
    def test_adding_field_with_default(self):
        # See #22962
        table_name = "postgres_tests_integerarraydefaultmodel"
        with connection.cursor() as cursor:
            self.assertNotIn(table_name, connection.introspection.table_names(cursor))
        call_command("migrate", "postgres_tests", verbosity=0)
        try:
            with connection.cursor() as cursor:
                self.assertIn(table_name, connection.introspection.table_names(cursor))
        finally:
            # Unapply even if an assertion failed, so the table can't leak into
            # later TransactionTestCases. This still runs under the method's
            # override_settings, so the same migration module is unapplied.
            call_command("migrate", "postgres_tests", "zero", verbosity=0)
        with connection.cursor() as cursor:
            self.assertNotIn(table_name, connection.introspection.table_names(cursor))

    @override_settings(
        MIGRATION_MODULES={
            "postgres_tests": "postgres_tests.array_index_migrations",
        }
    )
    def test_adding_arrayfield_with_index(self):
        """
        ArrayField shouldn't have varchar_patterns_ops or text_patterns_ops indexes.
        """
        table_name = "postgres_tests_chartextarrayindexmodel"
        call_command("migrate", "postgres_tests", verbosity=0)
        try:
            # The schema doesn't change between the two checks below, so a
            # single introspection call serves both.
            with connection.cursor() as cursor:
                constraints = connection.introspection.get_constraints(
                    cursor, table_name
                )
            like_constraint_columns_list = [
                details["columns"]
                for constraint_name, details in constraints.items()
                if constraint_name.endswith("_like")
            ]
            # Only the CharField should have a LIKE index.
            self.assertEqual(like_constraint_columns_list, [["char2"]])
            # All fields should have regular indexes.
            indexes = [
                details["columns"][0]
                for details in constraints.values()
                if details["index"] and len(details["columns"]) == 1
            ]
            self.assertIn("char", indexes)
            self.assertIn("char2", indexes)
            self.assertIn("text", indexes)
        finally:
            # Always unapply (see test_adding_field_with_default).
            call_command("migrate", "postgres_tests", "zero", verbosity=0)
        with connection.cursor() as cursor:
            self.assertNotIn(table_name, connection.introspection.table_names(cursor))
764
765
class TestSerialization(PostgreSQLSimpleTestCase):
    """JSON serialization of IntegerArrayModel.field, including None items."""

    test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]'

    def test_dumping(self):
        """Serializing an unsaved instance produces the fixture's JSON."""
        obj = IntegerArrayModel(field=[1, 2, None])
        dumped = serializers.serialize("json", [obj])
        # Compare parsed JSON so key order and whitespace don't matter.
        expected = json.loads(self.test_data)
        self.assertEqual(json.loads(dumped), expected)

    def test_loading(self):
        """Deserializing the fixture restores the list with None preserved."""
        loaded = list(serializers.deserialize("json", self.test_data))
        first_object = loaded[0].object
        self.assertEqual(first_object.field, [1, 2, None])
778
class TestValidation(PostgreSQLSimpleTestCase):
    """Model-level validation performed by ArrayField.clean()."""

    def _clean_error(self, field, value):
        """Call ``field.clean(value, None)``, assert it raises, and return the error."""
        with self.assertRaises(exceptions.ValidationError) as ctx:
            field.clean(value, None)
        return ctx.exception

    def test_unbounded(self):
        error = self._clean_error(ArrayField(models.IntegerField()), [1, None])
        self.assertEqual(error.code, "item_invalid")
        rendered = error.message % error.params
        self.assertEqual(
            rendered,
            "Item 2 in the array did not validate: This field cannot be null.",
        )

    def test_blank_true(self):
        # A nullable base field accepts None items; no error is expected.
        nullable = ArrayField(models.IntegerField(blank=True, null=True))
        nullable.clean([1, None], None)

    def test_with_size(self):
        sized = ArrayField(models.IntegerField(), size=3)
        sized.clean([1, 2, 3], None)  # Exactly ``size`` items is accepted.
        error = self._clean_error(sized, [1, 2, 3, 4])
        self.assertEqual(
            error.messages[0],
            "List contains 4 items, it should contain no more than 3.",
        )

    def test_nested_array_mismatch(self):
        nested = ArrayField(ArrayField(models.IntegerField()))
        nested.clean([[1, 2], [3, 4]], None)  # Equal-length rows are accepted.
        error = self._clean_error(nested, [[1, 2], [3, 4, 5]])
        self.assertEqual(error.code, "nested_array_mismatch")
        self.assertEqual(error.messages[0], "Nested arrays must have the same length.")

    def test_with_base_field_error_params(self):
        error = self._clean_error(ArrayField(models.CharField(max_length=2)), ["abc"])
        self.assertEqual(len(error.error_list), 1)
        item_error = error.error_list[0]
        self.assertEqual(
            item_error.message,
            "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
        )
        self.assertEqual(item_error.code, "item_invalid")
        self.assertEqual(
            item_error.params,
            {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
        )

    def test_with_validators(self):
        base = models.IntegerField(validators=[validators.MinValueValidator(1)])
        checked = ArrayField(base)
        checked.clean([1, 2], None)
        error = self._clean_error(checked, [0])
        self.assertEqual(len(error.error_list), 1)
        item_error = error.error_list[0]
        self.assertEqual(
            item_error.message,
            "Item 1 in the array did not validate: Ensure this value is greater than or equal to 1.",
        )
        self.assertEqual(item_error.code, "item_invalid")
        self.assertEqual(
            item_error.params, {"nth": 1, "value": 0, "limit_value": 1, "show_value": 0}
        )
849
class TestSimpleFormField(PostgreSQLSimpleTestCase):
    """SimpleArrayField: a delimiter-separated string cleaned item by item."""

    def _assert_first_message(self, field, value, expected):
        """Assert ``field.clean(value)`` raises with ``expected`` as its first message."""
        with self.assertRaises(exceptions.ValidationError) as ctx:
            field.clean(value)
        self.assertEqual(ctx.exception.messages[0], expected)

    def test_valid(self):
        cleaned = SimpleArrayField(forms.CharField()).clean("a,b,c")
        self.assertEqual(cleaned, ["a", "b", "c"])

    def test_to_python_fail(self):
        self._assert_first_message(
            SimpleArrayField(forms.IntegerField()),
            "a,b,9",
            "Item 1 in the array did not validate: Enter a whole number.",
        )

    def test_validate_fail(self):
        self._assert_first_message(
            SimpleArrayField(forms.CharField(required=True)),
            "a,b,",
            "Item 3 in the array did not validate: This field is required.",
        )

    def test_validate_fail_base_field_error_params(self):
        field = SimpleArrayField(forms.CharField(max_length=2))
        with self.assertRaises(exceptions.ValidationError) as ctx:
            field.clean("abc,c,defg")
        expected = [
            (
                "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
                {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
            ),
            (
                "Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).",
                {"nth": 3, "value": "defg", "limit_value": 2, "show_value": 4},
            ),
        ]
        found = ctx.exception.error_list
        self.assertEqual(len(found), len(expected))
        for error, (message, params) in zip(found, expected):
            self.assertEqual(error.message, message)
            self.assertEqual(error.code, "item_invalid")
            self.assertEqual(error.params, params)

    def test_validators_fail(self):
        self._assert_first_message(
            SimpleArrayField(forms.RegexField("[a-e]{2}")),
            "a,bc,de",
            "Item 1 in the array did not validate: Enter a valid value.",
        )

    def test_delimiter(self):
        piped = SimpleArrayField(forms.CharField(), delimiter="|")
        self.assertEqual(piped.clean("a|b|c"), ["a", "b", "c"])

    def test_delimiter_with_nesting(self):
        # Outer level splits on "|", inner level on the default ",".
        inner = SimpleArrayField(forms.CharField())
        outer = SimpleArrayField(inner, delimiter="|")
        self.assertEqual(outer.clean("a,b|c,d"), [["a", "b"], ["c", "d"]])

    def test_prepare_value(self):
        prepared = SimpleArrayField(forms.CharField()).prepare_value(["a", "b", "c"])
        self.assertEqual(prepared, "a,b,c")

    def test_max_length(self):
        self._assert_first_message(
            SimpleArrayField(forms.CharField(), max_length=2),
            "a,b,c",
            "List contains 3 items, it should contain no more than 2.",
        )

    def test_min_length(self):
        self._assert_first_message(
            SimpleArrayField(forms.CharField(), min_length=4),
            "a,b,c",
            "List contains 3 items, it should contain no fewer than 4.",
        )

    def test_required(self):
        self._assert_first_message(
            SimpleArrayField(forms.CharField(), required=True),
            "",
            "This field is required.",
        )

    def test_model_field_formfield(self):
        form_field = ArrayField(models.CharField(max_length=27)).formfield()
        self.assertIsInstance(form_field, SimpleArrayField)
        base = form_field.base_field
        self.assertIsInstance(base, forms.CharField)
        self.assertEqual(base.max_length, 27)

    def test_model_field_formfield_size(self):
        form_field = ArrayField(models.CharField(max_length=27), size=4).formfield()
        self.assertIsInstance(form_field, SimpleArrayField)
        self.assertEqual(form_field.max_length, 4)

    def test_model_field_choices(self):
        choices = ((1, "A"), (2, "B"))
        form_field = ArrayField(models.IntegerField(choices=choices)).formfield()
        self.assertEqual(form_field.clean("1,2"), [1, 2])

    def test_already_converted_value(self):
        items = ["a", "b", "c"]
        self.assertEqual(SimpleArrayField(forms.CharField()).clean(items), items)

    def test_has_changed(self):
        field = SimpleArrayField(forms.IntegerField())
        cases = [
            ([1, 2], [1, 2], False),
            ([1, 2], "1,2", False),
            ([1, 2], "1,2,3", True),
            ([1, 2], "a,b", True),
        ]
        for initial, data, changed in cases:
            self.assertIs(field.has_changed(initial, data), changed)

    def test_has_changed_empty(self):
        field = SimpleArrayField(forms.CharField())
        # Every combination of "empty" initial and submitted data is unchanged.
        empty_pairs = [
            (None, None),
            (None, ""),
            (None, []),
            ([], None),
            ([], ""),
        ]
        for initial, data in empty_pairs:
            self.assertIs(field.has_changed(initial, data), False)
987
class TestSplitFormField(PostgreSQLSimpleTestCase):
    """SplitArrayField: one sub-input (``<name>_<index>``) per array item."""

    def test_valid(self):
        class SplitForm(forms.Form):
            array = SplitArrayField(forms.CharField(), size=3)

        submitted = {f"array_{i}": letter for i, letter in enumerate("abc")}
        form = SplitForm(submitted)
        self.assertTrue(form.is_valid())
        self.assertEqual(form.cleaned_data, {"array": ["a", "b", "c"]})

    def test_required(self):
        class SplitForm(forms.Form):
            array = SplitArrayField(forms.CharField(), required=True, size=3)

        form = SplitForm({f"array_{i}": "" for i in range(3)})
        self.assertFalse(form.is_valid())
        self.assertEqual(form.errors, {"array": ["This field is required."]})

    def test_remove_trailing_nulls(self):
        class SplitForm(forms.Form):
            array = SplitArrayField(
                forms.CharField(required=False), size=5, remove_trailing_nulls=True
            )

        submitted_values = ["a", "", "b", "", ""]
        form = SplitForm(
            {f"array_{i}": value for i, value in enumerate(submitted_values)}
        )
        self.assertTrue(form.is_valid(), form.errors)
        # Only the trailing blanks are dropped; the interior blank is kept.
        self.assertEqual(form.cleaned_data, {"array": ["a", "", "b"]})

    def test_remove_trailing_nulls_not_required(self):
        class SplitForm(forms.Form):
            array = SplitArrayField(
                forms.CharField(required=False),
                size=2,
                remove_trailing_nulls=True,
                required=False,
            )

        form = SplitForm({"array_0": "", "array_1": ""})
        self.assertTrue(form.is_valid())
        self.assertEqual(form.cleaned_data, {"array": []})

    def test_required_field(self):
        class SplitForm(forms.Form):
            array = SplitArrayField(forms.CharField(), size=3)

        form = SplitForm({"array_0": "a", "array_1": "b", "array_2": ""})
        self.assertFalse(form.is_valid())
        expected_errors = {
            "array": ["Item 3 in the array did not validate: This field is required."]
        }
        self.assertEqual(form.errors, expected_errors)

    def test_invalid_integer(self):
        bounded = SplitArrayField(forms.IntegerField(max_value=100), size=2)
        msg = "Item 2 in the array did not validate: Ensure this value is less than or equal to 100."
        with self.assertRaisesMessage(exceptions.ValidationError, msg):
            bounded.clean([0, 101])

    # To locate the widget's template.
    @modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
    def test_rendering(self):
        class SplitForm(forms.Form):
            array = SplitArrayField(forms.CharField(), size=3)

        rendered = str(SplitForm())
        self.assertHTMLEqual(
            rendered,
            """
            <tr>
                <th><label for="id_array_0">Array:</label></th>
                <td>
                    <input id="id_array_0" name="array_0" type="text" required>
                    <input id="id_array_1" name="array_1" type="text" required>
                    <input id="id_array_2" name="array_2" type="text" required>
                </td>
            </tr>
            """,
        )

    def test_invalid_char_length(self):
        short = SplitArrayField(forms.CharField(max_length=2), size=3)
        with self.assertRaises(exceptions.ValidationError) as ctx:
            short.clean(["abc", "c", "defg"])
        expected_messages = [
            "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
            "Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).",
        ]
        self.assertEqual(ctx.exception.messages, expected_messages)

    def test_splitarraywidget_value_omitted_from_data(self):
        class Form(forms.ModelForm):
            field = SplitArrayField(forms.IntegerField(), required=False, size=2)

            class Meta:
                model = IntegerArrayModel
                fields = ("field",)

        form = Form({"field_0": "1", "field_1": "2"})
        self.assertEqual(form.errors, {})
        unsaved = form.save(commit=False)
        self.assertEqual(unsaved.field, [1, 2])

    def test_splitarrayfield_has_changed(self):
        class Form(forms.ModelForm):
            field = SplitArrayField(forms.IntegerField(), required=False, size=2)

            class Meta:
                model = IntegerArrayModel
                fields = ("field",)

        scenarios = [
            ({}, {"field_0": "", "field_1": ""}, True),
            ({"field": None}, {"field_0": "", "field_1": ""}, True),
            ({"field": [1]}, {"field_0": "", "field_1": ""}, True),
            ({"field": [1]}, {"field_0": "1", "field_1": "0"}, True),
            ({"field": [1, 2]}, {"field_0": "1", "field_1": "2"}, False),
            ({"field": [1, 2]}, {"field_0": "a", "field_1": "b"}, True),
        ]
        for instance_kwargs, submitted, changed in scenarios:
            with self.subTest(initial=instance_kwargs, data=submitted):
                form = Form(submitted, instance=IntegerArrayModel(**instance_kwargs))
                self.assertIs(form.has_changed(), changed)

    def test_splitarrayfield_remove_trailing_nulls_has_changed(self):
        class Form(forms.ModelForm):
            field = SplitArrayField(
                forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True
            )

            class Meta:
                model = IntegerArrayModel
                fields = ("field",)

        scenarios = [
            ({}, {"field_0": "", "field_1": ""}, False),
            ({"field": None}, {"field_0": "", "field_1": ""}, False),
            ({"field": []}, {"field_0": "", "field_1": ""}, False),
            ({"field": [1]}, {"field_0": "1", "field_1": ""}, False),
        ]
        for instance_kwargs, submitted, changed in scenarios:
            with self.subTest(initial=instance_kwargs, data=submitted):
                form = Form(submitted, instance=IntegerArrayModel(**instance_kwargs))
                self.assertIs(form.has_changed(), changed)
1148
class TestSplitFormWidget(PostgreSQLWidgetTestCase):
    """SplitArrayWidget context, rendering, and omitted-value detection."""

    def test_get_context(self):
        def text_subwidget(name, value):
            # Expected context entry for one TextInput subwidget.
            return {
                "name": name,
                "is_hidden": False,
                "required": False,
                "value": value,
                "attrs": {},
                "template_name": "django/forms/widgets/text.html",
                "type": "text",
            }

        widget = SplitArrayWidget(forms.TextInput(), size=2)
        expected = {
            "widget": {
                "name": "name",
                "is_hidden": False,
                "required": False,
                "value": "['val1', 'val2']",
                "attrs": {},
                "template_name": "postgres/widgets/split_array.html",
                "subwidgets": [
                    text_subwidget("name_0", "val1"),
                    text_subwidget("name_1", "val2"),
                ],
            }
        }
        self.assertEqual(widget.get_context("name", ["val1", "val2"]), expected)

    def test_checkbox_get_context_attrs(self):
        widget = SplitArrayWidget(forms.CheckboxInput(), size=2)
        context = widget.get_context("name", [True, False])
        self.assertEqual(context["widget"]["value"], "[True, False]")
        # Only the truthy item is rendered as checked.
        subwidget_attrs = [sub["attrs"] for sub in context["widget"]["subwidgets"]]
        self.assertEqual(subwidget_attrs, [{"checked": True}, {}])

    def test_render(self):
        widget = SplitArrayWidget(forms.TextInput(), size=2)
        self.check_html(
            widget,
            "array",
            None,
            """
            <input name="array_0" type="text">
            <input name="array_1" type="text">
            """,
        )

    def test_render_attrs(self):
        widget = SplitArrayWidget(forms.TextInput(), size=2)
        expected_html = """
            <input id="foo_0" name="array_0" type="text" value="val1">
            <input id="foo_1" name="array_1" type="text" value="val2">
            """
        self.check_html(
            widget,
            "array",
            ["val1", "val2"],
            attrs={"id": "foo"},
            html=expected_html,
        )

    def test_value_omitted_from_data(self):
        widget = SplitArrayWidget(forms.TextInput(), size=2)
        # The value counts as present as soon as any sub-key is submitted.
        cases = [
            ({}, True),
            ({"field_0": "value"}, False),
            ({"field_1": "value"}, False),
            ({"field_0": "value", "field_1": "value"}, False),
        ]
        for data, omitted in cases:
            self.assertIs(widget.value_omitted_from_data(data, {}, "field"), omitted)