diff --git a/awswrangler/opensearch/_read.py b/awswrangler/opensearch/_read.py index d1ead8418..b4356411c 100644 --- a/awswrangler/opensearch/_read.py +++ b/awswrangler/opensearch/_read.py @@ -41,12 +41,24 @@ def _hit_to_row(hit: Mapping[str, Any]) -> Mapping[str, Any]: return row -def _search_response_to_documents(response: Mapping[str, Any]) -> list[Mapping[str, Any]]: - return [_hit_to_row(hit) for hit in response.get("hits", {}).get("hits", [])] - - -def _search_response_to_df(response: Mapping[str, Any] | Any) -> pd.DataFrame: - return pd.DataFrame(_search_response_to_documents(response)) +def _search_response_to_documents( + response: Mapping[str, Any], aggregations: list[str] | None = None +) -> list[Mapping[str, Any]]: + hits = response.get("hits", {}).get("hits", []) + if not hits and aggregations: + hits = [ + dict(aggregation_hit, _aggregation_name=aggregation_name) + for aggregation_name in aggregations + for aggregation_hit in response.get("aggregations", {}) + .get(aggregation_name, {}) + .get("hits", {}) + .get("hits", []) + ] + return [_hit_to_row(hit) for hit in hits] + + +def _search_response_to_df(response: Mapping[str, Any] | Any, aggregations: list[str] | None = None) -> pd.DataFrame: + return pd.DataFrame(_search_response_to_documents(response=response, aggregations=aggregations)) @_utils.check_optional_dependency(opensearchpy, "opensearchpy") @@ -128,8 +140,16 @@ def search( documents = [_hit_to_row(doc) for doc in documents_generator] df = pd.DataFrame(documents) else: + aggregations = ( + list(search_body.get("aggregations", {}).keys() or search_body.get("aggs", {}).keys()) + if search_body + else None + ) response = client.search(index=index, body=search_body, filter_path=filter_path, **kwargs) - df = _search_response_to_df(response) + df = _search_response_to_df( + response=response, + aggregations=aggregations, + ) return df diff --git a/tests/unit/test_opensearch.py b/tests/unit/test_opensearch.py index 3fbc8293a..422ccc8c6 100644 --- a/tests/unit/test_opensearch.py +++ b/tests/unit/test_opensearch.py @@ -424,6 +424,41 @@ def test_search_scroll(client): wr.opensearch.delete_index(client, index) +def test_search_aggregation(client): + index = f"test_search_agg_{_get_unique_suffix()}" + kwargs = {} if _is_serverless(client) else {"refresh": "wait_for"} + try: + wr.opensearch.index_documents( + client, + documents=inspections_documents, + index=index, + id_keys=["inspection_id"], + **kwargs, + ) + if _is_serverless(client): + # The refresh interval for OpenSearch Serverless is between 10 and 30 seconds + # depending on the size of the request. + time.sleep(30) + df = wr.opensearch.search( + client, + index=index, + search_body={ + "aggregations": { + "latest_inspections": {"top_hits": {"sort": [{"inspection_date": {"order": "asc"}}], "size": 1}}, + "lowest_inspection_score": { + "top_hits": {"sort": [{"inspection_score": {"order": "asc"}}], "size": 1} + }, + } + }, + filter_path=["aggregations"], + ) + assert df.shape[0] == 2 + assert len(df.loc[df["_aggregation_name"] == "latest_inspections"]) == 1 + assert len(df.loc[df["_aggregation_name"] == "lowest_inspection_score"]) == 1 + finally: + wr.opensearch.delete_index(client, index) + + @pytest.mark.parametrize("fetch_size", [None, 1000, 10000]) @pytest.mark.parametrize("fetch_size_param_name", ["size", "fetch_size"]) def test_search_sql(client, fetch_size, fetch_size_param_name):