Skip to content

Commit

Permalink
feat: Set unique table suffix to allow parallel incremental executions (
Browse files Browse the repository at this point in the history
dbt-labs#650)

Co-authored-by: nicor88 <[email protected]>
  • Loading branch information
pierrebzl and nicor88 authored May 23, 2024
1 parent 20f039d commit 8813f9e
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

## Getting started

* Clone of fork the repo
* Clone or fork the repo
* Run `make setup`, it will:
1. Install all dependencies
2. Install pre-commit hooks
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ athena:
- Skip creating the table as CTAS and run the operation directly in batch insert mode
- This is particularly useful when the standard table creation process fails due to partition limitations,
allowing you to work with temporary tables and persist the dataset more efficiently
- `unique_tmp_table_suffix` (`default=false`)
- For incremental models using the insert overwrite strategy on Hive tables
- Replaces the __dbt_tmp suffix used as the temporary table name suffix with a unique UUID
- Useful when running multiple dbt builds in parallel that insert into the same table
- `lf_tags_config` (`default=none`)
- [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns
- `enabled` (`default=False`) whether LF tags management is enabled for a model
Expand Down
6 changes: 6 additions & 0 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class AthenaConfig(AdapterConfig):
seed_s3_upload_args: Dictionary containing boto3 ExtraArgs when uploading to S3.
partitions_limit: Maximum numbers of partitions when batching.
force_batch: Skip creating the table as ctas and run the operation directly in batch insert mode.
unique_tmp_table_suffix: Enforce the use of a unique id as tmp table suffix instead of __dbt_tmp.
"""

work_group: Optional[str] = None
Expand All @@ -119,6 +120,7 @@ class AthenaConfig(AdapterConfig):
seed_s3_upload_args: Optional[Dict[str, Any]] = None
partitions_limit: Optional[int] = None
force_batch: bool = False
unique_tmp_table_suffix: bool = False


class AthenaAdapter(SQLAdapter):
Expand Down Expand Up @@ -419,6 +421,10 @@ def clean_up_table(self, relation: AthenaRelation) -> None:
if table_location := self.get_glue_table_location(relation):
self.delete_from_s3(table_location)

@available
def generate_unique_temporary_table_suffix(self, suffix_initial: str = "__dbt_tmp") -> str:
    """Return a collision-free temporary-table suffix: ``<suffix_initial>_<uuid4>``.

    Called from the incremental materialization when ``unique_tmp_table_suffix``
    is enabled, so parallel runs against the same model do not contend for the
    same ``__dbt_tmp`` table name.
    """
    unique_part = uuid4()
    return f"{suffix_initial}_{unique_part}"

def quote(self, identifier: str) -> str:
return f"{self.quote_character}{identifier}{self.quote_character}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@
{% set lf_grants = config.get('lf_grants') %}
{% set partitioned_by = config.get('partitioned_by') %}
{% set force_batch = config.get('force_batch', False) | as_bool -%}
{% set unique_tmp_table_suffix = config.get('unique_tmp_table_suffix', False) | as_bool -%}
{% set target_relation = this.incorporate(type='table') %}
{% set existing_relation = load_relation(this) %}
{% set old_tmp_relation = adapter.get_relation(identifier=target_relation.identifier ~ '__dbt_tmp',
-- If using insert_overwrite on Hive table, allow to set a unique tmp table suffix
{% if unique_tmp_table_suffix == True and strategy == 'insert_overwrite' and table_type == 'hive' %}
{% set tmp_table_suffix = adapter.generate_unique_temporary_table_suffix() %}
{% else %}
{% set tmp_table_suffix = '__dbt_tmp' %}
{% endif %}

{% set old_tmp_relation = adapter.get_relation(identifier=target_relation.identifier ~ tmp_table_suffix,
schema=schema,
database=database) %}
{% set tmp_relation = make_temp_relation(target_relation, '__dbt_tmp') %}
{% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix) %}

-- If no partitions are used with insert_overwrite, we fall back to append mode.
{% if partitioned_by is none and strategy == 'insert_overwrite' %}
Expand Down
161 changes: 161 additions & 0 deletions tests/functional/adapter/test_unique_tmp_table_suffix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import json
import re
from typing import List

import pytest

from dbt.contracts.results import RunStatus
from dbt.tests.util import run_dbt

models__unique_tmp_table_suffix_sql = """
{{ config(
materialized='incremental',
incremental_strategy='insert_overwrite',
partitioned_by=['date_column'],
unique_tmp_table_suffix=True
)
}}
select
random() as rnd,
cast(from_iso8601_date('{{ var('logical_date') }}') as date) as date_column
"""


def extract_running_create_statements(dbt_run_capsys_output: str) -> List[str]:
    """Collect every "create table" SQL statement from dbt's json debug log output.

    The first run of the model does not produce a ``data.sql`` object in the
    execution logs, only ``data.base_msg`` containing the "Running Athena query:"
    initial create statement.  Subsequent incremental runs log the create for the
    tmp table under ``data.sql`` instead.  Both shapes are handled so the create
    statements of every run can be compared.

    :param dbt_run_capsys_output: captured stdout of a ``dbt run`` invocation
        executed with ``--log-format json``.
    :return: the create statements found, in log order.
    """
    sql_create_statements: List[str] = []
    # Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..."
    for events_msg in dbt_run_capsys_output.split("\n")[1:]:
        # Best effort solution to avoid invalid records and blank lines.
        try:
            event = json.loads(events_msg)
        except json.JSONDecodeError:
            continue
        # Valid JSON is not necessarily a dict (e.g. a bare "null" line);
        # only dict events can carry a "data" payload.
        base_msg_data = event.get("data") if isinstance(event, dict) else None
        if not base_msg_data:
            continue

        # Coerce once: base_msg is not guaranteed to be a string, and a
        # non-string would make the `in` test below raise TypeError.
        base_msg = str(base_msg_data.get("base_msg"))
        if "Running Athena query:" in base_msg and "create table" in base_msg:
            sql_create_statements.append(base_msg)

        if base_msg_data.get("conn_name") == "model.test.unique_tmp_table_suffix" and "sql" in base_msg_data:
            if "create table" in base_msg_data.get("sql"):
                sql_create_statements.append(base_msg_data.get("sql"))

    return sql_create_statements


def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
    """Pull the table name(s) out of a "create table ... with ..." statement.

    A name is whatever sits between "create table " and the next "with",
    matched across newlines; trailing whitespace is stripped from each name.
    """
    pattern = re.compile(r"(?<=create table ).*?(?=with)", re.DOTALL)
    return [name.rstrip() for name in pattern.findall(sql_create_statement)]


class TestUniqueTmpTableSuffix:
    """Verify that `unique_tmp_table_suffix` leaves the first (full) run's table
    name untouched, but gives every subsequent incremental run a distinct
    uuid-suffixed tmp table name."""

    @pytest.fixture(scope="class")
    def models(self):
        return {"unique_tmp_table_suffix.sql": models__unique_tmp_table_suffix_sql}

    @staticmethod
    def _run_model(relation_name, logical_date):
        # json debug logging is required so the executed SQL appears on stdout
        # for later inspection via capsys.
        return run_dbt(
            [
                "run",
                "--select",
                relation_name,
                "--vars",
                '{"logical_date": "%s"}' % logical_date,
                "--log-level",
                "debug",
                "--log-format",
                "json",
            ]
        )

    def test__unique_tmp_table_suffix(self, project, capsys):
        relation_name = "unique_tmp_table_suffix"
        row_count_sql = f"select count(*) as records from {project.test_schema}.{relation_name}"
        # "<model>__dbt_tmp_<uuid4>" as it appears in the logged create statements.
        unique_suffix_re = (
            r"unique_tmp_table_suffix__dbt_tmp_"
            r"[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}"
        )

        # First run creates the destination table directly: no unique suffix expected.
        first_run = self._run_model(relation_name, "2024-01-01")
        assert first_run.results[0].status == RunStatus.Success

        captured, _ = capsys.readouterr()
        create_statements = extract_running_create_statements(captured)
        assert len(create_statements) == 1

        first_table_name = extract_create_statement_table_names(create_statements[0])[0]
        assert not bool(re.search(unique_suffix_re, first_table_name))

        assert project.run_sql(row_count_sql, fetch="all")[0][0] == 1

        # Second (incremental) run: the tmp table must carry the uuid suffix.
        second_run = self._run_model(relation_name, "2024-01-02")
        assert second_run.results[0].status == RunStatus.Success

        captured, _ = capsys.readouterr()
        create_statements = extract_running_create_statements(captured)
        assert len(create_statements) == 1

        second_table_name = extract_create_statement_table_names(create_statements[0])[0]
        assert bool(re.search(unique_suffix_re, second_table_name))
        assert first_table_name != second_table_name

        # Third run: a fresh uuid each time, never repeating an earlier name.
        third_run = self._run_model(relation_name, "2024-01-03")
        assert third_run.results[0].status == RunStatus.Success

        captured, _ = capsys.readouterr()
        create_statements = extract_running_create_statements(captured)

        third_table_name = extract_create_statement_table_names(create_statements[0])[0]
        assert second_table_name != third_table_name
        assert first_table_name != third_table_name

0 comments on commit 8813f9e

Please sign in to comment.