@@ -58,24 +58,28 @@ class PostgreSQLOnlineStore(OnlineStore):
5858 _conn_pool_async : Optional [AsyncConnectionPool ] = None
5959
6060 @contextlib .contextmanager
61- def _get_conn (self , config : RepoConfig ) -> Generator [Connection , Any , Any ]:
61+ def _get_conn (
62+ self , config : RepoConfig , autocommit : bool = False
63+ ) -> Generator [Connection , Any , Any ]:
6264 assert config .online_store .type == "postgres"
6365
6466 if config .online_store .conn_type == ConnectionType .pool :
6567 if not self ._conn_pool :
6668 self ._conn_pool = _get_connection_pool (config .online_store )
6769 self ._conn_pool .open ()
6870 connection = self ._conn_pool .getconn ()
71+ connection .set_autocommit (autocommit )
6972 yield connection
7073 self ._conn_pool .putconn (connection )
7174 else :
7275 if not self ._conn :
7376 self ._conn = _get_conn (config .online_store )
77+ self ._conn .set_autocommit (autocommit )
7478 yield self ._conn
7579
7680 @contextlib .asynccontextmanager
7781 async def _get_conn_async (
78- self , config : RepoConfig
82+ self , config : RepoConfig , autocommit : bool = False
7983 ) -> AsyncGenerator [AsyncConnection , Any ]:
8084 if config .online_store .conn_type == ConnectionType .pool :
8185 if not self ._conn_pool_async :
@@ -84,11 +88,13 @@ async def _get_conn_async(
8488 )
8589 await self ._conn_pool_async .open ()
8690 connection = await self ._conn_pool_async .getconn ()
91+ await connection .set_autocommit (autocommit )
8792 yield connection
8893 await self ._conn_pool_async .putconn (connection )
8994 else :
9095 if not self ._conn_async :
9196 self ._conn_async = await _get_conn_async (config .online_store )
97+ await self ._conn_async .set_autocommit (autocommit )
9298 yield self ._conn_async
9399
94100 def online_write_batch (
@@ -161,7 +167,7 @@ def online_read(
161167 config , table , keys , requested_features
162168 )
163169
164- with self ._get_conn (config ) as conn , conn .cursor () as cur :
170+ with self ._get_conn (config , autocommit = True ) as conn , conn .cursor () as cur :
165171 cur .execute (query , params )
166172 rows = cur .fetchall ()
167173
@@ -179,7 +185,7 @@ async def online_read_async(
179185 config , table , keys , requested_features
180186 )
181187
182- async with self ._get_conn_async (config ) as conn :
188+ async with self ._get_conn_async (config , autocommit = True ) as conn :
183189 async with conn .cursor () as cur :
184190 await cur .execute (query , params )
185191 rows = await cur .fetchall ()
@@ -339,6 +345,7 @@ def teardown(
339345 for table in tables :
340346 table_name = _table_id (project , table )
341347 cur .execute (_drop_table_and_index (table_name ))
348+ conn .commit ()
342349 except Exception :
343350 logging .exception ("Teardown failed" )
344351 raise
@@ -398,7 +405,7 @@ def retrieve_online_documents(
398405 Optional [ValueProto ],
399406 ]
400407 ] = []
401- with self ._get_conn (config ) as conn , conn .cursor () as cur :
408+ with self ._get_conn (config , autocommit = True ) as conn , conn .cursor () as cur :
402409 table_name = _table_id (project , table )
403410
404411 # Search query template to find the top k items that are closest to the given embedding
0 commit comments