|
4 | 4 | import functools
|
5 | 5 | import os
|
6 | 6 | import re
|
| 7 | +from unittest.mock import patch, MagicMock |
7 | 8 |
|
8 | 9 | import pytest
|
9 | 10 | import sqlalchemy
|
|
15 | 16 | DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")]
|
16 | 17 |
|
17 | 18 |
|
| 19 | +class AsyncMock(MagicMock): |
| 20 | + async def __call__(self, *args, **kwargs): |
| 21 | + return super(AsyncMock, self).__call__(*args, **kwargs) |
| 22 | + |
| 23 | + |
18 | 24 | class MyEpochType(sqlalchemy.types.TypeDecorator):
|
19 | 25 | impl = sqlalchemy.Integer
|
20 | 26 |
|
@@ -267,6 +273,30 @@ async def test_ddl_queries(database_url):
|
267 | 273 | await database.execute(query)
|
268 | 274 |
|
269 | 275 |
|
| 276 | +@pytest.mark.parametrize("database_url", DATABASE_URLS) |
| 277 | +@async_adapter |
| 278 | +async def test_queries_after_error(database_url): |
| 279 | + """ |
| 280 | + Test that the basic `execute()` works after a previous error. |
| 281 | + """ |
| 282 | + |
| 283 | + class DBException(Exception): |
| 284 | + pass |
| 285 | + |
| 286 | + async with Database(database_url) as database: |
| 287 | + with patch.object( |
| 288 | + database.connection()._connection, |
| 289 | + "acquire", |
| 290 | + new=AsyncMock(side_effect=DBException), |
| 291 | + ): |
| 292 | + with pytest.raises(DBException): |
| 293 | + query = notes.select() |
| 294 | + await database.fetch_all(query) |
| 295 | + |
| 296 | + query = notes.select() |
| 297 | + await database.fetch_all(query) |
| 298 | + |
| 299 | + |
270 | 300 | @pytest.mark.parametrize("database_url", DATABASE_URLS)
|
271 | 301 | @async_adapter
|
272 | 302 | async def test_results_support_mapping_interface(database_url):
|
|
0 commit comments