Skip to content

Commit 40e23d2

Browse files
author
Jesse Whitehouse
committed
Merge branch 'main' into sqlalchemy-tz
Signed-off-by: Jesse Whitehouse <[email protected]>
2 parents 6719013 + 2d58e5b commit 40e23d2

File tree

7 files changed

+77
-15
lines changed

7 files changed

+77
-15
lines changed

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ To run all of these examples you can clone the entire repository to your disk. O
3333
- **`insert_data.py`** adds a tables called `squares` to your default catalog and inserts one hundred rows of example data. Then it fetches this data and prints it to the screen.
3434
- **`query_cancel.py`** shows how to cancel a query assuming that you can access the `Cursor` executing that query from a different thread. This is necessary because `databricks-sql-connector` does not yet implement an asynchronous API; calling `.execute()` blocks the current thread until execution completes. Therefore, the connector can't cancel queries from the same thread where they began.
3535
- **`interactive_oauth.py`** shows the simplest example of authenticating by OAuth (no need for a PAT generated in the DBSQL UI) while Bring Your Own IDP is in public preview. When you run the script it will open a browser window so you can authenticate. Afterward, the script fetches some sample data from Databricks and prints it to the screen. For this script, the OAuth token is not persisted which means you need to authenticate every time you run the script.
36+
- **`m2m_oauth.py`** shows the simplest example of authenticating by using OAuth M2M (machine-to-machine) for service principal.
3637
- **`persistent_oauth.py`** shows a more advanced example of authenticating by OAuth while Bring Your Own IDP is in public preview. In this case, it shows how to use a sublcass of `OAuthPersistence` to reuse an OAuth token across script executions.
3738
- **`set_user_agent.py`** shows how to customize the user agent header used for Thrift commands. In
3839
this example the string `ExamplePartnerTag` will be added to the the user agent on every request.

examples/m2m_oauth.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
3+
from databricks.sdk.core import oauth_service_principal, Config
4+
from databricks import sql
5+
6+
"""
7+
This example shows how to use OAuth M2M (machine-to-machine) for service principal
8+
9+
Pre-requisites:
10+
- Create service principal and OAuth secret in Account Console
11+
- Assign the service principal to the workspace
12+
13+
See more https://docs.databricks.com/en/dev-tools/authentication-oauth.html)
14+
"""
15+
16+
server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME")
17+
18+
19+
def credential_provider():
20+
config = Config(
21+
host=f"https://{server_hostname}",
22+
# Service Principal UUID
23+
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
24+
# Service Principal Secret
25+
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"))
26+
return oauth_service_principal(config)
27+
28+
29+
with sql.connect(
30+
server_hostname=server_hostname,
31+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
32+
credentials_provider=credential_provider) as connection:
33+
for x in range(1, 100):
34+
cursor = connection.cursor()
35+
cursor.execute('SELECT 1+1')
36+
result = cursor.fetchall()
37+
for row in result:
38+
print(row)
39+
cursor.close()
40+
41+
connection.close()

src/databricks/sql/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
# PEP 249 module globals
66
apilevel = "2.0"
77
threadsafety = 1 # Threads may share the module, but not connections.
8-
paramstyle = "pyformat" # Python extended format codes, e.g. ...WHERE name=%(name)s
8+
9+
# Python extended format codes, e.g. ...WHERE name=%(name)s
10+
# Note that when we switch to ParameterApproach.NATIVE, paramstyle will be `named`
11+
paramstyle = "pyformat"
912

1013

1114
class DBAPITypeObject(object):

src/databricks/sql/client.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600
3737
DEFAULT_ARRAY_SIZE = 100000
3838

39-
NO_NATIVE_PARAMS = []
39+
NO_NATIVE_PARAMS: List = []
4040

4141

4242
class Connection:
@@ -376,9 +376,12 @@ def __iter__(self):
376376
else:
377377
raise Error("There is no active result set")
378378

379-
def _determine_parameter_approach(self) -> ParameterApproach:
379+
def _determine_parameter_approach(
380+
self, params: Optional[Union[List, Dict[str, Any]]] = None
381+
) -> ParameterApproach:
380382
"""Encapsulates the logic for choosing whether to send parameters in native vs inline mode
381383
384+
If params is None then ParameterApproach.NONE is returned.
382385
If self.use_inline_params is True then inline mode is used.
383386
If self.use_inline_params is False, then check if the server supports them and proceed.
384387
Else raise an exception.
@@ -388,6 +391,9 @@ def _determine_parameter_approach(self) -> ParameterApproach:
388391
If inline approach is used when the server supports native approach, a warning is logged
389392
"""
390393

394+
if params is None:
395+
return ParameterApproach.NONE
396+
391397
server_supports_native_approach = (
392398
self.connection.server_parameterized_queries_enabled(
393399
self.connection.protocol_version
@@ -415,7 +421,7 @@ def _determine_parameter_approach(self) -> ParameterApproach:
415421
)
416422

417423
def _prepare_inline_parameters(
418-
self, stmt: str, params: Union[List, Dict[str, Any]]
424+
self, stmt: str, params: Optional[Union[List, Dict[str, Any]]]
419425
) -> Tuple[str, List]:
420426
"""Return a statement and list of native parameters to be passed to thrift_backend for execution
421427
@@ -440,7 +446,7 @@ def _prepare_inline_parameters(
440446
return rendered_statement, NO_NATIVE_PARAMS
441447

442448
def _prepare_native_parameters(
443-
self, stmt: str, params: Union[List[Any], Dict[str, Any]]
449+
self, stmt: str, params: Optional[Union[List[Any], Dict[str, Any]]]
444450
) -> Tuple[str, List[TSparkParameter]]:
445451
"""Return a statement and a list of native parameters to be passed to thrift_backend for execution
446452
@@ -455,12 +461,12 @@ def _prepare_native_parameters(
455461
can be wrapped in a DbsqlParameter class.
456462
457463
Returns a tuple of:
458-
stmt: the passed statement with the param markers replaced by literal rendered values
464+
stmt: the passed statement` with the param markers replaced by literal rendered values
459465
params: a list of TSparkParameters that will be passed in native mode
460466
"""
461467

462468
stmt = stmt
463-
params = named_parameters_to_tsparkparams(params)
469+
params = named_parameters_to_tsparkparams(params) # type: ignore
464470

465471
return stmt, params
466472

@@ -621,7 +627,7 @@ def _handle_staging_remove(self, presigned_url: str, headers: dict = None):
621627
def execute(
622628
self,
623629
operation: str,
624-
parameters: Optional[Union[List[Any], Dict[str, str]]] = None,
630+
parameters: Optional[Union[List[Any], Dict[str, Any]]] = None,
625631
) -> "Cursor":
626632
"""
627633
Execute a query and wait for execution to complete.
@@ -652,18 +658,16 @@ def execute(
652658
:returns self
653659
"""
654660

655-
if parameters:
656-
param_approach = self._determine_parameter_approach()
657-
else:
658-
param_approach = ParameterApproach.NONE
661+
param_approach = self._determine_parameter_approach(parameters)
662+
if param_approach == ParameterApproach.NONE:
659663
prepared_params = NO_NATIVE_PARAMS
660664
prepared_operation = operation
661665

662-
if param_approach == ParameterApproach.INLINE:
666+
elif param_approach == ParameterApproach.INLINE:
663667
prepared_operation, prepared_params = self._prepare_inline_parameters(
664668
operation, parameters
665669
)
666-
if param_approach == ParameterApproach.NATIVE:
670+
elif param_approach == ParameterApproach.NATIVE:
667671
prepared_operation, prepared_params = self._prepare_native_parameters(
668672
operation, parameters
669673
)

src/databricks/sql/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,9 @@ def calculate_decimal_cast_string(input: Decimal) -> str:
633633
return f"DECIMAL({overall},{after})"
634634

635635

636-
def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]]):
636+
def named_parameters_to_tsparkparams(
637+
parameters: Union[List[Any], Dict[str, str]]
638+
) -> List[TSparkParameter]:
637639
tspark_params = []
638640
if isinstance(parameters, dict):
639641
dbsql_params = named_parameters_to_dbsqlparams_v1(parameters)

src/databricks/sqlalchemy/test_local/e2e/test_basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
except ImportError:
2323
from sqlalchemy.ext.declarative import declarative_base
2424

25+
from databricks.sqlalchemy.test.test_suite import start_protocol_patch
26+
27+
start_protocol_patch()
28+
2529

2630
USER_AGENT_TOKEN = "PySQL e2e Tests"
2731

tests/e2e/test_parameterized_queries.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,13 @@ def conditional_protocol_patch(self, bypass=False):
130130
pass
131131

132132
def _inline_roundtrip(self, params: dict):
133+
"""This INSERT, SELECT, DELETE dance is necessary because simply selecting
134+
```
135+
"SELECT %(param)s"
136+
```
137+
in INLINE mode would always return a str and the nature of the test is to
138+
confirm that types are maintained.
139+
"""
133140
target_column = self.inline_type_map[type(params.get("p"))]
134141
INSERT_QUERY = f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)"
135142
SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1"

0 commit comments

Comments
 (0)