Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from functools import reduce
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import boto3
from boto3.session import Session
from mypy_boto3_glue import GlueClient
from mypy_boto3_s3 import S3Client
from mypy_boto3_sagemaker import SageMakerClient

from datahub.configuration import ConfigModel
from datahub.configuration.common import AllowDenyPattern
from datahub.emitter.mce_builder import DEFAULT_ENV

if TYPE_CHECKING:

from mypy_boto3_glue import GlueClient
from mypy_boto3_s3 import S3Client
from mypy_boto3_sagemaker import SageMakerClient


def assume_role(
role_arn: str, aws_region: str, credentials: Optional[dict] = None
Expand Down Expand Up @@ -88,13 +91,13 @@ def get_session(self) -> Session:
else:
return Session(region_name=self.aws_region)

def get_s3_client(self) -> S3Client:
def get_s3_client(self) -> "S3Client":
return self.get_session().client("s3")

def get_glue_client(self) -> GlueClient:
def get_glue_client(self) -> "GlueClient":
return self.get_session().client("glue")

def get_sagemaker_client(self) -> SageMakerClient:
def get_sagemaker_client(self) -> "SageMakerClient":
return self.get_session().client("sagemaker")


Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
from dataclasses import dataclass
from typing import Iterable, List

from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeFeatureGroupResponseTypeDef,
FeatureDefinitionTypeDef,
FeatureGroupSummaryTypeDef,
)
from typing import TYPE_CHECKING, Iterable, List

import datahub.emitter.mce_builder as builder
from datahub.ingestion.api.workunit import MetadataWorkUnit
Expand All @@ -27,14 +20,23 @@
MLPrimaryKeyPropertiesClass,
)

if TYPE_CHECKING:

from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeFeatureGroupResponseTypeDef,
FeatureDefinitionTypeDef,
FeatureGroupSummaryTypeDef,
)


@dataclass
class FeatureGroupProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport

def get_all_feature_groups(self) -> List[FeatureGroupSummaryTypeDef]:
def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]:
"""
List all feature groups in SageMaker.
"""
Expand All @@ -50,7 +52,7 @@ def get_all_feature_groups(self) -> List[FeatureGroupSummaryTypeDef]:

def get_feature_group_details(
self, feature_group_name: str
) -> DescribeFeatureGroupResponseTypeDef:
) -> "DescribeFeatureGroupResponseTypeDef":
"""
Get details of a feature group (including list of component features).
"""
Expand All @@ -74,7 +76,7 @@ def get_feature_group_details(
return feature_group

def get_feature_group_wu(
self, feature_group_details: DescribeFeatureGroupResponseTypeDef
self, feature_group_details: "DescribeFeatureGroupResponseTypeDef"
) -> MetadataWorkUnit:
"""
Generate an MLFeatureTable workunit for a SageMaker feature group.
Expand Down Expand Up @@ -146,8 +148,8 @@ def get_feature_type(self, aws_type: str, feature_name: str) -> str:

def get_feature_wu(
self,
feature_group_details: DescribeFeatureGroupResponseTypeDef,
feature: FeatureDefinitionTypeDef,
feature_group_details: "DescribeFeatureGroupResponseTypeDef",
feature: "FeatureDefinitionTypeDef",
) -> MetadataWorkUnit:
"""
Generate an MLFeature workunit for a SageMaker feature.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
Expand All @@ -16,8 +17,6 @@
Union,
)

from mypy_boto3_sagemaker import SageMakerClient

from datahub.emitter import mce_builder
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.aws_common import make_s3_urn
Expand Down Expand Up @@ -47,6 +46,9 @@
JobStatusClass,
)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

JobInfo = TypeVar(
"JobInfo",
AutoMlJobInfo,
Expand Down Expand Up @@ -151,7 +153,7 @@ class JobProcessor:
"""

# boto3 SageMaker client
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
# config filter for specific job types to ingest (see metadata-ingestion README)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, DefaultDict, Dict, List, Set

from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
ActionSummaryTypeDef,
ArtifactSummaryTypeDef,
AssociationSummaryTypeDef,
ContextSummaryTypeDef,
)
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set

from datahub.ingestion.source.aws.sagemaker_processors.common import (
SagemakerSourceReport,
)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
ActionSummaryTypeDef,
ArtifactSummaryTypeDef,
AssociationSummaryTypeDef,
ContextSummaryTypeDef,
)


@dataclass
class LineageInfo:
Expand Down Expand Up @@ -42,13 +43,13 @@ class LineageInfo:

@dataclass
class LineageProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict)
lineage_info: LineageInfo = field(default_factory=LineageInfo)

def get_all_actions(self) -> List[ActionSummaryTypeDef]:
def get_all_actions(self) -> List["ActionSummaryTypeDef"]:
"""
List all actions in SageMaker.
"""
Expand All @@ -62,7 +63,7 @@ def get_all_actions(self) -> List[ActionSummaryTypeDef]:

return actions

def get_all_artifacts(self) -> List[ArtifactSummaryTypeDef]:
def get_all_artifacts(self) -> List["ArtifactSummaryTypeDef"]:
"""
List all artifacts in SageMaker.
"""
Expand All @@ -76,7 +77,7 @@ def get_all_artifacts(self) -> List[ArtifactSummaryTypeDef]:

return artifacts

def get_all_contexts(self) -> List[ContextSummaryTypeDef]:
def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
"""
List all contexts in SageMaker.
"""
Expand All @@ -90,7 +91,7 @@ def get_all_contexts(self) -> List[ContextSummaryTypeDef]:

return contexts

def get_incoming_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:
def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
"""
Get all incoming edges for a node in the lineage graph.
"""
Expand All @@ -104,7 +105,7 @@ def get_incoming_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:

return edges

def get_outgoing_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]:
def get_outgoing_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
"""
Get all outgoing edges for a node in the lineage graph.
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple

from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeEndpointOutputTypeDef,
DescribeModelOutputTypeDef,
DescribeModelPackageGroupOutputTypeDef,
EndpointSummaryTypeDef,
ModelPackageGroupSummaryTypeDef,
ModelSummaryTypeDef,
from typing import (
TYPE_CHECKING,
DefaultDict,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)

import datahub.emitter.mce_builder as builder
Expand Down Expand Up @@ -43,6 +42,17 @@
OwnershipTypeClass,
)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
DescribeEndpointOutputTypeDef,
DescribeModelOutputTypeDef,
DescribeModelPackageGroupOutputTypeDef,
EndpointSummaryTypeDef,
ModelPackageGroupSummaryTypeDef,
ModelSummaryTypeDef,
)

ENDPOINT_STATUS_MAP: Dict[str, str] = {
"OutOfService": DeploymentStatusClass.OUT_OF_SERVICE,
"Creating": DeploymentStatusClass.CREATING,
Expand All @@ -58,7 +68,7 @@

@dataclass
class ModelProcessor:
sagemaker_client: SageMakerClient
sagemaker_client: "SageMakerClient"
env: str
report: SagemakerSourceReport
lineage: LineageInfo
Expand All @@ -81,7 +91,7 @@ class ModelProcessor:

group_arn_to_name: Dict[str, str] = field(default_factory=dict)

def get_all_models(self) -> List[ModelSummaryTypeDef]:
def get_all_models(self) -> List["ModelSummaryTypeDef"]:
"""
List all models in SageMaker.
"""
Expand All @@ -95,15 +105,15 @@ def get_all_models(self) -> List[ModelSummaryTypeDef]:

return models

def get_model_details(self, model_name: str) -> DescribeModelOutputTypeDef:
def get_model_details(self, model_name: str) -> "DescribeModelOutputTypeDef":
"""
Get details of a model.
"""

# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model
return self.sagemaker_client.describe_model(ModelName=model_name)

def get_all_groups(self) -> List[ModelPackageGroupSummaryTypeDef]:
def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]:
"""
List all model groups in SageMaker.
"""
Expand All @@ -118,7 +128,7 @@ def get_all_groups(self) -> List[ModelPackageGroupSummaryTypeDef]:

def get_group_details(
self, group_name: str
) -> DescribeModelPackageGroupOutputTypeDef:
) -> "DescribeModelPackageGroupOutputTypeDef":
"""
Get details of a model group.
"""
Expand All @@ -128,7 +138,7 @@ def get_group_details(
ModelPackageGroupName=group_name
)

def get_all_endpoints(self) -> List[EndpointSummaryTypeDef]:
def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]:

endpoints = []

Expand All @@ -140,7 +150,9 @@ def get_all_endpoints(self) -> List[EndpointSummaryTypeDef]:

return endpoints

def get_endpoint_details(self, endpoint_name: str) -> DescribeEndpointOutputTypeDef:
def get_endpoint_details(
self, endpoint_name: str
) -> "DescribeEndpointOutputTypeDef":

# see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_endpoint
return self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
Expand All @@ -162,7 +174,7 @@ def get_endpoint_status(
return endpoint_status

def get_endpoint_wu(
self, endpoint_details: DescribeEndpointOutputTypeDef
self, endpoint_details: "DescribeEndpointOutputTypeDef"
) -> MetadataWorkUnit:
"""
Get a workunit for an endpoint.
Expand Down Expand Up @@ -206,7 +218,7 @@ def get_endpoint_wu(

def get_model_endpoints(
self,
model_details: DescribeModelOutputTypeDef,
model_details: "DescribeModelOutputTypeDef",
endpoint_arn_to_name: Dict[str, str],
model_image: Optional[str],
model_uri: Optional[str],
Expand Down Expand Up @@ -235,7 +247,7 @@ def get_model_endpoints(
return model_endpoints_sorted

def get_group_wu(
self, group_details: DescribeModelPackageGroupOutputTypeDef
self, group_details: "DescribeModelPackageGroupOutputTypeDef"
) -> MetadataWorkUnit:
"""
Get a workunit for a model group.
Expand Down Expand Up @@ -285,7 +297,7 @@ def get_group_wu(
return MetadataWorkUnit(id=group_name, mce=mce)

def match_model_jobs(
self, model_details: DescribeModelOutputTypeDef
self, model_details: "DescribeModelOutputTypeDef"
) -> Tuple[Set[str], Set[str], List[MLHyperParamClass], List[MLMetricClass]]:

model_training_jobs: Set[str] = set()
Expand Down Expand Up @@ -380,7 +392,7 @@ def strip_quotes(string: str) -> str:

def get_model_wu(
self,
model_details: DescribeModelOutputTypeDef,
model_details: "DescribeModelOutputTypeDef",
endpoint_arn_to_name: Dict[str, str],
) -> MetadataWorkUnit:
"""
Expand Down