From a12a40e3bb50a53235f466bfcbe3d40aa8a633cc Mon Sep 17 00:00:00 2001 From: Phil Demetriou Date: Sat, 12 Sep 2020 15:35:31 +0300 Subject: [PATCH 1/7] Allowing extra transaction options --- databases/backends/aiopg.py | 2 +- databases/backends/mysql.py | 2 +- databases/backends/postgres.py | 4 ++-- databases/backends/sqlite.py | 2 +- databases/core.py | 12 ++++++----- databases/interfaces.py | 2 +- tests/test_databases.py | 37 ++++++++++++++++++++++++++++++++++ 7 files changed, 50 insertions(+), 11 deletions(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index bf6d9b11..5992640b 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -215,7 +215,7 @@ def __init__(self, connection: AiopgConnection): self._is_root = False self._savepoint_name = "" - async def start(self, is_root: bool) -> None: + async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: assert self._connection._connection is not None, "Connection is not acquired" self._is_root = is_root cursor = await self._connection._connection.cursor() diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 7ba74a9a..9125ecfb 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -206,7 +206,7 @@ def __init__(self, connection: MySQLConnection): self._is_root = False self._savepoint_name = "" - async def start(self, is_root: bool) -> None: + async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: assert self._connection._connection is not None, "Connection is not acquired" self._is_root = is_root if self._is_root: diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index dcedcdf8..faa28d90 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -278,9 +278,9 @@ def __init__(self, connection: PostgresConnection): None ) # type: typing.Optional[asyncpg.transaction.Transaction] - async def start(self, is_root: bool) -> None: + async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: assert self._connection._connection is not None, "Connection is not acquired" - self._transaction = self._connection._connection.transaction() + self._transaction = self._connection._connection.transaction(**extra_options) await self._transaction.start() async def commit(self) -> None: diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 264c97a8..d6ef6eb8 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -178,7 +178,7 @@ def __init__(self, connection: SQLiteConnection): self._is_root = False self._savepoint_name = "" - async def start(self, is_root: bool) -> None: + async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: assert self._connection._connection is not None, "Connection is not acquired" self._is_root = is_root if self._is_root: diff --git a/databases/core.py b/databases/core.py index cb201be3..587e3f90 100644 --- a/databases/core.py +++ b/databases/core.py @@ -184,8 +184,8 @@ def connection(self) -> "Connection": self._connection_context.set(connection) return connection - def transaction(self, *, force_rollback: bool = False) -> "Transaction": - return Transaction(self.connection, force_rollback=force_rollback) + def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": + return Transaction(self.connection, force_rollback=force_rollback, **kwargs) @contextlib.contextmanager def force_rollback(self) -> typing.Iterator[None]: @@ -276,11 +276,11 @@ async def iterate( async for record in self._connection.iterate(built_query): yield record - def transaction(self, *, force_rollback: bool = False) -> "Transaction": + def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": def connection_callable() -> Connection: return self - return Transaction(connection_callable, force_rollback) + return Transaction(connection_callable, force_rollback, **kwargs) @property def raw_connection(self) -> typing.Any: @@ -305,9 +305,11 @@ def __init__( self, connection_callable: typing.Callable[[], Connection], force_rollback: bool, + **kwargs ) -> None: self._connection_callable = connection_callable self._force_rollback = force_rollback + self._extra_options = kwargs async def __aenter__(self) -> "Transaction": """ @@ -355,7 +357,7 @@ async def start(self) -> "Transaction": async with self._connection._transaction_lock: is_root = not self._connection._transaction_stack await self._connection.__aenter__() - await self._transaction.start(is_root=is_root) + await self._transaction.start(is_root=is_root, extra_options=self._extra_options) self._connection._transaction_stack.append(self) return self diff --git a/databases/interfaces.py b/databases/interfaces.py index 9882ef3d..39289328 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -56,7 +56,7 @@ def raw_connection(self) -> typing.Any: class TransactionBackend: - async def start(self, is_root: bool) -> None: + async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: raise NotImplementedError() # pragma: no cover async def commit(self) -> None: diff --git a/tests/test_databases.py b/tests/test_databases.py index 99de7ae6..1ebc86e8 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -426,6 +426,43 @@ async def test_transaction_commit(database_url): assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_commit_serializable(database_url): + """ + Ensure that serializable transaction commit via extra parameters is supported. + """ + + database_url = DatabaseURL(database_url) + + if database_url.dialect != "postgresql": + pytest.skip("Test (currently) requires asyncpg") + + async def insert_concurrently(event): + async with Database(database_url) as database: + async with database.transaction(): + query = notes.insert().values(text="example1", completed=True) + await database.execute(query) + + event.set() + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True, isolation="serializable"): + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 0 + + event = asyncio.Event() + asyncio.create_task(insert_concurrently(event)) + await event.wait() + + query = notes.select() + results = await database.fetch_all(query=query) + assert len(results) == 0 + + await database.execute(notes.delete()) + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_rollback(database_url): From f43ab544cbadf700136343e740fc08ffb54072cf Mon Sep 17 00:00:00 2001 From: Phil Demetriou Date: Sat, 12 Sep 2020 15:45:13 +0300 Subject: [PATCH 2/7] Switching to Python 3.6-compatible asyncio primitives --- tests/test_databases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index 1ebc86e8..43e04d79 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -453,7 +453,7 @@ async def insert_concurrently(event): assert len(results) == 0 event = asyncio.Event() - asyncio.create_task(insert_concurrently(event)) + asyncio.ensure_future(insert_concurrently(event)) await event.wait() query = notes.select() From 76c90c7e7ea5e953d7fad58f45941bd5ee48e133 Mon Sep 17 00:00:00 2001 From: Phil Demetriou Date: Sat, 12 Sep 2020 16:20:35 +0300 Subject: [PATCH 3/7] Using native SQLAlchemy engine for independent queries in tests --- tests/test_databases.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index 43e04d79..6395169f 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -438,13 +438,19 @@ async def test_transaction_commit_serializable(database_url): if database_url.dialect != "postgresql": pytest.skip("Test (currently) requires asyncpg") - async def insert_concurrently(event): - async with Database(database_url) as database: - async with database.transaction(): - query = notes.insert().values(text="example1", completed=True) - await database.execute(query) + def insert_independently(): + engine = sqlalchemy.create_engine(str(database_url)) + conn = engine.connect() + + query = notes.insert().values(text="example1", completed=True) + conn.execute(query) + + def delete_independently(): + engine = sqlalchemy.create_engine(str(database_url)) + conn = engine.connect() - event.set() + query = notes.delete() + conn.execute(query) async with Database(database_url) as database: async with database.transaction(force_rollback=True, isolation="serializable"): @@ -452,15 +458,13 @@ async def insert_concurrently(event): results = await database.fetch_all(query=query) assert len(results) == 0 - event = asyncio.Event() - asyncio.ensure_future(insert_concurrently(event)) - await event.wait() + insert_independently() query = notes.select() results = await database.fetch_all(query=query) assert len(results) == 0 - await database.execute(notes.delete()) + delete_independently() @pytest.mark.parametrize("database_url", DATABASE_URLS) From d73647efe3d9f9eb4e68fcd8edf8e21a4af23c32 Mon Sep 17 00:00:00 2001 From: Phil Demetriou Date: Sat, 12 Sep 2020 16:26:58 +0300 Subject: [PATCH 4/7] Excluding postgresql+aiopg in parameterized transaction test --- tests/test_databases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index 6395169f..35967327 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -435,7 +435,7 @@ async def test_transaction_commit_serializable(database_url): database_url = DatabaseURL(database_url) - if database_url.dialect != "postgresql": + if database_url.scheme != "postgresql": pytest.skip("Test (currently) requires asyncpg") def insert_independently(): From d4a76c20acf9616a9286df1b75c6cb188ce1e4f6 Mon Sep 17 00:00:00 2001 From: Phil Demetriou Date: Sat, 12 Sep 2020 16:28:25 +0300 Subject: [PATCH 5/7] Clarifying test skip comment in parameterized transaction test --- tests/test_databases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index 35967327..d95ebcd7 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -436,7 +436,7 @@ async def test_transaction_commit_serializable(database_url): database_url = DatabaseURL(database_url) if database_url.scheme != "postgresql": - pytest.skip("Test (currently) requires asyncpg") + pytest.skip("Test (currently) only supports asyncpg") def insert_independently(): engine = sqlalchemy.create_engine(str(database_url)) From 916d03e7619e43317acec87f239cb6ef9bdda615 Mon Sep 17 00:00:00 2001 From: Phil Demetriou Date: Sat, 12 Sep 2020 16:33:30 +0300 Subject: [PATCH 6/7] Adding missing type annotation --- databases/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databases/core.py b/databases/core.py index 587e3f90..55edd1e1 100644 --- a/databases/core.py +++ b/databases/core.py @@ -305,7 +305,7 @@ def __init__( self, connection_callable: typing.Callable[[], Connection], force_rollback: bool, - **kwargs + **kwargs: typing.Any ) -> None: self._connection_callable = connection_callable self._force_rollback = force_rollback From f646820e89c35ba285c0106092c4c3e6d204bb33 Mon Sep 17 00:00:00 2001 From: Phil Demetriou Date: Sat, 12 Sep 2020 16:37:23 +0300 Subject: [PATCH 7/7] Formatting with black --- databases/backends/aiopg.py | 4 +++- databases/backends/mysql.py | 4 +++- databases/backends/postgres.py | 4 +++- databases/backends/sqlite.py | 4 +++- databases/core.py | 14 ++++++++++---- databases/interfaces.py | 4 +++- 6 files changed, 25 insertions(+), 9 deletions(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 5992640b..2fecb1b5 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -215,7 +215,9 @@ def __init__(self, connection: AiopgConnection): self._is_root = False self._savepoint_name = "" - async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: + async def start( + self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] + ) -> None: assert self._connection._connection is not None, "Connection is not acquired" self._is_root = is_root cursor = await self._connection._connection.cursor() diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 9125ecfb..b6476add 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -206,7 +206,9 @@ def __init__(self, connection: MySQLConnection): self._is_root = False self._savepoint_name = "" - async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: + async def start( + self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] + ) -> None: assert self._connection._connection is not None, "Connection is not acquired" self._is_root = is_root if self._is_root: diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index faa28d90..63fc40b2 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -278,7 +278,9 @@ def __init__(self, connection: PostgresConnection): None ) # type: typing.Optional[asyncpg.transaction.Transaction] - async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: + async def start( + self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] + ) -> None: assert self._connection._connection is not None, "Connection is not acquired" self._transaction = self._connection._connection.transaction(**extra_options) await self._transaction.start() diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index d6ef6eb8..28ceb6fb 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -178,7 +178,9 @@ def __init__(self, connection: SQLiteConnection): self._is_root = False self._savepoint_name = "" - async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: + async def start( + self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] + ) -> None: assert self._connection._connection is not None, "Connection is not acquired" self._is_root = is_root if self._is_root: diff --git a/databases/core.py b/databases/core.py index 55edd1e1..71e4a129 100644 --- a/databases/core.py +++ b/databases/core.py @@ -184,7 +184,9 @@ def connection(self) -> "Connection": self._connection_context.set(connection) return connection - def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": + def transaction( + self, *, force_rollback: bool = False, **kwargs: typing.Any + ) -> "Transaction": return Transaction(self.connection, force_rollback=force_rollback, **kwargs) @contextlib.contextmanager @@ -276,7 +278,9 @@ async def iterate( async for record in self._connection.iterate(built_query): yield record - def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction": + def transaction( + self, *, force_rollback: bool = False, **kwargs: typing.Any + ) -> "Transaction": def connection_callable() -> Connection: return self @@ -305,7 +309,7 @@ def __init__( self, connection_callable: typing.Callable[[], Connection], force_rollback: bool, - **kwargs: typing.Any + **kwargs: typing.Any, ) -> None: self._connection_callable = connection_callable self._force_rollback = force_rollback @@ -357,7 +361,9 @@ async def start(self) -> "Transaction": async with self._connection._transaction_lock: is_root = not self._connection._transaction_stack await self._connection.__aenter__() - await self._transaction.start(is_root=is_root, extra_options=self._extra_options) + await self._transaction.start( + is_root=is_root, extra_options=self._extra_options + ) self._connection._transaction_stack.append(self) return self diff --git a/databases/interfaces.py b/databases/interfaces.py index 39289328..c786f42d 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -56,7 +56,9 @@ def raw_connection(self) -> typing.Any: class TransactionBackend: - async def start(self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]) -> None: + async def start( + self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] + ) -> None: raise NotImplementedError() # pragma: no cover async def commit(self) -> None: