From 90c33da816dcaeaa834871c70b6761b215bb2405 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Fri, 7 Apr 2023 12:29:40 -0700 Subject: [PATCH 01/17] fix: incorrect concurrent usage of connection and transaction --- databases/core.py | 52 ++++++++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/databases/core.py b/databases/core.py index 8394ab5c..43ff192e 100644 --- a/databases/core.py +++ b/databases/core.py @@ -11,7 +11,7 @@ from sqlalchemy.sql import ClauseElement from databases.importer import import_from_string -from databases.interfaces import DatabaseBackend, Record +from databases.interfaces import DatabaseBackend, Record, TransactionBackend try: # pragma: no cover import click @@ -344,6 +344,9 @@ def __init__( self._connection_callable = connection_callable self._force_rollback = force_rollback self._extra_options = kwargs + self._transaction_context: ContextVar[TransactionBackend | None] = ContextVar( + "transaction_context" + ) async def __aenter__(self) -> "Transaction": """ @@ -385,31 +388,38 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore async def start(self) -> "Transaction": - self._connection = self._connection_callable() - self._transaction = self._connection._connection.transaction() - - async with self._connection._transaction_lock: - is_root = not self._connection._transaction_stack - await self._connection.__aenter__() - await self._transaction.start( - is_root=is_root, extra_options=self._extra_options - ) - self._connection._transaction_stack.append(self) + connection = self._connection_callable() + transaction = connection._connection.transaction() + self._transaction_context.set(transaction) + + async with connection._transaction_lock: + is_root = not connection._transaction_stack + await connection.__aenter__() + await transaction.start(is_root=is_root, extra_options=self._extra_options) + 
connection._transaction_stack.append(self) return self async def commit(self) -> None: - async with self._connection._transaction_lock: - assert self._connection._transaction_stack[-1] is self - self._connection._transaction_stack.pop() - await self._transaction.commit() - await self._connection.__aexit__() + connection = self._connection_callable() + transaction = self._transaction_context.get() + assert transaction is not None, "Transaction not found in current task" + async with connection._transaction_lock: + assert connection._transaction_stack[-1] is self + connection._transaction_stack.pop() + await transaction.commit() + await connection.__aexit__() + self._transaction_context.set(None) async def rollback(self) -> None: - async with self._connection._transaction_lock: - assert self._connection._transaction_stack[-1] is self - self._connection._transaction_stack.pop() - await self._transaction.rollback() - await self._connection.__aexit__() + connection = self._connection_callable() + transaction = self._transaction_context.get() + assert transaction is not None, "Transaction not found in current task" + async with connection._transaction_lock: + assert connection._transaction_stack[-1] is self + connection._transaction_stack.pop() + await transaction.rollback() + await connection.__aexit__() + self._transaction_context.set(None) class _EmptyNetloc(str): From bea6629187f3771e3bc564899cf359e69bf556f7 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Sun, 9 Apr 2023 19:57:32 -0700 Subject: [PATCH 02/17] refactor: rename contextvar class attributes, add some explaination comments --- databases/core.py | 50 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/databases/core.py b/databases/core.py index 43ff192e..71fbcae2 100644 --- a/databases/core.py +++ b/databases/core.py @@ -5,6 +5,7 @@ import typing from contextvars import ContextVar from types import TracebackType +from typing import Optional from 
urllib.parse import SplitResult, parse_qsl, unquote, urlsplit from sqlalchemy import text @@ -63,8 +64,13 @@ def __init__( assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) - # Connections are stored as task-local state. - self._connection_context: ContextVar = ContextVar("connection_context") + # Connections are stored as task-local state, and cannot be garbage collected, + # since the immutable global Context stores a strong reference to each ContextVar + # that is created. We need these local ContextVars since two Database objects + # could run in the same asyncio.Task with connections to different databases. + self._connection_contextvar: ContextVar[Optional["Connection"]] = ContextVar( + f"databases:Database:{id(self)}" + ) # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. @@ -113,7 +119,7 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - self._connection_context = ContextVar("connection_context") + self._connection_contextvar.set(None) await self._backend.disconnect() logger.info( @@ -187,12 +193,12 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - try: - return self._connection_context.get() - except LookupError: + connection = self._connection_contextvar.get(default=None) + if connection is None: connection = Connection(self._backend) - self._connection_context.set(connection) - return connection + self._connection_contextvar.set(connection) + + return connection def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any @@ -344,9 +350,15 @@ def __init__( self._connection_callable = connection_callable self._force_rollback = force_rollback self._extra_options = kwargs - self._transaction_context: ContextVar[TransactionBackend | None] = ContextVar( - "transaction_context" - ) + + # This ContextVar 
can never be garbage collected - similar to the ContextVar + # at Database._connection_contextvar - since the current Context has a strong + # reference to every ContextVar that is created. We need local ContextVars since + # there may be multiple (even nested) transactions in a single asyncio.Task, + # which each need their own unique TransactionBackend object. + self._transaction_contextvar: ContextVar[ + Optional[TransactionBackend] + ] = ContextVar(f"databases:Transaction:{id(self)}") async def __aenter__(self) -> "Transaction": """ @@ -390,7 +402,11 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: async def start(self) -> "Transaction": connection = self._connection_callable() transaction = connection._connection.transaction() - self._transaction_context.set(transaction) + + # Cannot store returned reset token anywhere, for the same reason + # we need a ContextVar in the first place - `self` is not + # a safe object on which to store references for concurrent code. 
+ self._transaction_contextvar.set(transaction) async with connection._transaction_lock: is_root = not connection._transaction_stack @@ -401,25 +417,27 @@ async def start(self) -> "Transaction": async def commit(self) -> None: connection = self._connection_callable() - transaction = self._transaction_context.get() + transaction = self._transaction_contextvar.get(default=None) assert transaction is not None, "Transaction not found in current task" async with connection._transaction_lock: assert connection._transaction_stack[-1] is self connection._transaction_stack.pop() await transaction.commit() await connection.__aexit__() - self._transaction_context.set(None) + # Have no reset token, set to None instead + self._transaction_contextvar.set(None) async def rollback(self) -> None: connection = self._connection_callable() - transaction = self._transaction_context.get() + transaction = self._transaction_contextvar.get(default=None) assert transaction is not None, "Transaction not found in current task" async with connection._transaction_lock: assert connection._transaction_stack[-1] is self connection._transaction_stack.pop() await transaction.rollback() await connection.__aexit__() - self._transaction_context.set(None) + # Have no reset token, set to None instead + self._transaction_contextvar.set(None) class _EmptyNetloc(str): From c9e34640b816215035993eaf695258cd0194281c Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Sun, 9 Apr 2023 22:37:18 -0700 Subject: [PATCH 03/17] fix: contextvar.get takes no keyword arguments --- databases/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/databases/core.py b/databases/core.py index 71fbcae2..7575fce5 100644 --- a/databases/core.py +++ b/databases/core.py @@ -193,7 +193,7 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - connection = self._connection_contextvar.get(default=None) + connection = self._connection_contextvar.get(None) if 
connection is None: connection = Connection(self._backend) self._connection_contextvar.set(connection) @@ -417,7 +417,7 @@ async def start(self) -> "Transaction": async def commit(self) -> None: connection = self._connection_callable() - transaction = self._transaction_contextvar.get(default=None) + transaction = self._transaction_contextvar.get(None) assert transaction is not None, "Transaction not found in current task" async with connection._transaction_lock: assert connection._transaction_stack[-1] is self @@ -429,7 +429,7 @@ async def commit(self) -> None: async def rollback(self) -> None: connection = self._connection_callable() - transaction = self._transaction_contextvar.get(default=None) + transaction = self._transaction_contextvar.get(None) assert transaction is not None, "Transaction not found in current task" async with connection._transaction_lock: assert connection._transaction_stack[-1] is self From f3078aac575b4a2ad64d0fc9d1f41e365d583574 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 11 Apr 2023 12:47:14 -0700 Subject: [PATCH 04/17] test: add concurrent task tests --- tests/test_databases.py | 55 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index a7545e31..b286a27a 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -961,16 +961,59 @@ async def test_database_url_interface(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_concurrent_access_on_single_connection(database_url): - database_url = DatabaseURL(database_url) - if database_url.dialect != "postgresql": - pytest.skip("Test requires `pg_sleep()`") - async with Database(database_url, force_rollback=True) as database: async def db_lookup(): - await database.fetch_one("SELECT pg_sleep(1)") + await database.fetch_one("SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + 
+@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + db_lookup(), + db_lookup(), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_tasks_on_single_connection(database_url: str): + async with Database(database_url) as database: + + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") + + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_concurrent_task_transactions_on_single_connection(database_url: str): + async with Database(database_url) as database: + + @database.transaction() + async def db_lookup(): + await database.fetch_one(query="SELECT 1 AS value") - await asyncio.gather(db_lookup(), db_lookup()) + await asyncio.gather( + asyncio.create_task(db_lookup()), + asyncio.create_task(db_lookup()), + ) @pytest.mark.parametrize("database_url", DATABASE_URLS) From 75969d343b505a25f6f960c1d22ac4367e45a83f Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 11 Apr 2023 14:51:36 -0700 Subject: [PATCH 05/17] feat: use ContextVar[dict] to track connections and transactions per task --- databases/core.py | 82 ++++++++++++++++++++++++++--------------------- 1 file changed, 46 insertions(+), 36 deletions(-) diff --git a/databases/core.py b/databases/core.py index 7575fce5..555c25d7 100644 --- a/databases/core.py +++ b/databases/core.py @@ -5,7 +5,7 @@ import typing from contextvars import ContextVar from types import TracebackType -from typing import Optional +from typing import Dict, Optional from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit from sqlalchemy 
import text @@ -35,6 +35,21 @@ logger = logging.getLogger("databases") +# Connections are stored as task-local state, but care must be taken to ensure +# that two database instances in the same task overwrite each other's connections. +# For this reason, key comprises the database instance and the current task. +_connection_contextmap: ContextVar[ + Dict[tuple["Database", asyncio.Task], "Connection"] +] = ContextVar("databases:Connection") + + +def _get_connection_contextmap() -> Dict[tuple["Database", asyncio.Task], "Connection"]: + connections = _connection_contextmap.get(None) + if connections is None: + connections = {} + _connection_contextmap.set(connections) + return connections + class Database: SUPPORTED_BACKENDS = { @@ -64,14 +79,6 @@ def __init__( assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) - # Connections are stored as task-local state, and cannot be garbage collected, - # since the immutable global Context stores a strong reference to each ContextVar - # that is created. We need these local ContextVars since two Database objects - # could run in the same asyncio.Task with connections to different databases. - self._connection_contextvar: ContextVar[Optional["Connection"]] = ContextVar( - f"databases:Database:{id(self)}" - ) - # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. 
self._global_connection: typing.Optional[Connection] = None @@ -119,7 +126,10 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - self._connection_contextvar.set(None) + task = asyncio.current_task() + connections = _get_connection_contextmap() + if (self, task) in connections: + del connections[self, task] await self._backend.disconnect() logger.info( @@ -193,12 +203,12 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - connection = self._connection_contextvar.get(None) - if connection is None: - connection = Connection(self._backend) - self._connection_contextvar.set(connection) + task = asyncio.current_task() + connections = _get_connection_contextmap() + if (self, task) not in connections: + connections[self, task] = Connection(self._backend) - return connection + return connections[self, task] def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any @@ -339,6 +349,19 @@ def _build_query( _CallableType = typing.TypeVar("_CallableType", bound=typing.Callable) +_transaction_contextmap: ContextVar[ + Dict["Transaction", TransactionBackend] +] = ContextVar("databases:Transactions") + + +def _get_transaction_contextmap() -> Dict["Transaction", TransactionBackend]: + transactions = _transaction_contextmap.get(None) + if transactions is None: + transactions = {} + _transaction_contextmap.set(transactions) + + return transactions + class Transaction: def __init__( @@ -351,15 +374,6 @@ def __init__( self._force_rollback = force_rollback self._extra_options = kwargs - # This ContextVar can never be garbage collected - similar to the ContextVar - # at Database._connection_contextvar - since the current Context has a strong - # reference to every ContextVar that is created. 
We need local ContextVars since - # there may be multiple (even nested) transactions in a single asyncio.Task, - # which each need their own unique TransactionBackend object. - self._transaction_contextvar: ContextVar[ - Optional[TransactionBackend] - ] = ContextVar(f"databases:Transaction:{id(self)}") - async def __aenter__(self) -> "Transaction": """ Called when entering `async with database.transaction()` @@ -402,12 +416,8 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: async def start(self) -> "Transaction": connection = self._connection_callable() transaction = connection._connection.transaction() - - # Cannot store returned reset token anywhere, for the same reason - # we need a ContextVar in the first place - `self` is not - # a safe object on which to store references for concurrent code. - self._transaction_contextvar.set(transaction) - + transactions = _get_transaction_contextmap() + transactions[self] = transaction async with connection._transaction_lock: is_root = not connection._transaction_stack await connection.__aenter__() @@ -417,27 +427,27 @@ async def start(self) -> "Transaction": async def commit(self) -> None: connection = self._connection_callable() - transaction = self._transaction_contextvar.get(None) + transactions = _get_transaction_contextmap() + transaction = transactions.get(self, None) assert transaction is not None, "Transaction not found in current task" async with connection._transaction_lock: assert connection._transaction_stack[-1] is self connection._transaction_stack.pop() await transaction.commit() await connection.__aexit__() - # Have no reset token, set to None instead - self._transaction_contextvar.set(None) + del transactions[self] async def rollback(self) -> None: connection = self._connection_callable() - transaction = self._transaction_contextvar.get(None) + transactions = _get_transaction_contextmap() + transaction = transactions.get(self, None) assert transaction is not None, "Transaction 
not found in current task" async with connection._transaction_lock: assert connection._transaction_stack[-1] is self connection._transaction_stack.pop() await transaction.rollback() await connection.__aexit__() - # Have no reset token, set to None instead - self._transaction_contextvar.set(None) + del transactions[self] class _EmptyNetloc(str): From 4cd74519e49a69117e6c15cd679a54d2cb00259d Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 11 Apr 2023 14:53:17 -0700 Subject: [PATCH 06/17] test: check multiple databases in the same task use independant connections --- tests/test_databases.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index b286a27a..10f0b856 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -5,7 +5,7 @@ import os import re from unittest.mock import MagicMock, patch - +import itertools import pytest import sqlalchemy @@ -789,15 +789,16 @@ async def test_connect_and_disconnect(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context(database_url): - """ - Test connection contexts are task-local. 
- """ +async def test_connection_context_same_task(database_url): async with Database(database_url) as database: async with database.connection() as connection_1: async with database.connection() as connection_2: assert connection_1 is connection_2 + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_tasks(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None @@ -817,9 +818,8 @@ async def get_connection_2(): connection_2 = connection await test_complete.wait() - loop = asyncio.get_event_loop() - task_1 = loop.create_task(get_connection_1()) - task_2 = loop.create_task(get_connection_2()) + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) while connection_1 is None or connection_2 is None: await asyncio.sleep(0.000001) assert connection_1 is not connection_2 @@ -828,6 +828,20 @@ async def get_connection_2(): await task_2 +@pytest.mark.parametrize( + "database_url1,database_url2", + ( + pytest.param(db1, db2, id=f"{db1} | {db2}") + for (db1, db2) in itertools.combinations(DATABASE_URLS, 2) + ), +) +@async_adapter +async def test_connection_context_multiple_databases(database_url1, database_url2): + async with Database(database_url1) as database1: + async with Database(database_url2) as database2: + assert database1.connection() is not database2.connection() + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_connection_context_with_raw_connection(database_url): From e4c95a7aedffb7d43721a38356895e8c4cac67b3 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 11 Apr 2023 15:13:21 -0700 Subject: [PATCH 07/17] chore: changes for linting and typechecking --- databases/core.py | 4 +++- tests/test_databases.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/databases/core.py b/databases/core.py index 555c25d7..5e0f3bc1 100644 --- a/databases/core.py +++ 
b/databases/core.py @@ -5,7 +5,7 @@ import typing from contextvars import ContextVar from types import TracebackType -from typing import Dict, Optional +from typing import Dict from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit from sqlalchemy import text @@ -127,6 +127,7 @@ async def disconnect(self) -> None: self._global_connection = None else: task = asyncio.current_task() + assert task is not None, "Not running in an asyncio task" connections = _get_connection_contextmap() if (self, task) in connections: del connections[self, task] @@ -204,6 +205,7 @@ def connection(self) -> "Connection": return self._global_connection task = asyncio.current_task() + assert task is not None, "Not running in an asyncio task" connections = _get_connection_contextmap() if (self, task) not in connections: connections[self, task] = Connection(self._backend) diff --git a/tests/test_databases.py b/tests/test_databases.py index 10f0b856..7f427372 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -2,10 +2,11 @@ import datetime import decimal import functools +import itertools import os import re from unittest.mock import MagicMock, patch -import itertools + import pytest import sqlalchemy From a38e135d541bdaa7e3c9d8ce50af07d61927cb82 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 11 Apr 2023 15:16:59 -0700 Subject: [PATCH 08/17] chore: use typing.Tuple for lower python version compatibility --- databases/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/databases/core.py b/databases/core.py index 5e0f3bc1..24fc8028 100644 --- a/databases/core.py +++ b/databases/core.py @@ -5,7 +5,7 @@ import typing from contextvars import ContextVar from types import TracebackType -from typing import Dict +from typing import Dict, Tuple from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit from sqlalchemy import text @@ -39,11 +39,11 @@ # that two database instances in the same task overwrite each other's connections. 
# For this reason, key comprises the database instance and the current task. _connection_contextmap: ContextVar[ - Dict[tuple["Database", asyncio.Task], "Connection"] + Dict[Tuple["Database", asyncio.Task], "Connection"] ] = ContextVar("databases:Connection") -def _get_connection_contextmap() -> Dict[tuple["Database", asyncio.Task], "Connection"]: +def _get_connection_contextmap() -> Dict[Tuple["Database", asyncio.Task], "Connection"]: connections = _connection_contextmap.get(None) if connections is None: connections = {} From 460f72eb241f0e8c009d1530b67ad874b9f97795 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 11 Apr 2023 15:37:02 -0700 Subject: [PATCH 09/17] docs: update comment on _connection_contextmap --- databases/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/databases/core.py b/databases/core.py index 24fc8028..0783819b 100644 --- a/databases/core.py +++ b/databases/core.py @@ -36,8 +36,8 @@ logger = logging.getLogger("databases") # Connections are stored as task-local state, but care must be taken to ensure -# that two database instances in the same task overwrite each other's connections. -# For this reason, key comprises the database instance and the current task. +# that two database instances in the same task do not overwrite each other's connections. +# For this reason, the dict key comprises the database instance and the current task. 
_connection_contextmap: ContextVar[ Dict[Tuple["Database", asyncio.Task], "Connection"] ] = ContextVar("databases:Connection") From 2d4554d79de684068f9703d4f8435622f8a4b94f Mon Sep 17 00:00:00 2001 From: Zanie Date: Sun, 16 Apr 2023 10:12:44 -0500 Subject: [PATCH 10/17] Update `Connection` and `Transaction` to be robust to concurrent use --- databases/core.py | 59 ++++++++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/databases/core.py b/databases/core.py index 8394ab5c..a67aa793 100644 --- a/databases/core.py +++ b/databases/core.py @@ -11,7 +11,7 @@ from sqlalchemy.sql import ClauseElement from databases.importer import import_from_string -from databases.interfaces import DatabaseBackend, Record +from databases.interfaces import DatabaseBackend, Record, TransactionBackend try: # pragma: no cover import click @@ -63,8 +63,8 @@ def __init__( assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) - # Connections are stored as task-local state. - self._connection_context: ContextVar = ContextVar("connection_context") + # Connections are stored per asyncio task + self._connections: typing.Dict[typing.Optional[asyncio.Task], Connection]= {} # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. 
@@ -113,7 +113,7 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - self._connection_context = ContextVar("connection_context") + self._connections.pop(asyncio.current_task(), None) await self._backend.disconnect() logger.info( @@ -187,12 +187,12 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - try: - return self._connection_context.get() - except LookupError: + current_task = asyncio.current_task() + if current_task not in self._connections: connection = Connection(self._backend) - self._connection_context.set(connection) - return connection + self._connections[current_task] = connection + + return self._connections[current_task] def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any @@ -344,6 +344,9 @@ def __init__( self._connection_callable = connection_callable self._force_rollback = force_rollback self._extra_options = kwargs + + # Transactions are stored per asyncio task + self._transactions: typing.Dict[typing.Optional[asyncio.Task], "TransactionBackend"]= {} async def __aenter__(self) -> "Transaction": """ @@ -385,31 +388,35 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore async def start(self) -> "Transaction": - self._connection = self._connection_callable() - self._transaction = self._connection._connection.transaction() + connection = self._connection_callable() + transaction = self._transactions[asyncio.current_task()] = connection._connection.transaction() - async with self._connection._transaction_lock: - is_root = not self._connection._transaction_stack - await self._connection.__aenter__() - await self._transaction.start( + async with connection._transaction_lock: + is_root = not connection._transaction_stack + await connection.__aenter__() + await transaction.start( is_root=is_root, extra_options=self._extra_options ) - 
self._connection._transaction_stack.append(self) + connection._transaction_stack.append(self) return self async def commit(self) -> None: - async with self._connection._transaction_lock: - assert self._connection._transaction_stack[-1] is self - self._connection._transaction_stack.pop() - await self._transaction.commit() - await self._connection.__aexit__() + connection = self._connection_callable() + transaction = self._transactions[asyncio.current_task()] + async with connection._transaction_lock: + assert connection._transaction_stack[-1] is self + connection._transaction_stack.pop() + await transaction.commit() + await connection.__aexit__() async def rollback(self) -> None: - async with self._connection._transaction_lock: - assert self._connection._transaction_stack[-1] is self - self._connection._transaction_stack.pop() - await self._transaction.rollback() - await self._connection.__aexit__() + connection = self._connection_callable() + transaction = self._transactions[asyncio.current_task()] + async with connection._transaction_lock: + assert connection._transaction_stack[-1] is self + connection._transaction_stack.pop() + await transaction.rollback() + await connection.__aexit__() class _EmptyNetloc(str): From 8370299b7b8bd4ed791fbdcf2805321ea12aac9f Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 18 Apr 2023 12:16:58 -0700 Subject: [PATCH 11/17] chore: remove optional annotation on asyncio.Task --- databases/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databases/core.py b/databases/core.py index 17a12e27..40fbaaed 100644 --- a/databases/core.py +++ b/databases/core.py @@ -63,7 +63,7 @@ def __init__( self._backend = backend_cls(self.url, **self.options) # Connections are stored per asyncio task - self._connections: typing.Dict[typing.Optional[asyncio.Task], Connection] = {} + self._connections: typing.Dict[asyncio.Task, Connection] = {} # When `force_rollback=True` is used, we use a single global # connection, within a 
transaction that always rolls back. From 1d4896fa35cdbe5a94e127bf76b02300186bcc9e Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 23 May 2023 17:53:54 -0700 Subject: [PATCH 12/17] test: add new tests for upcoming contextvar inheritance/isolation and weakref cleanup --- tests/test_databases.py | 209 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 208 insertions(+), 1 deletion(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index 7f427372..e13c7930 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -5,8 +5,9 @@ import itertools import os import re +import gc from unittest.mock import MagicMock, patch - +from typing import MutableMapping import pytest import sqlalchemy @@ -478,6 +479,212 @@ async def test_transaction_commit(database_url): assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_interaction(database_url): + """ + Ensure that child tasks may influence inherited transactions. + """ + # This is an practical example of the next test. 
+ async with Database(database_url) as database: + async with database.transaction(): + # Create a note + await database.execute( + notes.insert().values(id=1, text="setup", completed=True) + ) + + # Change the note from the same task + await database.execute( + notes.update().where(notes.c.id == 1).values(text="prior") + ) + + # Confirm the change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "prior" + + async def run_update_from_child_task(): + # Chage the note from a child task + await database.execute( + notes.update().where(notes.c.id == 1).values(text="test") + ) + + await asyncio.create_task(run_update_from_child_task()) + + # Confirm the child's change + result = await database.fetch_one(notes.select().where(notes.c.id == 1)) + assert result.text == "test" + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance(database_url): + """ + Ensure that transactions are inherited by child tasks. + """ + async with Database(database_url) as database: + + async def check_transaction(transaction, active_transaction): + # Should have inherited the same transaction backend from the parent task + assert transaction._transaction is active_transaction + + async with database.transaction() as transaction: + await asyncio.create_task( + check_transaction(transaction, transaction._transaction) + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_sibling_task_isolation(database_url): + """ + Ensure that transactions are isolated between sibling tasks. 
+ """ + start = asyncio.Event() + end = asyncio.Event() + + async with Database(database_url) as database: + + async def check_transaction(transaction): + await start.wait() + # Parent task is now in a transaction, we should not + # see its transaction backend since this task was + # _started_ in a context where no transaction was active. + assert transaction._transaction is None + end.set() + + transaction = database.transaction() + assert transaction._transaction is None + task = asyncio.create_task(check_transaction(transaction)) + + async with transaction: + start.set() + assert transaction._transaction is not None + await end.wait() + + # Cleanup for "Task not awaited" warning + await task + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_cleanup_contextmanager(database_url): + """ + Ensure that contextvar connections are not persisted unecessarily. + """ + from databases.core import _ACTIVE_CONNECTIONS + + assert _ACTIVE_CONNECTIONS.get() is None + + async with Database(database_url) as database: + # .connect is lazy, it doesn't create a Connection, but .connection does + connection = database.connection() + + open_connections = _ACTIVE_CONNECTIONS.get() + assert isinstance(open_connections, MutableMapping) + assert open_connections.get(database) is connection + + # Context manager closes, open_connections is cleaned up + open_connections = _ACTIVE_CONNECTIONS.get() + assert isinstance(open_connections, MutableMapping) + assert open_connections.get(database, None) is None + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_cleanup_garbagecollector(database_url): + """ + Ensure that contextvar connections are not persisted unecessarily, even + if exit handlers are not called. 
+ """ + from databases.core import _ACTIVE_CONNECTIONS + + assert _ACTIVE_CONNECTIONS.get() is None + + database = Database(database_url) + await database.connect() + connection = database.connection() + + # Should be tracking the connection + open_connections = _ACTIVE_CONNECTIONS.get() + assert isinstance(open_connections, MutableMapping) + assert open_connections.get(database) is connection + + # neither .disconnect nor .__aexit__ are called before deleting the reference + del database + gc.collect() + + # Should have dropped reference to connection, even without proper cleanup + open_connections = _ACTIVE_CONNECTIONS.get() + assert isinstance(open_connections, MutableMapping) + assert len(open_connections) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_contextmanager(database_url): + """ + Ensure that contextvar transactions are not persisted unecessarily. + """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + async with database.transaction() as transaction: + + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # Context manager closes, open_transactions is cleaned up + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction, None) is None + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_cleanup_garbagecollector(database_url): + """ + Ensure that contextvar transactions are not persisted unecessarily, even + if exit handlers are not called. + + This test should be an XFAIL, but cannot be due to the way that is hangs + during teardown. 
+ """ + from databases.core import _ACTIVE_TRANSACTIONS + + assert _ACTIVE_TRANSACTIONS.get() is None + + async with Database(database_url) as database: + transaction = database.transaction() + await transaction.start() + + # Should be tracking the transaction + open_transactions = _ACTIVE_TRANSACTIONS.get() + assert isinstance(open_transactions, MutableMapping) + assert open_transactions.get(transaction) is transaction._transaction + + # neither .commit, .rollback, nor .__aexit__ are called + del transaction + gc.collect() + + # TODO(zevisert,review): Could skip instead of using the logic below + # A strong reference to the transaction is kept alive by the connection's + # ._transaction_stack, so it is still be tracked at this point. + assert len(open_transactions) == 1 + + # If that were magically cleared, the transaction would be cleaned up, + # but as it stands this always causes a hang during teardown at + # `Database(...).disconnect()` if the transaction is not closed. + transaction = database.connection()._transaction_stack[-1] + await transaction.rollback() + del transaction + + # Now with the transaction rolled-back, it should be cleaned up. 
+ assert len(open_transactions) == 0 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_commit_serializable(database_url): From 02a9acb6022d10b9483cf65671b0b230b302e19a Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 23 May 2023 17:55:48 -0700 Subject: [PATCH 13/17] feat: reimplement concurrency system with contextvar and weakmap --- databases/core.py | 126 +++++++++++++++++++++++++++++----------------- 1 file changed, 79 insertions(+), 47 deletions(-) diff --git a/databases/core.py b/databases/core.py index 40fbaaed..7cdf5921 100644 --- a/databases/core.py +++ b/databases/core.py @@ -1,11 +1,12 @@ import asyncio import contextlib +from contextvars import ContextVar import functools import logging import typing from types import TracebackType from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit - +import weakref from sqlalchemy import text from sqlalchemy.sql import ClauseElement @@ -33,6 +34,14 @@ logger = logging.getLogger("databases") +_ACTIVE_CONNECTIONS: ContextVar[ + typing.Optional[weakref.WeakKeyDictionary["Database", "Connection"]] +] = ContextVar("databases:open_connections", default=None) + +_ACTIVE_TRANSACTIONS: ContextVar[ + typing.Optional[weakref.WeakKeyDictionary["Transaction", "TransactionBackend"]] +] = ContextVar("databases:open_transactions", default=None) + class Database: SUPPORTED_BACKENDS = { @@ -62,14 +71,31 @@ def __init__( assert issubclass(backend_cls, DatabaseBackend) self._backend = backend_cls(self.url, **self.options) - # Connections are stored per asyncio task - self._connections: typing.Dict[asyncio.Task, Connection] = {} - # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. 
self._global_connection: typing.Optional[Connection] = None self._global_transaction: typing.Optional[Transaction] = None + @property + def _connection(self) -> typing.Optional["Connection"]: + connections = _ACTIVE_CONNECTIONS.get() + if connections is None: + return None + + return connections.get(self, None) + + @_connection.setter + def _connection( + self, connection: typing.Optional["Connection"] + ) -> typing.Optional["Connection"]: + connections = _ACTIVE_CONNECTIONS.get() + if connections is None: + connections = weakref.WeakKeyDictionary() + _ACTIVE_CONNECTIONS.set(connections) + + connections[self] = connection + return connections[self] + async def connect(self) -> None: """ Establish the connection pool. @@ -112,10 +138,7 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - current_task = asyncio.current_task() - assert current_task is not None, "No currently running task" - if current_task in self._connections: - del self._connections[current_task] + self._connection = None await self._backend.disconnect() logger.info( @@ -189,12 +212,10 @@ def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection - current_task = asyncio.current_task() - assert current_task is not None, "No currently running task" - if current_task not in self._connections: - self._connections[current_task] = Connection(self._backend) + if not self._connection: + self._connection = Connection(self._backend) - return self._connections[current_task] + return self._connection def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any @@ -347,10 +368,30 @@ def __init__( self._force_rollback = force_rollback self._extra_options = kwargs - # Transactions are stored per asyncio task - self._transactions: typing.Dict[ - typing.Optional[asyncio.Task], TransactionBackend - ] = {} + @property + def _connection(self) -> "Connection": + # Returns the same connection if 
called multiple times + return self._connection_callable() + + @property + def _transaction(self) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + return None + + return transactions.get(self, None) + + @_transaction.setter + def _transaction( + self, transaction: typing.Optional["TransactionBackend"] + ) -> typing.Optional["TransactionBackend"]: + transactions = _ACTIVE_TRANSACTIONS.get() + if transactions is None: + transactions = weakref.WeakKeyDictionary() + _ACTIVE_TRANSACTIONS.set(transactions) + + transactions[self] = transaction + return transactions[self] async def __aenter__(self) -> "Transaction": """ @@ -392,41 +433,32 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return wrapper # type: ignore async def start(self) -> "Transaction": - connection = self._connection_callable() - current_task = asyncio.current_task() - assert current_task is not None, "No currently running task" - transaction = connection._connection.transaction() - self._transactions[current_task] = transaction - async with connection._transaction_lock: - is_root = not connection._transaction_stack - await connection.__aenter__() - await transaction.start(is_root=is_root, extra_options=self._extra_options) - connection._transaction_stack.append(self) + self._transaction = self._connection._connection.transaction() + + async with self._connection._transaction_lock: + is_root = not self._connection._transaction_stack + await self._connection.__aenter__() + await self._transaction.start( + is_root=is_root, extra_options=self._extra_options + ) + self._connection._transaction_stack.append(self) return self async def commit(self) -> None: - connection = self._connection_callable() - current_task = asyncio.current_task() - transaction = self._transactions.get(current_task, None) - assert transaction is not None, "Transaction not found in current task" - async with connection._transaction_lock: - 
assert connection._transaction_stack[-1] is self - connection._transaction_stack.pop() - await transaction.commit() - await connection.__aexit__() - del self._transactions[current_task] + async with self._connection._transaction_lock: + assert self._connection._transaction_stack[-1] is self + self._connection._transaction_stack.pop() + await self._transaction.commit() + await self._connection.__aexit__() + self._transaction = None async def rollback(self) -> None: - connection = self._connection_callable() - current_task = asyncio.current_task() - transaction = self._transactions.get(current_task, None) - assert transaction is not None, "Transaction not found in current task" - async with connection._transaction_lock: - assert connection._transaction_stack[-1] is self - connection._transaction_stack.pop() - await transaction.rollback() - await connection.__aexit__() - del self._transactions[current_task] + async with self._connection._transaction_lock: + assert self._connection._transaction_stack[-1] is self + self._connection._transaction_stack.pop() + await self._transaction.rollback() + await self._connection.__aexit__() + self._transaction = None class _EmptyNetloc(str): From 0f938079043e1a3f290f6368e06fbeb30a7e39d8 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 23 May 2023 18:09:07 -0700 Subject: [PATCH 14/17] chore: apply corrections from linters --- databases/core.py | 23 +++++++++++++++++------ tests/test_databases.py | 5 +++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/databases/core.py b/databases/core.py index 7cdf5921..6fd813f2 100644 --- a/databases/core.py +++ b/databases/core.py @@ -1,12 +1,13 @@ import asyncio import contextlib -from contextvars import ContextVar import functools import logging import typing +import weakref +from contextvars import ContextVar from types import TracebackType from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit -import weakref + from sqlalchemy import text from sqlalchemy.sql 
import ClauseElement @@ -93,8 +94,12 @@ def _connection( connections = weakref.WeakKeyDictionary() _ACTIVE_CONNECTIONS.set(connections) - connections[self] = connection - return connections[self] + if connection is None: + connections.pop(self, None) + else: + connections[self] = connection + + return connections.get(self, None) async def connect(self) -> None: """ @@ -390,8 +395,12 @@ def _transaction( transactions = weakref.WeakKeyDictionary() _ACTIVE_TRANSACTIONS.set(transactions) - transactions[self] = transaction - return transactions[self] + if transaction is None: + transactions.pop(self, None) + else: + transactions[self] = transaction + + return transactions.get(self, None) async def __aenter__(self) -> "Transaction": """ @@ -448,6 +457,7 @@ async def commit(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.commit() await self._connection.__aexit__() self._transaction = None @@ -456,6 +466,7 @@ async def rollback(self) -> None: async with self._connection._transaction_lock: assert self._connection._transaction_stack[-1] is self self._connection._transaction_stack.pop() + assert self._transaction is not None await self._transaction.rollback() await self._connection.__aexit__() self._transaction = None diff --git a/tests/test_databases.py b/tests/test_databases.py index e13c7930..c78ce4f3 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -2,12 +2,13 @@ import datetime import decimal import functools +import gc import itertools import os import re -import gc -from unittest.mock import MagicMock, patch from typing import MutableMapping +from unittest.mock import MagicMock, patch + import pytest import sqlalchemy From f091482d250aae97a550b9718fb1074ea1e24318 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Tue, 23 May 2023 18:37:03 -0700 Subject: [PATCH 15/17] fix: quote 
WeakKeyDictionary typing for python<=3.7 --- databases/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/databases/core.py b/databases/core.py index 6fd813f2..cf5a7aa0 100644 --- a/databases/core.py +++ b/databases/core.py @@ -35,12 +35,12 @@ logger = logging.getLogger("databases") + _ACTIVE_CONNECTIONS: ContextVar[ - typing.Optional[weakref.WeakKeyDictionary["Database", "Connection"]] + typing.Optional["weakref.WeakKeyDictionary['Database', 'Connection']"] ] = ContextVar("databases:open_connections", default=None) - _ACTIVE_TRANSACTIONS: ContextVar[ - typing.Optional[weakref.WeakKeyDictionary["Transaction", "TransactionBackend"]] + typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] ] = ContextVar("databases:open_transactions", default=None) From 6fb55a5b081dd5fbf04a36f2fe06ec121ddb24e3 Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Thu, 25 May 2023 16:24:49 -0700 Subject: [PATCH 16/17] docs: add examples for async transaction context and nested transactions --- docs/connections_and_transactions.md | 51 ++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index aa45537d..e52243e3 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -67,6 +67,7 @@ A transaction can be acquired from the database connection pool: async with database.transaction(): ... ``` + It can also be acquired from a specific database connection: ```python @@ -95,8 +96,54 @@ async def create_users(request): ... ``` -Transaction blocks are managed as task-local state. Nested transactions -are fully supported, and are implemented using database savepoints. +Transaction state is stored in the context of the currently executing asynchronous task. 
+This state is _inherited_ by tasks that are started from within an active transaction: + +```python +async def add_excitement(database: Database, id: int): + await database.execute( + "UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id", + {"id": id} + ) + + +async with Database(database_url) as database: + async with database.transaction(): + # This note won't exist until the transaction closes... + await database.execute( + "INSERT INTO notes(id, text) values (1, 'databases is cool')" + ) + # ...but child tasks inherit transaction state! + await asyncio.create_task(add_excitement(database, id=1)) + + await database.fetch_val("SELECT text FROM notes WHERE id=1") + # ^ returns: "databases is cool!!!" +``` + +!!! note + In python 3.11, you can opt-out of context propagation by providing a new context to + [`asyncio.create_task`](https://docs.python.org/3.11/library/asyncio-task.html#creating-tasks). + +Nested transactions are fully supported, and are implemented using database savepoints: + +```python +async with databases.Database(database_url) as db: + async with db.transaction() as outer: + # Do something in the outer transaction + ... + + # Suppress to prevent influence on the outer transaction + with contextlib.suppress(ValueError): + async with db.transaction(): + # Do something in the inner transaction + ... + + raise ValueError('Abort the inner transaction') + + # Observe the results of the outer transaction, + # without effects from the inner transaction. + await db.fetch_all('SELECT * FROM ...') +``` Transaction isolation-level can be specified if the driver backend supports that: From b94f0971bb2445ff9bd3427a5ab5954eb9fa066d Mon Sep 17 00:00:00 2001 From: Zev Isert Date: Fri, 26 May 2023 15:41:00 -0700 Subject: [PATCH 17/17] fix: remove connection inheritance, add more tests, update docs Connections are once again stored as state on the Database instance, keyed by the current asyncio.Task. 
Each task acquires its own connection, and a WeakKeyDictionary allows the connection to be discarded if the owning task is garbage collected. TransactionBackends are still stored as contextvars, and a connection must be explicitly provided to descendant tasks if active transaction state is to be inherited. --- databases/core.py | 44 +++--- docs/connections_and_transactions.md | 23 ++- tests/test_databases.py | 226 ++++++++++++++++++++------- 3 files changed, 203 insertions(+), 90 deletions(-) diff --git a/databases/core.py b/databases/core.py index cf5a7aa0..795609ea 100644 --- a/databases/core.py +++ b/databases/core.py @@ -36,12 +36,9 @@ logger = logging.getLogger("databases") -_ACTIVE_CONNECTIONS: ContextVar[ - typing.Optional["weakref.WeakKeyDictionary['Database', 'Connection']"] -] = ContextVar("databases:open_connections", default=None) _ACTIVE_TRANSACTIONS: ContextVar[ typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] -] = ContextVar("databases:open_transactions", default=None) +] = ContextVar("databases:active_transactions", default=None) class Database: @@ -54,6 +51,8 @@ class Database: "sqlite": "databases.backends.sqlite:SQLiteBackend", } + _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" + def __init__( self, url: typing.Union[str, "DatabaseURL"], @@ -64,6 +63,7 @@ def __init__( self.url = DatabaseURL(url) self.options = options self.is_connected = False + self._connection_map = weakref.WeakKeyDictionary() self._force_rollback = force_rollback @@ -78,28 +78,28 @@ def __init__( self._global_transaction: typing.Optional[Transaction] = None @property - def _connection(self) -> typing.Optional["Connection"]: - connections = _ACTIVE_CONNECTIONS.get() - if connections is None: - return None + def _current_task(self) -> asyncio.Task: + task = asyncio.current_task() + if not task: + raise RuntimeError("No currently active asyncio.Task found") + return task - return connections.get(self, None) + @property
+ def _connection(self) -> typing.Optional["Connection"]: + return self._connection_map.get(self._current_task) @_connection.setter def _connection( self, connection: typing.Optional["Connection"] ) -> typing.Optional["Connection"]: - connections = _ACTIVE_CONNECTIONS.get() - if connections is None: - connections = weakref.WeakKeyDictionary() - _ACTIVE_CONNECTIONS.set(connections) + task = self._current_task if connection is None: - connections.pop(self, None) + self._connection_map.pop(task, None) else: - connections[self] = connection + self._connection_map[task] = connection - return connections.get(self, None) + return self._connection async def connect(self) -> None: """ @@ -119,7 +119,7 @@ async def connect(self) -> None: assert self._global_connection is None assert self._global_transaction is None - self._global_connection = Connection(self._backend) + self._global_connection = Connection(self, self._backend) self._global_transaction = self._global_connection.transaction( force_rollback=True ) @@ -218,7 +218,7 @@ def connection(self) -> "Connection": return self._global_connection if not self._connection: - self._connection = Connection(self._backend) + self._connection = Connection(self, self._backend) return self._connection @@ -243,7 +243,8 @@ def _get_backend(self) -> str: class Connection: - def __init__(self, backend: DatabaseBackend) -> None: + def __init__(self, database: Database, backend: DatabaseBackend) -> None: + self._database = database self._backend = backend self._connection_lock = asyncio.Lock() @@ -277,6 +278,7 @@ async def __aexit__( self._connection_counter -= 1 if self._connection_counter == 0: await self._connection.release() + self._database._connection = None async def fetch_all( self, @@ -393,13 +395,15 @@ def _transaction( transactions = _ACTIVE_TRANSACTIONS.get() if transactions is None: transactions = weakref.WeakKeyDictionary() - _ACTIVE_TRANSACTIONS.set(transactions) + else: + transactions = transactions.copy() if transaction 
is None: transactions.pop(self, None) else: transactions[self] = transaction + _ACTIVE_TRANSACTIONS.set(transactions) return transactions.get(self, None) async def __aenter__(self) -> "Transaction": diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index e52243e3..11044655 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints. ## Connecting and disconnecting -You can control the database connect/disconnect, by using it as a async context manager. +You can control the database connection pool with an async context manager: ```python async with Database(DATABASE_URL) as database: ... ``` -Or by using explicit connection and disconnection: +Or by using the explicit `.connect()` and `.disconnect()` methods: ```python database = Database(DATABASE_URL) @@ -23,6 +23,8 @@ await database.connect() await database.disconnect() ``` +Connections within this connection pool are acquired for each new `asyncio.Task`. + If you're integrating against a web framework, then you'll probably want to hook into framework startup or shutdown events. For example, with [Starlette][starlette] you would use the following: @@ -96,12 +98,13 @@ async def create_users(request): ... ``` -Transaction state is stored in the context of the currently executing asynchronous task. -This state is _inherited_ by tasks that are started from within an active transaction: +Transaction state is tied to the connection used in the currently executing asynchronous task. +If you would like to influence an active transaction from another task, the connection must be +shared. 
This state is _inherited_ by tasks that share the same connection: ```python -async def add_excitement(database: Database, id: int): - await database.execute( +async def add_excitement(connection: databases.core.Connection, id: int): + await connection.execute( "UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id", {"id": id} ) @@ -113,17 +116,13 @@ async with Database(database_url) as database: await database.execute( "INSERT INTO notes(id, text) values (1, 'databases is cool')" ) - # ...but child tasks inherit transaction state! - await asyncio.create_task(add_excitement(database, id=1)) + # ...but child tasks can use this connection now! + await asyncio.create_task(add_excitement(database.connection(), id=1)) await database.fetch_val("SELECT text FROM notes WHERE id=1") # ^ returns: "databases is cool!!!" ``` -!!! note - In python 3.11, you can opt-out of context propagation by providing a new context to - [`asyncio.create_task`](https://docs.python.org/3.11/library/asyncio-task.html#creating-tasks). - Nested transactions are fully supported, and are implemented using database savepoints: ```python diff --git a/tests/test_databases.py b/tests/test_databases.py index c78ce4f3..4d737261 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -482,11 +482,29 @@ async def test_transaction_commit(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_transaction_context_child_task_interaction(database_url): +async def test_transaction_context_child_task_inheritance(database_url): + """ + Ensure that transactions are inherited by child tasks.
+ """ + async with Database(database_url) as database: + + async def check_transaction(transaction, active_transaction): + # Should have inherited the same transaction backend from the parent task + assert transaction._transaction is active_transaction + + async with database.transaction() as transaction: + await asyncio.create_task( + check_transaction(transaction, transaction._transaction) + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance_example(database_url): """ Ensure that child tasks may influence inherited transactions. """ - # This is an practical example of the next test. + # This is an practical example of the above test. async with Database(database_url) as database: async with database.transaction(): # Create a note @@ -503,37 +521,19 @@ async def test_transaction_context_child_task_interaction(database_url): result = await database.fetch_one(notes.select().where(notes.c.id == 1)) assert result.text == "prior" - async def run_update_from_child_task(): - # Chage the note from a child task - await database.execute( + async def run_update_from_child_task(connection): + # Change the note from a child task + await connection.execute( notes.update().where(notes.c.id == 1).values(text="test") ) - await asyncio.create_task(run_update_from_child_task()) + await asyncio.create_task(run_update_from_child_task(database.connection())) # Confirm the child's change result = await database.fetch_one(notes.select().where(notes.c.id == 1)) assert result.text == "test" -@pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter -async def test_transaction_context_child_task_inheritance(database_url): - """ - Ensure that transactions are inherited by child tasks. 
- """ - async with Database(database_url) as database: - - async def check_transaction(transaction, active_transaction): - # Should have inherited the same transaction backend from the parent task - assert transaction._transaction is active_transaction - - async with database.transaction() as transaction: - await asyncio.create_task( - check_transaction(transaction, transaction._transaction) - ) - - @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_context_sibling_task_isolation(database_url): @@ -568,56 +568,99 @@ async def check_transaction(transaction): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_cleanup_contextmanager(database_url): +async def test_transaction_context_sibling_task_isolation_example(database_url): + """ + Ensure that transactions are running in sibling tasks are isolated from eachother. + """ + # This is an practical example of the above test. + setup = asyncio.Event() + done = asyncio.Event() + + async def tx1(connection): + async with connection.transaction(): + await db.execute( + notes.insert(), values={"id": 1, "text": "tx1", "completed": False} + ) + setup.set() + await done.wait() + + async def tx2(connection): + async with connection.transaction(): + await setup.wait() + result = await db.fetch_all(notes.select()) + assert result == [], result + done.set() + + async with Database(database_url) as db: + await asyncio.gather(tx1(db), tx2(db)) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_contextmanager(database_url): """ - Ensure that contextvar connections are not persisted unecessarily. + Ensure that task connections are not persisted unecessarily. 
""" - from databases.core import _ACTIVE_CONNECTIONS - assert _ACTIVE_CONNECTIONS.get() is None + ready = asyncio.Event() + done = asyncio.Event() + + async def check_child_connection(database: Database): + async with database.connection(): + ready.set() + await done.wait() async with Database(database_url) as database: + # Should have a connection in this task # .connect is lazy, it doesn't create a Connection, but .connection does connection = database.connection() + assert isinstance(database._connection_map, MutableMapping) + assert database._connection_map.get(asyncio.current_task()) is connection - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert open_connections.get(database) is connection + # Create a child task and see if it registers a connection + task = asyncio.create_task(check_child_connection(database)) + await ready.wait() + assert database._connection_map.get(task) is not None + assert database._connection_map.get(task) is not connection - # Context manager closes, open_connections is cleaned up - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert open_connections.get(database, None) is None + # Let the child task finish, and see if it cleaned up + done.set() + await task + # This is normal exit logic cleanup, the WeakKeyDictionary + # shouldn't have cleaned up yet since the task is still referenced + assert task not in database._connection_map + + # Context manager closes, all open connections are removed + assert isinstance(database._connection_map, MutableMapping) + assert len(database._connection_map) == 0 @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_cleanup_garbagecollector(database_url): +async def test_connection_cleanup_garbagecollector(database_url): """ - Ensure that contextvar connections are not persisted unecessarily, even + Ensure that connections for tasks are 
not persisted unecessarily, even if exit handlers are not called. """ - from databases.core import _ACTIVE_CONNECTIONS - - assert _ACTIVE_CONNECTIONS.get() is None - database = Database(database_url) await database.connect() - connection = database.connection() - # Should be tracking the connection - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert open_connections.get(database) is connection + created = asyncio.Event() + + async def check_child_connection(database: Database): + # neither .disconnect nor .__aexit__ are called before deleting this task + database.connection() + created.set() - # neither .disconnect nor .__aexit__ are called before deleting the reference - del database + task = asyncio.create_task(check_child_connection(database)) + await created.wait() + assert task in database._connection_map + await task + del task gc.collect() - # Should have dropped reference to connection, even without proper cleanup - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert len(open_connections) == 0 + # Should not have a connection for the task anymore + assert len(database._connection_map) == 0 @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -632,7 +675,6 @@ async def test_transaction_context_cleanup_contextmanager(database_url): async with Database(database_url) as database: async with database.transaction() as transaction: - open_transactions = _ACTIVE_TRANSACTIONS.get() assert isinstance(open_transactions, MutableMapping) assert open_transactions.get(transaction) is transaction._transaction @@ -818,17 +860,44 @@ async def insert_data(raise_exception): with pytest.raises(RuntimeError): await insert_data(raise_exception=True) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 0 await insert_data(raise_exception=False) - query = 
notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_decorator_concurrent(database_url): + """ + Ensure that @database.transaction() can be called concurrently. + """ + + database = Database(database_url) + + @database.transaction() + async def insert_data(): + await database.execute( + query=notes.insert().values(text="example", completed=True) + ) + + async with database: + await asyncio.gather( + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + ) + + results = await database.fetch_all(query=notes.select()) + assert len(results) == 6 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_datetime_field(database_url): @@ -1007,7 +1076,7 @@ async def test_connection_context_same_task(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_multiple_tasks(database_url): +async def test_connection_context_multiple_sibling_tasks(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None @@ -1037,6 +1106,47 @@ async def get_connection_2(): await task_2 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_tasks(database_url): + async with Database(database_url) as database: + parent_connection = database.connection() + connection_1 = None + connection_2 = None + task_1_ready = asyncio.Event() + task_2_ready = asyncio.Event() + test_complete = asyncio.Event() + + async def get_connection_1(): + nonlocal connection_1 + + async with database.connection() as connection: + connection_1 = connection + task_1_ready.set() + await test_complete.wait() + + async def get_connection_2(): + nonlocal connection_2 + + async with 
database.connection() as connection: + connection_2 = connection + task_2_ready.set() + await test_complete.wait() + + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) + await task_1_ready.wait() + await task_2_ready.wait() + + assert connection_1 is not parent_connection + assert connection_2 is not parent_connection + assert connection_1 is not connection_2 + + test_complete.set() + await task_1 + await task_2 + + @pytest.mark.parametrize( "database_url1,database_url2", (