Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions awswrangler/athena/_formatter.py → awswrangler/_sql_formatter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Formatting logic for Athena parameters."""
"""Formatting logic for SQL parameters."""
import datetime
import decimal
import re
from enum import Enum
from typing import Any, Dict, Generic, Sequence, Type, TypeVar
from typing import Any, Dict, Generic, Optional, Sequence, Type, TypeVar


class _EngineType(Enum):
Expand Down Expand Up @@ -165,3 +166,26 @@ def _format_parameters(params: Dict[str, Any], engine: _EngineType) -> Dict[str,
processed_params[k] = str(abs_type)

return processed_params


# A placeholder is a colon followed by one or more ASCII letters, digits or underscores.
_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")


def _process_sql_params(sql: str, params: Optional[Dict[str, Any]], engine: _EngineType = _EngineType.PRESTO) -> str:
    """Substitute ``:name`` placeholders in ``sql`` with their formatted parameter values.

    Parameters
    ----------
    sql : str
        SQL text that may contain ``:name`` style placeholders.
    params : Optional[Dict[str, Any]]
        Mapping of placeholder name to value. ``None`` is treated as an empty mapping.
    engine : _EngineType
        Engine passed through to ``_format_parameters`` to render each value.

    Returns
    -------
    str
        The SQL text with every placeholder that has a matching parameter replaced.
        Placeholders with no matching parameter are left untouched.
    """
    formatted = _format_parameters({} if params is None else params, engine=engine)

    def _substitute(match: re.Match) -> str:  # type: ignore
        name = match.group(1)
        if name in formatted:
            return str(formatted[name])
        # Unknown placeholder: hand back the matched text exactly as it was written.
        return str(match.group(0))

    return _PATTERN.sub(_substitute, sql)
30 changes: 3 additions & 27 deletions awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import csv
import logging
import re
import sys
import uuid
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
Expand All @@ -14,7 +13,7 @@
from awswrangler import _utils, catalog, exceptions, s3
from awswrangler._config import apply_configs
from awswrangler._data_types import cast_pandas_with_athena_types
from awswrangler.athena._formatter import _EngineType, _format_parameters
from awswrangler._sql_formatter import _process_sql_params
from awswrangler.athena._utils import (
_apply_query_metadata,
_empty_dataframe_response,
Expand Down Expand Up @@ -561,29 +560,6 @@ def _unload(
return query_metadata


# A placeholder is a colon followed by one or more ASCII letters, digits or underscores.
_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")


def _process_sql_params(sql: str, params: Optional[Dict[str, Any]]) -> str:
    """Replace ``:name`` placeholders in ``sql`` using PRESTO-formatted parameter values.

    ``params`` maps placeholder names to values; ``None`` is treated as an empty
    mapping. Values are rendered by ``_format_parameters`` with the PRESTO engine.
    Placeholders that have no entry in ``params`` are left in the SQL text unchanged.
    """
    values = _format_parameters({} if params is None else params, engine=_EngineType.PRESTO)

    def _render(found: re.Match) -> str:  # type: ignore
        # Fall back to the full matched text when no parameter was supplied for this name.
        return str(values.get(found.group(1), found.group(0)))

    return _PATTERN.sub(_render, sql)


@apply_configs
def get_query_results(
query_execution_id: str,
Expand Down Expand Up @@ -915,7 +891,7 @@ def read_sql_query(
>>> import awswrangler as wr
>>> df = wr.athena.read_sql_query(
... sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
... params={"name": "'filtered_name'", "city": "'filtered_city'"}
... params={"name": "filtered_name", "city": "filtered_city"}
... )

"""
Expand Down Expand Up @@ -1296,7 +1272,7 @@ def unload(
>>> import awswrangler as wr
>>> res = wr.athena.unload(
... sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
... params={"name": "'filtered_name'", "city": "'filtered_city'"}
... params={"name": "filtered_name", "city": "filtered_city"}
... )

"""
Expand Down
2 changes: 1 addition & 1 deletion awswrangler/athena/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts
from awswrangler._config import apply_configs
from awswrangler.athena._formatter import _EngineType, _format_parameters
from awswrangler._sql_formatter import _EngineType, _format_parameters
from awswrangler.catalog._utils import _catalog_id, _transaction_id

from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results, _LocalMetadataCacheManager
Expand Down
9 changes: 4 additions & 5 deletions awswrangler/lakeformation/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from awswrangler import _data_types, _utils, catalog
from awswrangler._config import apply_configs
from awswrangler._distributed import engine
from awswrangler._sql_formatter import _process_sql_params
from awswrangler._threading import _get_executor
from awswrangler.catalog._utils import _catalog_id, _transaction_id
from awswrangler.distributed.ray import RayLogger
Expand Down Expand Up @@ -159,17 +160,15 @@ def read_sql_query(
... sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
... database="my_db",
... query_as_of_time="1611142914",
... params={"name": "'filtered_name'", "city": "'filtered_city'"}
... params={"name": "filtered_name", "city": "filtered_city"}
... )

"""
session: boto3.Session = _utils.ensure_session(session=boto3_session)
client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session)
commit_trans: bool = False
if params is None:
params = {}
for key, value in params.items():
sql = sql.replace(f":{key};", str(value))

sql = _process_sql_params(sql, params)

if not any([transaction_id, query_as_of_time]):
_logger.debug("Neither `transaction_id` nor `query_as_of_time` were specified, starting transaction")
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_lakeformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def test_lakeformation(path, path2, glue_database, glue_table, glue_table2, use_

# Filter query
df2 = wr.lakeformation.read_sql_query(
sql=f"SELECT * FROM {glue_table} WHERE iint16 = :iint16;",
sql=f'SELECT * FROM {glue_table} WHERE "string" = :city_name',
database=glue_database,
params={"iint16": 1},
params={"city_name": "Washington"},
)
assert len(df2.index) == 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from awswrangler.athena._formatter import _EngineType, _format_parameters
from awswrangler._sql_formatter import _EngineType, _format_parameters


@pytest.mark.parametrize("engine", [_EngineType.HIVE, _EngineType.PRESTO])
Expand Down