16 changes: 15 additions & 1 deletion awswrangler/athena/_read.py
@@ -261,6 +261,7 @@ def _resolve_query_without_cache_ctas(
alt_database: Optional[str],
name: Optional[str],
ctas_bucketing_info: Optional[Tuple[List[str], int]],
ctas_write_compression: Optional[str],
use_threads: Union[bool, int],
s3_additional_kwargs: Optional[Dict[str, Any]],
boto3_session: boto3.Session,
@@ -276,6 +277,7 @@ def _resolve_query_without_cache_ctas(
s3_output=s3_output,
workgroup=workgroup,
encryption=encryption,
write_compression=ctas_write_compression,
kms_key=kms_key,
wait=True,
boto3_session=boto3_session,
@@ -409,6 +411,7 @@ def _resolve_query_without_cache(
ctas_database_name: Optional[str],
ctas_temp_table_name: Optional[str],
ctas_bucketing_info: Optional[Tuple[List[str], int]],
ctas_write_compression: Optional[str],
use_threads: Union[bool, int],
s3_additional_kwargs: Optional[Dict[str, Any]],
boto3_session: boto3.Session,
@@ -439,6 +442,7 @@ def _resolve_query_without_cache(
alt_database=ctas_database_name,
name=name,
ctas_bucketing_info=ctas_bucketing_info,
ctas_write_compression=ctas_write_compression,
use_threads=use_threads,
s3_additional_kwargs=s3_additional_kwargs,
boto3_session=boto3_session,
@@ -656,7 +660,7 @@ def get_query_results(


@apply_configs
def read_sql_query(
def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
sql: str,
database: str,
ctas_approach: bool = True,
@@ -672,6 +676,7 @@ def read_sql_query(
ctas_database_name: Optional[str] = None,
ctas_temp_table_name: Optional[str] = None,
ctas_bucketing_info: Optional[Tuple[List[str], int]] = None,
ctas_write_compression: Optional[str] = None,
use_threads: Union[bool, int] = True,
boto3_session: Optional[boto3.Session] = None,
max_cache_seconds: int = 0,
@@ -838,6 +843,9 @@ def read_sql_query(
Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the
second element.
Only `str`, `int` and `bool` are supported as column data types for bucketing.
ctas_write_compression : str, optional
Write compression for the temporary table where the CTAS result is stored.
Corresponds to the `write_compression` parameter of the CREATE TABLE AS statement in Athena.
use_threads : bool, int
True to enable concurrent requests, False to disable multiple threads.
If enabled os.cpu_count() will be used as the max number of threads.
@@ -963,6 +971,7 @@ def read_sql_query(
ctas_database_name=ctas_database_name,
ctas_temp_table_name=ctas_temp_table_name,
ctas_bucketing_info=ctas_bucketing_info,
ctas_write_compression=ctas_write_compression,
use_threads=use_threads,
s3_additional_kwargs=s3_additional_kwargs,
boto3_session=session,
@@ -987,6 +996,7 @@ def read_sql_table(
ctas_database_name: Optional[str] = None,
ctas_temp_table_name: Optional[str] = None,
ctas_bucketing_info: Optional[Tuple[List[str], int]] = None,
ctas_write_compression: Optional[str] = None,
use_threads: Union[bool, int] = True,
boto3_session: Optional[boto3.Session] = None,
max_cache_seconds: int = 0,
@@ -1131,6 +1141,9 @@ def read_sql_table(
Tuple consisting of the column names used for bucketing as the first element and the number of buckets as the
second element.
Only `str`, `int` and `bool` are supported as column data types for bucketing.
ctas_write_compression : str, optional
Write compression for the temporary table where the CTAS result is stored.
Corresponds to the `write_compression` parameter of the CREATE TABLE AS statement in Athena.
use_threads : bool, int
True to enable concurrent requests, False to disable multiple threads.
If enabled os.cpu_count() will be used as the max number of threads.
@@ -1202,6 +1215,7 @@ def read_sql_table(
ctas_database_name=ctas_database_name,
ctas_temp_table_name=ctas_temp_table_name,
ctas_bucketing_info=ctas_bucketing_info,
ctas_write_compression=ctas_write_compression,
use_threads=use_threads,
boto3_session=boto3_session,
max_cache_seconds=max_cache_seconds,
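For reference, a minimal usage sketch of the new parameter as exposed by `read_sql_query` (the database and table names below are placeholders, and the compression values mirror those exercised in the new test):

```python
import awswrangler as wr

# Sketch only: with ctas_approach=True the result set is materialized through a
# temporary CTAS table, and ctas_write_compression is forwarded to that table's
# write_compression property. "my_database" and "my_table" are hypothetical names.
df = wr.athena.read_sql_query(
    sql="SELECT * FROM my_table",
    database="my_database",
    ctas_approach=True,
    ctas_write_compression="snappy",  # e.g. "snappy" or "gzip"; None keeps the default
)
```

The same keyword is also accepted by `read_sql_table`, which forwards it to `read_sql_query` as shown in the hunks above.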
2 changes: 2 additions & 0 deletions awswrangler/athena/_utils.py
@@ -472,9 +472,11 @@ def start_query_execution(
max_cache_query_inspections=max_cache_query_inspections,
max_remote_cache_entries=max_remote_cache_entries,
)
_logger.debug("cache_info:\n%s", cache_info)

if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
query_execution_id = cache_info.query_execution_id
_logger.debug("Valid cache found. Retrieving...")
else:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup)
query_execution_id = _start_query_execution(
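The new `_logger.debug` calls only show up when DEBUG logging is enabled for the package; a minimal sketch using the standard library (only the parent logger name `awswrangler` is assumed, matching the package's per-module loggers):

```python
import logging

# Surface the new "cache_info" / "Valid cache found" debug messages.
logging.basicConfig(level=logging.INFO)
logging.getLogger("awswrangler").setLevel(logging.DEBUG)
```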
32 changes: 32 additions & 0 deletions tests/test_athena.py
@@ -1,6 +1,7 @@
import datetime
import logging
import string
from unittest.mock import patch

import boto3
import numpy as np
@@ -1252,3 +1253,34 @@ def test_get_query_execution(workgroup0, workgroup1):
assert isinstance(unprocessed_query_executions_df, pd.DataFrame)
assert set(query_execution_ids).intersection(set(query_executions_df["QueryExecutionId"].values.tolist()))
assert {"aaa", "bbb"}.intersection(set(unprocessed_query_executions_df["QueryExecutionId"].values.tolist()))


@pytest.mark.parametrize("compression", [None, "snappy", "gzip"])
def test_read_sql_query_ctas_write_compression(path, glue_database, glue_table, compression):
wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
wr.s3.to_parquet(
df=get_df(),
path=path,
index=True,
use_threads=True,
dataset=True,
mode="overwrite",
database=glue_database,
table=glue_table,
partition_cols=["par0", "par1"],
)

with patch(
"awswrangler.athena._read.create_ctas_table", wraps=wr.athena.create_ctas_table
) as mock_create_ctas_table:
wr.athena.read_sql_query(
sql=f"SELECT * FROM {glue_table}",
database=glue_database,
ctas_approach=True,
ctas_write_compression=compression,
)

mock_create_ctas_table.assert_called_once()

create_ctas_table_args = mock_create_ctas_table.call_args.kwargs
assert create_ctas_table_args["write_compression"] == compression
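A hedged note on exercising just this test locally (assumes a configured pytest environment plus the AWS fixtures the suite relies on; the selector is simply the test name from the diff):

```python
import pytest

# Run only the new parametrized test, verbosely; requires the live AWS resources the fixtures expect.
pytest.main(["tests/test_athena.py", "-k", "test_read_sql_query_ctas_write_compression", "-v"])
```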