Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 0c5e047

Browse files
committed
test: check multiple databases in the same task use independant connections
1 parent 574626a commit 0c5e047

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

tests/test_databases.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import re
77
from unittest.mock import MagicMock, patch
8-
8+
import itertools
99
import pytest
1010
import sqlalchemy
1111

@@ -789,15 +789,16 @@ async def test_connect_and_disconnect(database_url):
789789

790790
@pytest.mark.parametrize("database_url", DATABASE_URLS)
791791
@async_adapter
792-
async def test_connection_context(database_url):
793-
"""
794-
Test connection contexts are task-local.
795-
"""
792+
async def test_connection_context_same_task(database_url):
796793
async with Database(database_url) as database:
797794
async with database.connection() as connection_1:
798795
async with database.connection() as connection_2:
799796
assert connection_1 is connection_2
800797

798+
799+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
800+
@async_adapter
801+
async def test_connection_context_multiple_tasks(database_url):
801802
async with Database(database_url) as database:
802803
connection_1 = None
803804
connection_2 = None
@@ -817,9 +818,8 @@ async def get_connection_2():
817818
connection_2 = connection
818819
await test_complete.wait()
819820

820-
loop = asyncio.get_event_loop()
821-
task_1 = loop.create_task(get_connection_1())
822-
task_2 = loop.create_task(get_connection_2())
821+
task_1 = asyncio.create_task(get_connection_1())
822+
task_2 = asyncio.create_task(get_connection_2())
823823
while connection_1 is None or connection_2 is None:
824824
await asyncio.sleep(0.000001)
825825
assert connection_1 is not connection_2
@@ -828,6 +828,20 @@ async def get_connection_2():
828828
await task_2
829829

830830

831+
@pytest.mark.parametrize(
832+
"database_url1,database_url2",
833+
(
834+
pytest.param(db1, db2, id=f"{db1} | {db2}")
835+
for (db1, db2) in itertools.combinations(DATABASE_URLS, 2)
836+
),
837+
)
838+
@async_adapter
839+
async def test_connection_context_multiple_databases(database_url1, database_url2):
840+
async with Database(database_url1) as database1:
841+
async with Database(database_url2) as database2:
842+
assert database1.connection() is not database2.connection()
843+
844+
831845
@pytest.mark.parametrize("database_url", DATABASE_URLS)
832846
@async_adapter
833847
async def test_connection_context_with_raw_connection(database_url):

0 commit comments

Comments
 (0)