Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 6fcb168

Browse files
taybinqweryty
andauthored
Reset counter for failed connections (#385)
Co-authored-by: Sergey Morozov <[email protected]>
1 parent e3e7fa0 commit 6fcb168

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

databases/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,12 @@ def __init__(self, backend: DatabaseBackend) -> None:
237237
async def __aenter__(self) -> "Connection":
238238
async with self._connection_lock:
239239
self._connection_counter += 1
240-
if self._connection_counter == 1:
241-
await self._connection.acquire()
240+
try:
241+
if self._connection_counter == 1:
242+
await self._connection.acquire()
243+
except Exception as e:
244+
self._connection_counter -= 1
245+
raise e
242246
return self
243247

244248
async def __aexit__(

tests/test_databases.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
import os
66
import re
7+
from unittest.mock import patch, MagicMock
78

89
import pytest
910
import sqlalchemy
@@ -15,6 +16,11 @@
1516
DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")]
1617

1718

19+
class AsyncMock(MagicMock):
20+
async def __call__(self, *args, **kwargs):
21+
return super(AsyncMock, self).__call__(*args, **kwargs)
22+
23+
1824
class MyEpochType(sqlalchemy.types.TypeDecorator):
1925
impl = sqlalchemy.Integer
2026

@@ -267,6 +273,30 @@ async def test_ddl_queries(database_url):
267273
await database.execute(query)
268274

269275

276+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
277+
@async_adapter
278+
async def test_queries_after_error(database_url):
279+
"""
280+
Test that the basic `execute()` works after a previous error.
281+
"""
282+
283+
class DBException(Exception):
284+
pass
285+
286+
async with Database(database_url) as database:
287+
with patch.object(
288+
database.connection()._connection,
289+
"acquire",
290+
new=AsyncMock(side_effect=DBException),
291+
):
292+
with pytest.raises(DBException):
293+
query = notes.select()
294+
await database.fetch_all(query)
295+
296+
query = notes.select()
297+
await database.fetch_all(query)
298+
299+
270300
@pytest.mark.parametrize("database_url", DATABASE_URLS)
271301
@async_adapter
272302
async def test_results_support_mapping_interface(database_url):

0 commit comments

Comments
 (0)