chore: Upgrade mypy #2243

Merged · 6 commits · Apr 4, 2024
@@ -1,6 +1,7 @@
# Copyright Contributors to the Amundsen project.
# SPDX-License-Identifier: Apache-2.0

+ import abc
import json
import re
from typing import (
@@ -99,7 +100,7 @@ def execute_query(self) -> Dict[str, Any]:
'Content-Type': 'application/json',
'X-Tableau-Auth': self._auth_token
}
- params = {
+ params: Dict[str, Any] = {
'headers': headers
}
if self._verify_request is not None:
@@ -108,6 +109,7 @@ def execute_query(self) -> Dict[str, Any]:
response = requests.post(url=self._metadata_url, data=query_payload, **params)
return response.json()['data']

+ @abc.abstractmethod
def execute(self) -> Iterator[Dict[str, Any]]:
"""
Must be overriden by any extractor using this class. This should parse the result and yield each entity's
@@ -187,7 +189,7 @@ def _authenticate(self) -> str:
'Content-Type': 'application/json'
}
# verify = False is needed bypass occasional (valid) self-signed cert errors. TODO: actually fix it!!
- params = {
+ params: Dict[str, Any] = {
'headers': headers
}
if self._verify_request is not None:
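Note on the two `params: Dict[str, Any]` annotations above (the added `@abc.abstractmethod` on `execute` is the companion change that lets the base class leave that method unimplemented): mypy infers a narrow value type from a dict literal whose values are all of one kind, so later assigning a bool such as a `verify` flag is rejected unless the variable is annotated up front. A minimal, hypothetical sketch of that failure mode — not the extractor's real code:

```python
from typing import Any, Dict

import requests


def post_with_optional_verify(url: str, token: str, verify: bool = True) -> Any:
    headers = {'Content-Type': 'application/json', 'X-Tableau-Auth': token}

    # Without the annotation mypy infers Dict[str, Dict[str, str]] here, so the
    # assignment below is rejected with roughly:
    #   Incompatible types in assignment (expression has type "bool",
    #   target has type "Dict[str, str]")
    params: Dict[str, Any] = {'headers': headers}
    params['verify'] = verify  # accepted once the value type is Any

    return requests.post(url=url, data='{}', **params)
```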
@@ -58,6 +58,8 @@ def init(self, conf: ConfigTree) -> None:
def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
self._extract_iter = self._get_extract_iter()
+ if self._extract_iter is None:
+     return None
try:
return next(self._extract_iter)
except StopIteration:
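The two added lines above are the usual way to satisfy mypy when an iterator attribute is `Optional`: assigning from a method that may return `None` leaves the attribute typed `Optional[Iterator[...]]`, and the explicit `is None` check narrows it before `next()` is called. A rough sketch of the pattern with illustrative names, not the real extractor API:

```python
from typing import Any, Dict, Iterator, Optional


class SketchExtractor:
    def __init__(self) -> None:
        self._extract_iter: Optional[Iterator[Dict[str, Any]]] = None

    def _get_extract_iter(self) -> Optional[Iterator[Dict[str, Any]]]:
        # May legitimately produce nothing, hence the Optional return type.
        return iter([{'table': 'a.b'}])

    def extract(self) -> Optional[Dict[str, Any]]:
        if not self._extract_iter:
            self._extract_iter = self._get_extract_iter()
        if self._extract_iter is None:
            # Narrows Optional[Iterator[...]] to Iterator[...] below,
            # so next() type-checks without a cast or an ignore.
            return None
        try:
            return next(self._extract_iter)
        except StopIteration:
            return None
```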
@@ -80,7 +80,7 @@ def __init__(self,
dashboard_group: str,
dashboard_name: str,
description: Union[str, None],
- tags: List = None,
+ tags: Optional[List] = None,
cluster: str = 'gold',
product: Optional[str] = '',
dashboard_group_id: Optional[str] = None,
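The `tags: Optional[List] = None` change above is the pattern repeated across most files in this diff: under PEP 484 a `None` default no longer makes a parameter implicitly optional, and recent mypy releases enable `--no-implicit-optional` by default (from roughly 0.990 onward, if memory serves), so the `Optional[...]` has to be spelled out. A small illustrative example:

```python
from typing import Any, Dict, List, Optional


# Rejected under current mypy defaults with roughly:
#   Incompatible default for argument "tags"
#   (default has type "None", argument has type "List[str]")
# def make_dashboard(name: str, tags: List[str] = None) -> Dict[str, Any]: ...

def make_dashboard(name: str, tags: Optional[List[str]] = None) -> Dict[str, Any]:
    # Spell out the None default in the type and normalise it inside.
    return {'name': name, 'tags': tags or []}
```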
4 changes: 2 additions & 2 deletions databuilder/databuilder/models/feature/feature_metadata.py
@@ -52,9 +52,9 @@ def __init__(self,
status: Optional[str] = None,
entity: Optional[str] = None,
data_type: Optional[str] = None,
- availability: List[str] = None, # list of databases
+ availability: Optional[List[str]] = None, # list of databases
description: Optional[str] = None,
- tags: List[str] = None,
+ tags: Optional[List[str]] = None,
created_timestamp: Optional[int] = None,
last_updated_timestamp: Optional[int] = None,
**kwargs: Any
6 changes: 3 additions & 3 deletions databuilder/databuilder/models/schema/schema.py
@@ -3,7 +3,7 @@

import re
from typing import (
- Any, Iterator, Union,
+ Any, Iterator, Optional, Union,
)

from amundsen_common.utils.atlas import AtlasCommonParams, AtlasTableTypes
@@ -32,8 +32,8 @@ class SchemaModel(GraphSerializable, TableSerializable, AtlasSerializable):
def __init__(self,
schema_key: str,
schema: str,
- description: str = None,
- description_source: str = None,
+ description: Optional[str] = None,
+ description_source: Optional[str] = None,
**kwargs: Any
) -> None:
self._schema_key = schema_key
6 changes: 3 additions & 3 deletions databuilder/databuilder/models/table_lineage.py
@@ -3,7 +3,7 @@

from abc import abstractmethod
from typing import (
- Iterator, List, Union,
+ Iterator, List, Optional, Union,
)

from amundsen_common.utils.atlas import AtlasCommonParams, AtlasTableTypes
@@ -142,7 +142,7 @@ class TableLineage(BaseLineage):

def __init__(self,
table_key: str,
- downstream_deps: List = None, # List of table keys
+ downstream_deps: Optional[List] = None, # List of table keys
) -> None:
self.table_key = table_key
# a list of downstream dependencies, each of which will follow
@@ -196,7 +196,7 @@ class ColumnLineage(BaseLineage):

def __init__(self,
column_key: str,
- downstream_deps: List = None, # List of column keys
+ downstream_deps: Optional[List] = None, # List of column keys
) -> None:
self.column_key = column_key
# a list of downstream dependencies, each of which will follow
7 changes: 2 additions & 5 deletions databuilder/databuilder/models/table_metadata.py
@@ -279,9 +279,9 @@ def __init__(self,
schema: str,
name: str,
description: Union[str, None],
- columns: Iterable[ColumnMetadata] = None,
+ columns: Optional[Iterable[ColumnMetadata]] = None,
is_view: bool = False,
- tags: Union[List, str] = None,
+ tags: Union[List, str, None] = None,
description_source: Union[str, None] = None,
**kwargs: Any
) -> None:
@@ -802,9 +802,6 @@ def _create_atlas_column_entity(self, column_metadata: ColumnMetadata) -> AtlasE

return entity

- def _create_next_atlas_relation(self) -> Iterator[AtlasRelationship]:
-     pass
-
def _create_atlas_relation_iterator(self) -> Iterator[AtlasRelationship]:
for tag in self.tags:
tag_relation = TagMetadata(tag).create_atlas_tag_relation(self._get_table_key())
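Deleting the `_create_next_atlas_relation` stub (and, in `table_stats.py` below, the `_create_record_iterator` one) is one way to deal with mypy's newer check on functions whose body is just `pass` but whose declared return type is not `None`: recent versions report it under the `empty-body` error code, as far as I recall. A hypothetical before/after sketch:

```python
from typing import Iterator


class Before:
    def relations(self) -> Iterator[str]:
        pass  # recent mypy: error: Missing return statement  [empty-body]


class After:
    # The fix used here: drop the unused stub and keep only the real generator.
    def relations(self) -> Iterator[str]:
        yield from ('owner', 'tag')
```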
17 changes: 5 additions & 12 deletions databuilder/databuilder/models/table_stats.py
@@ -32,10 +32,10 @@ def __init__(self,
stat_val: str,
is_metric: bool,
db: str = 'hive',
- schema: str = None,
+ schema: Optional[str] = None,
cluster: str = 'gold',
- start_epoch: str = None,
- end_epoch: str = None
+ start_epoch: Optional[str] = None,
+ end_epoch: Optional[str] = None
) -> None:
if schema is None:
self.schema, self.table = table_name.split('.')
@@ -53,7 +53,6 @@ def __init__(self,
self.is_metric = is_metric
self._node_iter = self._create_node_iterator()
self._relation_iter = self._create_relation_iterator()
- self._record_iter = self._create_record_iterator()

def create_next_node(self) -> Optional[GraphNode]:
# return the string representation of the data
@@ -69,10 +68,7 @@ def create_next_relation(self) -> Optional[GraphRelationship]:
return None

def create_next_record(self) -> Union[RDSModel, None]:
- try:
-     return next(self._record_iter)
- except StopIteration:
-     return None
+ return None

def get_table_stat_model_key(self) -> str:
return TableStats.KEY_FORMAT.format(db=self.db,
@@ -123,9 +119,6 @@ def _create_relation_iterator(self) -> Iterator[GraphRelationship]:
)
yield relationship

- def _create_record_iterator(self) -> Iterator[RDSModel]:
-     pass
-

class TableColumnStats(GraphSerializable, TableSerializable):
"""
@@ -144,7 +137,7 @@ def __init__(self,
end_epoch: str,
db: str = 'hive',
cluster: str = 'gold',
- schema: str = None
+ schema: Optional[str] = None
) -> None:
if schema is None:
self.schema, self.table = table_name.split('.')
2 changes: 1 addition & 1 deletion databuilder/databuilder/models/user.py
@@ -132,7 +132,7 @@ def create_next_record(self) -> Union[RDSModel, None]:

@classmethod
def get_user_model_key(cls,
- email: str = None
+ email: Optional[str] = None
) -> str:
if not email:
return ''
4 changes: 2 additions & 2 deletions databuilder/databuilder/publisher/neo4j_csv_publisher.py
@@ -9,7 +9,7 @@
from os import listdir
from os.path import isfile, join
from typing import (
- Dict, List, Set,
+ Dict, List, Optional, Set,
)

import neo4j
@@ -468,7 +468,7 @@ def _create_props_body(self,
def _execute_statement(self,
stmt: str,
tx: Transaction,
- params: dict = None,
+ params: Optional[dict] = None,
expect_result: bool = False) -> Transaction:
"""
Executes statement against Neo4j. If execution fails, it rollsback and raise exception.
4 changes: 2 additions & 2 deletions databuilder/databuilder/publisher/neo4j_preprocessor.py
@@ -111,7 +111,7 @@ def preprocess_cypher_impl(self,
end_key: str,
relation: str,
reverse_relation: str) -> Tuple[str, Dict[str, str]]:
- pass
+ return '', {}

def is_perform_preprocess(self) -> bool:
return False
@@ -143,7 +143,7 @@ class DeleteRelationPreprocessor(RelationPreprocessor):
""")

def __init__(self,
- label_tuples: List[Tuple[str, str]] = None,
+ label_tuples: Optional[List[Tuple[str, str]]] = None,
where_clause: str = '') -> None:
super(DeleteRelationPreprocessor, self).__init__()
self._label_tuples = set(label_tuples) if label_tuples else set()
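Here the no-op base-class hook is kept rather than deleted, so the fix is to return a value matching the declared `Tuple[str, Dict[str, str]]` (the same idea applies to `get_scope` returning `''` in `base_transformer.py` below). A sketch of that choice with illustrative names:

```python
from typing import Dict, Tuple


class RelationPreprocessorSketch:
    def preprocess(self, start_key: str, end_key: str) -> Tuple[str, Dict[str, str]]:
        # A typed empty default keeps the non-abstract base class usable and
        # satisfies the missing-return / empty-body checks; callers are expected
        # to consult is_enabled() before relying on the result.
        return '', {}

    def is_enabled(self) -> bool:
        return False
```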
6 changes: 3 additions & 3 deletions databuilder/databuilder/rest_api/rest_api_query.py
@@ -4,7 +4,7 @@
import copy
import logging
from typing import (
- Any, Callable, Dict, Iterator, List, Union,
+ Any, Callable, Dict, Iterator, List, Optional, Union,
)

import requests
@@ -62,8 +62,8 @@ def __init__(self,
fail_no_result: bool = False,
skip_no_result: bool = False,
json_path_contains_or: bool = False,
- can_skip_failure: Callable = None,
- query_merger: QueryMerger = None,
+ can_skip_failure: Optional[Callable] = None,
+ query_merger: Optional[QueryMerger] = None,
**kwargs: Any
) -> None:
"""
2 changes: 1 addition & 1 deletion databuilder/databuilder/transformer/base_transformer.py
@@ -36,7 +36,7 @@ def transform(self, record: Any) -> Any:
return record

def get_scope(self) -> str:
- pass
+ return ''


class ChainedTransformer(Transformer):
4 changes: 2 additions & 2 deletions databuilder/databuilder/utils/publisher_utils.py
@@ -5,7 +5,7 @@
from os import listdir
from os.path import isfile, join
from typing import (
- Iterator, List, Set,
+ Iterator, List, Optional, Set,
)

import pandas
@@ -82,7 +82,7 @@ def create_props_param(record_dict: dict, additional_publisher_metadata_fields:

def execute_neo4j_statement(tx: Transaction,
stmt: str,
- params: dict = None) -> None:
+ params: Optional[dict] = None) -> None:
"""
Executes statement against Neo4j. If execution fails, it rollsback and raises exception.
"""
5 changes: 3 additions & 2 deletions databuilder/example/scripts/sample_tableau_data_loader.py
@@ -19,6 +19,7 @@
import os
import sys
import uuid
+ from typing import List

from amundsen_common.models.index_map import DASHBOARD_ELASTICSEARCH_INDEX_MAPPING
from elasticsearch import Elasticsearch
@@ -74,12 +75,12 @@
tableau_site_name = ""
tableau_personal_access_token_name = ""
tableau_personal_access_token_secret = ""
- tableau_excluded_projects = []
+ tableau_excluded_projects: List = []
tableau_dashboard_cluster = ""
tableau_dashboard_database = ""
tableau_external_table_cluster = ""
tableau_external_table_schema = ""
- tableau_external_table_types = []
+ tableau_external_table_types: List = []
tableau_verify_request = None

common_tableau_config = {
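The `: List` annotations on the two empty lists above address mypy's `var-annotated` error: an empty literal gives it nothing to infer an element type from. A tiny illustrative sketch (a concrete element type is even better when it is known):

```python
from typing import List

# Without the annotation mypy reports roughly:
#   Need type annotation for "excluded_projects"
#   (hint: "excluded_projects: List[<type>] = ...")  [var-annotated]
excluded_projects: List[str] = []
excluded_projects.append('Archived')
```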
2 changes: 1 addition & 1 deletion databuilder/setup.py
@@ -63,7 +63,7 @@
'Flask==1.0.2',
'gremlinpython==3.4.3',
'requests-aws4auth==1.1.0',
- 'typing-extensions==4.0.0',
+ 'typing-extensions==4.1.0',
'overrides==2.5',
'boto3==1.17.23'
]
6 changes: 3 additions & 3 deletions databuilder/tests/unit/rest_api/test_query_merger.py
@@ -49,7 +49,7 @@ def test_ensure_record_get_updated(self) -> None:
results[1],
)

- def test_exception_rasied_with_duplicate_merge_key(self) -> None:
+ def test_exception_raised_with_duplicate_merge_key(self) -> None:
"""
Two records in query_to_merge results have {'dashboard_id': 'd2'},
exception should be raised
@@ -70,7 +70,7 @@ def test_exception_rasied_with_duplicate_merge_key(self) -> None:
query = RestApiQuery(query_to_join=self.query_to_join, url=self.url, params={},
json_path=self.json_path, field_names=self.field_names,
query_merger=query_merger)
- self.assertRaises(Exception, query.execute())
+ self.assertRaises(Exception, query.execute()) # type: ignore

def test_exception_raised_with_missing_merge_key(self) -> None:
"""
@@ -92,7 +92,7 @@ def test_exception_raised_with_missing_merge_key(self) -> None:
query = RestApiQuery(query_to_join=self.query_to_join, url=self.url, params={},
json_path=self.json_path, field_names=self.field_names,
query_merger=query_merger)
- self.assertRaises(Exception, query.execute())
+ self.assertRaises(Exception, query.execute()) # type: ignore


if __name__ == '__main__':
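The `# type: ignore` comments above are presumably needed because the two-argument form of `assertRaises` expects a callable, while `query.execute()` passes the already-invoked result, which newer mypy flags as an argument-type error. If a refactor were in scope, the context-manager form would avoid the suppression entirely; a hypothetical, self-contained illustration (not part of this PR):

```python
import unittest
from typing import Dict, Iterator


class QueryMergerSketchTest(unittest.TestCase):
    def _execute(self) -> Iterator[Dict[str, str]]:
        # Stand-in for RestApiQuery.execute() raising on a duplicate merge key.
        raise Exception('duplicate merge key')

    def test_exception_raised(self) -> None:
        # The call happens inside the block, so nothing is invoked eagerly and
        # there is no callable-vs-result mismatch for mypy to flag.
        with self.assertRaises(Exception):
            list(self._execute())


if __name__ == '__main__':
    unittest.main()
```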
3 changes: 2 additions & 1 deletion frontend/amundsen_application/__init__.py
@@ -10,6 +10,7 @@

from flask import Blueprint, Flask
from flask_restful import Api
+ from typing import Optional

from amundsen_application.api import init_routes
from amundsen_application.api.announcements.v0 import announcements_blueprint
@@ -48,7 +49,7 @@
static_dir = os.path.join(PROJECT_ROOT, STATIC_ROOT)


- def create_app(config_module_class: str = None, template_folder: str = None) -> Flask:
+ def create_app(config_module_class: Optional[str] = None, template_folder: Optional[str] = None) -> Flask:
""" Support for importing arguments for a subclass of flask.Flask """
args = ast.literal_eval(FLASK_APP_KWARGS_DICT_STR) if FLASK_APP_KWARGS_DICT_STR else {}
