Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def __init__(self, connection: AiopgConnection):
self._is_root = False
self._savepoint_name = ""

async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._is_root = is_root
cursor = await self._connection._connection.cursor()
Expand Down
4 changes: 3 additions & 1 deletion databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ def __init__(self, connection: MySQLConnection):
self._is_root = False
self._savepoint_name = ""

async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._is_root = is_root
if self._is_root:
Expand Down
6 changes: 4 additions & 2 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,11 @@ def __init__(self, connection: PostgresConnection):
None
) # type: typing.Optional[asyncpg.transaction.Transaction]

async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._transaction = self._connection._connection.transaction()
self._transaction = self._connection._connection.transaction(**extra_options)
await self._transaction.start()

async def commit(self) -> None:
Expand Down
4 changes: 3 additions & 1 deletion databases/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def __init__(self, connection: SQLiteConnection):
self._is_root = False
self._savepoint_name = ""

async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._is_root = is_root
if self._is_root:
Expand Down
18 changes: 13 additions & 5 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ def connection(self) -> "Connection":
self._connection_context.set(connection)
return connection

def transaction(self, *, force_rollback: bool = False) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback)
def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)

@contextlib.contextmanager
def force_rollback(self) -> typing.Iterator[None]:
Expand Down Expand Up @@ -276,11 +278,13 @@ async def iterate(
async for record in self._connection.iterate(built_query):
yield record

def transaction(self, *, force_rollback: bool = False) -> "Transaction":
def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
) -> "Transaction":
def connection_callable() -> Connection:
return self

return Transaction(connection_callable, force_rollback)
return Transaction(connection_callable, force_rollback, **kwargs)

@property
def raw_connection(self) -> typing.Any:
Expand All @@ -305,9 +309,11 @@ def __init__(
self,
connection_callable: typing.Callable[[], Connection],
force_rollback: bool,
**kwargs: typing.Any,
) -> None:
self._connection_callable = connection_callable
self._force_rollback = force_rollback
self._extra_options = kwargs

async def __aenter__(self) -> "Transaction":
"""
Expand Down Expand Up @@ -355,7 +361,9 @@ async def start(self) -> "Transaction":
async with self._connection._transaction_lock:
is_root = not self._connection._transaction_stack
await self._connection.__aenter__()
await self._transaction.start(is_root=is_root)
await self._transaction.start(
is_root=is_root, extra_options=self._extra_options
)
self._connection._transaction_stack.append(self)
return self

Expand Down
4 changes: 3 additions & 1 deletion databases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def raw_connection(self) -> typing.Any:


class TransactionBackend:
async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
raise NotImplementedError() # pragma: no cover

async def commit(self) -> None:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,47 @@ async def test_transaction_commit(database_url):
assert len(results) == 1


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_transaction_commit_serializable(database_url):
"""
Ensure that serializable transaction commit via extra parameters is supported.
"""

database_url = DatabaseURL(database_url)

if database_url.scheme != "postgresql":
pytest.skip("Test (currently) only supports asyncpg")

def insert_independently():
engine = sqlalchemy.create_engine(str(database_url))
conn = engine.connect()

query = notes.insert().values(text="example1", completed=True)
conn.execute(query)

def delete_independently():
engine = sqlalchemy.create_engine(str(database_url))
conn = engine.connect()

query = notes.delete()
conn.execute(query)

async with Database(database_url) as database:
async with database.transaction(force_rollback=True, isolation="serializable"):
query = notes.select()
results = await database.fetch_all(query=query)
assert len(results) == 0

insert_independently()

query = notes.select()
results = await database.fetch_all(query=query)
assert len(results) == 0

delete_independently()


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_transaction_rollback(database_url):
Expand Down