Skip to content

Commit

Permalink
Merge pull request #60 from altitudenetworks/task/data_table_filter_r…
Browse files Browse the repository at this point in the history
…ecords_not_equals

Add support for not_equals operand in DataTable.filter_records
  • Loading branch information
vemel authored Sep 9, 2020
2 parents 1d57603 + 8915a9d commit cf3b192
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 15 deletions.
42 changes: 29 additions & 13 deletions dynamo_query/data_table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
from copy import copy, deepcopy
from enum import Enum, auto
from typing import (
Any,
DefaultDict,
Expand All @@ -24,10 +25,12 @@
_RecordType = TypeVar("_RecordType", bound=RecordType)
_R = TypeVar("_R", bound="DataTable")

__all__ = (
"DataTable",
"DataTableError",
)
__all__ = ("DataTable", "DataTableError", "Filter")


class Filter(Enum):
EQUALS = auto()
NOT_EQUALS = auto()


class DataTableError(BaseException):
Expand Down Expand Up @@ -394,7 +397,7 @@ def get_record(self, record_index: int) -> _RecordType:

return self._convert_record(result)

def filter_records(self: _R, query: Dict[str, Any]) -> _R:
def filter_records(self: _R, query: Dict[str, Any], operand: Filter = Filter.EQUALS) -> _R:
"""
Create a new `DataTable` instance with records that match `query`
Expand All @@ -416,17 +419,30 @@ def filter_records(self: _R, query: Dict[str, Any]) -> _R:
raise DataTableError("Cannot filter not normalized table. Use `normalize` method.")

result = self.__class__({key: [] for key in self.keys()}, record_class=self.record_class)
for record in self.get_records():
record_match = True
for lookup_key, lookup_value in query.items():
if record.get(lookup_key) != lookup_value:
record_match = False
break

if record_match:
def _equals() -> _R:
for record in self.get_records():
if any(
record.get(lookup_key) != lookup_value
for lookup_key, lookup_value in query.items()
):
continue
result.extend({key: [value] for key, value in record.items()})
return result

def _not_equals() -> _R:
for record in self.get_records():
if all(
record.get(lookup_key) == lookup_value
for lookup_key, lookup_value in query.items()
):
continue
result.extend({key: [value] for key, value in record.items()})
return result

return result
job_map = {Filter.EQUALS: _equals, Filter.NOT_EQUALS: _not_equals}

return job_map[operand]()

def _convert_record(self, record: Union[_RecordType, Dict]) -> _RecordType:
# pylint: disable=isinstance-second-argument-not-valid-type
Expand Down
26 changes: 24 additions & 2 deletions tests/test_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from typing_extensions import TypedDict

from dynamo_query.data_table import DataTable, DataTableError
from dynamo_query.data_table import DataTable, DataTableError, Filter
from dynamo_query.dictclasses.dictclass import DictClass
from dynamo_query.dictclasses.dynamo_dictclass import DynamoDictClass

Expand Down Expand Up @@ -172,7 +172,7 @@ def test_get_records() -> None:
next(records)

@staticmethod
def test_filter_records() -> None:
def test_filter_records_equals() -> None:
data_table = DataTable({"a": [1, 2, 1], "b": [3, 4, 5]})
assert data_table.filter_records({"a": 1}) == {"a": [1, 1], "b": [3, 5]}
assert data_table.filter_records({"a": 2, "b": 4}) == {"a": [2], "b": [4]}
Expand All @@ -182,6 +182,28 @@ def test_filter_records() -> None:
with pytest.raises(DataTableError):
DataTable({"a": [1, 2, 1], "b": [3, 4]}).filter_records({"a": 1})

@staticmethod
def test_filter_records_not_equals() -> None:
data_table = DataTable({"a": [1, 2, 1], "b": [3, 4, 5]})
assert data_table.filter_records({"a": 1}, operand=Filter.NOT_EQUALS) == {
"a": [2],
"b": [4],
}
assert data_table.filter_records({"a": 2, "b": 4}, operand=Filter.NOT_EQUALS) == {
"a": [1, 1],
"b": [3, 5],
}

assert data_table.filter_records({"a": 1, "b": 4}, operand=Filter.NOT_EQUALS) == {
"a": [1, 2, 1],
"b": [3, 4, 5],
}

with pytest.raises(DataTableError):
DataTable({"a": [1, 2, 1], "b": [3, 4]}).filter_records(
{"a": 1}, operand=Filter.NOT_EQUALS
)

@staticmethod
def test_add_record() -> None:
data_table = DataTable({"a": [1], "b": [3]})
Expand Down

0 comments on commit cf3b192

Please sign in to comment.