Tests for ArrayField

__init__.py

 1import unittest
 2
 3from django.db import connection
 4from django.test import modify_settings
 5from django.test import SimpleTestCase
 6from django.test import TestCase
 7from forms_tests.widget_tests.base import WidgetTest
 8
 9
# Shared skip decorator: every PostgreSQL-specific base case below is skipped
# when the default connection's vendor is not PostgreSQL. The condition is
# evaluated once, at import time.
_postgresql_only = unittest.skipUnless(
    connection.vendor == "postgresql", "PostgreSQL specific tests"
)


@_postgresql_only
class PostgreSQLSimpleTestCase(SimpleTestCase):
    """SimpleTestCase that only runs against a PostgreSQL database."""


@_postgresql_only
class PostgreSQLTestCase(TestCase):
    """TestCase that only runs against a PostgreSQL database."""


@_postgresql_only
# django.contrib.postgres is appended to INSTALLED_APPS so the widget's
# template can be located.
@modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
class PostgreSQLWidgetTestCase(WidgetTest, PostgreSQLSimpleTestCase):
    """Widget test base that only runs against a PostgreSQL database."""

fields.py

 1"""
 2Indirection layer for PostgreSQL-specific fields, so the tests don't fail when
 3run with a backend other than PostgreSQL.
 4"""
 5import enum
 6
 7from django.db import models
 8
 9try:
10    from django.contrib.postgres.fields import (
11        ArrayField,
12        BigIntegerRangeField,
13        CICharField,
14        CIEmailField,
15        CITextField,
16        DateRangeField,
17        DateTimeRangeField,
18        DecimalRangeField,
19        HStoreField,
20        IntegerRangeField,
21        JSONField,
22    )
23    from django.contrib.postgres.search import SearchVectorField
24except ImportError:
25
26    class DummyArrayField(models.Field):
27        def __init__(self, base_field, size=None, **kwargs):
28            super().__init__(**kwargs)
29
30        def deconstruct(self):
31            name, path, args, kwargs = super().deconstruct()
32            kwargs.update(
33                {
34                    "base_field": "",
35                    "size": 1,
36                }
37            )
38            return name, path, args, kwargs
39
40    class DummyJSONField(models.Field):
41        def __init__(self, encoder=None, **kwargs):
42            super().__init__(**kwargs)
43
44    ArrayField = DummyArrayField
45    BigIntegerRangeField = models.Field
46    CICharField = models.Field
47    CIEmailField = models.Field
48    CITextField = models.Field
49    DateRangeField = models.Field
50    DateTimeRangeField = models.Field
51    DecimalRangeField = models.Field
52    HStoreField = models.Field
53    IntegerRangeField = models.Field
54    JSONField = DummyJSONField
55    SearchVectorField = models.Field
56
57
class EnumField(models.CharField):
    # CharField that accepts Enum members and stores their ``.value``.
    def get_prep_value(self, value):
        if isinstance(value, enum.Enum):
            return value.value
        # Anything else is handed to the database untouched; CharField's own
        # get_prep_value (and its str coercion) is deliberately not called.
        return value

models.py

  1from django.core.serializers.json import DjangoJSONEncoder
  2from django.db import models
  3
  4from .fields import ArrayField
  5from .fields import BigIntegerRangeField
  6from .fields import CICharField
  7from .fields import CIEmailField
  8from .fields import CITextField
  9from .fields import DateRangeField
 10from .fields import DateTimeRangeField
 11from .fields import DecimalRangeField
 12from .fields import EnumField
 13from .fields import HStoreField
 14from .fields import IntegerRangeField
 15from .fields import JSONField
 16from .fields import SearchVectorField
 17
 18
 19class Tag:
 20    def __init__(self, tag_id):
 21        self.tag_id = tag_id
 22
 23    def __eq__(self, other):
 24        return isinstance(other, Tag) and self.tag_id == other.tag_id
 25
 26
 27class TagField(models.SmallIntegerField):
 28    def from_db_value(self, value, expression, connection):
 29        if value is None:
 30            return value
 31        return Tag(int(value))
 32
 33    def to_python(self, value):
 34        if isinstance(value, Tag):
 35            return value
 36        if value is None:
 37            return value
 38        return Tag(int(value))
 39
 40    def get_prep_value(self, value):
 41        return value.tag_id
 42
 43
 44class PostgreSQLModel(models.Model):
 45    class Meta:
 46        abstract = True
 47        required_db_vendor = "postgresql"
 48
 49
 50class IntegerArrayModel(PostgreSQLModel):
 51    field = ArrayField(models.IntegerField(), default=list, blank=True)
 52
 53
 54class NullableIntegerArrayModel(PostgreSQLModel):
 55    field = ArrayField(models.IntegerField(), blank=True, null=True)
 56    field_nested = ArrayField(ArrayField(models.IntegerField(null=True)), null=True)
 57
 58
 59class CharArrayModel(PostgreSQLModel):
 60    field = ArrayField(models.CharField(max_length=10))
 61
 62
 63class DateTimeArrayModel(PostgreSQLModel):
 64    datetimes = ArrayField(models.DateTimeField())
 65    dates = ArrayField(models.DateField())
 66    times = ArrayField(models.TimeField())
 67
 68
 69class NestedIntegerArrayModel(PostgreSQLModel):
 70    field = ArrayField(ArrayField(models.IntegerField()))
 71
 72
 73class OtherTypesArrayModel(PostgreSQLModel):
 74    ips = ArrayField(models.GenericIPAddressField(), default=list)
 75    uuids = ArrayField(models.UUIDField(), default=list)
 76    decimals = ArrayField(
 77        models.DecimalField(max_digits=5, decimal_places=2), default=list
 78    )
 79    tags = ArrayField(TagField(), blank=True, null=True)
 80    json = ArrayField(JSONField(default=dict), default=list)
 81    int_ranges = ArrayField(IntegerRangeField(), blank=True, null=True)
 82    bigint_ranges = ArrayField(BigIntegerRangeField(), blank=True, null=True)
 83
 84
 85class HStoreModel(PostgreSQLModel):
 86    field = HStoreField(blank=True, null=True)
 87    array_field = ArrayField(HStoreField(), null=True)
 88
 89
 90class ArrayEnumModel(PostgreSQLModel):
 91    array_of_enums = ArrayField(EnumField(max_length=20))
 92
 93
 94class CharFieldModel(models.Model):
 95    field = models.CharField(max_length=16)
 96
 97
 98class TextFieldModel(models.Model):
 99    field = models.TextField()
100
101    def __str__(self):
102        return self.field
103
104
105class SmallAutoFieldModel(models.Model):
106    id = models.SmallAutoField(primary_key=True)
107
108
109class BigAutoFieldModel(models.Model):
110    id = models.BigAutoField(primary_key=True)
111
112
# Full-text search fixtures: Scene, Character and Line are filled with
# content from Monty Python and the Holy Grail.
class Scene(models.Model):
    scene = models.CharField(max_length=255)
    setting = models.CharField(max_length=255)

    def __str__(self):
        return self.scene


class Character(models.Model):
    name = models.CharField(max_length=255)

    def __str__(self):
        return self.name


class CITestModel(PostgreSQLModel):
    # Case-insensitive text columns; ``name`` is the primary key.
    name = CICharField(max_length=255, primary_key=True)
    email = CIEmailField()
    description = CITextField()
    array_field = ArrayField(CITextField(), null=True)

    def __str__(self):
        return self.name


class Line(PostgreSQLModel):
    scene = models.ForeignKey("Scene", on_delete=models.CASCADE)
    character = models.ForeignKey("Character", on_delete=models.CASCADE)
    dialogue = models.TextField(null=True, blank=True)
    dialogue_search_vector = SearchVectorField(null=True, blank=True)
    dialogue_config = models.CharField(max_length=100, null=True, blank=True)

    def __str__(self):
        # NULL / empty dialogue renders as an empty string.
        return self.dialogue if self.dialogue else ""
149
150
class RangesModel(PostgreSQLModel):
    # Every range column is optional.
    ints = IntegerRangeField(null=True, blank=True)
    bigints = BigIntegerRangeField(null=True, blank=True)
    decimals = DecimalRangeField(null=True, blank=True)
    timestamps = DateTimeRangeField(null=True, blank=True)
    timestamps_inner = DateTimeRangeField(null=True, blank=True)
    dates = DateRangeField(null=True, blank=True)
    dates_inner = DateRangeField(null=True, blank=True)


class RangeLookupsModel(PostgreSQLModel):
    # Scalar columns of various types, presumably compared against
    # RangesModel's range columns by lookup tests.
    parent = models.ForeignKey(
        RangesModel, on_delete=models.SET_NULL, null=True, blank=True
    )
    integer = models.IntegerField(null=True, blank=True)
    big_integer = models.BigIntegerField(null=True, blank=True)
    float = models.FloatField(null=True, blank=True)
    timestamp = models.DateTimeField(null=True, blank=True)
    date = models.DateField(null=True, blank=True)
    small_integer = models.SmallIntegerField(null=True, blank=True)
    decimal_field = models.DecimalField(
        decimal_places=2, max_digits=5, null=True, blank=True
    )


class JSONModel(PostgreSQLModel):
    field = JSONField(null=True, blank=True)
    # Same column type, but serialized with DjangoJSONEncoder.
    field_custom = JSONField(encoder=DjangoJSONEncoder, null=True, blank=True)
177
178
class ArrayFieldSubclass(ArrayField):
    # Ignores every constructor argument and always builds an integer array.
    # NOTE(review): presumably deliberate, so tests can exercise the
    # subclass's deconstruct()/clone() path; confirm before forwarding args.
    def __init__(self, *args, **kwargs):
        element_field = models.IntegerField()
        super().__init__(element_field)
182
183
class AggregateTestModel(models.Model):
    """
    To test postgres-specific general aggregation functions
    """

    char_field = models.CharField(blank=True, max_length=30)
    integer_field = models.IntegerField(null=True)
    boolean_field = models.BooleanField(null=True)


class StatTestModel(models.Model):
    """
    To test postgres-specific aggregation functions for statistics
    """

    int1 = models.IntegerField()
    int2 = models.IntegerField()
    related_field = models.ForeignKey(
        AggregateTestModel, on_delete=models.SET_NULL, null=True
    )


class NowTestModel(models.Model):
    when = models.DateTimeField(default=None, null=True)


class UUIDTestModel(models.Model):
    uuid = models.UUIDField(null=True, default=None)


class Room(models.Model):
    number = models.IntegerField(unique=True)


class HotelReservation(PostgreSQLModel):
    room = models.ForeignKey(to="Room", on_delete=models.CASCADE)
    datespan = DateRangeField()
    start = models.DateTimeField()
    end = models.DateTimeField()
    cancelled = models.BooleanField(default=False)

test_array.py

   1import decimal
   2import enum
   3import json
   4import unittest
   5import uuid
   6
   7from django import forms
   8from django.core import checks
   9from django.core import exceptions
  10from django.core import serializers
  11from django.core import validators
  12from django.core.exceptions import FieldError
  13from django.core.management import call_command
  14from django.db import connection
  15from django.db import IntegrityError
  16from django.db import models
  17from django.db.models.expressions import RawSQL
  18from django.db.models.functions import Cast
  19from django.test import modify_settings
  20from django.test import override_settings
  21from django.test import TransactionTestCase
  22from django.test.utils import isolate_apps
  23from django.utils import timezone
  24
  25from . import PostgreSQLSimpleTestCase
  26from . import PostgreSQLTestCase
  27from . import PostgreSQLWidgetTestCase
  28from .models import ArrayEnumModel
  29from .models import ArrayFieldSubclass
  30from .models import CharArrayModel
  31from .models import DateTimeArrayModel
  32from .models import IntegerArrayModel
  33from .models import NestedIntegerArrayModel
  34from .models import NullableIntegerArrayModel
  35from .models import OtherTypesArrayModel
  36from .models import PostgreSQLModel
  37from .models import Tag
  38
  39try:
  40    from django.contrib.postgres.aggregates import ArrayAgg
  41    from django.contrib.postgres.fields import ArrayField
  42    from django.contrib.postgres.fields.array import IndexTransform, SliceTransform
  43    from django.contrib.postgres.forms import (
  44        SimpleArrayField,
  45        SplitArrayField,
  46        SplitArrayWidget,
  47    )
  48    from django.db.backends.postgresql.base import PSYCOPG2_VERSION
  49    from psycopg2.extras import NumericRange
  50except ImportError:
  51    pass
  52
  53
  54@isolate_apps("postgres_tests")
  55class BasicTests(PostgreSQLSimpleTestCase):
  56    def test_get_field_display(self):
  57        class MyModel(PostgreSQLModel):
  58            field = ArrayField(
  59                models.CharField(max_length=16),
  60                choices=[
  61                    ["Media", [(["vinyl", "cd"], "Audio")]],
  62                    (("mp3", "mp4"), "Digital"),
  63                ],
  64            )
  65
  66        tests = (
  67            (["vinyl", "cd"], "Audio"),
  68            (("mp3", "mp4"), "Digital"),
  69            (("a", "b"), "('a', 'b')"),
  70            (["c", "d"], "['c', 'd']"),
  71        )
  72        for value, display in tests:
  73            with self.subTest(value=value, display=display):
  74                instance = MyModel(field=value)
  75                self.assertEqual(instance.get_field_display(), display)
  76
  77    def test_get_field_display_nested_array(self):
  78        class MyModel(PostgreSQLModel):
  79            field = ArrayField(
  80                ArrayField(models.CharField(max_length=16)),
  81                choices=[
  82                    [
  83                        "Media",
  84                        [([["vinyl", "cd"], ("x",)], "Audio")],
  85                    ],
  86                    ((["mp3"], ("mp4",)), "Digital"),
  87                ],
  88            )
  89
  90        tests = (
  91            ([["vinyl", "cd"], ("x",)], "Audio"),
  92            ((["mp3"], ("mp4",)), "Digital"),
  93            ((("a", "b"), ("c",)), "(('a', 'b'), ('c',))"),
  94            ([["a", "b"], ["c"]], "[['a', 'b'], ['c']]"),
  95        )
  96        for value, display in tests:
  97            with self.subTest(value=value, display=display):
  98                instance = MyModel(field=value)
  99                self.assertEqual(instance.get_field_display(), display)
 100
 101
 102class TestSaveLoad(PostgreSQLTestCase):
 103    def test_integer(self):
 104        instance = IntegerArrayModel(field=[1, 2, 3])
 105        instance.save()
 106        loaded = IntegerArrayModel.objects.get()
 107        self.assertEqual(instance.field, loaded.field)
 108
 109    def test_char(self):
 110        instance = CharArrayModel(field=["hello", "goodbye"])
 111        instance.save()
 112        loaded = CharArrayModel.objects.get()
 113        self.assertEqual(instance.field, loaded.field)
 114
 115    def test_dates(self):
 116        instance = DateTimeArrayModel(
 117            datetimes=[timezone.now()],
 118            dates=[timezone.now().date()],
 119            times=[timezone.now().time()],
 120        )
 121        instance.save()
 122        loaded = DateTimeArrayModel.objects.get()
 123        self.assertEqual(instance.datetimes, loaded.datetimes)
 124        self.assertEqual(instance.dates, loaded.dates)
 125        self.assertEqual(instance.times, loaded.times)
 126
 127    def test_tuples(self):
 128        instance = IntegerArrayModel(field=(1,))
 129        instance.save()
 130        loaded = IntegerArrayModel.objects.get()
 131        self.assertSequenceEqual(instance.field, loaded.field)
 132
 133    def test_integers_passed_as_strings(self):
 134        # This checks that get_prep_value is deferred properly
 135        instance = IntegerArrayModel(field=["1"])
 136        instance.save()
 137        loaded = IntegerArrayModel.objects.get()
 138        self.assertEqual(loaded.field, [1])
 139
 140    def test_default_null(self):
 141        instance = NullableIntegerArrayModel()
 142        instance.save()
 143        loaded = NullableIntegerArrayModel.objects.get(pk=instance.pk)
 144        self.assertIsNone(loaded.field)
 145        self.assertEqual(instance.field, loaded.field)
 146
 147    def test_null_handling(self):
 148        instance = NullableIntegerArrayModel(field=None)
 149        instance.save()
 150        loaded = NullableIntegerArrayModel.objects.get()
 151        self.assertEqual(instance.field, loaded.field)
 152
 153        instance = IntegerArrayModel(field=None)
 154        with self.assertRaises(IntegrityError):
 155            instance.save()
 156
 157    def test_nested(self):
 158        instance = NestedIntegerArrayModel(field=[[1, 2], [3, 4]])
 159        instance.save()
 160        loaded = NestedIntegerArrayModel.objects.get()
 161        self.assertEqual(instance.field, loaded.field)
 162
 163    def test_other_array_types(self):
 164        instance = OtherTypesArrayModel(
 165            ips=["192.168.0.1", "::1"],
 166            uuids=[uuid.uuid4()],
 167            decimals=[decimal.Decimal(1.25), 1.75],
 168            tags=[Tag(1), Tag(2), Tag(3)],
 169            json=[{"a": 1}, {"b": 2}],
 170            int_ranges=[NumericRange(10, 20), NumericRange(30, 40)],
 171            bigint_ranges=[
 172                NumericRange(7000000000, 10000000000),
 173                NumericRange(50000000000, 70000000000),
 174            ],
 175        )
 176        instance.save()
 177        loaded = OtherTypesArrayModel.objects.get()
 178        self.assertEqual(instance.ips, loaded.ips)
 179        self.assertEqual(instance.uuids, loaded.uuids)
 180        self.assertEqual(instance.decimals, loaded.decimals)
 181        self.assertEqual(instance.tags, loaded.tags)
 182        self.assertEqual(instance.json, loaded.json)
 183        self.assertEqual(instance.int_ranges, loaded.int_ranges)
 184        self.assertEqual(instance.bigint_ranges, loaded.bigint_ranges)
 185
 186    def test_null_from_db_value_handling(self):
 187        instance = OtherTypesArrayModel.objects.create(
 188            ips=["192.168.0.1", "::1"],
 189            uuids=[uuid.uuid4()],
 190            decimals=[decimal.Decimal(1.25), 1.75],
 191            tags=None,
 192        )
 193        instance.refresh_from_db()
 194        self.assertIsNone(instance.tags)
 195        self.assertEqual(instance.json, [])
 196        self.assertIsNone(instance.int_ranges)
 197        self.assertIsNone(instance.bigint_ranges)
 198
 199    def test_model_set_on_base_field(self):
 200        instance = IntegerArrayModel()
 201        field = instance._meta.get_field("field")
 202        self.assertEqual(field.model, IntegerArrayModel)
 203        self.assertEqual(field.base_field.model, IntegerArrayModel)
 204
 205    def test_nested_nullable_base_field(self):
 206        if PSYCOPG2_VERSION < (2, 7, 5):
 207            self.skipTest("See https://github.com/psycopg/psycopg2/issues/325")
 208        instance = NullableIntegerArrayModel.objects.create(
 209            field_nested=[[None, None], [None, None]],
 210        )
 211        self.assertEqual(instance.field_nested, [[None, None], [None, None]])
 212
 213
 214class TestQuerying(PostgreSQLTestCase):
 215    @classmethod
 216    def setUpTestData(cls):
 217        cls.objs = NullableIntegerArrayModel.objects.bulk_create(
 218            [
 219                NullableIntegerArrayModel(field=[1]),
 220                NullableIntegerArrayModel(field=[2]),
 221                NullableIntegerArrayModel(field=[2, 3]),
 222                NullableIntegerArrayModel(field=[20, 30, 40]),
 223                NullableIntegerArrayModel(field=None),
 224            ]
 225        )
 226
 227    def test_empty_list(self):
 228        NullableIntegerArrayModel.objects.create(field=[])
 229        obj = (
 230            NullableIntegerArrayModel.objects.annotate(
 231                empty_array=models.Value(
 232                    [], output_field=ArrayField(models.IntegerField())
 233                ),
 234            )
 235            .filter(field=models.F("empty_array"))
 236            .get()
 237        )
 238        self.assertEqual(obj.field, [])
 239        self.assertEqual(obj.empty_array, [])
 240
 241    def test_exact(self):
 242        self.assertSequenceEqual(
 243            NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1]
 244        )
 245
 246    def test_exact_charfield(self):
 247        instance = CharArrayModel.objects.create(field=["text"])
 248        self.assertSequenceEqual(
 249            CharArrayModel.objects.filter(field=["text"]), [instance]
 250        )
 251
 252    def test_exact_nested(self):
 253        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
 254        self.assertSequenceEqual(
 255            NestedIntegerArrayModel.objects.filter(field=[[1, 2], [3, 4]]), [instance]
 256        )
 257
 258    def test_isnull(self):
 259        self.assertSequenceEqual(
 260            NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:]
 261        )
 262
 263    def test_gt(self):
 264        self.assertSequenceEqual(
 265            NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4]
 266        )
 267
 268    def test_lt(self):
 269        self.assertSequenceEqual(
 270            NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1]
 271        )
 272
 273    def test_in(self):
 274        self.assertSequenceEqual(
 275            NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]),
 276            self.objs[:2],
 277        )
 278
 279    def test_in_subquery(self):
 280        IntegerArrayModel.objects.create(field=[2, 3])
 281        self.assertSequenceEqual(
 282            NullableIntegerArrayModel.objects.filter(
 283                field__in=IntegerArrayModel.objects.all().values_list(
 284                    "field", flat=True
 285                )
 286            ),
 287            self.objs[2:3],
 288        )
 289
 290    @unittest.expectedFailure
 291    def test_in_including_F_object(self):
 292        # This test asserts that Array objects passed to filters can be
 293        # constructed to contain F objects. This currently doesn't work as the
 294        # psycopg2 mogrify method that generates the ARRAY() syntax is
 295        # expecting literals, not column references (#27095).
 296        self.assertSequenceEqual(
 297            NullableIntegerArrayModel.objects.filter(field__in=[[models.F("id")]]),
 298            self.objs[:2],
 299        )
 300
 301    def test_in_as_F_object(self):
 302        self.assertSequenceEqual(
 303            NullableIntegerArrayModel.objects.filter(field__in=[models.F("field")]),
 304            self.objs[:4],
 305        )
 306
 307    def test_contained_by(self):
 308        self.assertSequenceEqual(
 309            NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]),
 310            self.objs[:2],
 311        )
 312
 313    @unittest.expectedFailure
 314    def test_contained_by_including_F_object(self):
 315        # This test asserts that Array objects passed to filters can be
 316        # constructed to contain F objects. This currently doesn't work as the
 317        # psycopg2 mogrify method that generates the ARRAY() syntax is
 318        # expecting literals, not column references (#27095).
 319        self.assertSequenceEqual(
 320            NullableIntegerArrayModel.objects.filter(
 321                field__contained_by=[models.F("id"), 2]
 322            ),
 323            self.objs[:2],
 324        )
 325
 326    def test_contains(self):
 327        self.assertSequenceEqual(
 328            NullableIntegerArrayModel.objects.filter(field__contains=[2]),
 329            self.objs[1:3],
 330        )
 331
 332    def test_icontains(self):
 333        # Using the __icontains lookup with ArrayField is inefficient.
 334        instance = CharArrayModel.objects.create(field=["FoO"])
 335        self.assertSequenceEqual(
 336            CharArrayModel.objects.filter(field__icontains="foo"), [instance]
 337        )
 338
 339    def test_contains_charfield(self):
 340        # Regression for #22907
 341        self.assertSequenceEqual(
 342            CharArrayModel.objects.filter(field__contains=["text"]), []
 343        )
 344
 345    def test_contained_by_charfield(self):
 346        self.assertSequenceEqual(
 347            CharArrayModel.objects.filter(field__contained_by=["text"]), []
 348        )
 349
 350    def test_overlap_charfield(self):
 351        self.assertSequenceEqual(
 352            CharArrayModel.objects.filter(field__overlap=["text"]), []
 353        )
 354
 355    def test_lookups_autofield_array(self):
 356        qs = (
 357            NullableIntegerArrayModel.objects.filter(
 358                field__0__isnull=False,
 359            )
 360            .values("field__0")
 361            .annotate(
 362                arrayagg=ArrayAgg("id"),
 363            )
 364            .order_by("field__0")
 365        )
 366        tests = (
 367            ("contained_by", [self.objs[1].pk, self.objs[2].pk, 0], [2]),
 368            ("contains", [self.objs[2].pk], [2]),
 369            ("exact", [self.objs[3].pk], [20]),
 370            ("overlap", [self.objs[1].pk, self.objs[3].pk], [2, 20]),
 371        )
 372        for lookup, value, expected in tests:
 373            with self.subTest(lookup=lookup):
 374                self.assertSequenceEqual(
 375                    qs.filter(
 376                        **{"arrayagg__" + lookup: value},
 377                    ).values_list("field__0", flat=True),
 378                    expected,
 379                )
 380
 381    def test_index(self):
 382        self.assertSequenceEqual(
 383            NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
 384        )
 385
 386    def test_index_chained(self):
 387        self.assertSequenceEqual(
 388            NullableIntegerArrayModel.objects.filter(field__0__lt=3), self.objs[0:3]
 389        )
 390
 391    def test_index_nested(self):
 392        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
 393        self.assertSequenceEqual(
 394            NestedIntegerArrayModel.objects.filter(field__0__0=1), [instance]
 395        )
 396
 397    @unittest.expectedFailure
 398    def test_index_used_on_nested_data(self):
 399        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
 400        self.assertSequenceEqual(
 401            NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance]
 402        )
 403
 404    def test_index_transform_expression(self):
 405        expr = RawSQL("string_to_array(%s, ';')", ["1;2"])
 406        self.assertSequenceEqual(
 407            NullableIntegerArrayModel.objects.filter(
 408                field__0=Cast(
 409                    IndexTransform(1, models.IntegerField, expr),
 410                    output_field=models.IntegerField(),
 411                ),
 412            ),
 413            self.objs[:1],
 414        )
 415
 416    def test_overlap(self):
 417        self.assertSequenceEqual(
 418            NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
 419            self.objs[0:3],
 420        )
 421
 422    def test_len(self):
 423        self.assertSequenceEqual(
 424            NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3]
 425        )
 426
 427    def test_len_empty_array(self):
 428        obj = NullableIntegerArrayModel.objects.create(field=[])
 429        self.assertSequenceEqual(
 430            NullableIntegerArrayModel.objects.filter(field__len=0), [obj]
 431        )
 432
 433    def test_slice(self):
 434        self.assertSequenceEqual(
 435            NullableIntegerArrayModel.objects.filter(field__0_1=[2]), self.objs[1:3]
 436        )
 437
 438        self.assertSequenceEqual(
 439            NullableIntegerArrayModel.objects.filter(field__0_2=[2, 3]), self.objs[2:3]
 440        )
 441
 442    def test_order_by_slice(self):
 443        more_objs = (
 444            NullableIntegerArrayModel.objects.create(field=[1, 637]),
 445            NullableIntegerArrayModel.objects.create(field=[2, 1]),
 446            NullableIntegerArrayModel.objects.create(field=[3, -98123]),
 447            NullableIntegerArrayModel.objects.create(field=[4, 2]),
 448        )
 449        self.assertSequenceEqual(
 450            NullableIntegerArrayModel.objects.order_by("field__1"),
 451            [
 452                more_objs[2],
 453                more_objs[1],
 454                more_objs[3],
 455                self.objs[2],
 456                self.objs[3],
 457                more_objs[0],
 458                self.objs[4],
 459                self.objs[1],
 460                self.objs[0],
 461            ],
 462        )
 463
 464    @unittest.expectedFailure
 465    def test_slice_nested(self):
 466        instance = NestedIntegerArrayModel.objects.create(field=[[1, 2], [3, 4]])
 467        self.assertSequenceEqual(
 468            NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance]
 469        )
 470
 471    def test_slice_transform_expression(self):
 472        expr = RawSQL("string_to_array(%s, ';')", ["9;2;3"])
 473        self.assertSequenceEqual(
 474            NullableIntegerArrayModel.objects.filter(
 475                field__0_2=SliceTransform(2, 3, expr)
 476            ),
 477            self.objs[2:3],
 478        )
 479
 480    def test_usage_in_subquery(self):
 481        self.assertSequenceEqual(
 482            NullableIntegerArrayModel.objects.filter(
 483                id__in=NullableIntegerArrayModel.objects.filter(field__len=3)
 484            ),
 485            [self.objs[3]],
 486        )
 487
 488    def test_enum_lookup(self):
 489        class TestEnum(enum.Enum):
 490            VALUE_1 = "value_1"
 491
 492        instance = ArrayEnumModel.objects.create(array_of_enums=[TestEnum.VALUE_1])
 493        self.assertSequenceEqual(
 494            ArrayEnumModel.objects.filter(array_of_enums__contains=[TestEnum.VALUE_1]),
 495            [instance],
 496        )
 497
 498    def test_unsupported_lookup(self):
 499        msg = "Unsupported lookup '0_bar' for ArrayField or join on the field not permitted."
 500        with self.assertRaisesMessage(FieldError, msg):
 501            list(NullableIntegerArrayModel.objects.filter(field__0_bar=[2]))
 502
 503        msg = "Unsupported lookup '0bar' for ArrayField or join on the field not permitted."
 504        with self.assertRaisesMessage(FieldError, msg):
 505            list(NullableIntegerArrayModel.objects.filter(field__0bar=[2]))
 506
 507    def test_grouping_by_annotations_with_array_field_param(self):
 508        value = models.Value([1], output_field=ArrayField(models.IntegerField()))
 509        self.assertEqual(
 510            NullableIntegerArrayModel.objects.annotate(
 511                array_length=models.Func(value, 1, function="ARRAY_LENGTH"),
 512            )
 513            .values("array_length")
 514            .annotate(
 515                count=models.Count("pk"),
 516            )
 517            .get()["array_length"],
 518            1,
 519        )
 520
 521
 522class TestDateTimeExactQuerying(PostgreSQLTestCase):
 523    @classmethod
 524    def setUpTestData(cls):
 525        now = timezone.now()
 526        cls.datetimes = [now]
 527        cls.dates = [now.date()]
 528        cls.times = [now.time()]
 529        cls.objs = [
 530            DateTimeArrayModel.objects.create(
 531                datetimes=cls.datetimes, dates=cls.dates, times=cls.times
 532            ),
 533        ]
 534
 535    def test_exact_datetimes(self):
 536        self.assertSequenceEqual(
 537            DateTimeArrayModel.objects.filter(datetimes=self.datetimes), self.objs
 538        )
 539
 540    def test_exact_dates(self):
 541        self.assertSequenceEqual(
 542            DateTimeArrayModel.objects.filter(dates=self.dates), self.objs
 543        )
 544
 545    def test_exact_times(self):
 546        self.assertSequenceEqual(
 547            DateTimeArrayModel.objects.filter(times=self.times), self.objs
 548        )
 549
 550
 551class TestOtherTypesExactQuerying(PostgreSQLTestCase):
 552    @classmethod
 553    def setUpTestData(cls):
 554        cls.ips = ["192.168.0.1", "::1"]
 555        cls.uuids = [uuid.uuid4()]
 556        cls.decimals = [decimal.Decimal(1.25), 1.75]
 557        cls.tags = [Tag(1), Tag(2), Tag(3)]
 558        cls.objs = [
 559            OtherTypesArrayModel.objects.create(
 560                ips=cls.ips,
 561                uuids=cls.uuids,
 562                decimals=cls.decimals,
 563                tags=cls.tags,
 564            )
 565        ]
 566
 567    def test_exact_ip_addresses(self):
 568        self.assertSequenceEqual(
 569            OtherTypesArrayModel.objects.filter(ips=self.ips), self.objs
 570        )
 571
 572    def test_exact_uuids(self):
 573        self.assertSequenceEqual(
 574            OtherTypesArrayModel.objects.filter(uuids=self.uuids), self.objs
 575        )
 576
 577    def test_exact_decimals(self):
 578        self.assertSequenceEqual(
 579            OtherTypesArrayModel.objects.filter(decimals=self.decimals), self.objs
 580        )
 581
 582    def test_exact_tags(self):
 583        self.assertSequenceEqual(
 584            OtherTypesArrayModel.objects.filter(tags=self.tags), self.objs
 585        )
 586
 587
 588@isolate_apps("postgres_tests")
 589class TestChecks(PostgreSQLSimpleTestCase):
 590    def test_field_checks(self):
 591        class MyModel(PostgreSQLModel):
 592            field = ArrayField(models.CharField())
 593
 594        model = MyModel()
 595        errors = model.check()
 596        self.assertEqual(len(errors), 1)
 597        # The inner CharField is missing a max_length.
 598        self.assertEqual(errors[0].id, "postgres.E001")
 599        self.assertIn("max_length", errors[0].msg)
 600
 601    def test_invalid_base_fields(self):
 602        class MyModel(PostgreSQLModel):
 603            field = ArrayField(
 604                models.ManyToManyField("postgres_tests.IntegerArrayModel")
 605            )
 606
 607        model = MyModel()
 608        errors = model.check()
 609        self.assertEqual(len(errors), 1)
 610        self.assertEqual(errors[0].id, "postgres.E002")
 611
 612    def test_invalid_default(self):
 613        class MyModel(PostgreSQLModel):
 614            field = ArrayField(models.IntegerField(), default=[])
 615
 616        model = MyModel()
 617        self.assertEqual(
 618            model.check(),
 619            [
 620                checks.Warning(
 621                    msg=(
 622                        "ArrayField default should be a callable instead of an "
 623                        "instance so that it's not shared between all field "
 624                        "instances."
 625                    ),
 626                    hint="Use a callable instead, e.g., use `list` instead of `[]`.",
 627                    obj=MyModel._meta.get_field("field"),
 628                    id="fields.E010",
 629                )
 630            ],
 631        )
 632
 633    def test_valid_default(self):
 634        class MyModel(PostgreSQLModel):
 635            field = ArrayField(models.IntegerField(), default=list)
 636
 637        model = MyModel()
 638        self.assertEqual(model.check(), [])
 639
 640    def test_valid_default_none(self):
 641        class MyModel(PostgreSQLModel):
 642            field = ArrayField(models.IntegerField(), default=None)
 643
 644        model = MyModel()
 645        self.assertEqual(model.check(), [])
 646
 647    def test_nested_field_checks(self):
 648        """
 649        Nested ArrayFields are permitted.
 650        """
 651
 652        class MyModel(PostgreSQLModel):
 653            field = ArrayField(ArrayField(models.CharField()))
 654
 655        model = MyModel()
 656        errors = model.check()
 657        self.assertEqual(len(errors), 1)
 658        # The inner CharField is missing a max_length.
 659        self.assertEqual(errors[0].id, "postgres.E001")
 660        self.assertIn("max_length", errors[0].msg)
 661
 662    def test_choices_tuple_list(self):
 663        class MyModel(PostgreSQLModel):
 664            field = ArrayField(
 665                models.CharField(max_length=16),
 666                choices=[
 667                    [
 668                        "Media",
 669                        [(["vinyl", "cd"], "Audio"), (("vhs", "dvd"), "Video")],
 670                    ],
 671                    (["mp3", "mp4"], "Digital"),
 672                ],
 673            )
 674
 675        self.assertEqual(MyModel._meta.get_field("field").check(), [])
 676
 677
 678@unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific tests")
 679class TestMigrations(TransactionTestCase):
 680
 681    available_apps = ["postgres_tests"]
 682
 683    def test_deconstruct(self):
 684        field = ArrayField(models.IntegerField())
 685        name, path, args, kwargs = field.deconstruct()
 686        new = ArrayField(*args, **kwargs)
 687        self.assertEqual(type(new.base_field), type(field.base_field))
 688        self.assertIsNot(new.base_field, field.base_field)
 689
 690    def test_deconstruct_with_size(self):
 691        field = ArrayField(models.IntegerField(), size=3)
 692        name, path, args, kwargs = field.deconstruct()
 693        new = ArrayField(*args, **kwargs)
 694        self.assertEqual(new.size, field.size)
 695
 696    def test_deconstruct_args(self):
 697        field = ArrayField(models.CharField(max_length=20))
 698        name, path, args, kwargs = field.deconstruct()
 699        new = ArrayField(*args, **kwargs)
 700        self.assertEqual(new.base_field.max_length, field.base_field.max_length)
 701
 702    def test_subclass_deconstruct(self):
 703        field = ArrayField(models.IntegerField())
 704        name, path, args, kwargs = field.deconstruct()
 705        self.assertEqual(path, "django.contrib.postgres.fields.ArrayField")
 706
 707        field = ArrayFieldSubclass()
 708        name, path, args, kwargs = field.deconstruct()
 709        self.assertEqual(path, "postgres_tests.models.ArrayFieldSubclass")
 710
 711    @override_settings(
 712        MIGRATION_MODULES={
 713            "postgres_tests": "postgres_tests.array_default_migrations",
 714        }
 715    )
 716    def test_adding_field_with_default(self):
 717        # See #22962
 718        table_name = "postgres_tests_integerarraydefaultmodel"
 719        with connection.cursor() as cursor:
 720            self.assertNotIn(table_name, connection.introspection.table_names(cursor))
 721        call_command("migrate", "postgres_tests", verbosity=0)
 722        with connection.cursor() as cursor:
 723            self.assertIn(table_name, connection.introspection.table_names(cursor))
 724        call_command("migrate", "postgres_tests", "zero", verbosity=0)
 725        with connection.cursor() as cursor:
 726            self.assertNotIn(table_name, connection.introspection.table_names(cursor))
 727
 728    @override_settings(
 729        MIGRATION_MODULES={
 730            "postgres_tests": "postgres_tests.array_index_migrations",
 731        }
 732    )
 733    def test_adding_arrayfield_with_index(self):
 734        """
 735        ArrayField shouldn't have varchar_patterns_ops or text_patterns_ops indexes.
 736        """
 737        table_name = "postgres_tests_chartextarrayindexmodel"
 738        call_command("migrate", "postgres_tests", verbosity=0)
 739        with connection.cursor() as cursor:
 740            like_constraint_columns_list = [
 741                v["columns"]
 742                for k, v in list(
 743                    connection.introspection.get_constraints(cursor, table_name).items()
 744                )
 745                if k.endswith("_like")
 746            ]
 747        # Only the CharField should have a LIKE index.
 748        self.assertEqual(like_constraint_columns_list, [["char2"]])
 749        # All fields should have regular indexes.
 750        with connection.cursor() as cursor:
 751            indexes = [
 752                c["columns"][0]
 753                for c in connection.introspection.get_constraints(
 754                    cursor, table_name
 755                ).values()
 756                if c["index"] and len(c["columns"]) == 1
 757            ]
 758        self.assertIn("char", indexes)
 759        self.assertIn("char2", indexes)
 760        self.assertIn("text", indexes)
 761        call_command("migrate", "postgres_tests", "zero", verbosity=0)
 762        with connection.cursor() as cursor:
 763            self.assertNotIn(table_name, connection.introspection.table_names(cursor))
 764
 765
 766class TestSerialization(PostgreSQLSimpleTestCase):
 767    test_data = '[{"fields": {"field": "[\\"1\\", \\"2\\", null]"}, "model": "postgres_tests.integerarraymodel", "pk": null}]'
 768
 769    def test_dumping(self):
 770        instance = IntegerArrayModel(field=[1, 2, None])
 771        data = serializers.serialize("json", [instance])
 772        self.assertEqual(json.loads(data), json.loads(self.test_data))
 773
 774    def test_loading(self):
 775        instance = list(serializers.deserialize("json", self.test_data))[0].object
 776        self.assertEqual(instance.field, [1, 2, None])
 777
 778
 779class TestValidation(PostgreSQLSimpleTestCase):
 780    def test_unbounded(self):
 781        field = ArrayField(models.IntegerField())
 782        with self.assertRaises(exceptions.ValidationError) as cm:
 783            field.clean([1, None], None)
 784        self.assertEqual(cm.exception.code, "item_invalid")
 785        self.assertEqual(
 786            cm.exception.message % cm.exception.params,
 787            "Item 2 in the array did not validate: This field cannot be null.",
 788        )
 789
 790    def test_blank_true(self):
 791        field = ArrayField(models.IntegerField(blank=True, null=True))
 792        # This should not raise a validation error
 793        field.clean([1, None], None)
 794
 795    def test_with_size(self):
 796        field = ArrayField(models.IntegerField(), size=3)
 797        field.clean([1, 2, 3], None)
 798        with self.assertRaises(exceptions.ValidationError) as cm:
 799            field.clean([1, 2, 3, 4], None)
 800        self.assertEqual(
 801            cm.exception.messages[0],
 802            "List contains 4 items, it should contain no more than 3.",
 803        )
 804
 805    def test_nested_array_mismatch(self):
 806        field = ArrayField(ArrayField(models.IntegerField()))
 807        field.clean([[1, 2], [3, 4]], None)
 808        with self.assertRaises(exceptions.ValidationError) as cm:
 809            field.clean([[1, 2], [3, 4, 5]], None)
 810        self.assertEqual(cm.exception.code, "nested_array_mismatch")
 811        self.assertEqual(
 812            cm.exception.messages[0], "Nested arrays must have the same length."
 813        )
 814
 815    def test_with_base_field_error_params(self):
 816        field = ArrayField(models.CharField(max_length=2))
 817        with self.assertRaises(exceptions.ValidationError) as cm:
 818            field.clean(["abc"], None)
 819        self.assertEqual(len(cm.exception.error_list), 1)
 820        exception = cm.exception.error_list[0]
 821        self.assertEqual(
 822            exception.message,
 823            "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
 824        )
 825        self.assertEqual(exception.code, "item_invalid")
 826        self.assertEqual(
 827            exception.params,
 828            {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
 829        )
 830
 831    def test_with_validators(self):
 832        field = ArrayField(
 833            models.IntegerField(validators=[validators.MinValueValidator(1)])
 834        )
 835        field.clean([1, 2], None)
 836        with self.assertRaises(exceptions.ValidationError) as cm:
 837            field.clean([0], None)
 838        self.assertEqual(len(cm.exception.error_list), 1)
 839        exception = cm.exception.error_list[0]
 840        self.assertEqual(
 841            exception.message,
 842            "Item 1 in the array did not validate: Ensure this value is greater than or equal to 1.",
 843        )
 844        self.assertEqual(exception.code, "item_invalid")
 845        self.assertEqual(
 846            exception.params, {"nth": 1, "value": 0, "limit_value": 1, "show_value": 0}
 847        )
 848
 849
 850class TestSimpleFormField(PostgreSQLSimpleTestCase):
 851    def test_valid(self):
 852        field = SimpleArrayField(forms.CharField())
 853        value = field.clean("a,b,c")
 854        self.assertEqual(value, ["a", "b", "c"])
 855
 856    def test_to_python_fail(self):
 857        field = SimpleArrayField(forms.IntegerField())
 858        with self.assertRaises(exceptions.ValidationError) as cm:
 859            field.clean("a,b,9")
 860        self.assertEqual(
 861            cm.exception.messages[0],
 862            "Item 1 in the array did not validate: Enter a whole number.",
 863        )
 864
 865    def test_validate_fail(self):
 866        field = SimpleArrayField(forms.CharField(required=True))
 867        with self.assertRaises(exceptions.ValidationError) as cm:
 868            field.clean("a,b,")
 869        self.assertEqual(
 870            cm.exception.messages[0],
 871            "Item 3 in the array did not validate: This field is required.",
 872        )
 873
 874    def test_validate_fail_base_field_error_params(self):
 875        field = SimpleArrayField(forms.CharField(max_length=2))
 876        with self.assertRaises(exceptions.ValidationError) as cm:
 877            field.clean("abc,c,defg")
 878        errors = cm.exception.error_list
 879        self.assertEqual(len(errors), 2)
 880        first_error = errors[0]
 881        self.assertEqual(
 882            first_error.message,
 883            "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
 884        )
 885        self.assertEqual(first_error.code, "item_invalid")
 886        self.assertEqual(
 887            first_error.params,
 888            {"nth": 1, "value": "abc", "limit_value": 2, "show_value": 3},
 889        )
 890        second_error = errors[1]
 891        self.assertEqual(
 892            second_error.message,
 893            "Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).",
 894        )
 895        self.assertEqual(second_error.code, "item_invalid")
 896        self.assertEqual(
 897            second_error.params,
 898            {"nth": 3, "value": "defg", "limit_value": 2, "show_value": 4},
 899        )
 900
 901    def test_validators_fail(self):
 902        field = SimpleArrayField(forms.RegexField("[a-e]{2}"))
 903        with self.assertRaises(exceptions.ValidationError) as cm:
 904            field.clean("a,bc,de")
 905        self.assertEqual(
 906            cm.exception.messages[0],
 907            "Item 1 in the array did not validate: Enter a valid value.",
 908        )
 909
 910    def test_delimiter(self):
 911        field = SimpleArrayField(forms.CharField(), delimiter="|")
 912        value = field.clean("a|b|c")
 913        self.assertEqual(value, ["a", "b", "c"])
 914
 915    def test_delimiter_with_nesting(self):
 916        field = SimpleArrayField(SimpleArrayField(forms.CharField()), delimiter="|")
 917        value = field.clean("a,b|c,d")
 918        self.assertEqual(value, [["a", "b"], ["c", "d"]])
 919
 920    def test_prepare_value(self):
 921        field = SimpleArrayField(forms.CharField())
 922        value = field.prepare_value(["a", "b", "c"])
 923        self.assertEqual(value, "a,b,c")
 924
 925    def test_max_length(self):
 926        field = SimpleArrayField(forms.CharField(), max_length=2)
 927        with self.assertRaises(exceptions.ValidationError) as cm:
 928            field.clean("a,b,c")
 929        self.assertEqual(
 930            cm.exception.messages[0],
 931            "List contains 3 items, it should contain no more than 2.",
 932        )
 933
 934    def test_min_length(self):
 935        field = SimpleArrayField(forms.CharField(), min_length=4)
 936        with self.assertRaises(exceptions.ValidationError) as cm:
 937            field.clean("a,b,c")
 938        self.assertEqual(
 939            cm.exception.messages[0],
 940            "List contains 3 items, it should contain no fewer than 4.",
 941        )
 942
 943    def test_required(self):
 944        field = SimpleArrayField(forms.CharField(), required=True)
 945        with self.assertRaises(exceptions.ValidationError) as cm:
 946            field.clean("")
 947        self.assertEqual(cm.exception.messages[0], "This field is required.")
 948
 949    def test_model_field_formfield(self):
 950        model_field = ArrayField(models.CharField(max_length=27))
 951        form_field = model_field.formfield()
 952        self.assertIsInstance(form_field, SimpleArrayField)
 953        self.assertIsInstance(form_field.base_field, forms.CharField)
 954        self.assertEqual(form_field.base_field.max_length, 27)
 955
 956    def test_model_field_formfield_size(self):
 957        model_field = ArrayField(models.CharField(max_length=27), size=4)
 958        form_field = model_field.formfield()
 959        self.assertIsInstance(form_field, SimpleArrayField)
 960        self.assertEqual(form_field.max_length, 4)
 961
 962    def test_model_field_choices(self):
 963        model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B"))))
 964        form_field = model_field.formfield()
 965        self.assertEqual(form_field.clean("1,2"), [1, 2])
 966
 967    def test_already_converted_value(self):
 968        field = SimpleArrayField(forms.CharField())
 969        vals = ["a", "b", "c"]
 970        self.assertEqual(field.clean(vals), vals)
 971
 972    def test_has_changed(self):
 973        field = SimpleArrayField(forms.IntegerField())
 974        self.assertIs(field.has_changed([1, 2], [1, 2]), False)
 975        self.assertIs(field.has_changed([1, 2], "1,2"), False)
 976        self.assertIs(field.has_changed([1, 2], "1,2,3"), True)
 977        self.assertIs(field.has_changed([1, 2], "a,b"), True)
 978
 979    def test_has_changed_empty(self):
 980        field = SimpleArrayField(forms.CharField())
 981        self.assertIs(field.has_changed(None, None), False)
 982        self.assertIs(field.has_changed(None, ""), False)
 983        self.assertIs(field.has_changed(None, []), False)
 984        self.assertIs(field.has_changed([], None), False)
 985        self.assertIs(field.has_changed([], ""), False)
 986
 987
 988class TestSplitFormField(PostgreSQLSimpleTestCase):
 989    def test_valid(self):
 990        class SplitForm(forms.Form):
 991            array = SplitArrayField(forms.CharField(), size=3)
 992
 993        data = {"array_0": "a", "array_1": "b", "array_2": "c"}
 994        form = SplitForm(data)
 995        self.assertTrue(form.is_valid())
 996        self.assertEqual(form.cleaned_data, {"array": ["a", "b", "c"]})
 997
 998    def test_required(self):
 999        class SplitForm(forms.Form):
1000            array = SplitArrayField(forms.CharField(), required=True, size=3)
1001
1002        data = {"array_0": "", "array_1": "", "array_2": ""}
1003        form = SplitForm(data)
1004        self.assertFalse(form.is_valid())
1005        self.assertEqual(form.errors, {"array": ["This field is required."]})
1006
1007    def test_remove_trailing_nulls(self):
1008        class SplitForm(forms.Form):
1009            array = SplitArrayField(
1010                forms.CharField(required=False), size=5, remove_trailing_nulls=True
1011            )
1012
1013        data = {
1014            "array_0": "a",
1015            "array_1": "",
1016            "array_2": "b",
1017            "array_3": "",
1018            "array_4": "",
1019        }
1020        form = SplitForm(data)
1021        self.assertTrue(form.is_valid(), form.errors)
1022        self.assertEqual(form.cleaned_data, {"array": ["a", "", "b"]})
1023
1024    def test_remove_trailing_nulls_not_required(self):
1025        class SplitForm(forms.Form):
1026            array = SplitArrayField(
1027                forms.CharField(required=False),
1028                size=2,
1029                remove_trailing_nulls=True,
1030                required=False,
1031            )
1032
1033        data = {"array_0": "", "array_1": ""}
1034        form = SplitForm(data)
1035        self.assertTrue(form.is_valid())
1036        self.assertEqual(form.cleaned_data, {"array": []})
1037
1038    def test_required_field(self):
1039        class SplitForm(forms.Form):
1040            array = SplitArrayField(forms.CharField(), size=3)
1041
1042        data = {"array_0": "a", "array_1": "b", "array_2": ""}
1043        form = SplitForm(data)
1044        self.assertFalse(form.is_valid())
1045        self.assertEqual(
1046            form.errors,
1047            {
1048                "array": [
1049                    "Item 3 in the array did not validate: This field is required."
1050                ]
1051            },
1052        )
1053
1054    def test_invalid_integer(self):
1055        msg = "Item 2 in the array did not validate: Ensure this value is less than or equal to 100."
1056        with self.assertRaisesMessage(exceptions.ValidationError, msg):
1057            SplitArrayField(forms.IntegerField(max_value=100), size=2).clean([0, 101])
1058
1059    # To locate the widget's template.
1060    @modify_settings(INSTALLED_APPS={"append": "django.contrib.postgres"})
1061    def test_rendering(self):
1062        class SplitForm(forms.Form):
1063            array = SplitArrayField(forms.CharField(), size=3)
1064
1065        self.assertHTMLEqual(
1066            str(SplitForm()),
1067            """
1068            <tr>
1069                <th><label for="id_array_0">Array:</label></th>
1070                <td>
1071                    <input id="id_array_0" name="array_0" type="text" required>
1072                    <input id="id_array_1" name="array_1" type="text" required>
1073                    <input id="id_array_2" name="array_2" type="text" required>
1074                </td>
1075            </tr>
1076        """,
1077        )
1078
1079    def test_invalid_char_length(self):
1080        field = SplitArrayField(forms.CharField(max_length=2), size=3)
1081        with self.assertRaises(exceptions.ValidationError) as cm:
1082            field.clean(["abc", "c", "defg"])
1083        self.assertEqual(
1084            cm.exception.messages,
1085            [
1086                "Item 1 in the array did not validate: Ensure this value has at most 2 characters (it has 3).",
1087                "Item 3 in the array did not validate: Ensure this value has at most 2 characters (it has 4).",
1088            ],
1089        )
1090
1091    def test_splitarraywidget_value_omitted_from_data(self):
1092        class Form(forms.ModelForm):
1093            field = SplitArrayField(forms.IntegerField(), required=False, size=2)
1094
1095            class Meta:
1096                model = IntegerArrayModel
1097                fields = ("field",)
1098
1099        form = Form({"field_0": "1", "field_1": "2"})
1100        self.assertEqual(form.errors, {})
1101        obj = form.save(commit=False)
1102        self.assertEqual(obj.field, [1, 2])
1103
1104    def test_splitarrayfield_has_changed(self):
1105        class Form(forms.ModelForm):
1106            field = SplitArrayField(forms.IntegerField(), required=False, size=2)
1107
1108            class Meta:
1109                model = IntegerArrayModel
1110                fields = ("field",)
1111
1112        tests = [
1113            ({}, {"field_0": "", "field_1": ""}, True),
1114            ({"field": None}, {"field_0": "", "field_1": ""}, True),
1115            ({"field": [1]}, {"field_0": "", "field_1": ""}, True),
1116            ({"field": [1]}, {"field_0": "1", "field_1": "0"}, True),
1117            ({"field": [1, 2]}, {"field_0": "1", "field_1": "2"}, False),
1118            ({"field": [1, 2]}, {"field_0": "a", "field_1": "b"}, True),
1119        ]
1120        for initial, data, expected_result in tests:
1121            with self.subTest(initial=initial, data=data):
1122                obj = IntegerArrayModel(**initial)
1123                form = Form(data, instance=obj)
1124                self.assertIs(form.has_changed(), expected_result)
1125
1126    def test_splitarrayfield_remove_trailing_nulls_has_changed(self):
1127        class Form(forms.ModelForm):
1128            field = SplitArrayField(
1129                forms.IntegerField(), required=False, size=2, remove_trailing_nulls=True
1130            )
1131
1132            class Meta:
1133                model = IntegerArrayModel
1134                fields = ("field",)
1135
1136        tests = [
1137            ({}, {"field_0": "", "field_1": ""}, False),
1138            ({"field": None}, {"field_0": "", "field_1": ""}, False),
1139            ({"field": []}, {"field_0": "", "field_1": ""}, False),
1140            ({"field": [1]}, {"field_0": "1", "field_1": ""}, False),
1141        ]
1142        for initial, data, expected_result in tests:
1143            with self.subTest(initial=initial, data=data):
1144                obj = IntegerArrayModel(**initial)
1145                form = Form(data, instance=obj)
1146                self.assertIs(form.has_changed(), expected_result)
1147
1148
1149class TestSplitFormWidget(PostgreSQLWidgetTestCase):
1150    def test_get_context(self):
1151        self.assertEqual(
1152            SplitArrayWidget(forms.TextInput(), size=2).get_context(
1153                "name", ["val1", "val2"]
1154            ),
1155            {
1156                "widget": {
1157                    "name": "name",
1158                    "is_hidden": False,
1159                    "required": False,
1160                    "value": "['val1', 'val2']",
1161                    "attrs": {},
1162                    "template_name": "postgres/widgets/split_array.html",
1163                    "subwidgets": [
1164                        {
1165                            "name": "name_0",
1166                            "is_hidden": False,
1167                            "required": False,
1168                            "value": "val1",
1169                            "attrs": {},
1170                            "template_name": "django/forms/widgets/text.html",
1171                            "type": "text",
1172                        },
1173                        {
1174                            "name": "name_1",
1175                            "is_hidden": False,
1176                            "required": False,
1177                            "value": "val2",
1178                            "attrs": {},
1179                            "template_name": "django/forms/widgets/text.html",
1180                            "type": "text",
1181                        },
1182                    ],
1183                }
1184            },
1185        )
1186
1187    def test_checkbox_get_context_attrs(self):
1188        context = SplitArrayWidget(
1189            forms.CheckboxInput(),
1190            size=2,
1191        ).get_context("name", [True, False])
1192        self.assertEqual(context["widget"]["value"], "[True, False]")
1193        self.assertEqual(
1194            [subwidget["attrs"] for subwidget in context["widget"]["subwidgets"]],
1195            [{"checked": True}, {}],
1196        )
1197
1198    def test_render(self):
1199        self.check_html(
1200            SplitArrayWidget(forms.TextInput(), size=2),
1201            "array",
1202            None,
1203            """
1204            <input name="array_0" type="text">
1205            <input name="array_1" type="text">
1206            """,
1207        )
1208
1209    def test_render_attrs(self):
1210        self.check_html(
1211            SplitArrayWidget(forms.TextInput(), size=2),
1212            "array",
1213            ["val1", "val2"],
1214            attrs={"id": "foo"},
1215            html=(
1216                """
1217                <input id="foo_0" name="array_0" type="text" value="val1">
1218                <input id="foo_1" name="array_1" type="text" value="val2">
1219                """
1220            ),
1221        )
1222
1223    def test_value_omitted_from_data(self):
1224        widget = SplitArrayWidget(forms.TextInput(), size=2)
1225        self.assertIs(widget.value_omitted_from_data({}, {}, "field"), True)
1226        self.assertIs(
1227            widget.value_omitted_from_data({"field_0": "value"}, {}, "field"), False
1228        )
1229        self.assertIs(
1230            widget.value_omitted_from_data({"field_1": "value"}, {}, "field"), False
1231        )
1232        self.assertIs(
1233            widget.value_omitted_from_data(
1234                {"field_0": "value", "field_1": "value"}, {}, "field"
1235            ),
1236            False,
1237        )