5
5
import os
6
6
import re
7
7
from unittest .mock import MagicMock , patch
8
-
8
+ import itertools
9
9
import pytest
10
10
import sqlalchemy
11
11
@@ -789,15 +789,16 @@ async def test_connect_and_disconnect(database_url):
789
789
790
790
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
791
791
@async_adapter
792
- async def test_connection_context (database_url ):
793
- """
794
- Test connection contexts are task-local.
795
- """
792
+ async def test_connection_context_same_task (database_url ):
796
793
async with Database (database_url ) as database :
797
794
async with database .connection () as connection_1 :
798
795
async with database .connection () as connection_2 :
799
796
assert connection_1 is connection_2
800
797
798
+
799
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
800
+ @async_adapter
801
+ async def test_connection_context_multiple_tasks (database_url ):
801
802
async with Database (database_url ) as database :
802
803
connection_1 = None
803
804
connection_2 = None
@@ -817,9 +818,8 @@ async def get_connection_2():
817
818
connection_2 = connection
818
819
await test_complete .wait ()
819
820
820
- loop = asyncio .get_event_loop ()
821
- task_1 = loop .create_task (get_connection_1 ())
822
- task_2 = loop .create_task (get_connection_2 ())
821
+ task_1 = asyncio .create_task (get_connection_1 ())
822
+ task_2 = asyncio .create_task (get_connection_2 ())
823
823
while connection_1 is None or connection_2 is None :
824
824
await asyncio .sleep (0.000001 )
825
825
assert connection_1 is not connection_2
@@ -828,6 +828,20 @@ async def get_connection_2():
828
828
await task_2
829
829
830
830
831
+ @pytest .mark .parametrize (
832
+ "database_url1,database_url2" ,
833
+ (
834
+ pytest .param (db1 , db2 , id = f"{ db1 } | { db2 } " )
835
+ for (db1 , db2 ) in itertools .combinations (DATABASE_URLS , 2 )
836
+ ),
837
+ )
838
+ @async_adapter
839
+ async def test_connection_context_multiple_databases (database_url1 , database_url2 ):
840
+ async with Database (database_url1 ) as database1 :
841
+ async with Database (database_url2 ) as database2 :
842
+ assert database1 .connection () is not database2 .connection ()
843
+
844
+
831
845
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
832
846
@async_adapter
833
847
async def test_connection_context_with_raw_connection (database_url ):
0 commit comments