Skip to content

Commit

Permalink
fix: PR Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sagar-salvi-apptware committed Aug 16, 2024
1 parent e93e5d1 commit 2050d9e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# extract jobs if specified
if self.source_config.extract_jobs is not False:
job_processor = JobProcessor(
config=self.source_config
sagemaker_client=self.source_config.get_auto_refreshing_sagemaker_client()
if self.source_config.allowed_cred_refresh()
else None,
sagemaker_client=self.sagemaker_client,
else self.sagemaker_client,
env=self.env,
report=self.report,
job_type_filter=self.source_config.extract_jobs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ class SagemakerSourceConfig(
def sagemaker_client(self):
    """Return the SageMaker client obtained from ``get_sagemaker_client()``."""
    client = self.get_sagemaker_client()
    return client

def get_auto_refreshing_sagemaker_client(self):
    """
    Hand back the ``get_sagemaker_client`` bound method itself, uncalled.

    Callers invoke the returned callable whenever they need a client, so
    obtaining the client is deferred to the point of use rather than
    happening here.
    """
    client_factory = self.get_sagemaker_client
    return client_factory


@dataclass
class SagemakerSourceReport(StaleEntityRemovalSourceReport):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -21,7 +22,6 @@
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.aws.sagemaker_processors.common import (
SagemakerSourceConfig,
SagemakerSourceReport,
)
from datahub.ingestion.source.aws.sagemaker_processors.job_classes import (
Expand Down Expand Up @@ -147,14 +147,8 @@ class JobProcessor:
Job ingestion module, called by top-level SageMaker ingestion handler.
"""

# boto3 SageMaker client using configuration
# Accessing `config.sagemaker_client` within the function ensures that
# the property is re-evaluated each time it is accessed, allowing the
# refresh logic to be triggered if necessary.
config: Optional[SagemakerSourceConfig]

# This is the SageMaker client instance when AWS roles are not used
sagemaker_client: "SageMakerClient"
# boto3 SageMaker client
sagemaker_client: Any

env: str
report: SagemakerSourceReport
Expand All @@ -178,13 +172,7 @@ class JobProcessor:

def get_jobs(self, job_type: JobType, job_spec: JobInfo) -> List[Any]:
jobs = []

paginator = (
self.config.sagemaker_client.get_paginator(job_spec.list_command)
if self.config is not None
else self.sagemaker_client.get_paginator(job_spec.list_command)
)

paginator = self.get_sagemaker_client().get_paginator(job_spec.list_command)
for page in paginator.paginate():
page_jobs: List[Any] = page[job_spec.list_key]

Expand Down Expand Up @@ -282,12 +270,9 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
describe_command = job_type_to_info[job_type].describe_command
describe_name_key = job_type_to_info[job_type].describe_name_key

client = (
self.config.sagemaker_client
if self.config is not None
else self.sagemaker_client
return getattr(self.get_sagemaker_client(), describe_command)(
**{describe_name_key: job_name}
)
return getattr(client, describe_command)(**{describe_name_key: job_name})

def get_workunits(self) -> Iterable[MetadataWorkUnit]:
jobs = self.get_all_jobs()
Expand Down Expand Up @@ -956,3 +941,8 @@ def process_transform_job(self, job: Dict[str, Any]) -> SageMakerJob:
output_datasets=output_datasets,
input_jobs=input_jobs,
)

def get_sagemaker_client(self) -> "SageMakerClient":
    """
    Resolve ``self.sagemaker_client`` to a usable SageMaker client.

    ``self.sagemaker_client`` may hold either a client object or a bound
    method that produces one. A bound method is called on every invocation
    and its result is returned; any other value is returned unchanged.

    The return annotation is written as a string so it is not evaluated
    when the method is defined.
    NOTE(review): ``SageMakerClient`` looks like a typing-only name (the
    module imports ``TYPE_CHECKING``) -- confirm against the import block.
    """
    client_or_factory = self.sagemaker_client
    # Only bound methods are treated as factories; every other object,
    # callable or not, is passed through as the client itself.
    if isinstance(client_or_factory, MethodType):
        return client_or_factory()
    return client_or_factory

0 comments on commit 2050d9e

Please sign in to comment.