Skip to content

Commit

Permalink
Merge pull request #34 from altitudenetworks/hotfix/immutable-records
Browse files Browse the repository at this point in the history
Immutable records
  • Loading branch information
vemel authored May 29, 2020
2 parents 563cd56 + f6d919d commit 23df695
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 78 deletions.
55 changes: 51 additions & 4 deletions docs/dynamo_query/dynamo_record.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
- [dynamo-query](../README.md#dynamoquery) / [Modules](../MODULES.md#dynamo-query-modules) / [Dynamo Query](index.md#dynamo-query) / DynamoRecord
- [DynamoRecord](#dynamorecord)
- [DynamoRecord().\_\_post\_init\_\_](#dynamorecord__post_init__)
- [DynamoRecord().sanitize](#dynamorecordsanitize)
- [DynamoRecord().sanitize_key](#dynamorecordsanitize_key)
- [NullableDynamoRecord](#nullabledynamorecord)

## DynamoRecord

[[find in source code]](https://github.com/altitudenetworks/dynamoquery/blob/master/dynamo_query/dynamo_record.py#L9)
[[find in source code]](https://github.com/altitudenetworks/dynamoquery/blob/master/dynamo_query/dynamo_record.py#L10)

```python
class DynamoRecord(UserDict):
Expand All @@ -33,27 +35,72 @@ class UserRecord(DynamoRecord):
# do any post-init operations here
self.age = self.age or 35

# add extra computed field
def get_key_min_age(self) -> int:
return 18

# sanitize value on set
def sanitize_key_age(self, value: int) -> int:
return max(self.age, 18)

record = UserRecord(name="Jon")
record["age"] = 30
record.age = 30
record.update({"age": 30})

record.asdict() # {"name": "Jon", "company": "Amazon", "age": 30}
dict(record) # {"name": "Jon", "company": "Amazon", "age": 30, "min_age": 18}
```

### DynamoRecord().\_\_post\_init\_\_

[[find in source code]](https://github.com/altitudenetworks/dynamoquery/blob/master/dynamo_query/dynamo_record.py#L58)
[[find in source code]](https://github.com/altitudenetworks/dynamoquery/blob/master/dynamo_query/dynamo_record.py#L72)

```python
def __post_init__() -> None:
```

Override this method for post-init operations

### DynamoRecord().sanitize

[[find in source code]](https://github.com/altitudenetworks/dynamoquery/blob/master/dynamo_query/dynamo_record.py#L302)

```python
def sanitize(**kwargs: Any) -> None:
```

Sanitize all set fields.

#### Arguments

- `kwargs` - Arguments for sanitize_key_{key}

### DynamoRecord().sanitize_key

[[find in source code]](https://github.com/altitudenetworks/dynamoquery/blob/master/dynamo_query/dynamo_record.py#L268)

```python
def sanitize_key(key: str, value: Any, **kwargs: Any) -> Any:
```

Sanitize value before putting it to dict.

- Converts decimals to int/float
- Calls `sanitize_key_{key}` method if it is defined
- Checks if sanitized value has a proper type

#### Arguments

- `key` - Dictionary key
- `value` - Raw value

#### Returns

A sanitized value

## NullableDynamoRecord

[[find in source code]](https://github.com/altitudenetworks/dynamoquery/blob/master/dynamo_query/dynamo_record.py#L244)
[[find in source code]](https://github.com/altitudenetworks/dynamoquery/blob/master/dynamo_query/dynamo_record.py#L315)

```python
class NullableDynamoRecord(UserDict):
Expand Down
179 changes: 125 additions & 54 deletions dynamo_query/dynamo_record.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from collections import UserDict
from copy import deepcopy
from decimal import Decimal
from typing import Any, Dict, List, Tuple

Expand All @@ -25,26 +26,39 @@ def __post_init__(self):
# do any post-init operations here
self.age = self.age or 35
# add extra computed field
def get_key_min_age(self) -> int:
return 18
# sanitize value on set
def sanitize_key_age(self, value: int) -> int:
return max(self.age, 18)
record = UserRecord(name="Jon")
record["age"] = 30
record.age = 30
record.update({"age": 30})
record.asdict() # {"name": "Jon", "company": "Amazon", "age": 30}
dict(record) # {"name": "Jon", "company": "Amazon", "age": 30, "min_age": 18}
```
"""

# Marker for optional fields with no initial value, oerride to None if needed
NOT_SET: Any = None

# List of methods that should be updated on field change
COMPUTED_FIELDS: List[str] = []

# KeyError is raised if unknown key provided
SKIP_UNKNOWN_KEYS: bool = True

# Prefix for computed key method names
GET_KEY_PREFIX: str = "get_key_"

# Prefix for sanitize key method names
SANITIZE_KEY_PREFIX: str = "sanitize_key_"

def __init__(self, *args: Dict[str, Any], **kwargs: Any) -> None:
super().__init__()
self._computed_field_names = self._get_computed_field_names()
self._sanitized_field_names = self._get_sanitized_field_names()
self._local_members = self._get_local_members()
self._allowed_types = self._get_allowed_types(self._local_members["__annotations__"])
del self._local_members["__annotations__"]
Expand All @@ -64,17 +78,44 @@ def __post_init__(self) -> None:
def _get_allowed_types(annotations: Dict[str, Any]) -> Dict[str, Tuple[Any, ...]]:
result: Dict[str, Tuple[Any, ...]] = {}
for key, annotation in annotations.items():
if inspect.isclass(annotation):
annotation_str = str(annotation)
if not annotation_str.startswith("typing.") and inspect.isclass(annotation):
result[key] = (annotation,)
continue

if not hasattr(annotation, "__args__"):
result[key] = tuple()
continue
child_types: Tuple[Any, ...] = tuple()

if hasattr(annotation, "__args__"):
child_types = tuple([i for i in annotation.__args__ if inspect.isclass(i)])

child_types = tuple([i for i in annotation.__args__ if inspect.isclass(i)])
if child_types:
if annotation_str.startswith("typing.Dict"):
result[key] = (dict,)
if annotation_str.startswith("typing.List"):
result[key] = (list,)
if annotation_str.startswith("typing.Set"):
result[key] = (set,)
if annotation_str.startswith("typing.Union"):
result[key] = child_types
if annotation_str.startswith("typing.Optional"):
result[key] = (*child_types, None)

return result

@classmethod
def _get_computed_field_names(cls) -> List[str]:
result = []
for name, member in inspect.getmembers(cls):
if name.startswith(cls.GET_KEY_PREFIX) and inspect.isfunction(member):
result.append(name.replace(cls.GET_KEY_PREFIX, "", 1))

return result

@classmethod
def _get_sanitized_field_names(cls) -> List[str]:
result = []
for name, member in inspect.getmembers(cls):
if name.startswith(cls.SANITIZE_KEY_PREFIX) and inspect.isfunction(member):
result.append(name.replace(cls.SANITIZE_KEY_PREFIX, "", 1))

return result

Expand Down Expand Up @@ -123,7 +164,7 @@ def _get_field_names(self) -> List[str]:
result.append(key)

for key in self._local_members:
if key.startswith("_"):
if key.startswith("_") or key.upper() == key:
continue

result.append(key)
Expand All @@ -139,7 +180,7 @@ def _init_data(self, *mappings: Dict[str, Any]) -> None:
):
continue

self.data[member_name] = member
self.data[member_name] = deepcopy(member)

for mapping in mappings:
for key, value in mapping.items():
Expand All @@ -149,85 +190,69 @@ def _init_data(self, *mappings: Dict[str, Any]) -> None:

continue

value = self._fix_decimal(key, value)
allowed_types = self._allowed_types.get(key)
if allowed_types and not isinstance(value, allowed_types):
raise ValueError(
f"{self._class_name}.{key} has type {allowed_types}, got {repr(value)}."
)
self.data[key] = value
self.data[key] = self.sanitize_key(key, value)

for key in self._required_field_names:
if key not in self.data:
raise ValueError(f"{self._class_name}.{key} must be set.")

for key in self.COMPUTED_FIELDS:
self.data[key] = getattr(self, key)()
for key in self._computed_field_names:
self.data[key] = getattr(self, f"{self.GET_KEY_PREFIX}{key}")()

for key, value in list(self.data.items()):
if value is self.NOT_SET:
del self.data[key]

self._update_computed()

def _fix_decimal(self, key: str, value: Any) -> Any:
if not isinstance(value, Decimal):
return value

allowed_types = self._allowed_types.get(key, tuple())
if float in allowed_types:
return float(value)

return int(value)

@property
def _class_name(self) -> str:
return self.__class__.__name__

def _set_item(self, key: str, value: Any) -> None:
if value is self.NOT_SET:
if key in self.data:
del self.data[key]
self._update_computed()
return
def _set_item(
self, key: str, value: Any, is_initial: bool, sanitize_kwargs: Dict[str, Any]
) -> None:
sanitized_value = self.sanitize_key(key, value, **sanitize_kwargs)

value = self._fix_decimal(key, value)
allowed_types = self._allowed_types.get(key)
if allowed_types and not isinstance(value, allowed_types):
raise ValueError(
f"{self._class_name}.{key} has type {allowed_types}, got {repr(value)}."
)
if not is_initial:
if sanitized_value is self.NOT_SET:
if key in self.data:
del self.data[key]
self._update_computed()
return

self.data[key] = value
self._update_computed()
self.data[key] = sanitized_value

if not is_initial:
self._update_computed()

def _update_computed(self) -> None:
for field_name in self.COMPUTED_FIELDS:
value = getattr(self, field_name)()
for key in self._computed_field_names:
value = getattr(self, f"{self.GET_KEY_PREFIX}{key}")()
if value is self.NOT_SET:
if field_name in self.data:
del self.data[field_name]
if key in self.data:
del self.data[key]
else:
self.data[field_name] = value
self.data[key] = value

def __setitem__(self, key: str, value: Any) -> None:
if key in self.COMPUTED_FIELDS:
if key in self._computed_field_names:
return

if key not in self._field_names:
raise KeyError(f"Key {self._class_name}.{key} is incorrect")

self._set_item(key, value)
self._set_item(key, value, is_initial=False, sanitize_kwargs={})

def __setattr__(self, name: str, value: Any) -> None:
if name in self.COMPUTED_FIELDS:
if hasattr(self, "_computed_field_names") and name in self._computed_field_names:
raise KeyError(f"Key {self._class_name}.{name} is computed and cannot be set directly")

if not hasattr(self, "_field_names") or name not in self._field_names:
super().__setattr__(name, value)
return

self._set_item(name, value)
self._set_item(name, value, is_initial=False, sanitize_kwargs={})

def __getattribute__(self, name: str) -> Any:
if name.startswith("_"):
Expand All @@ -240,6 +265,52 @@ def __getattribute__(self, name: str) -> Any:
def __str__(self) -> str:
return f"{self._class_name}({self.data})"

def sanitize_key(self, key: str, value: Any, **kwargs: Any) -> Any:
"""
Sanitize value before putting it to dict.
- Converts decimals to int/float
- Calls `sanitize_key_{key}` method if it is defined
- Checks if sanitized value has a proper type
Arguments:
key -- Dictionary key
value -- Raw value
Returns:
A sanitized value
"""
original_value = value
allowed_types = self._allowed_types.get(key)
if isinstance(value, Decimal):
if allowed_types and float in allowed_types:
value = float(value)
else:
value = int(value)

if key in self._sanitized_field_names:
sanitize_method = getattr(self, f"{self.SANITIZE_KEY_PREFIX}{key}")
value = sanitize_method(value, **kwargs)

if allowed_types and not isinstance(value, allowed_types):
raise ValueError(
f"{self._class_name}.{key} has type {allowed_types}, got {repr(value)} (raw {repr(original_value)})."
)

return value

def sanitize(self, **kwargs: Any) -> None:
"""
Sanitize all set fields.
Arguments:
kwargs -- Arguments for sanitize_key_{key}
"""
for key in self._sanitized_field_names:
self._set_item(
key, self.get(key, self.NOT_SET), is_initial=False, sanitize_kwargs=kwargs
)


class NullableDynamoRecord(UserDict):
"""
Expand Down
Loading

0 comments on commit 23df695

Please sign in to comment.