Commit b7ca573: resolve conflict

sid-acryl committed Dec 9, 2024
2 parents ab5a98f + 4811de1
Showing 36 changed files with 116,085 additions and 292 deletions.
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

import boto3
from boto3.session import Session
@@ -107,6 +107,14 @@ class AwsConnectionConfig(ConfigModel):
default=None,
description="A set of proxy configs to use with AWS. See the [botocore.config](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) docs for details.",
)
aws_retry_num: int = Field(
default=5,
description="Number of times to retry failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
)
aws_retry_mode: Literal["legacy", "standard", "adaptive"] = Field(
default="standard",
description="Retry mode to use for failed AWS requests. See the [botocore.retry](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html) docs for details.",
)

read_timeout: float = Field(
default=DEFAULT_TIMEOUT,
@@ -199,6 +207,10 @@ def _aws_config(self) -> Config:
return Config(
proxies=self.aws_proxy,
read_timeout=self.read_timeout,
retries={
"max_attempts": self.aws_retry_num,
"mode": self.aws_retry_mode,
},
**self.aws_advanced_config,
)

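
For context, the two new fields feed directly into botocore's retry configuration via the retries dict built in _aws_config above. A minimal standalone sketch of the equivalent direct botocore usage; the client choice and literal values here are illustrative only, not part of this commit:

import boto3
from botocore.config import Config

# Mirrors the new defaults: aws_retry_num=5, aws_retry_mode="standard".
retry_config = Config(
    read_timeout=60,
    retries={
        "max_attempts": 5,   # aws_retry_num
        "mode": "standard",  # aws_retry_mode: "legacy", "standard", or "adaptive"
    },
)

# Any client created with this config inherits the retry behaviour.
s3_client = boto3.client("s3", config=retry_config)
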
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, Iterable, List, Optional

@@ -36,6 +37,8 @@
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

logger = logging.getLogger(__name__)


@platform_name("SageMaker")
@config_class(SagemakerSourceConfig)
@@ -75,6 +78,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
]

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
logger.info("Starting SageMaker ingestion...")
# get common lineage graph
lineage_processor = LineageProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
@@ -83,6 +87,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

# extract feature groups if specified
if self.source_config.extract_feature_groups:
logger.info("Extracting feature groups...")
feature_group_processor = FeatureGroupProcessor(
sagemaker_client=self.sagemaker_client, env=self.env, report=self.report
)
@@ -95,6 +100,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

# extract jobs if specified
if self.source_config.extract_jobs is not False:
logger.info("Extracting jobs...")
job_processor = JobProcessor(
sagemaker_client=self.client_factory.get_client,
env=self.env,
@@ -109,6 +115,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

# extract models if specified
if self.source_config.extract_models:
logger.info("Extracting models...")

model_processor = ModelProcessor(
sagemaker_client=self.sagemaker_client,
env=self.env,
@@ -40,8 +40,11 @@ class SagemakerSourceReport(StaleEntityRemovalSourceReport):
groups_scanned = 0
models_scanned = 0
jobs_scanned = 0
jobs_processed = 0
datasets_scanned = 0
filtered: List[str] = field(default_factory=list)
model_endpoint_lineage = 0
model_group_lineage = 0

def report_feature_group_scanned(self) -> None:
self.feature_groups_scanned += 1
@@ -58,6 +61,9 @@ def report_group_scanned(self) -> None:
def report_model_scanned(self) -> None:
self.models_scanned += 1

def report_job_processed(self) -> None:
self.jobs_processed += 1

def report_job_scanned(self) -> None:
self.jobs_scanned += 1

@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
@@ -49,6 +50,8 @@
if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient

logger = logging.getLogger(__name__)

JobInfo = TypeVar(
"JobInfo",
AutoMlJobInfo,
@@ -274,15 +277,18 @@ def get_job_details(self, job_name: str, job_type: JobType) -> Dict[str, Any]:
)

def get_workunits(self) -> Iterable[MetadataWorkUnit]:
logger.info("Getting all SageMaker jobs")
jobs = self.get_all_jobs()

processed_jobs: Dict[str, SageMakerJob] = {}

logger.info("Processing SageMaker jobs")
# first pass: process jobs and collect datasets used
logger.info("first pass: process jobs and collect datasets used")
for job in jobs:
job_type = job_type_to_info[job["type"]]
job_name = job[job_type.list_name_key]

logger.debug(f"Processing job {job_name} with type {job_type}")
job_details = self.get_job_details(job_name, job["type"])

processed_job = getattr(self, job_type.processor)(job_details)
@@ -293,6 +299,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
# second pass:
# - move output jobs to inputs
# - aggregate i/o datasets
logger.info(
"second pass: move output jobs to inputs and aggregate i/o datasets"
)
for job_urn in sorted(processed_jobs):
processed_job = processed_jobs[job_urn]

@@ -301,6 +310,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:

all_datasets.update(processed_job.input_datasets)
all_datasets.update(processed_job.output_datasets)
self.report.report_job_processed()

# yield datasets
for dataset_urn, dataset in all_datasets.items():
@@ -322,6 +332,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
self.report.report_dataset_scanned()

# third pass: construct and yield MCEs
logger.info("third pass: construct and yield MCEs")
for job_urn in sorted(processed_jobs):
processed_job = processed_jobs[job_urn]
job_snapshot = processed_job.job_snapshot
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set
@@ -6,6 +7,8 @@
SagemakerSourceReport,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker.type_defs import (
@@ -88,7 +91,6 @@ def get_all_contexts(self) -> List["ContextSummaryTypeDef"]:
paginator = self.sagemaker_client.get_paginator("list_contexts")
for page in paginator.paginate():
contexts += page["ContextSummaries"]

return contexts

def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]:
@@ -225,27 +227,32 @@ def get_lineage(self) -> LineageInfo:
"""
Get the lineage of all artifacts in SageMaker.
"""

logger.info("Getting lineage for SageMaker artifacts...")
logger.info("Getting all actions")
for action in self.get_all_actions():
self.nodes[action["ActionArn"]] = {**action, "node_type": "action"}
logger.info("Getting all artifacts")
for artifact in self.get_all_artifacts():
self.nodes[artifact["ArtifactArn"]] = {**artifact, "node_type": "artifact"}
logger.info("Getting all contexts")
for context in self.get_all_contexts():
self.nodes[context["ContextArn"]] = {**context, "node_type": "context"}

logger.info("Getting lineage for model deployments and model groups")
for node_arn, node in self.nodes.items():
logger.debug(f"Getting lineage for node {node_arn}")
# get model-endpoint lineage
if (
node["node_type"] == "action"
and node.get("ActionType") == "ModelDeployment"
):
self.get_model_deployment_lineage(node_arn)

self.report.model_endpoint_lineage += 1
# get model-group lineage
if (
node["node_type"] == "context"
and node.get("ContextType") == "ModelGroup"
):
self.get_model_group_lineage(node_arn, node)

self.report.model_group_lineage += 1
return self.lineage_info
@@ -207,6 +207,9 @@ def fetch_dpis(self, job_urn: str, batch_size: int) -> List[dict]:
assert self.ctx.graph
dpis = []
start = 0
# This graphql endpoint doesn't support scrolling and therefore after 10k DPIs it causes performance issues on ES
# Therefore, we are limiting the max DPIs to 9000
max_item = 9000
while True:
try:
job_query_result = self.ctx.graph.execute_graphql(
@@ -226,10 +229,12 @@
runs = runs_data.get("runs")
dpis.extend(runs)
start += batch_size
if len(runs) < batch_size:
if len(runs) < batch_size or start >= max_item:
break
except Exception as e:
logger.error(f"Exception while fetching DPIs for job {job_urn}: {e}")
self.report.failure(
f"Exception while fetching DPIs for job {job_urn}:", exc=e
)
break
return dpis
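
The new exit condition caps how deep the offset-based pagination goes. A condensed sketch of the same pattern; fetch_page, batch_size, and max_items are placeholder names, not identifiers from this codebase:

def fetch_capped(fetch_page, batch_size=500, max_items=9000):
    items, start = [], 0
    while True:
        page = fetch_page(start=start, count=batch_size)
        items.extend(page)
        start += batch_size
        # Stop on a short page (no more data) or once the offset reaches the
        # cap that keeps Elasticsearch deep paging under control.
        if len(page) < batch_size or start >= max_items:
            break
    return items
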

@@ -254,8 +259,9 @@ def keep_last_n_dpi(
deleted_count_last_n += 1
futures[future]["deleted"] = True
except Exception as e:
logger.error(f"Exception while deleting DPI: {e}")

self.report.report_failure(
f"Exception while deleting DPI: {e}", exc=e
)
if deleted_count_last_n % self.config.batch_size == 0:
logger.info(f"Deleted {deleted_count_last_n} DPIs from {job.urn}")
if self.config.delay:
@@ -289,7 +295,7 @@ def delete_dpi_from_datajobs(self, job: DataJobEntity) -> None:
dpis = self.fetch_dpis(job.urn, self.config.batch_size)
dpis.sort(
key=lambda x: x["created"]["time"]
if "created" in x and "time" in x["created"]
if x.get("created") and x["created"].get("time")
else 0,
reverse=True,
)
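
The reworked sort key tolerates a "created" entry that is present but None: with the previous membership test, "time" in x["created"] would raise a TypeError for such a run, whereas .get() short-circuits to the 0 fallback. A small illustration with made-up records:

runs = [
    {"created": {"time": 1700000000000}},  # sorted by its timestamp
    {"created": None},                     # old check would raise TypeError here
    {},                                    # missing entirely, falls back to 0
]
runs.sort(
    key=lambda x: x["created"]["time"]
    if x.get("created") and x["created"].get("time")
    else 0,
    reverse=True,
)
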
@@ -325,8 +331,8 @@ def remove_old_dpis(
continue

if (
"created" not in dpi
or "time" not in dpi["created"]
not dpi.get("created")
or not dpi["created"].get("time")
or dpi["created"]["time"] < retention_time * 1000
):
future = executor.submit(
@@ -340,7 +346,7 @@ def remove_old_dpis(
deleted_count_retention += 1
futures[future]["deleted"] = True
except Exception as e:
logger.error(f"Exception while deleting DPI: {e}")
self.report.report_failure(f"Exception while deleting DPI: {e}", exc=e)

if deleted_count_retention % self.config.batch_size == 0:
logger.info(
@@ -351,9 +357,12 @@
logger.info(f"Sleeping for {self.config.delay} seconds")
time.sleep(self.config.delay)

logger.info(
f"Deleted {deleted_count_retention} DPIs from {job.urn} due to retention"
)
if deleted_count_retention > 0:
logger.info(
f"Deleted {deleted_count_retention} DPIs from {job.urn} due to retention"
)
else:
logger.debug(f"No DPIs to delete from {job.urn} due to retention")

def get_data_flows(self) -> Iterable[DataFlowEntity]:
assert self.ctx.graph
@@ -1,5 +1,4 @@
import os
from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
@@ -12,18 +11,8 @@
TRACE_POWERBI_MQUERY_PARSER = os.getenv("DATAHUB_TRACE_POWERBI_MQUERY_PARSER", False)


class AbstractIdentifierAccessor(ABC): # To pass lint
pass


# @dataclass
# class ItemSelector:
# items: Dict[str, Any]
# next: Optional[AbstractIdentifierAccessor]


@dataclass
class IdentifierAccessor(AbstractIdentifierAccessor):
class IdentifierAccessor:
"""
statement
public_order_date = Source{[Schema="public",Item="order_date"]}[Data]
@@ -40,7 +29,7 @@ class IdentifierAccessor(AbstractIdentifierAccessor):

identifier: str
items: Dict[str, Any]
next: Optional[AbstractIdentifierAccessor]
next: Optional["IdentifierAccessor"]


@dataclass
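
With the unused AbstractIdentifierAccessor base removed, the chain of accessors is typed with a string forward reference to the class itself. A minimal standalone example of that pattern; the Node class and its values are invented for illustration:

from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class Node:
    identifier: str
    items: Dict[str, Any]
    next: Optional["Node"]  # forward reference lets the class point at itself

# Two-link chain, analogous to chained IdentifierAccessor instances.
tail = Node(identifier="order_date", items={"Item": "order_date"}, next=None)
head = Node(identifier="Source", items={"Schema": "public"}, next=tail)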