Skip to content

Commit

Permalink
fix(ingest/sagemaker): ensure consistent STS token usage with refresh…
Browse files Browse the repository at this point in the history
… mechanism (#11170)

Co-authored-by: Aseem Bansal <asmbansal2@gmail.com>
  • Loading branch information
sagar-salvi-apptware and anshbansal authored Aug 22, 2024
1 parent dc30c0a commit 50ed448
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 13 deletions.
27 changes: 22 additions & 5 deletions metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import boto3
Expand Down Expand Up @@ -73,6 +74,8 @@ class AwsConnectionConfig(ConfigModel):
- dbt source
"""

_credentials_expiration: Optional[datetime] = None

aws_access_key_id: Optional[str] = Field(
default=None,
description=f"AWS access key ID. {AUTODETECT_CREDENTIALS_DOC_LINK}",
Expand Down Expand Up @@ -115,6 +118,11 @@ class AwsConnectionConfig(ConfigModel):
description="Advanced AWS configuration options. These are passed directly to [botocore.config.Config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html).",
)

def allowed_cred_refresh(self) -> bool:
if self._normalized_aws_roles():
return True
return False

def _normalized_aws_roles(self) -> List[AwsAssumeRoleConfig]:
if not self.aws_role:
return []
Expand Down Expand Up @@ -153,11 +161,14 @@ def get_session(self) -> Session:
}

for role in self._normalized_aws_roles():
credentials = assume_role(
role,
self.aws_region,
credentials=credentials,
)
if self._should_refresh_credentials():
credentials = assume_role(
role,
self.aws_region,
credentials=credentials,
)
if isinstance(credentials["Expiration"], datetime):
self._credentials_expiration = credentials["Expiration"]

session = Session(
aws_access_key_id=credentials["AccessKeyId"],
Expand All @@ -168,6 +179,12 @@ def get_session(self) -> Session:

return session

def _should_refresh_credentials(self) -> bool:
if self._credentials_expiration is None:
return True
remaining_time = self._credentials_expiration - datetime.now(timezone.utc)
return remaining_time < timedelta(minutes=5)

def get_credentials(self) -> Dict[str, Optional[str]]:
credentials = self.get_session().get_credentials()
if credentials is not None:
Expand Down
20 changes: 18 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import DefaultDict, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
Expand Down Expand Up @@ -33,6 +33,9 @@
StatefulIngestionSourceBase,
)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient


@platform_name("SageMaker")
@config_class(SagemakerSourceConfig)
Expand All @@ -56,6 +59,7 @@ def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext):
self.report = SagemakerSourceReport()
self.sagemaker_client = config.sagemaker_client
self.env = config.env
self.client_factory = ClientFactory(config)

@classmethod
def create(cls, config_dict, ctx):
Expand Down Expand Up @@ -92,7 +96,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# extract jobs if specified
if self.source_config.extract_jobs is not False:
job_processor = JobProcessor(
sagemaker_client=self.sagemaker_client,
sagemaker_client=self.client_factory.get_client,
env=self.env,
report=self.report,
job_type_filter=self.source_config.extract_jobs,
Expand All @@ -118,3 +122,15 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

def get_report(self):
return self.report


class ClientFactory:
def __init__(self, config: SagemakerSourceConfig):
self.config = config
self._cached_client = self.config.sagemaker_client

def get_client(self) -> "SageMakerClient":
if self.config.allowed_cred_refresh():
# Always fetch the client dynamically with auto-refresh logic
return self.config.sagemaker_client
return self._cached_client
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Expand Down Expand Up @@ -147,7 +148,7 @@ class JobProcessor:
"""

# boto3 SageMaker client
sagemaker_client: "SageMakerClient"
sagemaker_client: Callable[[], "SageMakerClient"]
env: str
report: SagemakerSourceReport
# config filter for specific job types to ingest (see metadata-ingestion README)
Expand All @@ -170,8 +171,7 @@ class JobProcessor:

def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]:
jobs = []

paginator = self.sagemaker_client.get_paginator(job_spec.list_command)
paginator = self.sagemaker_client().get_paginator(job_spec.list_command)
for page in paginator.paginate():
page_jobs: List[Any] = page[job_spec.list_key]

Expand Down Expand Up @@ -269,7 +269,7 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
describe_command = job_type_to_info[job_type].describe_command
describe_name_key = job_type_to_info[job_type].describe_name_key

return getattr(self.sagemaker_client, describe_command)(
return getattr(self.sagemaker_client(), describe_command)(
**{describe_name_key: job_name}
)

Expand Down
15 changes: 13 additions & 2 deletions metadata-ingestion/tests/unit/sagemaker/test_sagemaker_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

from botocore.stub import Stubber
from freezegun import freeze_time

Expand Down Expand Up @@ -220,8 +222,17 @@ def test_sagemaker_ingest(tmp_path, pytestconfig):
{"ModelName": "the-second-model"},
)

mce_objects = [wu.metadata for wu in sagemaker_source_instance.get_workunits()]
write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects)
# Patch the client factory's get_client method to return the stubbed client for jobs
with patch.object(
sagemaker_source_instance.client_factory,
"get_client",
return_value=sagemaker_source_instance.sagemaker_client,
):
# Run the test and generate the MCEs
mce_objects = [
wu.metadata for wu in sagemaker_source_instance.get_workunits()
]
write_metadata_file(tmp_path / "sagemaker_mces.json", mce_objects)

# Verify the output.
test_resources_dir = pytestconfig.rootpath / "tests/unit/sagemaker"
Expand Down

0 comments on commit 50ed448

Please sign in to comment.