From 54daef38ccf62d5354fd48bbff2b1529756655a3 Mon Sep 17 00:00:00 2001
From: Leon Luttenberger
Date: Wed, 23 Nov 2022 09:30:01 -0600
Subject: [PATCH 1/3] Add ctas_write_compression argument to athena.read_sql_query

---
 awswrangler/athena/_read.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py
index 96da9f83c..af75279c7 100644
--- a/awswrangler/athena/_read.py
+++ b/awswrangler/athena/_read.py
@@ -261,6 +261,7 @@ def _resolve_query_without_cache_ctas(
     alt_database: Optional[str],
     name: Optional[str],
     ctas_bucketing_info: Optional[Tuple[List[str], int]],
+    ctas_write_compression: Optional[str],
     use_threads: Union[bool, int],
     s3_additional_kwargs: Optional[Dict[str, Any]],
     boto3_session: boto3.Session,
@@ -276,6 +277,7 @@ def _resolve_query_without_cache_ctas(
         s3_output=s3_output,
         workgroup=workgroup,
         encryption=encryption,
+        write_compression=ctas_write_compression,
         kms_key=kms_key,
         wait=True,
         boto3_session=boto3_session,
@@ -409,6 +411,7 @@ def _resolve_query_without_cache(
     ctas_database_name: Optional[str],
     ctas_temp_table_name: Optional[str],
     ctas_bucketing_info: Optional[Tuple[List[str], int]],
+    ctas_write_compression: Optional[str],
     use_threads: Union[bool, int],
     s3_additional_kwargs: Optional[Dict[str, Any]],
     boto3_session: boto3.Session,
@@ -439,6 +442,7 @@ def _resolve_query_without_cache(
                 alt_database=ctas_database_name,
                 name=name,
                 ctas_bucketing_info=ctas_bucketing_info,
+                ctas_write_compression=ctas_write_compression,
                 use_threads=use_threads,
                 s3_additional_kwargs=s3_additional_kwargs,
                 boto3_session=boto3_session,
@@ -656,7 +660,7 @@ def get_query_results(
 
 
 @apply_configs
-def read_sql_query(
+def read_sql_query(  # pylint: disable=too-many-arguments,too-many-locals
     sql: str,
     database: str,
     ctas_approach: bool = True,
@@ -672,6 +676,7 @@ def read_sql_query(
     ctas_database_name: Optional[str] = None,
     ctas_temp_table_name: Optional[str] = None,
     ctas_bucketing_info: Optional[Tuple[List[str], int]] = None,
+    ctas_write_compression: Optional[str] = None,
     use_threads: Union[bool, int] = True,
     boto3_session: Optional[boto3.Session] = None,
     max_cache_seconds: int = 0,
@@ -838,6 +843,9 @@ def read_sql_query(
         Tuple consisting of the column names used for bucketing as the first element and the
         number of buckets as the second element.
         Only `str`, `int` and `bool` are supported as column data types for bucketing.
+    ctas_write_compression: str, optional
+        Write compression for the temporary table where the CTAS result is stored.
+        Corresponds to the `write_compression` parameter of the CREATE TABLE AS statement in Athena.
     use_threads : bool, int
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -963,6 +971,7 @@ def read_sql_query(
         ctas_database_name=ctas_database_name,
         ctas_temp_table_name=ctas_temp_table_name,
         ctas_bucketing_info=ctas_bucketing_info,
+        ctas_write_compression=ctas_write_compression,
         use_threads=use_threads,
         s3_additional_kwargs=s3_additional_kwargs,
         boto3_session=session,
@@ -987,6 +996,7 @@ def read_sql_table(
     ctas_database_name: Optional[str] = None,
     ctas_temp_table_name: Optional[str] = None,
     ctas_bucketing_info: Optional[Tuple[List[str], int]] = None,
+    ctas_write_compression: Optional[str] = None,
     use_threads: Union[bool, int] = True,
     boto3_session: Optional[boto3.Session] = None,
     max_cache_seconds: int = 0,
@@ -1131,6 +1141,9 @@ def read_sql_table(
         Tuple consisting of the column names used for bucketing as the first element and the
         number of buckets as the second element.
         Only `str`, `int` and `bool` are supported as column data types for bucketing.
+    ctas_write_compression: str, optional
+        Write compression for the temporary table where the CTAS result is stored.
+        Corresponds to the `write_compression` parameter of the CREATE TABLE AS statement in Athena.
     use_threads : bool, int
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -1202,6 +1215,7 @@ def read_sql_table(
         ctas_database_name=ctas_database_name,
         ctas_temp_table_name=ctas_temp_table_name,
         ctas_bucketing_info=ctas_bucketing_info,
+        ctas_write_compression=ctas_write_compression,
         use_threads=use_threads,
         boto3_session=boto3_session,
         max_cache_seconds=max_cache_seconds,

From 7524ec807f3a5737071ce76e38382a9ee418fc74 Mon Sep 17 00:00:00 2001
From: Leon Luttenberger
Date: Wed, 23 Nov 2022 11:15:20 -0600
Subject: [PATCH 2/3] Add unit test

---
 tests/test_athena.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/tests/test_athena.py b/tests/test_athena.py
index 6b2d6175a..73ee77848 100644
--- a/tests/test_athena.py
+++ b/tests/test_athena.py
@@ -1,6 +1,7 @@
 import datetime
 import logging
 import string
+from unittest.mock import patch
 
 import boto3
 import numpy as np
@@ -1252,3 +1253,34 @@ def test_get_query_execution(workgroup0, workgroup1):
     assert isinstance(unprocessed_query_executions_df, pd.DataFrame)
     assert set(query_execution_ids).intersection(set(query_executions_df["QueryExecutionId"].values.tolist()))
     assert {"aaa", "bbb"}.intersection(set(unprocessed_query_executions_df["QueryExecutionId"].values.tolist()))
+
+
+@pytest.mark.parametrize("compression", [None, "snappy", "gzip"])
+def test_read_sql_query_ctas_write_compression(path, glue_database, glue_table, compression):
+    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
+    wr.s3.to_parquet(
+        df=get_df(),
+        path=path,
+        index=True,
+        use_threads=True,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+        partition_cols=["par0", "par1"],
+    )
+
+    with patch(
+        "awswrangler.athena._read.create_ctas_table", wraps=wr.athena.create_ctas_table
+    ) as mock_create_ctas_table:
+        wr.athena.read_sql_query(
+            sql=f"SELECT * FROM {glue_table}",
+            database=glue_database,
+            ctas_approach=True,
+            ctas_write_compression=compression,
+        )
+
+    mock_create_ctas_table.assert_called_once()
+
+    create_ctas_table_args = mock_create_ctas_table.call_args.kwargs
+    assert create_ctas_table_args["compression"] == compression

From 6fc93314e4c1f78449cc7895de2ae7476501bafc Mon Sep 17 00:00:00 2001
From: Leon Luttenberger
Date: Wed, 23 Nov 2022 14:40:40 -0600
Subject: [PATCH 3/3] Add more debug logging

---
 awswrangler/athena/_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py
index 26bd5526f..c247213d7 100644
--- a/awswrangler/athena/_utils.py
+++ b/awswrangler/athena/_utils.py
@@ -472,9 +472,11 @@ def start_query_execution(
         max_cache_query_inspections=max_cache_query_inspections,
         max_remote_cache_entries=max_remote_cache_entries,
     )
+    _logger.debug("cache_info:\n%s", cache_info)
 
     if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
         query_execution_id = cache_info.query_execution_id
+        _logger.debug("Valid cache found. Retrieving...")
     else:
         wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup)
         query_execution_id = _start_query_execution(
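
Taken together, the three patches thread a new ctas_write_compression argument from read_sql_query/read_sql_table down to the CTAS temporary table. A minimal usage sketch follows (not part of the patches; the database and table names below are hypothetical):

    import awswrangler as wr

    # With ctas_approach=True, the query result is first materialized through a
    # CREATE TABLE AS SELECT statement; ctas_write_compression is forwarded to
    # that statement's write_compression property (e.g. "snappy" or "gzip").
    df = wr.athena.read_sql_query(
        sql="SELECT * FROM my_table",   # hypothetical table
        database="my_database",         # hypothetical Glue database
        ctas_approach=True,
        ctas_write_compression="snappy",
    )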