From 49d7ca3cf006722885e83845bd67ef5f2870d165 Mon Sep 17 00:00:00 2001 From: Kara Wang Date: Thu, 28 Oct 2021 17:35:29 -0400 Subject: [PATCH] Support reverse relation filter by in comparison --- django_mock_queries/comparisons.py | 3 +++ tests/test_query.py | 12 ++++++++++++ tests/test_utils.py | 6 ++++++ 3 files changed, 21 insertions(+) diff --git a/django_mock_queries/comparisons.py b/django_mock_queries/comparisons.py index 17bccd3..dbccfc3 100644 --- a/django_mock_queries/comparisons.py +++ b/django_mock_queries/comparisons.py @@ -34,6 +34,9 @@ def lte_comparison(first, second): def in_comparison(first, second): + if isinstance(first, list): + return bool(set(first).intersection(set(second))) + return first in second if first is not None else False diff --git a/tests/test_query.py b/tests/test_query.py index ebcd308..06536c6 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -202,6 +202,18 @@ def test_query_filters_model_objects_by_bad_field(self): r"Choices are 'id', 'make', 'make_id', 'model', 'passengers', 'sedan', 'speed', 'variations'\."): self.mock_set.filter(bad_field='bogus') + def test_query_filters_reverse_relationship_by_in_comparison(self): + with mocked_relations(Manufacturer): + cars = [Car(speed=1)] + + make = Manufacturer() + make.car_set = MockSet(*cars) + + self.mock_set.add(make) + + result = self.mock_set.filter(car__speed__in=[1, 2]) + assert result.count() == 1 + def test_query_exclude(self): item_1 = MockModel(foo=1, bar='a') item_2 = MockModel(foo=1, bar='b') diff --git a/tests/test_utils.py b/tests/test_utils.py index 21739a2..3487a07 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -227,6 +227,12 @@ def test_is_match_in_value_check(self): result = utils.is_match(1, [1, 3], constants.COMPARISON_IN) assert result is True + result = utils.is_match([3], [1, 2], constants.COMPARISON_IN) + assert result is False + + result = utils.is_match([1, 3], [1, 2], constants.COMPARISON_IN) + assert result is True + @patch('django_mock_queries.utils.get_attribute') @patch('django_mock_queries.utils.is_match', MagicMock(return_value=True)) def test_matches_includes_object_in_results_when_match(self, get_attr_mock):