1+ from __future__ import annotations
2+
13import contextvars
2- import types
34from dataclasses import dataclass
4- from typing import Any , Dict , Generator , List , Mapping , Optional , Sequence , Type , Union
5+ from typing import TYPE_CHECKING , Any , Generator , Mapping , Sequence
56
67from piccolo .engine .base import BaseBatch , Engine , validate_savepoint_name
78from piccolo .engine .exceptions import TransactionError
89from piccolo .query .base import DDL , Query
9- from piccolo .querystring import QueryString
1010from piccolo .utils .sync import run_sync
1111from piccolo .utils .warnings import Level , colored_warning
1212from psqlpy import Connection , ConnectionPool , Cursor , Transaction
1313from psqlpy .exceptions import RustPSQLDriverPyBaseError
1414from typing_extensions import Self
1515
16+ if TYPE_CHECKING :
17+ import types
18+
19+ from piccolo .querystring import QueryString
20+
1621
1722@dataclass
1823class AsyncBatch (BaseBatch ):
@@ -23,8 +28,8 @@ class AsyncBatch(BaseBatch):
2328 batch_size : int
2429
2530 # Set internally
26- _transaction : Optional [ Transaction ] = None
27- _cursor : Optional [ Cursor ] = None
31+ _transaction : Transaction | None = None
32+ _cursor : Cursor | None = None
2833
2934 @property
3035 def cursor (self ) -> Cursor :
@@ -37,19 +42,19 @@ def cursor(self) -> Cursor:
3742 raise ValueError ("_cursor not set" )
3843 return self ._cursor
3944
40- async def next (self ) -> List [ Dict [str , Any ]]:
45+ async def next (self ) -> list [ dict [str , Any ]]:
4146 """Retrieve next batch from the Cursor.
4247
4348 ### Returns:
44- List of dicts of results.
49+ list of dicts of results.
4550 """
4651 data = await self .cursor .fetch (self .batch_size )
4752 return data .result ()
4853
4954 def __aiter__ (self : Self ) -> Self :
5055 return self
5156
52- async def __anext__ (self : Self ) -> List [ Dict [str , Any ]]:
57+ async def __anext__ (self : Self ) -> list [ dict [str , Any ]]:
5358 response = await self .next ()
5459 if response == []:
5560 raise StopAsyncIteration
@@ -70,9 +75,9 @@ async def __aenter__(self: Self) -> Self:
7075
7176 async def __aexit__ (
7277 self : Self ,
73- exception_type : Optional [ Type [ BaseException ]] ,
74- exception : Optional [ BaseException ] ,
75- traceback : Optional [ types .TracebackType ] ,
78+ exception_type : type [ BaseException ] | None ,
79+ exception : BaseException | None ,
80+ traceback : types .TracebackType | None ,
7681 ) -> bool :
7782 if exception :
7883 await self ._transaction .rollback () # type: ignore[union-attr]
@@ -98,19 +103,19 @@ class Atomic:
98103
99104 __slots__ = ("engine" , "queries" )
100105
101- def __init__ (self : Self , engine : " PSQLPyEngine" ) -> None :
106+ def __init__ (self : Self , engine : PSQLPyEngine ) -> None :
102107 """Initialize programmatically configured atomic transaction.
103108
104109 ### Parameters:
105110 - `engine`: engine for query executing.
106111 """
107112 self .engine = engine
108- self .queries : List [ Union [ Query [Any , Any ], DDL ] ] = []
113+ self .queries : list [ Query [Any , Any ] | DDL ] = []
109114
110115 def __await__ (self : Self ) -> Generator [Any , None , None ]:
111116 return self .run ().__await__ ()
112117
113- def add (self : Self , * query : Union [ Query [Any , Any ], DDL ] ) -> None :
118+ def add (self : Self , * query : Query [Any , Any ] | DDL ) -> None :
114119 """Add query to atomic transaction.
115120
116121 ### Params:
@@ -128,7 +133,7 @@ async def run(self: Self) -> None:
128133 if isinstance (query , (Query , DDL , Create , GetOrCreate )):
129134 await query .run ()
130135 else :
131- raise ValueError ("Unrecognised query" )
136+ raise TypeError ("Unrecognised query" ) # noqa: TRY301
132137 self .queries = []
133138 except Exception as exception :
134139 self .queries = []
@@ -142,7 +147,7 @@ def run_sync(self: Self) -> None:
142147class Savepoint :
143148 """PostgreSQL `SAVEPOINT` representation in Python."""
144149
145- def __init__ (self : Self , name : str , transaction : " PostgresTransaction" ) -> None :
150+ def __init__ (self : Self , name : str , transaction : PostgresTransaction ) -> None :
146151 """Initialize new `SAVEPOINT`.
147152
148153 ### Parameters:
@@ -179,7 +184,7 @@ class PostgresTransaction:
179184
180185 """
181186
182- def __init__ (self : Self , engine : " PSQLPyEngine" , allow_nested : bool = True ) -> None :
187+ def __init__ (self : Self , engine : PSQLPyEngine , allow_nested : bool = True ) -> None :
183188 """Initialize new transaction.
184189
185190 ### Parameters:
@@ -204,7 +209,7 @@ def __init__(self: Self, engine: "PSQLPyEngine", allow_nested: bool = True) -> N
204209 "aren't allowed." ,
205210 )
206211
207- async def __aenter__ (self : Self ) -> "PostgresTransaction" :
212+ async def __aenter__ (self : Self ) -> Self :
208213 if self ._parent is not None :
209214 return self ._parent
210215
@@ -218,9 +223,9 @@ async def __aenter__(self: Self) -> "PostgresTransaction":
218223
219224 async def __aexit__ (
220225 self : Self ,
221- exception_type : Optional [ Type [ BaseException ]] ,
222- exception : Optional [ BaseException ] ,
223- traceback : Optional [ types .TracebackType ] ,
226+ exception_type : type [ BaseException ] | None ,
227+ exception : BaseException | None ,
228+ traceback : types .TracebackType | None ,
224229 ) -> bool :
225230 if self ._parent :
226231 return exception is None
@@ -271,7 +276,7 @@ def get_savepoint_id(self: Self) -> int:
271276 self ._savepoint_id += 1
272277 return self ._savepoint_id
273278
274- async def savepoint (self : Self , name : Optional [ str ] = None ) -> Savepoint :
279+ async def savepoint (self : Self , name : str | None = None ) -> Savepoint :
275280 """Create new savepoint.
276281
277282 ### Parameters:
@@ -351,11 +356,11 @@ class PSQLPyEngine(Engine[PostgresTransaction]):
351356
352357 def __init__ (
353358 self : Self ,
354- config : Dict [str , Any ],
359+ config : dict [str , Any ],
355360 extensions : Sequence [str ] = ("uuid-ossp" ,),
356361 log_queries : bool = False ,
357362 log_responses : bool = False ,
358- extra_nodes : Optional [ Mapping [str , " PSQLPyEngine" ]] = None ,
363+ extra_nodes : Mapping [str , PSQLPyEngine ] | None = None ,
359364 ) -> None :
360365 """Initialize `PSQLPyEngine`.
361366
@@ -421,7 +426,7 @@ def __init__(
421426 self .log_queries = log_queries
422427 self .log_responses = log_responses
423428 self .extra_nodes = extra_nodes
424- self .pool : Optional [ ConnectionPool ] = None
429+ self .pool : ConnectionPool | None = None
425430 database_name = config .get ("database" , "Unknown" )
426431 self .current_transaction = contextvars .ContextVar (
427432 f"pg_current_transaction_{ database_name } " ,
@@ -449,7 +454,7 @@ def _parse_raw_version_string(version_string: str) -> float:
449454 async def get_version (self : Self ) -> float :
450455 """Retrieve the version of Postgres being run."""
451456 try :
452- response : Sequence [Dict [str , Any ]] = await self ._run_in_new_connection (
457+ response : Sequence [dict [str , Any ]] = await self ._run_in_new_connection (
453458 "SHOW server_version" ,
454459 )
455460 except ConnectionRefusedError as exception :
@@ -475,7 +480,7 @@ async def prep_database(self: Self) -> None:
475480 await self ._run_in_new_connection (
476481 f'CREATE EXTENSION IF NOT EXISTS "{ extension } "' ,
477482 )
478- except RustPSQLDriverPyBaseError :
483+ except RustPSQLDriverPyBaseError : # noqa: PERF203
479484 colored_warning (
480485 f"=> Unable to create { extension } extension - some "
481486 "functionality may not behave as expected. Make sure "
@@ -487,7 +492,7 @@ async def prep_database(self: Self) -> None:
487492
488493 async def start_connnection_pool (
489494 self : Self ,
490- ** kwargs : Dict [str , Any ],
495+ ** _kwargs : dict [str , Any ],
491496 ) -> None :
492497 """Start new connection pool.
493498
@@ -504,7 +509,7 @@ async def start_connnection_pool(
504509 )
505510 return await self .start_connection_pool ()
506511
507- async def close_connnection_pool (self : Self , ** kwargs : Dict [str , Any ]) -> None :
512+ async def close_connnection_pool (self : Self , ** _kwargs : dict [str , Any ]) -> None :
508513 """Close connection pool."""
509514 colored_warning (
510515 "`close_connnection_pool` is a typo - please change it to "
@@ -513,7 +518,7 @@ async def close_connnection_pool(self: Self, **kwargs: Dict[str, Any]) -> None:
513518 )
514519 return await self .close_connection_pool ()
515520
516- async def start_connection_pool (self : Self , ** kwargs : Dict [str , Any ]) -> None :
521+ async def start_connection_pool (self : Self , ** kwargs : dict [str , Any ]) -> None :
517522 """Start new connection pool.
518523
519524 Create and start new connection pool.
@@ -530,9 +535,6 @@ async def start_connection_pool(self: Self, **kwargs: Dict[str, Any]) -> None:
530535 else :
531536 config = dict (self .config )
532537 config .update (** kwargs )
533- print ("----------------" )
534- print (config )
535- print ("----------------" )
536538 self .pool = ConnectionPool (
537539 db_name = config .pop ("database" , None ),
538540 username = config .pop ("user" , None ),
@@ -549,7 +551,7 @@ async def close_connection_pool(self) -> None:
549551 colored_warning ("No pool is running." )
550552
551553 async def get_new_connection (self ) -> Connection :
552- """Returns a new connection - doesn't retrieve it from the pool."""
554+ """Return a new connection - doesn't retrieve it from the pool."""
553555 if self .pool :
554556 return await self .pool .connection ()
555557
@@ -562,11 +564,21 @@ async def get_new_connection(self) -> Connection:
562564 )
563565 ).connection ()
564566
567+ def transform_response_to_dicts (
568+ self ,
569+ results : list [dict [str , Any ]] | dict [str , Any ],
570+ ) -> list [dict [str , Any ]]:
571+ """Transform result to list of dicts."""
572+ if isinstance (results , list ):
573+ return results
574+
575+ return [results ]
576+
565577 async def batch (
566578 self : Self ,
567579 query : Query [Any , Any ],
568580 batch_size : int = 100 ,
569- node : Optional [ str ] = None ,
581+ node : str | None = None ,
570582 ) -> AsyncBatch :
571583 """Create new `AsyncBatch`.
572584
@@ -588,8 +600,8 @@ async def batch(
588600 async def _run_in_pool (
589601 self : Self ,
590602 query : str ,
591- args : Optional [ Sequence [Any ]] = None ,
592- ) -> List [ Dict [str , Any ]]:
603+ args : Sequence [Any ] | None = None ,
604+ ) -> list [ dict [str , Any ]]:
593605 """Run query in the pool.
594606
595607 ### Parameters:
@@ -613,8 +625,8 @@ async def _run_in_pool(
613625 async def _run_in_new_connection (
614626 self : Self ,
615627 query : str ,
616- args : Optional [ Sequence [Any ]] = None ,
617- ) -> List [ Dict [str , Any ]]:
628+ args : Sequence [Any ] | None = None ,
629+ ) -> list [ dict [str , Any ]]:
618630 """Run query in a new connection.
619631
620632 ### Parameters:
@@ -625,21 +637,19 @@ async def _run_in_new_connection(
625637 Result from the database as a list of dicts.
626638 """
627639 connection = await self .get_new_connection ()
628- try :
629- results = await connection .execute (
630- querystring = query ,
631- parameters = args ,
632- )
633- except RustPSQLDriverPyBaseError as exception :
634- raise exception
640+ results = await connection .execute (
641+ querystring = query ,
642+ parameters = args ,
643+ )
644+ connection .back_to_pool ()
635645
636646 return results .result ()
637647
638648 async def run_querystring (
639649 self : Self ,
640650 querystring : QueryString ,
641651 in_pool : bool = True ,
642- ) -> List [ Dict [str , Any ]]:
652+ ) -> list [ dict [str , Any ]]:
643653 """Run querystring.
644654
645655 ### Parameters:
@@ -649,9 +659,6 @@ async def run_querystring(
649659 ### Returns:
650660 Result from the database as a list of dicts.
651661 """
652- print ("------------------" )
653- print ("RUN" , querystring )
654- print ("------------------" )
655662 query , query_args = querystring .compile_string (engine_type = self .engine_type )
656663
657664 query_id = self .get_query_id ()
@@ -674,14 +681,14 @@ async def run_querystring(
674681
675682 if self .log_responses :
676683 self .print_response (query_id = query_id , response = response )
677- print ( response )
684+
678685 return response
679686
680687 async def run_ddl (
681688 self : Self ,
682689 ddl : str ,
683690 in_pool : bool = True ,
684- ) -> List [ Dict [str , Any ]]:
691+ ) -> list [ dict [str , Any ]]:
685692 """Run ddl query.
686693
687694 ### Parameters:
@@ -697,7 +704,7 @@ async def run_ddl(
697704 current_transaction = self .current_transaction .get ()
698705 if current_transaction :
699706 raw_response = await current_transaction .connection .fetch (ddl )
700- raw_response .result ()
707+ response = raw_response .result ()
701708 elif in_pool and self .pool :
702709 response = await self ._run_in_pool (ddl )
703710 else :
0 commit comments