Skip to content

Commit 23fd719

Browse files
feat: add clickhouse data sink (#4850)
## Changes Made <!-- Describe what changes were made and why. Include implementation details if necessary. --> Add the ClickHouse data sink and use it to write the DataFrame to a ClickHouse table ## Related Issues <!-- Link to related GitHub issues, e.g., "Closes #123" --> ## Checklist - [ ] Documented in API Docs (if applicable) - [ ] Documented in User Guide (if applicable) - [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation - [ ] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review) --------- Co-authored-by: Desmond Cheong <[email protected]>
1 parent 058c6cd commit 23fd719

File tree

7 files changed

+294
-1
lines changed

7 files changed

+294
-1
lines changed

daft/dataframe/dataframe.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,57 @@ def write_turbopuffer(
15941594
)
15951595
return self.write_sink(sink)
15961596

1597+
@DataframePublicAPI
def write_clickhouse(
    self,
    table: str,
    *,
    host: str,
    port: Optional[int] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    database: Optional[str] = None,
    client_kwargs: Optional[dict[str, Any]] = None,
    write_kwargs: Optional[dict[str, Any]] = None,
) -> "DataFrame":
    """Writes the DataFrame to a ClickHouse table.

    All arguments are handed unchanged to a ``ClickHouseDataSink``, which is then
    passed to :meth:`write_sink`.

    Args:
        table: Name of the destination ClickHouse table.
        host: Hostname of the ClickHouse server.
        port: Port of the ClickHouse server.
        user: User name used to connect.
        password: Password used to connect.
        database: Database that contains ``table``.
        client_kwargs: Optional extra keyword arguments for the ClickHouse client constructor.
        write_kwargs: Optional extra keyword arguments for the ClickHouse write call.

    Returns:
        The value returned by ``write_sink`` for the ClickHouse sink (see the example below).

    Examples:
        >>> import daft
        >>> df = daft.from_pydict({"a": [1, 2, 3, 4]})  # doctest: +SKIP
        >>> df.write_clickhouse(table="", host="", port=8123, user="", password="")  # doctest: +SKIP
        ╭────────────────────┬─────────────────────╮
        │ total_written_rows ┆ total_written_bytes │
        │ ---                ┆ ---                 │
        │ Int64              ┆ Int64               │
        ╞════════════════════╪═════════════════════╡
        │ 4                  ┆ 32                  │
        ╰────────────────────┴─────────────────────╯
    """
    # Imported lazily so the ClickHouse sink module is only loaded when this method is used.
    from daft.io.clickhouse.clickhouse_data_sink import ClickHouseDataSink

    sink_options: dict[str, Any] = {
        "host": host,
        "port": port,
        "user": user,
        "password": password,
        "database": database,
        "client_kwargs": client_kwargs,
        "write_kwargs": write_kwargs,
    }
    return self.write_sink(ClickHouseDataSink(table, **sink_options))
1647+
15971648
###
15981649
# DataFrame operations
15991650
###

daft/io/clickhouse/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from __future__ import annotations
2+
3+
from daft.io.clickhouse.clickhouse_data_sink import ClickHouseDataSink
4+
5+
# Names exported by the daft.io.clickhouse package.
__all__ = ["ClickHouseDataSink"]
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
from clickhouse_connect import get_client
6+
from clickhouse_connect.driver.summary import QuerySummary
7+
8+
from daft import Schema
9+
from daft.datatype import DataType
10+
from daft.io import DataSink
11+
from daft.io.sink import WriteResult
12+
from daft.recordbatch.micropartition import MicroPartition
13+
14+
if TYPE_CHECKING:
15+
from collections.abc import Iterator
16+
17+
18+
class ClickHouseDataSink(DataSink[QuerySummary]):
    """DataSink that inserts micropartitions into a ClickHouse table.

    Each micropartition is converted to a pandas DataFrame and inserted with the
    client's ``insert_df``. The client is created inside :meth:`write`, not in the
    constructor, so that the sink itself holds only plain configuration values.
    """

    def __init__(
        self,
        table: str,
        *,
        host: str,
        port: int | None = None,
        user: str | None = None,
        password: str | None = None,
        database: str | None = None,
        client_kwargs: dict[str, Any] | None = None,
        write_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Stores the connection and write configuration.

        Args:
            table: Name of the ClickHouse table to insert into.
            host: ClickHouse host.
            port: ClickHouse port; omitted from the client arguments when None.
            user: ClickHouse user; omitted from the client arguments when None.
            password: ClickHouse password; omitted from the client arguments when None.
            database: ClickHouse database; omitted from the client arguments when None.
            client_kwargs: Extra keyword arguments for ``get_client``. The explicit
                connection parameters above take precedence over same-named keys here.
            write_kwargs: Extra keyword arguments forwarded to ``insert_df``.
        """
        connection_params: dict[str, Any] = {"host": host}
        optional_params = {"port": port, "user": user, "password": password, "database": database}
        # Only forward the optional parameters the caller actually supplied.
        connection_params.update({key: value for key, value in optional_params.items() if value is not None})

        # Explicit connection parameters are merged last so they win over client_kwargs.
        self._client_kwargs = {**(client_kwargs or {}), **connection_params}
        self._table = table

        # Copy so that later mutation of the caller's dict cannot change this sink.
        self._write_kwargs = dict(write_kwargs or {})

        # Schema of the single-row summary produced by finalize().
        self._result_schema = Schema._from_field_name_and_types(
            [("total_written_rows", DataType.int64()), ("total_written_bytes", DataType.int64())]
        )

    def schema(self) -> Schema:
        """Returns the schema of the summary produced by :meth:`finalize`."""
        return self._result_schema

    def write(self, micropartitions: Iterator[MicroPartition]) -> Iterator[WriteResult[QuerySummary]]:
        """Writes to ClickHouse from the given micropartitions.

        Args:
            micropartitions: Micropartitions to insert, one ``insert_df`` call each.

        Yields:
            One WriteResult per micropartition, carrying the client's QuerySummary
            together with the row count and in-memory byte size of the inserted frame.
        """
        # socket cannot be serialized, so we need to create a new client in write
        ck_client = get_client(**self._client_kwargs)
        try:
            for micropartition in micropartitions:
                df = micropartition.to_pandas()
                # Exclude the pandas index: only the columns are inserted, so counting
                # the index would overstate the written size. Cast the NumPy scalars
                # to plain ints so the results are ordinary Python values.
                bytes_written = int(df.memory_usage(index=False).sum())
                rows_written = int(df.shape[0])

                query_summary = ck_client.insert_df(self._table, df, **self._write_kwargs)
                yield WriteResult(
                    result=query_summary,
                    bytes_written=bytes_written,
                    rows_written=rows_written,
                )
        finally:
            # Runs on normal exhaustion, on insert errors, and when the generator is closed early.
            ck_client.close()

    def finalize(self, write_results: list[WriteResult[QuerySummary]]) -> MicroPartition:
        """Finishes the write to ClickHouse.

        Args:
            write_results: Results yielded by :meth:`write`.

        Returns:
            A single-row MicroPartition with the summed ``total_written_rows`` and
            ``total_written_bytes`` (both zero when ``write_results`` is empty).
        """
        from daft.dependencies import pa

        total_written_rows = sum(write_result.rows_written for write_result in write_results)
        total_written_bytes = sum(write_result.bytes_written for write_result in write_results)

        return MicroPartition.from_pydict(
            {
                "total_written_rows": pa.array([total_written_rows], pa.int64()),
                "total_written_bytes": pa.array([total_written_bytes], pa.int64()),
            }
        )

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ readme = "README.rst"
2222
requires-python = ">=3.9"
2323

2424
[project.optional-dependencies]
25-
all = ["daft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake, spark, sql, unity]"]
25+
all = ["daft[aws, azure, gcp, ray, pandas, numpy, iceberg, deltalake, spark, sql, unity, clickhouse]"]
2626
aws = ["boto3"]
2727
azure = []
28+
clickhouse = ["clickhouse_connect"]
2829
deltalake = ["deltalake", "packaging"]
2930
gcp = []
3031
hudi = ["pyarrow >= 8.0.0"]

requirements-dev.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,6 @@ grpcio-status==1.67.0
116116
# ai
117117
vllm; platform_system == "Linux" and platform_machine == "x86_64" # for other systems, see install instructions: https://docs.vllm.ai/en/latest/getting_started/installation.html
118118
openai
119+
120+
# clickhouse
121+
clickhouse-connect

tests/io/clickhouse/__init__.py

Whitespace-only changes.
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
from unittest.mock import MagicMock, Mock, patch
5+
6+
import pandas as pd
7+
import pytest
8+
from clickhouse_connect.driver.summary import QuerySummary
9+
10+
import daft
11+
from daft.dataframe.dataframe import DataFrame
12+
from daft.io.clickhouse.clickhouse_data_sink import ClickHouseDataSink
13+
14+
15+
class TestWriteClickHouse:
    """Checks that DataFrame.write_clickhouse forwards its arguments to ClickHouseDataSink."""

    @pytest.fixture
    def sample_data(self):
        """Two-row DataFrame with one list column and two float columns."""
        columns = {
            "vector": [[1.1, 1.2], [0.2, 1.8]],
            "lat": [45.5, 40.1],
            "long": [-122.7, -74.1],
        }
        return daft.from_pydict(columns)

    @patch.object(ClickHouseDataSink, "__init__", return_value=None)
    @patch.object(DataFrame, "write_sink")
    def test_write_clickhouse_minimal_params(self, mock_write_sink, mock_clickhouse_sink, sample_data):
        """With only table and host given, every optional argument is forwarded as None."""
        optional_names = ("port", "user", "password", "database", "client_kwargs", "write_kwargs")
        expected_kwargs = dict.fromkeys(optional_names)

        sample_data.write_clickhouse(table="test_table", host="localhost")

        mock_clickhouse_sink.assert_called_once_with("test_table", host="localhost", **expected_kwargs)
        mock_write_sink.assert_called_once()

    @patch.object(ClickHouseDataSink, "__init__", return_value=None)
    @patch.object(DataFrame, "write_sink")
    def test_write_clickhouse_all_params(self, mock_write_sink, mock_clickhouse_sink, sample_data):
        """With every argument given, each one reaches the sink constructor unchanged."""
        expected_kwargs = {
            "host": "localhost",
            "port": 8123,
            "user": "user",
            "password": "pass",
            "database": "db",
            "client_kwargs": {"timeout": 10},
            "write_kwargs": {"batch_size": 1000},
        }

        sample_data.write_clickhouse(
            table="test_table",
            host="localhost",
            port=8123,
            user="user",
            password="pass",
            database="db",
            client_kwargs={"timeout": 10},
            write_kwargs={"batch_size": 1000},
        )

        # The table name is passed positionally; everything else by keyword.
        mock_clickhouse_sink.assert_called_once_with("test_table", **expected_kwargs)
        mock_write_sink.assert_called_once()

    @patch.object(ClickHouseDataSink, "__init__", return_value=None)
    @patch.object(DataFrame, "write_sink", side_effect=Exception("Write error"))
    def test_write_clickhouse_write_error(self, mock_write_sink, mock_clickhouse_sink, sample_data):
        """An exception raised by write_sink propagates out of write_clickhouse."""
        call_args = {"table": "test_table", "host": "localhost"}
        with pytest.raises(Exception, match="Write error"):
            sample_data.write_clickhouse(**call_args)
77+
78+
79+
class TestClickHouseDataSink:
    """Tests for ClickHouseDataSink class."""

    @pytest.fixture
    def clickhouse_sink(self):
        """Sink configured with every explicit connection parameter."""
        return ClickHouseDataSink(
            table="test_table", host="localhost", port=8123, user="user", password="pass", database="db"
        )

    def test_initialization(self, clickhouse_sink):
        """Explicit connection parameters end up in the stored client kwargs."""
        assert clickhouse_sink._table == "test_table"
        assert clickhouse_sink._client_kwargs["host"] == "localhost"
        assert clickhouse_sink._client_kwargs["port"] == 8123
        assert clickhouse_sink._client_kwargs["user"] == "user"
        assert clickhouse_sink._client_kwargs["password"] == "pass"
        assert clickhouse_sink._client_kwargs["database"] == "db"
        # No write_kwargs were supplied, so the sink falls back to an empty mapping.
        assert clickhouse_sink._write_kwargs == {}

    @pytest.fixture
    def mock_client(self):
        """Replaces get_client so that no real ClickHouse connection is opened."""
        # Patch get_client under the sink module's own name, since that is where write() looks it up.
        with patch("daft.io.clickhouse.clickhouse_data_sink.get_client") as mock:
            client = MagicMock()
            client.insert_df.return_value = QuerySummary(
                summary={"written_rows": 3, "written_bytes": 128, "query_id": "test_query"}
            )
            # NOTE(review): presumably stands in for a server version/timezone query; confirm it is needed.
            client.command.return_value = ("22.8.0", "UTC")
            mock.return_value = client
            yield client

    def test_clickhouse_sink_write(self, mock_client):
        """write() inserts each micropartition once and reports its row count."""
        sink = ClickHouseDataSink(
            table="test_table",
            host="localhost",
            client_kwargs={"settings": {"async_insert": 1}},
            write_kwargs={"column_names": ["col1"]},
        )

        mp = daft.recordbatch.MicroPartition.from_pydict({"col1": [1, 2, 3]})
        results = list(sink.write(iter([mp])))

        mock_client.insert_df.assert_called_once_with("test_table", unittest.mock.ANY, column_names=["col1"])
        pd.testing.assert_frame_equal(mock_client.insert_df.call_args[0][1], mp.to_pandas())

        assert len(results) == 1
        assert isinstance(results[0].result, QuerySummary)
        # The reported row count comes from the inserted frame, which has three rows.
        assert results[0].rows_written == 3
        # The client must also be closed when the write succeeds, not only on failure.
        mock_client.close.assert_called_once()

    def test_finalize_statistics(self):
        """finalize() sums rows and bytes across all write results."""
        sink = ClickHouseDataSink(table="test", host="localhost")
        mock_results = [Mock(rows_written=10, bytes_written=100), Mock(rows_written=20, bytes_written=200)]

        result = sink.finalize(mock_results)
        assert result.to_pydict() == {"total_written_rows": [30], "total_written_bytes": [300]}

    def test_client_cleanup(self, mock_client):
        """The client is closed even when insert_df raises."""
        sink = ClickHouseDataSink(table="test", host="localhost")
        mp = daft.recordbatch.MicroPartition.from_pydict({"col1": [1]})

        mock_client.insert_df.side_effect = Exception("DB error")
        # Match the message so an unrelated failure cannot satisfy this test.
        with pytest.raises(Exception, match="DB error"):
            list(sink.write(iter([mp])))

        mock_client.close.assert_called_once()

0 commit comments

Comments
 (0)