4
4
import functools
5
5
import os
6
6
import re
7
- from unittest .mock import patch , MagicMock
7
+ from unittest .mock import MagicMock , patch
8
8
9
9
import pytest
10
10
import sqlalchemy
@@ -78,14 +78,18 @@ def process_result_value(self, value, dialect):
78
78
)
79
79
80
80
81
- @pytest .fixture (autouse = True , scope = "module " )
81
+ @pytest .fixture (autouse = True , scope = "function " )
82
82
def create_test_database ():
83
83
# Create test databases with tables creation
84
84
for url in DATABASE_URLS :
85
85
database_url = DatabaseURL (url )
86
- if database_url .scheme == "mysql" :
86
+ if database_url .scheme in [ "mysql" , "mysql+aiomysql" ] :
87
87
url = str (database_url .replace (driver = "pymysql" ))
88
- elif database_url .scheme == "postgresql+aiopg" :
88
+ elif database_url .scheme in [
89
+ "postgresql+aiopg" ,
90
+ "sqlite+aiosqlite" ,
91
+ "postgresql+asyncpg" ,
92
+ ]:
89
93
url = str (database_url .replace (driver = None ))
90
94
engine = sqlalchemy .create_engine (url )
91
95
metadata .create_all (engine )
@@ -96,9 +100,13 @@ def create_test_database():
96
100
# Drop test databases
97
101
for url in DATABASE_URLS :
98
102
database_url = DatabaseURL (url )
99
- if database_url .scheme == "mysql" :
103
+ if database_url .scheme in [ "mysql" , "mysql+aiomysql" ] :
100
104
url = str (database_url .replace (driver = "pymysql" ))
101
- elif database_url .scheme == "postgresql+aiopg" :
105
+ elif database_url .scheme in [
106
+ "postgresql+aiopg" ,
107
+ "sqlite+aiosqlite" ,
108
+ "postgresql+asyncpg" ,
109
+ ]:
102
110
url = str (database_url .replace (driver = None ))
103
111
engine = sqlalchemy .create_engine (url )
104
112
metadata .drop_all (engine )
@@ -478,9 +486,12 @@ async def test_transaction_commit_serializable(database_url):
478
486
479
487
database_url = DatabaseURL (database_url )
480
488
481
- if database_url .scheme != "postgresql" :
489
+ if database_url .scheme not in [ "postgresql" , "postgresql+asyncpg" ] :
482
490
pytest .skip ("Test (currently) only supports asyncpg" )
483
491
492
+ if database_url .scheme == "postgresql+asyncpg" :
493
+ database_url = database_url .replace (driver = None )
494
+
484
495
def insert_independently ():
485
496
engine = sqlalchemy .create_engine (str (database_url ))
486
497
conn = engine .connect ()
@@ -844,26 +855,34 @@ async def test_queries_with_expose_backend_connection(database_url):
844
855
raw_connection = connection .raw_connection
845
856
846
857
# Insert query
847
- if database .url .scheme in ["mysql" , "postgresql+aiopg" ]:
858
+ if database .url .scheme in [
859
+ "mysql" ,
860
+ "mysql+aiomysql" ,
861
+ "postgresql+aiopg" ,
862
+ ]:
848
863
insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)"
849
864
else :
850
865
insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)"
851
866
852
867
# execute()
853
868
values = ("example1" , True )
854
869
855
- if database .url .scheme in ["mysql" , "postgresql+aiopg" ]:
870
+ if database .url .scheme in [
871
+ "mysql" ,
872
+ "mysql+aiomysql" ,
873
+ "postgresql+aiopg" ,
874
+ ]:
856
875
cursor = await raw_connection .cursor ()
857
876
await cursor .execute (insert_query , values )
858
- elif database .url .scheme == "postgresql" :
877
+ elif database .url .scheme in [ "postgresql" , "postgresql+asyncpg" ] :
859
878
await raw_connection .execute (insert_query , * values )
860
- elif database .url .scheme == "sqlite" :
879
+ elif database .url .scheme in [ "sqlite" , "sqlite+aiosqlite" ] :
861
880
await raw_connection .execute (insert_query , values )
862
881
863
882
# execute_many()
864
883
values = [("example2" , False ), ("example3" , True )]
865
884
866
- if database .url .scheme == "mysql" :
885
+ if database .url .scheme in [ "mysql" , "mysql+aiomysql" ] :
867
886
cursor = await raw_connection .cursor ()
868
887
await cursor .executemany (insert_query , values )
869
888
elif database .url .scheme == "postgresql+aiopg" :
@@ -878,13 +897,17 @@ async def test_queries_with_expose_backend_connection(database_url):
878
897
select_query = "SELECT notes.id, notes.text, notes.completed FROM notes"
879
898
880
899
# fetch_all()
881
- if database .url .scheme in ["mysql" , "postgresql+aiopg" ]:
900
+ if database .url .scheme in [
901
+ "mysql" ,
902
+ "mysql+aiomysql" ,
903
+ "postgresql+aiopg" ,
904
+ ]:
882
905
cursor = await raw_connection .cursor ()
883
906
await cursor .execute (select_query )
884
907
results = await cursor .fetchall ()
885
- elif database .url .scheme == "postgresql" :
908
+ elif database .url .scheme in [ "postgresql" , "postgresql+asyncpg" ] :
886
909
results = await raw_connection .fetch (select_query )
887
- elif database .url .scheme == "sqlite" :
910
+ elif database .url .scheme in [ "sqlite" , "sqlite+aiosqlite" ] :
888
911
results = await raw_connection .execute_fetchall (select_query )
889
912
890
913
assert len (results ) == 3
@@ -897,7 +920,7 @@ async def test_queries_with_expose_backend_connection(database_url):
897
920
assert results [2 ][2 ] == True
898
921
899
922
# fetch_one()
900
- if database .url .scheme == "postgresql" :
923
+ if database .url .scheme in [ "postgresql" , "postgresql+asyncpg" ] :
901
924
result = await raw_connection .fetchrow (select_query )
902
925
else :
903
926
cursor = await raw_connection .cursor ()
@@ -1065,8 +1088,8 @@ async def test_posgres_interface(database_url):
1065
1088
"""
1066
1089
database_url = DatabaseURL (database_url )
1067
1090
1068
- if database_url .scheme != "postgresql" :
1069
- pytest .skip ("Test is only for postgresql " )
1091
+ if database_url .scheme not in [ "postgresql" , "postgresql+asyncpg" ] :
1092
+ pytest .skip ("Test is only for asyncpg " )
1070
1093
1071
1094
async with Database (database_url ) as database :
1072
1095
async with database .transaction (force_rollback = True ):
0 commit comments