Skip to content

Commit fc7cefb

Browse files
authored
Merge branch 'main' into feature/modify-refresh-interval-flag
2 parents f66c358 + cc77561 commit fc7cefb

File tree

8 files changed

+87
-5
lines changed

8 files changed

+87
-5
lines changed

awswrangler/_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ def _apply_type(name: str, value: Any, dtype: type[_ConfigValueType], nullable:
224224
raise exceptions.InvalidArgumentValue(
225225
f"{name} configuration does not accept a null value. Please pass {dtype}."
226226
)
227+
# Handle case where string is empty, "False" or "0". Anything else is True
228+
if isinstance(value, str) and dtype is bool:
229+
return value.lower() not in ("false", "0", "")
227230
try:
228231
return dtype(value) if isinstance(value, dtype) is False else value
229232
except ValueError as ex:

awswrangler/_data_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ def pyarrow2postgresql( # noqa: PLR0911
213213
return pyarrow2postgresql(dtype=dtype.value_type, string_type=string_type)
214214
if pa.types.is_binary(dtype):
215215
return "BYTEA"
216+
if pa.types.is_list(dtype):
217+
return pyarrow2postgresql(dtype=dtype.value_type, string_type=string_type) + "[]"
216218
raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}")
217219

218220

awswrangler/_databases.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,8 @@ def generate_placeholder_parameter_pairs(
359359
"""Extract Placeholder and Parameter pairs."""
360360

361361
def convert_value_to_native_python_type(value: Any) -> Any:
362+
if isinstance(value, list):
363+
return value
362364
if pd.isna(value):
363365
return None
364366
if hasattr(value, "to_pydatetime"):

awswrangler/athena/_write_iceberg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,13 @@ def _determine_differences(
115115

116116
catalog_column_types = typing.cast(
117117
Dict[str, str],
118-
catalog.get_table_types(database=database, table=table, catalog_id=catalog_id, boto3_session=boto3_session),
118+
catalog.get_table_types(
119+
database=database,
120+
table=table,
121+
catalog_id=catalog_id,
122+
filter_iceberg_current=True,
123+
boto3_session=boto3_session,
124+
),
119125
)
120126

121127
original_column_names = set(catalog_column_types)

awswrangler/catalog/_get.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def get_table_types(
107107
database: str,
108108
table: str,
109109
catalog_id: str | None = None,
110+
filter_iceberg_current: bool = False,
110111
boto3_session: boto3.Session | None = None,
111112
) -> dict[str, str] | None:
112113
"""Get all columns and types from a table.
@@ -120,6 +121,9 @@ def get_table_types(
120121
catalog_id
121122
The ID of the Data Catalog from which to retrieve Databases.
122123
If ``None`` is provided, the AWS account ID is used by default.
124+
filter_iceberg_current
125+
If True, returns only current iceberg fields (fields marked with iceberg.field.current: true).
126+
Otherwise, returns all the fields. False by default (returns all fields).
123127
boto3_session
124128
The default boto3 session will be used if **boto3_session** receive ``None``.
125129
@@ -139,7 +143,10 @@ def get_table_types(
139143
response = client_glue.get_table(**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table))
140144
except client_glue.exceptions.EntityNotFoundException:
141145
return None
142-
return _extract_dtypes_from_table_details(response=response)
146+
return _extract_dtypes_from_table_details(
147+
response=response,
148+
filter_iceberg_current=filter_iceberg_current,
149+
)
143150

144151

145152
def get_databases(

awswrangler/catalog/_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,16 @@ def _sanitize_name(name: str) -> str:
3131
return re.sub("[^A-Za-z0-9_]+", "_", name).lower() # Replacing non alphanumeric characters by underscore
3232

3333

34-
def _extract_dtypes_from_table_details(response: "GetTableResponseTypeDef") -> dict[str, str]:
34+
def _extract_dtypes_from_table_details(
35+
response: "GetTableResponseTypeDef",
36+
filter_iceberg_current: bool = False,
37+
) -> dict[str, str]:
3538
dtypes: dict[str, str] = {}
3639
for col in response["Table"]["StorageDescriptor"]["Columns"]:
37-
dtypes[col["Name"]] = col["Type"]
40+
# Only return current fields if flag is enabled
41+
if not filter_iceberg_current or col.get("Parameters", {}).get("iceberg.field.current") == "true":
42+
dtypes[col["Name"]] = col["Type"]
43+
# Add partition keys as columns
3844
if "PartitionKeys" in response["Table"]:
3945
for par in response["Table"]["PartitionKeys"]:
4046
dtypes[par["Name"]] = par["Type"]

tests/unit/test_athena_iceberg.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,3 +1159,55 @@ def test_to_iceberg_fill_missing_columns_with_complex_types(
11591159
schema_evolution=True,
11601160
fill_missing_columns_in_df=True,
11611161
)
1162+
1163+
1164+
def test_athena_to_iceberg_alter_schema(
1165+
path: str,
1166+
path2: str,
1167+
glue_database: str,
1168+
glue_table: str,
1169+
) -> None:
1170+
df = pd.DataFrame(
1171+
{
1172+
"id": pd.Series([1, 2, 3, 4, 5], dtype="Int64"),
1173+
"name": pd.Series(["a", "b", "c", "d", "e"], dtype="string"),
1174+
},
1175+
).reset_index(drop=True)
1176+
1177+
split_index = 3
1178+
1179+
wr.athena.to_iceberg(
1180+
df=df[:split_index],
1181+
database=glue_database,
1182+
table=glue_table,
1183+
table_location=path,
1184+
temp_path=path2,
1185+
schema_evolution=True,
1186+
keep_files=False,
1187+
)
1188+
1189+
wr.athena.start_query_execution(
1190+
sql=f"ALTER TABLE {glue_table} CHANGE COLUMN id new_id bigint",
1191+
database=glue_database,
1192+
wait=True,
1193+
)
1194+
1195+
df = df.rename(columns={"id": "new_id"})
1196+
1197+
wr.athena.to_iceberg(
1198+
df=df[split_index:],
1199+
database=glue_database,
1200+
table=glue_table,
1201+
table_location=path,
1202+
temp_path=path2,
1203+
schema_evolution=True,
1204+
keep_files=False,
1205+
)
1206+
1207+
df_actual = wr.athena.read_sql_query(
1208+
sql=f"SELECT new_id, name FROM {glue_table} ORDER BY new_id",
1209+
database=glue_database,
1210+
ctas_approach=False,
1211+
)
1212+
1213+
assert_pandas_equals(df, df_actual)

tests/unit/test_postgresql.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def test_unknown_overwrite_method_error(postgresql_table, postgresql_con):
100100
def test_sql_types(postgresql_table, postgresql_con):
101101
table = postgresql_table
102102
df = get_df()
103+
df["arrint"] = pd.Series([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
104+
df["arrstr"] = pd.Series([["a", "b", "c"], ["d", "e", "f"], ["g", "h", "i"]])
103105
df.drop(["binary"], axis=1, inplace=True)
104106
wr.postgresql.to_sql(
105107
df=df,
@@ -108,7 +110,7 @@ def test_sql_types(postgresql_table, postgresql_con):
108110
schema="public",
109111
mode="overwrite",
110112
index=True,
111-
dtype={"iint32": "INTEGER"},
113+
dtype={"iint32": "INTEGER", "arrint": "INTEGER[]", "arrstr": "VARCHAR[]"},
112114
)
113115
df = wr.postgresql.read_sql_query(f"SELECT * FROM public.{table}", postgresql_con)
114116
ensure_data_types(df, has_list=False)
@@ -130,6 +132,8 @@ def test_sql_types(postgresql_table, postgresql_con):
130132
"timestamp": pa.timestamp(unit="ns"),
131133
"binary": pa.binary(),
132134
"category": pa.float64(),
135+
"arrint": pa.list_(pa.int64()),
136+
"arrstr": pa.list_(pa.string()),
133137
},
134138
)
135139
for df in dfs:

0 commit comments

Comments
 (0)