1313# limitations under the License.
1414
1515"""DB-API Connection for the Google Cloud Spanner."""
16- import time
1716import warnings
1817
1918from google .api_core .exceptions import Aborted
2322from google .cloud .spanner_dbapi .batch_dml_executor import BatchMode , BatchDmlExecutor
2423from google .cloud .spanner_dbapi .parse_utils import _get_statement_type
2524from google .cloud .spanner_dbapi .parsed_statement import (
26- ParsedStatement ,
27- Statement ,
2825 StatementType ,
2926)
3027from google .cloud .spanner_dbapi .partition_helper import PartitionId
28+ from google .cloud .spanner_dbapi .parsed_statement import ParsedStatement , Statement
29+ from google .cloud .spanner_dbapi .transaction_helper import TransactionRetryHelper
30+ from google .cloud .spanner_dbapi .cursor import Cursor
3131from google .cloud .spanner_v1 import RequestOptions
32- from google .cloud .spanner_v1 .session import _get_retry_delay
3332from google .cloud .spanner_v1 .snapshot import Snapshot
3433from deprecated import deprecated
3534
36- from google .cloud .spanner_dbapi .checksum import _compare_checksums
37- from google .cloud .spanner_dbapi .checksum import ResultsChecksum
38- from google .cloud .spanner_dbapi .cursor import Cursor
3935from google .cloud .spanner_dbapi .exceptions import (
4036 InterfaceError ,
4137 OperationalError ,
4440from google .cloud .spanner_dbapi .version import DEFAULT_USER_AGENT
4541from google .cloud .spanner_dbapi .version import PY_VERSION
4642
47- from google .rpc .code_pb2 import ABORTED
48-
4943
5044CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
5145 "This method is non-operational as a transaction has not been started."
5246)
53- MAX_INTERNAL_RETRIES = 50
5447
5548
5649def check_not_closed (function ):
@@ -106,9 +99,6 @@ def __init__(self, instance, database=None, read_only=False):
10699 self ._transaction = None
107100 self ._session = None
108101 self ._snapshot = None
109- # SQL statements, which were executed
110- # within the current transaction
111- self ._statements = []
112102
113103 self .is_closed = False
114104 self ._autocommit = False
@@ -125,6 +115,7 @@ def __init__(self, instance, database=None, read_only=False):
125115 self ._spanner_transaction_started = False
126116 self ._batch_mode = BatchMode .NONE
127117 self ._batch_dml_executor : BatchDmlExecutor = None
118+ self ._transaction_helper = TransactionRetryHelper (self )
128119
129120 @property
130121 def autocommit (self ):
@@ -288,76 +279,6 @@ def _release_session(self):
288279 self .database ._pool .put (self ._session )
289280 self ._session = None
290281
291- def retry_transaction (self ):
292- """Retry the aborted transaction.
293-
294- All the statements executed in the original transaction
295- will be re-executed in new one. Results checksums of the
296- original statements and the retried ones will be compared.
297-
298- :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
299- If results checksum of the retried statement is
300- not equal to the checksum of the original one.
301- """
302- attempt = 0
303- while True :
304- self ._spanner_transaction_started = False
305- attempt += 1
306- if attempt > MAX_INTERNAL_RETRIES :
307- raise
308-
309- try :
310- self ._rerun_previous_statements ()
311- break
312- except Aborted as exc :
313- delay = _get_retry_delay (exc .errors [0 ], attempt )
314- if delay :
315- time .sleep (delay )
316-
317- def _rerun_previous_statements (self ):
318- """
319- Helper to run all the remembered statements
320- from the last transaction.
321- """
322- for statement in self ._statements :
323- if isinstance (statement , list ):
324- statements , checksum = statement
325-
326- transaction = self .transaction_checkout ()
327- statements_tuple = []
328- for single_statement in statements :
329- statements_tuple .append (single_statement .get_tuple ())
330- status , res = transaction .batch_update (statements_tuple )
331-
332- if status .code == ABORTED :
333- raise Aborted (status .details )
334-
335- retried_checksum = ResultsChecksum ()
336- retried_checksum .consume_result (res )
337- retried_checksum .consume_result (status .code )
338-
339- _compare_checksums (checksum , retried_checksum )
340- else :
341- res_iter , retried_checksum = self .run_statement (statement , retried = True )
342- # executing all the completed statements
343- if statement != self ._statements [- 1 ]:
344- for res in res_iter :
345- retried_checksum .consume_result (res )
346-
347- _compare_checksums (statement .checksum , retried_checksum )
348- # executing the failed statement
349- else :
350- # streaming up to the failed result or
351- # to the end of the streaming iterator
352- while len (retried_checksum ) < len (statement .checksum ):
353- try :
354- res = next (iter (res_iter ))
355- retried_checksum .consume_result (res )
356- except StopIteration :
357- break
358-
359- _compare_checksums (statement .checksum , retried_checksum )
360-
361282 def transaction_checkout (self ):
362283 """Get a Cloud Spanner transaction.
363284
@@ -433,12 +354,10 @@ def begin(self):
433354
434355 def commit (self ):
435356 """Commits any pending transaction to the database.
436-
437357 This is a no-op if there is no active client transaction.
438358 """
439359 if self .database is None :
440360 raise ValueError ("Database needs to be passed for this operation" )
441-
442361 if not self ._client_transaction_started :
443362 warnings .warn (
444363 CLIENT_TRANSACTION_NOT_STARTED_WARNING , UserWarning , stacklevel = 2
@@ -450,33 +369,31 @@ def commit(self):
450369 if self ._spanner_transaction_started and not self ._read_only :
451370 self ._transaction .commit ()
452371 except Aborted :
453- self .retry_transaction ()
372+ self ._transaction_helper . retry_transaction ()
454373 self .commit ()
455374 finally :
456- self ._release_session ()
457- self ._statements = []
458- self ._transaction_begin_marked = False
459- self ._spanner_transaction_started = False
375+ self ._reset_post_commit_or_rollback ()
460376
461377 def rollback (self ):
462378 """Rolls back any pending transaction.
463-
464379 This is a no-op if there is no active client transaction.
465380 """
466381 if not self ._client_transaction_started :
467382 warnings .warn (
468383 CLIENT_TRANSACTION_NOT_STARTED_WARNING , UserWarning , stacklevel = 2
469384 )
470385 return
471-
472386 try :
473387 if self ._spanner_transaction_started and not self ._read_only :
474388 self ._transaction .rollback ()
475389 finally :
476- self ._release_session ()
477- self ._statements = []
478- self ._transaction_begin_marked = False
479- self ._spanner_transaction_started = False
390+ self ._reset_post_commit_or_rollback ()
391+
392+ def _reset_post_commit_or_rollback (self ):
393+ self ._release_session ()
394+ self ._transaction_helper .reset ()
395+ self ._transaction_begin_marked = False
396+ self ._spanner_transaction_started = False
480397
481398 @check_not_closed
482399 def cursor (self ):
@@ -493,7 +410,7 @@ def run_prior_DDL_statements(self):
493410
494411 return self .database .update_ddl (ddl_statements ).result ()
495412
496- def run_statement (self , statement : Statement , retried = False ):
413+ def run_statement (self , statement : Statement ):
497414 """Run single SQL statement in begun transaction.
498415
499416 This method is never used in autocommit mode. In
@@ -513,17 +430,11 @@ def run_statement(self, statement: Statement, retried=False):
513430 checksum of this statement results.
514431 """
515432 transaction = self .transaction_checkout ()
516- if not retried :
517- self ._statements .append (statement )
518-
519- return (
520- transaction .execute_sql (
521- statement .sql ,
522- statement .params ,
523- param_types = statement .param_types ,
524- request_options = self .request_options ,
525- ),
526- ResultsChecksum () if retried else statement .checksum ,
433+ return transaction .execute_sql (
434+ statement .sql ,
435+ statement .params ,
436+ param_types = statement .param_types ,
437+ request_options = self .request_options ,
527438 )
528439
529440 @check_not_closed
0 commit comments