Skip to content

Commit 44ae3fb

Browse files
feat: add columns parameters support (#2814)
* feat: add columns parameters support
* tests: expand test routine
* docs: fix minor docs edits

---------

Co-authored-by: Leon Luttenberger <[email protected]>
1 parent 1686539 commit 44ae3fb

File tree

10 files changed

+156
-22
lines changed

10 files changed

+156
-22
lines changed

awswrangler/catalog/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_get_table_input,
3131
databases,
3232
get_columns_comments,
33+
get_columns_parameters,
3334
get_connection,
3435
get_csv_partitions,
3536
get_databases,
@@ -83,6 +84,7 @@
8384
"_get_table_input",
8485
"databases",
8586
"get_columns_comments",
87+
"get_columns_parameters",
8688
"get_connection",
8789
"get_csv_partitions",
8890
"get_databases",

awswrangler/catalog/_create.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
_logger: logging.Logger = logging.getLogger(__name__)
2626

2727

28-
def _update_if_necessary(dic: dict[str, str], key: str, value: str | None, mode: str) -> str:
28+
def _update_if_necessary(
29+
dic: dict[str, str | dict[str, str]], key: str, value: str | dict[str, str] | None, mode: str
30+
) -> str:
2931
if value is not None:
3032
if key not in dic or dic[key] != value:
3133
dic[key] = value
@@ -46,6 +48,7 @@ def _create_table( # noqa: PLR0912,PLR0915
4648
table_exist: bool,
4749
partitions_types: dict[str, str] | None,
4850
columns_comments: dict[str, str] | None,
51+
columns_parameters: dict[str, dict[str, str]] | None,
4952
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
5053
catalog_id: str | None,
5154
) -> None:
@@ -130,6 +133,19 @@ def _create_table( # noqa: PLR0912,PLR0915
130133
if name in columns_comments:
131134
mode = _update_if_necessary(dic=par, key="Comment", value=columns_comments[name], mode=mode)
132135

136+
# Column parameters
137+
columns_parameters = columns_parameters if columns_parameters else {}
138+
columns_parameters = {sanitize_column_name(k): v for k, v in columns_parameters.items()}
139+
if columns_parameters:
140+
for col in table_input["StorageDescriptor"]["Columns"]:
141+
name: str = col["Name"] # type: ignore[no-redef]
142+
if name in columns_parameters:
143+
mode = _update_if_necessary(dic=col, key="Parameters", value=columns_parameters[name], mode=mode)
144+
for par in table_input["PartitionKeys"]:
145+
name = par["Name"]
146+
if name in columns_parameters:
147+
mode = _update_if_necessary(dic=par, key="Parameters", value=columns_parameters[name], mode=mode)
148+
133149
_logger.debug("table_input: %s", table_input)
134150

135151
client_glue = _utils.client(service_name="glue", session=boto3_session)
@@ -275,6 +291,7 @@ def _create_parquet_table(
275291
description: str | None,
276292
parameters: dict[str, str] | None,
277293
columns_comments: dict[str, str] | None,
294+
columns_parameters: dict[str, dict[str, str]] | None,
278295
mode: str,
279296
catalog_versioning: bool,
280297
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
@@ -311,6 +328,7 @@ def _create_parquet_table(
311328
description=description,
312329
parameters=parameters,
313330
columns_comments=columns_comments,
331+
columns_parameters=columns_parameters,
314332
mode=mode,
315333
catalog_versioning=catalog_versioning,
316334
boto3_session=boto3_session,
@@ -335,6 +353,7 @@ def _create_orc_table(
335353
description: str | None,
336354
parameters: dict[str, str] | None,
337355
columns_comments: dict[str, str] | None,
356+
columns_parameters: dict[str, dict[str, str]] | None,
338357
mode: str,
339358
catalog_versioning: bool,
340359
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
@@ -371,6 +390,7 @@ def _create_orc_table(
371390
description=description,
372391
parameters=parameters,
373392
columns_comments=columns_comments,
393+
columns_parameters=columns_parameters,
374394
mode=mode,
375395
catalog_versioning=catalog_versioning,
376396
boto3_session=boto3_session,
@@ -394,6 +414,7 @@ def _create_csv_table(
394414
compression: str | None,
395415
parameters: dict[str, str] | None,
396416
columns_comments: dict[str, str] | None,
417+
columns_parameters: dict[str, dict[str, str]] | None,
397418
mode: str,
398419
catalog_versioning: bool,
399420
schema_evolution: bool,
@@ -444,6 +465,7 @@ def _create_csv_table(
444465
description=description,
445466
parameters=parameters,
446467
columns_comments=columns_comments,
468+
columns_parameters=columns_parameters,
447469
mode=mode,
448470
catalog_versioning=catalog_versioning,
449471
boto3_session=boto3_session,
@@ -467,6 +489,7 @@ def _create_json_table(
467489
compression: str | None,
468490
parameters: dict[str, str] | None,
469491
columns_comments: dict[str, str] | None,
492+
columns_parameters: dict[str, dict[str, str]] | None,
470493
mode: str,
471494
catalog_versioning: bool,
472495
schema_evolution: bool,
@@ -512,6 +535,7 @@ def _create_json_table(
512535
description=description,
513536
parameters=parameters,
514537
columns_comments=columns_comments,
538+
columns_parameters=columns_parameters,
515539
mode=mode,
516540
catalog_versioning=catalog_versioning,
517541
boto3_session=boto3_session,
@@ -713,6 +737,7 @@ def create_parquet_table(
713737
description: str | None = None,
714738
parameters: dict[str, str] | None = None,
715739
columns_comments: dict[str, str] | None = None,
740+
columns_parameters: dict[str, dict[str, str]] | None = None,
716741
mode: Literal["overwrite", "append"] = "overwrite",
717742
catalog_versioning: bool = False,
718743
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None = None,
@@ -751,6 +776,8 @@ def create_parquet_table(
751776
Key/value pairs to tag the table.
752777
columns_comments: Dict[str, str], optional
753778
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
779+
columns_parameters: Dict[str, Dict[str, str]], optional
780+
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
754781
mode: str
755782
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
756783
catalog_versioning : bool
@@ -848,6 +875,7 @@ def create_parquet_table(
848875
description=description,
849876
parameters=parameters,
850877
columns_comments=columns_comments,
878+
columns_parameters=columns_parameters,
851879
mode=mode,
852880
catalog_versioning=catalog_versioning,
853881
athena_partition_projection_settings=athena_partition_projection_settings,
@@ -870,6 +898,7 @@ def create_orc_table(
870898
description: str | None = None,
871899
parameters: dict[str, str] | None = None,
872900
columns_comments: dict[str, str] | None = None,
901+
columns_parameters: dict[str, dict[str, str]] | None = None,
873902
mode: Literal["overwrite", "append"] = "overwrite",
874903
catalog_versioning: bool = False,
875904
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None = None,
@@ -908,6 +937,8 @@ def create_orc_table(
908937
Key/value pairs to tag the table.
909938
columns_comments: Dict[str, str], optional
910939
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
940+
columns_parameters: Dict[str, Dict[str, str]], optional
941+
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
911942
mode: str
912943
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
913944
catalog_versioning : bool
@@ -1005,6 +1036,7 @@ def create_orc_table(
10051036
description=description,
10061037
parameters=parameters,
10071038
columns_comments=columns_comments,
1039+
columns_parameters=columns_parameters,
10081040
mode=mode,
10091041
catalog_versioning=catalog_versioning,
10101042
athena_partition_projection_settings=athena_partition_projection_settings,
@@ -1026,6 +1058,7 @@ def create_csv_table(
10261058
description: str | None = None,
10271059
parameters: dict[str, str] | None = None,
10281060
columns_comments: dict[str, str] | None = None,
1061+
columns_parameters: dict[str, dict[str, str]] | None = None,
10291062
mode: Literal["overwrite", "append"] = "overwrite",
10301063
catalog_versioning: bool = False,
10311064
schema_evolution: bool = False,
@@ -1072,6 +1105,8 @@ def create_csv_table(
10721105
Key/value pairs to tag the table.
10731106
columns_comments: Dict[str, str], optional
10741107
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
1108+
columns_parameters: Dict[str, Dict[str, str]], optional
1109+
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
10751110
mode : str
10761111
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
10771112
catalog_versioning : bool
@@ -1188,6 +1223,7 @@ def create_csv_table(
11881223
description=description,
11891224
parameters=parameters,
11901225
columns_comments=columns_comments,
1226+
columns_parameters=columns_parameters,
11911227
mode=mode,
11921228
catalog_versioning=catalog_versioning,
11931229
schema_evolution=schema_evolution,
@@ -1214,6 +1250,7 @@ def create_json_table(
12141250
description: str | None = None,
12151251
parameters: dict[str, str] | None = None,
12161252
columns_comments: dict[str, str] | None = None,
1253+
columns_parameters: dict[str, dict[str, str]] | None = None,
12171254
mode: Literal["overwrite", "append"] = "overwrite",
12181255
catalog_versioning: bool = False,
12191256
schema_evolution: bool = False,
@@ -1253,6 +1290,8 @@ def create_json_table(
12531290
Key/value pairs to tag the table.
12541291
columns_comments: Dict[str, str], optional
12551292
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
1293+
columns_parameters: Dict[str, Dict[str, str]], optional
1294+
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
12561295
mode : str
12571296
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
12581297
catalog_versioning : bool
@@ -1361,6 +1400,7 @@ def create_json_table(
13611400
description=description,
13621401
parameters=parameters,
13631402
columns_comments=columns_comments,
1403+
columns_parameters=columns_parameters,
13641404
mode=mode,
13651405
catalog_versioning=catalog_versioning,
13661406
schema_evolution=schema_evolution,

awswrangler/catalog/_get.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import base64
66
import itertools
77
import logging
8-
from typing import TYPE_CHECKING, Any, Dict, Iterator, cast
8+
from typing import TYPE_CHECKING, Any, Dict, Iterator, Mapping, cast
99

1010
import boto3
1111
import botocore.exceptions
@@ -887,6 +887,49 @@ def get_columns_comments(
887887
return comments
888888

889889

890+
@apply_configs
def get_columns_parameters(
    database: str,
    table: str,
    catalog_id: str | None = None,
    boto3_session: boto3.Session | None = None,
) -> dict[str, Mapping[str, str] | None]:
    """Get the parameters of every column of a table.

    Both the regular columns (``StorageDescriptor.Columns``) and, when the
    table has them, the partition keys (``PartitionKeys``) are included.

    Parameters
    ----------
    database : str
        Database name.
    table : str
        Table name.
    catalog_id : str, optional
        The ID of the Data Catalog holding the table.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session is None.

    Returns
    -------
    Dict[str, Optional[Dict[str, str]]]
        Column name mapped to its ``Parameters`` entry, or to ``None`` when
        the column definition carries no ``Parameters``.

    Examples
    --------
    >>> import awswrangler as wr
    >>> pars = wr.catalog.get_columns_parameters(database="...", table="...")

    """
    glue = _utils.client("glue", session=boto3_session)
    get_table_args = _catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table)
    table_definition = glue.get_table(**get_table_args)["Table"]

    result = {
        column["Name"]: column.get("Parameters") for column in table_definition["StorageDescriptor"]["Columns"]
    }
    # Partition keys are added last, so a partition key sharing a name with a
    # regular column overrides that column's entry.
    if "PartitionKeys" in table_definition:
        result.update(
            {partition["Name"]: partition.get("Parameters") for partition in table_definition["PartitionKeys"]}
        )
    return result
931+
932+
890933
@apply_configs
891934
def get_table_versions(
892935
database: str, table: str, catalog_id: str | None = None, boto3_session: boto3.Session | None = None

awswrangler/s3/_write.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def _validate_args(
7070
description: str | None,
7171
parameters: dict[str, str] | None,
7272
columns_comments: dict[str, str] | None,
73+
columns_parameters: dict[str, dict[str, str]] | None,
7374
execution_engine: Enum,
7475
) -> None:
7576
if df.empty is True:
@@ -87,11 +88,11 @@ def _validate_args(
8788
raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use bucketing_info.")
8889
if mode is not None:
8990
raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use mode.")
90-
if any(arg is not None for arg in (table, description, parameters, columns_comments)):
91+
if any(arg is not None for arg in (table, description, parameters, columns_comments, columns_parameters)):
9192
raise exceptions.InvalidArgumentCombination(
9293
"Please pass dataset=True to be able to use any one of these "
9394
"arguments: database, table, description, parameters, "
94-
"columns_comments."
95+
"columns_comments, columns_parameters."
9596
)
9697
elif (database is None) != (table is None):
9798
raise exceptions.InvalidArgumentCombination(
@@ -214,6 +215,7 @@ def _create_glue_table(
214215
description: str | None,
215216
parameters: dict[str, str] | None,
216217
columns_comments: dict[str, str] | None,
218+
columns_parameters: dict[str, dict[str, str]] | None,
217219
mode: str,
218220
catalog_versioning: bool,
219221
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
@@ -262,6 +264,7 @@ def write( # noqa: PLR0912,PLR0913
262264
description: str | None,
263265
parameters: dict[str, str] | None,
264266
columns_comments: dict[str, str] | None,
267+
columns_parameters: dict[str, dict[str, str]] | None,
265268
regular_partitions: bool,
266269
table_type: str | None,
267270
dtype: dict[str, str] | None,
@@ -361,6 +364,7 @@ def write( # noqa: PLR0912,PLR0913
361364
"description": description,
362365
"parameters": parameters,
363366
"columns_comments": columns_comments,
367+
"columns_parameters": columns_parameters,
364368
"boto3_session": boto3_session,
365369
"mode": mode,
366370
"catalog_versioning": catalog_versioning,

awswrangler/s3/_write_orc.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def _create_glue_table(
253253
description: str | None = None,
254254
parameters: dict[str, str] | None = None,
255255
columns_comments: dict[str, str] | None = None,
256+
columns_parameters: dict[str, dict[str, str]] | None = None,
256257
mode: str = "overwrite",
257258
catalog_versioning: bool = False,
258259
athena_partition_projection_settings: AthenaPartitionProjectionSettings | None = None,
@@ -272,6 +273,7 @@ def _create_glue_table(
272273
description=description,
273274
parameters=parameters,
274275
columns_comments=columns_comments,
276+
columns_parameters=columns_parameters,
275277
mode=mode,
276278
catalog_versioning=catalog_versioning,
277279
athena_partition_projection_settings=athena_partition_projection_settings,
@@ -629,6 +631,7 @@ def to_orc(
629631
description = glue_table_settings.get("description")
630632
parameters = glue_table_settings.get("parameters")
631633
columns_comments = glue_table_settings.get("columns_comments")
634+
columns_parameters = glue_table_settings.get("columns_parameters")
632635
regular_partitions = glue_table_settings.get("regular_partitions", True)
633636

634637
_validate_args(
@@ -643,6 +646,7 @@ def to_orc(
643646
description=description,
644647
parameters=parameters,
645648
columns_comments=columns_comments,
649+
columns_parameters=columns_parameters,
646650
execution_engine=engine.get(),
647651
)
648652

@@ -682,6 +686,7 @@ def to_orc(
682686
description=description,
683687
parameters=parameters,
684688
columns_comments=columns_comments,
689+
columns_parameters=columns_parameters,
685690
table_type=table_type,
686691
regular_partitions=regular_partitions,
687692
dtype=dtype,

0 commit comments

Comments
 (0)