Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 3525ca5

Browse files
mhadamaminalaee
andauthored
support dialect+driver for default drivers (closes #395) (#396)
Co-authored-by: Amin Alaee <[email protected]>
1 parent 612857d commit 3525ca5

File tree

4 files changed

+61
-24
lines changed

4 files changed

+61
-24
lines changed

.github/workflows/test-suite.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,5 @@ jobs:
4747
run: "scripts/install"
4848
- name: "Run tests"
4949
env:
50-
TEST_DATABASE_URLS: "sqlite:///testsuite, mysql://username:password@localhost:3306/testsuite, postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:[email protected]:5432/testsuite"
50+
TEST_DATABASE_URLS: "sqlite:///testsuite, sqlite+aiosqlite:///testsuite, mysql://username:password@localhost:3306/testsuite, mysql+aiomysql://username:password@localhost:3306/testsuite, postgresql://username:password@localhost:5432/testsuite, postgresql+aiopg://username:[email protected]:5432/testsuite, postgresql+asyncpg://username:password@localhost:5432/testsuite"
5151
run: "scripts/test"

databases/core.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262

6363
self._force_rollback = force_rollback
6464

65-
backend_str = self.SUPPORTED_BACKENDS[self.url.scheme]
65+
backend_str = self._get_backend()
6666
backend_cls = import_from_string(backend_str)
6767
assert issubclass(backend_cls, DatabaseBackend)
6868
self._backend = backend_cls(self.url, **self.options)
@@ -220,6 +220,12 @@ def force_rollback(self) -> typing.Iterator[None]:
220220
finally:
221221
self._force_rollback = initial
222222

223+
def _get_backend(self) -> str:
224+
try:
225+
return self.SUPPORTED_BACKENDS[self.url.scheme]
226+
except KeyError:
227+
return self.SUPPORTED_BACKENDS[self.url.dialect]
228+
223229

224230
class Connection:
225231
def __init__(self, backend: DatabaseBackend) -> None:

tests/test_databases.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import functools
55
import os
66
import re
7-
from unittest.mock import patch, MagicMock
7+
from unittest.mock import MagicMock, patch
88

99
import pytest
1010
import sqlalchemy
@@ -78,14 +78,18 @@ def process_result_value(self, value, dialect):
7878
)
7979

8080

81-
@pytest.fixture(autouse=True, scope="module")
81+
@pytest.fixture(autouse=True, scope="function")
8282
def create_test_database():
8383
# Create test databases with tables creation
8484
for url in DATABASE_URLS:
8585
database_url = DatabaseURL(url)
86-
if database_url.scheme == "mysql":
86+
if database_url.scheme in ["mysql", "mysql+aiomysql"]:
8787
url = str(database_url.replace(driver="pymysql"))
88-
elif database_url.scheme == "postgresql+aiopg":
88+
elif database_url.scheme in [
89+
"postgresql+aiopg",
90+
"sqlite+aiosqlite",
91+
"postgresql+asyncpg",
92+
]:
8993
url = str(database_url.replace(driver=None))
9094
engine = sqlalchemy.create_engine(url)
9195
metadata.create_all(engine)
@@ -96,9 +100,13 @@ def create_test_database():
96100
# Drop test databases
97101
for url in DATABASE_URLS:
98102
database_url = DatabaseURL(url)
99-
if database_url.scheme == "mysql":
103+
if database_url.scheme in ["mysql", "mysql+aiomysql"]:
100104
url = str(database_url.replace(driver="pymysql"))
101-
elif database_url.scheme == "postgresql+aiopg":
105+
elif database_url.scheme in [
106+
"postgresql+aiopg",
107+
"sqlite+aiosqlite",
108+
"postgresql+asyncpg",
109+
]:
102110
url = str(database_url.replace(driver=None))
103111
engine = sqlalchemy.create_engine(url)
104112
metadata.drop_all(engine)
@@ -478,9 +486,12 @@ async def test_transaction_commit_serializable(database_url):
478486

479487
database_url = DatabaseURL(database_url)
480488

481-
if database_url.scheme != "postgresql":
489+
if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]:
482490
pytest.skip("Test (currently) only supports asyncpg")
483491

492+
if database_url.scheme == "postgresql+asyncpg":
493+
database_url = database_url.replace(driver=None)
494+
484495
def insert_independently():
485496
engine = sqlalchemy.create_engine(str(database_url))
486497
conn = engine.connect()
@@ -844,26 +855,34 @@ async def test_queries_with_expose_backend_connection(database_url):
844855
raw_connection = connection.raw_connection
845856

846857
# Insert query
847-
if database.url.scheme in ["mysql", "postgresql+aiopg"]:
858+
if database.url.scheme in [
859+
"mysql",
860+
"mysql+aiomysql",
861+
"postgresql+aiopg",
862+
]:
848863
insert_query = "INSERT INTO notes (text, completed) VALUES (%s, %s)"
849864
else:
850865
insert_query = "INSERT INTO notes (text, completed) VALUES ($1, $2)"
851866

852867
# execute()
853868
values = ("example1", True)
854869

855-
if database.url.scheme in ["mysql", "postgresql+aiopg"]:
870+
if database.url.scheme in [
871+
"mysql",
872+
"mysql+aiomysql",
873+
"postgresql+aiopg",
874+
]:
856875
cursor = await raw_connection.cursor()
857876
await cursor.execute(insert_query, values)
858-
elif database.url.scheme == "postgresql":
877+
elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
859878
await raw_connection.execute(insert_query, *values)
860-
elif database.url.scheme == "sqlite":
879+
elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
861880
await raw_connection.execute(insert_query, values)
862881

863882
# execute_many()
864883
values = [("example2", False), ("example3", True)]
865884

866-
if database.url.scheme == "mysql":
885+
if database.url.scheme in ["mysql", "mysql+aiomysql"]:
867886
cursor = await raw_connection.cursor()
868887
await cursor.executemany(insert_query, values)
869888
elif database.url.scheme == "postgresql+aiopg":
@@ -878,13 +897,17 @@ async def test_queries_with_expose_backend_connection(database_url):
878897
select_query = "SELECT notes.id, notes.text, notes.completed FROM notes"
879898

880899
# fetch_all()
881-
if database.url.scheme in ["mysql", "postgresql+aiopg"]:
900+
if database.url.scheme in [
901+
"mysql",
902+
"mysql+aiomysql",
903+
"postgresql+aiopg",
904+
]:
882905
cursor = await raw_connection.cursor()
883906
await cursor.execute(select_query)
884907
results = await cursor.fetchall()
885-
elif database.url.scheme == "postgresql":
908+
elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
886909
results = await raw_connection.fetch(select_query)
887-
elif database.url.scheme == "sqlite":
910+
elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
888911
results = await raw_connection.execute_fetchall(select_query)
889912

890913
assert len(results) == 3
@@ -897,7 +920,7 @@ async def test_queries_with_expose_backend_connection(database_url):
897920
assert results[2][2] == True
898921

899922
# fetch_one()
900-
if database.url.scheme == "postgresql":
923+
if database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
901924
result = await raw_connection.fetchrow(select_query)
902925
else:
903926
cursor = await raw_connection.cursor()
@@ -1065,8 +1088,8 @@ async def test_posgres_interface(database_url):
10651088
"""
10661089
database_url = DatabaseURL(database_url)
10671090

1068-
if database_url.scheme != "postgresql":
1069-
pytest.skip("Test is only for postgresql")
1091+
if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]:
1092+
pytest.skip("Test is only for asyncpg")
10701093

10711094
async with Database(database_url) as database:
10721095
async with database.transaction(force_rollback=True):

tests/test_integration.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@ def create_test_database():
2828
# Create test databases
2929
for url in DATABASE_URLS:
3030
database_url = DatabaseURL(url)
31-
if database_url.scheme == "mysql":
31+
if database_url.scheme in ["mysql", "mysql+aiomysql"]:
3232
url = str(database_url.replace(driver="pymysql"))
33-
elif database_url.scheme == "postgresql+aiopg":
33+
elif database_url.scheme in [
34+
"postgresql+aiopg",
35+
"sqlite+aiosqlite",
36+
"postgresql+asyncpg",
37+
]:
3438
url = str(database_url.replace(driver=None))
3539
engine = sqlalchemy.create_engine(url)
3640
metadata.create_all(engine)
@@ -41,9 +45,13 @@ def create_test_database():
4145
# Drop test databases
4246
for url in DATABASE_URLS:
4347
database_url = DatabaseURL(url)
44-
if database_url.scheme == "mysql":
48+
if database_url.scheme in ["mysql", "mysql+aiomysql"]:
4549
url = str(database_url.replace(driver="pymysql"))
46-
elif database_url.scheme == "postgresql+aiopg":
50+
elif database_url.scheme in [
51+
"postgresql+aiopg",
52+
"sqlite+aiosqlite",
53+
"postgresql+asyncpg",
54+
]:
4755
url = str(database_url.replace(driver=None))
4856
engine = sqlalchemy.create_engine(url)
4957
metadata.drop_all(engine)

0 commit comments

Comments
 (0)