Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions awswrangler/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,26 @@ def create_athena_bucket(self):
s3_resource.Bucket(s3_output)
return s3_output

def run_query(self, query, database, s3_output=None):
def run_query(self, query, database, s3_output=None, workgroup=None):
"""
Run a SQL Query against AWS Athena

:param query: SQL query
:param database: AWS Glue/Athena database name
:param s3_output: AWS S3 path
        :param workgroup: Athena workgroup (By default uses the Session() workgroup)
:return: Query execution ID
"""
if not s3_output:
if s3_output is None:
s3_output = self.create_athena_bucket()
if workgroup is None:
workgroup = self._session.athena_workgroup
logger.debug(f"Workgroup: {workgroup}")
response = self._client_athena.start_query_execution(
QueryString=query,
QueryExecutionContext={"Database": database},
ResultConfiguration={"OutputLocation": s3_output},
)
WorkGroup=workgroup)
return response["QueryExecutionId"]

def wait_query(self, query_execution_id):
Expand Down Expand Up @@ -109,7 +113,7 @@ def wait_query(self, query_execution_id):
response["QueryExecution"]["Status"].get("StateChangeReason"))
return response

def repair_table(self, database, table, s3_output=None):
def repair_table(self, database, table, s3_output=None, workgroup=None):
"""
Hive's metastore consistency check
"MSCK REPAIR TABLE table;"
Expand All @@ -122,12 +126,14 @@ def repair_table(self, database, table, s3_output=None):
:param database: Glue database name
:param table: Glue table name
:param s3_output: AWS S3 path
        :param workgroup: Athena workgroup (By default uses the Session() workgroup)
:return: Query execution ID
"""
query = f"MSCK REPAIR TABLE {table};"
query_id = self.run_query(query=query,
database=database,
s3_output=s3_output)
s3_output=s3_output,
workgroup=workgroup)
self.wait_query(query_execution_id=query_id)
return query_id

Expand Down
48 changes: 29 additions & 19 deletions awswrangler/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
spark_session=None,
procs_cpu_bound=os.cpu_count(),
procs_io_bound=os.cpu_count() * PROCS_IO_BOUND_FACTOR,
athena_workgroup="primary",
):
"""
Most parameters inherit from Boto3 or Pyspark.
Expand All @@ -59,10 +60,9 @@ def __init__(
:param s3_additional_kwargs: Passed on to s3fs (https://s3fs.readthedocs.io/en/latest/#serverside-encryption)
:param spark_context: Spark Context (pyspark.SparkContext)
:param spark_session: Spark Session (pyspark.sql.SparkSession)
:param procs_cpu_bound: number of processes that can be used in single
node applications for CPU bound case (Default: os.cpu_count())
:param procs_io_bound: number of processes that can be used in single
node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
:param procs_cpu_bound: number of processes that can be used in single node applications for CPU bound case (Default: os.cpu_count())
:param procs_io_bound: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
:param athena_workgroup: Default AWS Athena Workgroup (str)
"""
self._profile_name = (boto3_session.profile_name
if boto3_session else profile_name)
Expand All @@ -81,6 +81,7 @@ def __init__(
self._spark_session = spark_session
self._procs_cpu_bound = procs_cpu_bound
self._procs_io_bound = procs_io_bound
self._athena_workgroup = athena_workgroup
self._primitives = None
self._load_new_primitives()
if boto3_session:
Expand Down Expand Up @@ -134,6 +135,7 @@ def _load_new_primitives(self):
botocore_config=self._botocore_config,
procs_cpu_bound=self._procs_cpu_bound,
procs_io_bound=self._procs_io_bound,
athena_workgroup=self._athena_workgroup,
)

@property
Expand Down Expand Up @@ -184,6 +186,10 @@ def procs_cpu_bound(self):
def procs_io_bound(self):
return self._procs_io_bound

@property
def athena_workgroup(self):
return self._athena_workgroup

@property
def boto3_session(self):
return self._boto3_session
Expand Down Expand Up @@ -255,6 +261,7 @@ def __init__(
botocore_config=None,
procs_cpu_bound=None,
procs_io_bound=None,
athena_workgroup=None,
):
"""
Most parameters inherit from Boto3.
Expand All @@ -268,10 +275,9 @@ def __init__(
:param botocore_max_retries: Botocore max retries
:param s3_additional_kwargs: Passed on to s3fs (https://s3fs.readthedocs.io/en/latest/#serverside-encryption)
:param botocore_config: Botocore configurations
:param procs_cpu_bound: number of processes that can be used in single
node applications for CPU bound case (Default: os.cpu_count())
:param procs_io_bound: number of processes that can be used in single
node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
:param procs_cpu_bound: number of processes that can be used in single node applications for CPU bound case (Default: os.cpu_count())
:param procs_io_bound: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
:param athena_workgroup: Default AWS Athena Workgroup (str)
"""
self._profile_name = profile_name
self._aws_access_key_id = aws_access_key_id
Expand All @@ -283,6 +289,7 @@ def __init__(
self._botocore_config = botocore_config
self._procs_cpu_bound = procs_cpu_bound
self._procs_io_bound = procs_io_bound
self._athena_workgroup = athena_workgroup

@property
def profile_name(self):
Expand Down Expand Up @@ -324,20 +331,23 @@ def procs_cpu_bound(self):
def procs_io_bound(self):
return self._procs_io_bound

@property
def athena_workgroup(self):
return self._athena_workgroup

@property
def session(self):
"""
Reconstruct the session from primitives
:return: awswrangler.session.Session
"""
return Session(
profile_name=self._profile_name,
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
aws_session_token=self._aws_session_token,
region_name=self._region_name,
botocore_max_retries=self._botocore_max_retries,
s3_additional_kwargs=self._s3_additional_kwargs,
procs_cpu_bound=self._procs_cpu_bound,
procs_io_bound=self._procs_io_bound,
)
return Session(profile_name=self._profile_name,
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
aws_session_token=self._aws_session_token,
region_name=self._region_name,
botocore_max_retries=self._botocore_max_retries,
s3_additional_kwargs=self._s3_additional_kwargs,
procs_cpu_bound=self._procs_cpu_bound,
procs_io_bound=self._procs_io_bound,
athena_workgroup=self._athena_workgroup)
46 changes: 46 additions & 0 deletions testing/test_awswrangler/test_athena.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from pprint import pprint

import pytest
import boto3
Expand Down Expand Up @@ -37,6 +38,51 @@ def database(cloudformation_outputs):
yield database


@pytest.fixture(scope="module")
def bucket(session, cloudformation_outputs):
if "BucketName" in cloudformation_outputs:
bucket = cloudformation_outputs["BucketName"]
session.s3.delete_objects(path=f"s3://{bucket}/")
else:
raise Exception(
"You must deploy the test infrastructure using Cloudformation!")
yield bucket
session.s3.delete_objects(path=f"s3://{bucket}/")


@pytest.fixture(scope="module")
def workgroup_secondary(bucket):
wkg_name = "awswrangler_test"
client = boto3.client('athena')
wkgs = client.list_work_groups()
wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
if wkg_name not in wkgs:
response = client.create_work_group(
Name=wkg_name,
Configuration={
"ResultConfiguration": {
"OutputLocation":
f"s3://{bucket}/athena_workgroup_secondary/",
"EncryptionConfiguration": {
"EncryptionOption": "SSE_S3",
}
},
"EnforceWorkGroupConfiguration": True,
"PublishCloudWatchMetricsEnabled": True,
"BytesScannedCutoffPerQuery": 100_000_000,
"RequesterPaysEnabled": False
},
Description="AWS Data Wrangler Test WorkGroup")
pprint(response)
yield wkg_name


def test_workgroup_secondary(session, database, workgroup_secondary):
session.athena.run_query(query="SELECT 1",
database=database,
workgroup=workgroup_secondary)


def test_query_cancelled(session, database):
client_athena = boto3.client("athena")
query_execution_id = session.athena.run_query(query="""
Expand Down