
feat(mode/ingest): Add support for missing Mode datasets in lineage #11290

Merged
@@ -72,6 +72,7 @@ class BIAssetSubTypes(StrEnum):

    # Mode
    MODE_REPORT = "Report"
+    MODE_DATASET = "Dataset"
    MODE_QUERY = "Query"
    MODE_CHART = "Chart"

128 changes: 100 additions & 28 deletions metadata-ingestion/src/datahub/ingestion/source/mode.py
@@ -106,7 +106,7 @@
    infer_output_schema,
)
from datahub.utilities import config_clean
-from datahub.utilities.lossy_collections import LossyDict, LossyList
+from datahub.utilities.lossy_collections import LossyList

logger: logging.Logger = logging.getLogger(__name__)

@@ -199,10 +199,6 @@ class ModeSourceReport(StaleEntityRemovalSourceReport):
    num_query_template_render_failures: int = 0
    num_query_template_render_success: int = 0

-    dropped_imported_datasets: LossyDict[str, LossyList[str]] = dataclasses.field(
-        default_factory=LossyDict
-    )
-
def report_dropped_space(self, ent_name: str) -> None:
self.filtered_spaces.append(ent_name)

@@ -429,10 +425,25 @@ def construct_dashboard(
# Last refreshed ts.
last_refreshed_ts = self._parse_last_run_at(report_info)

+        # Datasets
+        datasets = []
+        for imported_dataset_name in report_info.get("imported_datasets", {}):
+            mode_dataset = self._get_request_json(
+                f"{self.workspace_uri}/reports/{imported_dataset_name.get('token')}"
+            )
+            dataset_urn = builder.make_dataset_urn_with_platform_instance(
+                self.platform,
+                str(mode_dataset.get("id")),
+                platform_instance=None,
+                env=self.config.env,
+            )
+            datasets.append(dataset_urn)

dashboard_info_class = DashboardInfoClass(
description=description if description else "",
title=title if title else "",
charts=self._get_chart_urns(report_token),
+            datasets=datasets if datasets else None,
lastModified=last_modified,
lastRefreshed=last_refreshed_ts,
dashboardUrl=f"{self.config.connect_uri}/{self.config.workspace}/reports/{report_token}",
@@ -725,6 +736,10 @@ def _get_platform_and_dbname(
data_source.get("adapter", ""), data_source.get("name", "")
)
database = data_source.get("database", "")
+                # This is hacky, but on BigQuery we want to change the database if it's "default".
+                # For lineage we need project_id.db.table.
+                if platform == "bigquery" and database == "default":
+                    database = data_source.get("host", "")
Review comment (Collaborator): what is normally in the "host" key?

Reply (Author): host in this case is the BQ Project ID:

    "database": "default",
    "host": "harshal-playground-306419",

return platform, database
else:
self.report.report_warning(
@@ -900,24 +915,36 @@ def normalize_mode_query(self, query: str) -> str:

return rendered_query

-    def construct_query_from_api_data(
+    def construct_query_or_dataset(
        self,
        report_token: str,
        query_data: dict,
        space_token: str,
        report_info: dict,
+        is_mode_dataset: bool,
    ) -> Iterable[MetadataWorkUnit]:
-        query_urn = self.get_dataset_urn_from_query(query_data)
+        query_urn = (
+            self.get_dataset_urn_from_query(query_data)
+            if not is_mode_dataset
+            else self.get_dataset_urn_from_query(report_info)
+        )

        query_token = query_data.get("token")

+        externalUrl = (
+            f"{self.config.connect_uri}/{self.config.workspace}/datasets/{report_token}"
+            if is_mode_dataset
+            else f"{self.config.connect_uri}/{self.config.workspace}/reports/{report_token}/details/queries/{query_token}"
+        )

        dataset_props = DatasetPropertiesClass(
-            name=query_data.get("name"),
+            name=report_info.get("name") if is_mode_dataset else query_data.get("name"),
            description=f"""### Source Code
``` sql
{query_data.get("raw_query")}
```
""",
-            externalUrl=f"{self.config.connect_uri}/{self.config.workspace}/reports/{report_token}/details/queries/{query_token}",
+            externalUrl=externalUrl,
customProperties=self.get_custom_props_from_dict(
query_data,
[
@@ -939,7 +966,22 @@ def construct_query_from_api_data(
).as_workunit()
)

-        subtypes = SubTypesClass(typeNames=([BIAssetSubTypes.MODE_QUERY]))
+        if is_mode_dataset:
+            space_container_key = self.gen_space_key(space_token)
+            yield from add_dataset_to_container(
+                container_key=space_container_key,
+                dataset_urn=query_urn,
+            )
+
+        subtypes = SubTypesClass(
+            typeNames=(
+                [
+                    BIAssetSubTypes.MODE_DATASET
+                    if is_mode_dataset
+                    else BIAssetSubTypes.MODE_QUERY
+                ]
+            )
+        )
yield (
MetadataChangeProposalWrapper(
entityUrn=query_urn,
@@ -950,15 +992,16 @@ def construct_query_from_api_data(
yield MetadataChangeProposalWrapper(
entityUrn=query_urn,
            aspect=BrowsePathsV2Class(
-                path=self._browse_path_query(space_token, report_info)
+                path=self._browse_path_dashboard(space_token)
+                if is_mode_dataset
+                else self._browse_path_query(space_token, report_info)
            ),
).as_workunit()

        (
            upstream_warehouse_platform,
            upstream_warehouse_db_name,
        ) = self._get_platform_and_dbname(query_data.get("data_source_id"))
-
        if upstream_warehouse_platform is None:
# this means we can't infer the platform
return
@@ -1022,7 +1065,7 @@ def construct_query_from_api_data(
schema_fields = infer_output_schema(parsed_query_object)
if schema_fields:
schema_metadata = SchemaMetadataClass(
schemaName="mode_query",
schemaName="mode_dataset" if is_mode_dataset else "mode_query",
platform=f"urn:li:dataPlatform:{self.platform}",
version=0,
fields=schema_fields,
@@ -1040,7 +1083,7 @@ def construct_query_from_api_data(
)

yield from self.get_upstream_lineage_for_parsed_sql(
-            query_data, parsed_query_object
+            query_urn, query_data, parsed_query_object
)

operation = OperationClass(
@@ -1089,10 +1132,9 @@ def construct_query_from_api_data(
).as_workunit()

    def get_upstream_lineage_for_parsed_sql(
-        self, query_data: dict, parsed_query_object: SqlParsingResult
+        self, query_urn: str, query_data: dict, parsed_query_object: SqlParsingResult
    ) -> List[MetadataWorkUnit]:
        wu = []
-        query_urn = self.get_dataset_urn_from_query(query_data)

if parsed_query_object is None:
logger.info(
@@ -1350,6 +1392,24 @@ def _get_reports(self, space_token: str) -> List[dict]:
)
return reports

+    @lru_cache(maxsize=None)
+    def _get_datasets(self, space_token: str) -> List[dict]:
+        """
+        Retrieves datasets for a given space token.
+        """
+        datasets = []
+        try:
+            url = f"{self.workspace_uri}/spaces/{space_token}/datasets"
+            datasets_json = self._get_request_json(url)
+            datasets = datasets_json.get("_embedded", {}).get("reports", [])
+        except HTTPError as http_error:
+            self.report.report_failure(
+                title="Failed to Retrieve Datasets for Space",
+                message=f"Unable to retrieve datasets for space token {space_token}.",
+                context=f"Error: {str(http_error)}",
+            )
+        return datasets
+
@lru_cache(maxsize=None)
def _get_queries(self, report_token: str) -> list:
queries = []
@@ -1523,24 +1583,14 @@ def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
for report in reports:
report_token = report.get("token", "")

if report.get("imported_datasets"):
# The connector doesn't support imported datasets yet.
# For now, we just keep this in the report to track what we're missing.
imported_datasets = [
imported_dataset.get("name") or str(imported_dataset)
for imported_dataset in report["imported_datasets"]
]
self.report.dropped_imported_datasets.setdefault(
sagar-salvi-apptware marked this conversation as resolved.
Show resolved Hide resolved
report_token, LossyList()
).extend(imported_datasets)

queries = self._get_queries(report_token)
for query in queries:
query_mcps = self.construct_query_from_api_data(
query_mcps = self.construct_query_or_dataset(
report_token,
query,
space_token=space_token,
report_info=report,
is_mode_dataset=False,
)
chart_fields: Dict[str, SchemaFieldClass] = {}
for wu in query_mcps:
@@ -1566,6 +1616,27 @@ def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
query_name=query["name"],
)

+    def emit_dataset_mces(self):
+        """
+        Emits MetadataChangeEvents (MCEs) for datasets within each space.
+        """
+        for space_token, _ in self.space_tokens.items():
+            datasets = self._get_datasets(space_token)
+
+            for report in datasets:
+                report_token = report.get("token", "")
+                queries = self._get_queries(report_token)
+                for query in queries:
+                    query_mcps = self.construct_query_or_dataset(
+                        report_token,
+                        query,
+                        space_token=space_token,
+                        report_info=report,
+                        is_mode_dataset=True,
+                    )
+                    for wu in query_mcps:
+                        yield wu
+
@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "ModeSource":
config: ModeConfig = ModeConfig.parse_obj(config_dict)
@@ -1581,6 +1652,7 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
yield from self.emit_dashboard_mces()
+        yield from self.emit_dataset_mces()
yield from self.emit_chart_mces()

def get_report(self) -> SourceReport: