Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest): add support for capability report in snowflake test connection #5472

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/api/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class SourceCapability(Enum):
OWNERSHIP = "Extract Ownership"
DELETION_DETECTION = "Detect Deleted Entities"
TAGS = "Extract Tags"
SCHEMA_METADATA = "Schema Metadata"
CONTAINERS = "Asset Containers"


@dataclass
Expand Down
201 changes: 171 additions & 30 deletions metadata-ingestion/src/datahub/ingestion/source/sql/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

Expand All @@ -9,6 +10,7 @@
# This import verifies that the dependencies are available.
import snowflake.sqlalchemy # noqa: F401
import sqlalchemy.engine
from snowflake import connector
from snowflake.sqlalchemy import custom_types, snowdialect
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import Inspector
Expand Down Expand Up @@ -39,10 +41,7 @@
TimeTypeClass,
register_custom_type,
)
from datahub.ingestion.source_config.sql.snowflake import (
APPLICATION_NAME,
SnowflakeConfig,
)
from datahub.ingestion.source_config.sql.snowflake import SnowflakeConfig
from datahub.ingestion.source_report.sql.snowflake import SnowflakeReport
from datahub.metadata.com.linkedin.pegasus2avro.dataset import (
DatasetLineageTypeClass,
Expand All @@ -68,6 +67,8 @@
@support_status(SupportStatus.CERTIFIED)
@capability(SourceCapability.PLATFORM_INSTANCE, "Enabled by default")
@capability(SourceCapability.DOMAINS, "Supported via the `domain` config field")
@capability(SourceCapability.CONTAINERS, "Enabled by default")
@capability(SourceCapability.SCHEMA_METADATA, "Enabled by default")
@capability(SourceCapability.DATA_PROFILING, "Optionally enabled via configuration")
@capability(SourceCapability.DESCRIPTIONS, "Enabled by default")
@capability(SourceCapability.LINEAGE_COARSE, "Optionally enabled via configuration")
Expand All @@ -82,53 +83,193 @@ def __init__(self, config: SnowflakeConfig, ctx: PipelineContext):
self.provision_role_in_progress: bool = False
self.profile_candidates: Dict[str, List[str]] = {}

@staticmethod
def check_capabilities( # noqa: C901
conn: connector.SnowflakeConnection, connection_conf: SnowflakeConfig
) -> Dict[Union[SourceCapability, str], CapabilityReport]: # noqa: C901

# Currently only overall capabilities are reported.
# Resource level variations in capabilities are not considered.

@dataclass
class SnowflakePrivilege:
privilege: str
object_name: str
object_type: str

def query(query):
logger.info("Query : {}".format(query))
resp = conn.cursor().execute(query)
return resp

_report: Dict[Union[SourceCapability, str], CapabilityReport] = dict()
privileges: List[SnowflakePrivilege] = []
capabilities: List[SourceCapability] = [c.capability for c in SnowflakeSource.get_capabilities() if c.capability not in (SourceCapability.PLATFORM_INSTANCE, SourceCapability.DOMAINS, SourceCapability.DELETION_DETECTION)] # type: ignore

cur = query("select current_role()")
current_role = [row[0] for row in cur][0]

cur = query("select current_secondary_roles()")
secondary_roles_str = json.loads([row[0] for row in cur][0])["roles"]
secondary_roles = (
[] if secondary_roles_str == "" else secondary_roles_str.split(",")
)

roles = [current_role] + secondary_roles

# PUBLIC role is automatically granted to every role
if "PUBLIC" not in roles:
roles.append("PUBLIC")
i = 0

while i < len(roles):
role = roles[i]
i = i + 1
try:
cur = query(f"show grants to role {role}")
except connector.errors.ProgrammingError:
# for some roles, quoting is necessary. for example test-role
cur = query(f'show grants to role "{role}"')
shirshanka marked this conversation as resolved.
Show resolved Hide resolved
for row in cur:
privilege = SnowflakePrivilege(
privilege=row[1], object_type=row[2], object_name=row[3]
)
privileges.append(privilege)

if privilege.object_type in (
"DATABASE",
"SCHEMA",
) and privilege.privilege in ("OWNERSHIP", "USAGE"):
_report[SourceCapability.CONTAINERS] = CapabilityReport(
capable=True
)
elif privilege.object_type in (
"TABLE",
"VIEW",
"MATERIALIZED_VIEW",
):
_report[SourceCapability.SCHEMA_METADATA] = CapabilityReport(
capable=True
)
_report[SourceCapability.DESCRIPTIONS] = CapabilityReport(
capable=True
)

if privilege.privilege in ("SELECT", "OWNERSHIP"):
_report[SourceCapability.DATA_PROFILING] = CapabilityReport(
capable=True
)

if privilege.object_name.startswith("SNOWFLAKE.ACCOUNT_USAGE."):
# if access to "snowflake" shared database, access to all account_usage views is automatically granted
# Finer access control is not yet supported for shares
# https://community.snowflake.com/s/article/Error-Granting-individual-privileges-on-imported-database-is-not-allowed-Use-GRANT-IMPORTED-PRIVILEGES-instead
_report[SourceCapability.LINEAGE_COARSE] = CapabilityReport(
capable=True
)
# If all capabilities supported, no need to continue
if set(capabilities) == set(_report.keys()):
break

# Due to this, entire role hierarchy is considered
if (
privilege.object_type == "ROLE"
and privilege.privilege == "USAGE"
and privilege.object_name not in roles
):
roles.append(privilege.object_name)

# If Some capabilities are missing, then mark them as not capable

cur = query("select current_warehouse()")
current_warehouse = [row[0] for row in cur][0]
if current_warehouse is None:
failure_message = (
f"Current role does not have permissions to use warehouse {connection_conf.warehouse}"
if connection_conf.warehouse is not None
else "No default warehouse set for user. Either set default warehouse for user or configure warehouse in recipe"
)

for c in capabilities: # type:ignore

# These capabilities do not work without active warehouse
if current_warehouse is None and c in (
SourceCapability.SCHEMA_METADATA,
SourceCapability.DESCRIPTIONS,
SourceCapability.DATA_PROFILING,
SourceCapability.LINEAGE_COARSE,
):
failure_message = (
f"Current role does not have permissions to use warehouse {connection_conf.warehouse}"
if connection_conf.warehouse is not None
else "No default warehouse set for user. Either set default warehouse for user or configure warehouse in recipe"
)
_report[c] = CapabilityReport(
capable=False,
failure_reason=failure_message,
)
if c in _report.keys():
continue
if c in (SourceCapability.SCHEMA_METADATA, SourceCapability.DESCRIPTIONS):
_report[c] = CapabilityReport(
capable=False,
failure_reason="Either no tables exist or current role does not have permissions to access them",
)
elif c == SourceCapability.DATA_PROFILING:
_report[c] = CapabilityReport(
capable=False,
failure_reason="Either no tables exist or current role does not have permissions to access them",
)
elif c == SourceCapability.CONTAINERS:
_report[c] = CapabilityReport(
capable=False,
failure_reason="Current role does not have permissions to use any database",
)
elif c == SourceCapability.LINEAGE_COARSE:
_report[c] = CapabilityReport(
capable=False,
failure_reason="Current role does not have permissions to snowflake account usage views",
)
return _report

@classmethod
def create(cls, config_dict, ctx):
config = SnowflakeConfig.parse_obj(config_dict)
return cls(config, ctx)

@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()

try:
SnowflakeConfig.Config.extra = (
pydantic.Extra.allow
) # we are okay with extra fields during this stage
connection_conf = SnowflakeConfig.parse_obj(config_dict)
if connection_conf.authentication_type == "DEFAULT_AUTHENTICATOR":
connection: snowflake.connector.SnowflakeConnection = (
snowflake.connector.connect(
user=connection_conf.username,
password=connection_conf.password.get_secret_value()
if connection_conf.password
else None,
account=connection_conf.account_id,
warehouse=connection_conf.warehouse,
role=connection_conf.role,
application=APPLICATION_NAME,
**connection_conf.connect_args or {},
)
)
assert connection
return TestConnectionReport(
basic_connectivity=CapabilityReport(capable=True)
)
else:
raise NotImplementedError(
"Don't support testing connections for non DEFAULT AUTHENTICATED modes"
)

connection: connector.SnowflakeConnection = connection_conf.get_connection()
assert connection

test_report.basic_connectivity = CapabilityReport(capable=True)

test_report.capability_report = SnowflakeSource.check_capabilities(
connection, connection_conf
)

except Exception as e:
# TODO - do we need sensitive error logging ?
logger.error(f"Failed to test connection due to {e}", exc_info=e)
return TestConnectionReport(
basic_connectivity=CapabilityReport(
if test_report.basic_connectivity is None:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=f"{e}"
)
)
else:
test_report.internal_failure = True
test_report.internal_failure_reason = f"{e}"
finally:
SnowflakeConfig.Config.extra = (
pydantic.Extra.forbid
) # set config flexibility back to strict
return test_report

def get_metadata_engine(
self, database: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,35 +239,6 @@ def validate_include_view_lineage(cls, v, values):
)
return v

def get_oauth_connection(self):
assert (
self.oauth_config
), "oauth_config should be provided if using oauth based authentication"
generator = OauthTokenGenerator(
self.oauth_config.client_id,
self.oauth_config.authority_url,
self.oauth_config.provider,
)
if self.oauth_config.use_certificate is True:
response = generator.get_token_with_certificate(
private_key_content=str(self.oauth_config.encoded_oauth_public_key),
public_key_content=str(self.oauth_config.encoded_oauth_private_key),
scopes=self.oauth_config.scopes,
)
else:
response = generator.get_token_with_secret(
secret=str(self.oauth_config.client_secret),
scopes=self.oauth_config.scopes,
)
token = response["access_token"]
return snowflake.connector.connect(
user=self.username,
account=self.account_id,
authenticator="oauth",
token=token,
warehouse=self.warehouse,
)

def get_sql_alchemy_url(
self,
database: Optional[str] = None,
Expand Down Expand Up @@ -349,3 +320,79 @@ def get_options(self) -> dict:
options_connect_args.update(self.options.get("connect_args", {}))
self.options["connect_args"] = options_connect_args
return self.options

def get_oauth_connection(self):
assert (
self.oauth_config
), "oauth_config should be provided if using oauth based authentication"
generator = OauthTokenGenerator(
self.oauth_config.client_id,
self.oauth_config.authority_url,
self.oauth_config.provider,
)
if self.oauth_config.use_certificate is True:
response = generator.get_token_with_certificate(
private_key_content=str(self.oauth_config.encoded_oauth_public_key),
public_key_content=str(self.oauth_config.encoded_oauth_private_key),
scopes=self.oauth_config.scopes,
)
else:
response = generator.get_token_with_secret(
secret=str(self.oauth_config.client_secret),
scopes=self.oauth_config.scopes,
)
token = response["access_token"]
connect_args = self.get_options()["connect_args"]
return snowflake.connector.connect(
user=self.username,
account=self.account_id,
token=token,
warehouse=self.warehouse,
authenticator=VALID_AUTH_TYPES.get(self.authentication_type),
application=APPLICATION_NAME,
**connect_args,
)

def get_key_pair_connection(self) -> snowflake.connector.SnowflakeConnection:
connect_args = self.get_options()["connect_args"]

return snowflake.connector.connect(
user=self.username,
account=self.account_id,
warehouse=self.warehouse,
role=self.role,
authenticator=VALID_AUTH_TYPES.get(self.authentication_type),
application=APPLICATION_NAME,
**connect_args,
)

def get_connection(self) -> snowflake.connector.SnowflakeConnection:
connect_args = self.get_options()["connect_args"]
if self.authentication_type == "DEFAULT_AUTHENTICATOR":
return snowflake.connector.connect(
user=self.username,
password=self.password.get_secret_value() if self.password else None,
account=self.account_id,
warehouse=self.warehouse,
role=self.role,
application=APPLICATION_NAME,
**connect_args,
)
elif self.authentication_type == "OAUTH_AUTHENTICATOR":
return self.get_oauth_connection()
elif self.authentication_type == "KEY_PAIR_AUTHENTICATOR":
return self.get_key_pair_connection()
elif self.authentication_type == "EXTERNAL_BROWSER_AUTHENTICATOR":
return snowflake.connector.connect(
user=self.username,
password=self.password.get_secret_value() if self.password else None,
account=self.account_id,
warehouse=self.warehouse,
role=self.role,
authenticator=VALID_AUTH_TYPES.get(self.authentication_type),
application=APPLICATION_NAME,
**connect_args,
)
else:
# not expected to be here
raise Exception("Not expected to be here.")
Loading