Skip to content

Commit a6bc908

Browse files
authored
docs: sort data by Hilbert curve, use pathlib (#6)
1 parent 9c407f4 commit a6bc908

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

packages/docs/generate_demo_data.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# /// script
22
# requires-python = ">=3.11"
3-
# dependencies = ["click", "datasets", "pandas", "sentence-transformers", "umap-learn"]
3+
# dependencies = ["click", "datasets", "pandas", "sentence-transformers", "umap-learn", "duckdb"]
44
# ///
55

66
import json
77
import os
88
import shutil
9+
from pathlib import Path
910

1011
import click
12+
import duckdb
1113
import numpy as np
1214
import pandas as pd
1315
from datasets import load_dataset
@@ -45,8 +47,9 @@ def add_embedding_projection(df: pd.DataFrame, text: str):
4547
@click.command()
4648
@click.option("--output", default="demo-data")
4749
def main(output: str):
48-
shutil.rmtree(output, ignore_errors=True)
49-
os.makedirs(output, exist_ok=True)
50+
output_path = Path(output)
51+
shutil.rmtree(output_path, ignore_errors=True)
52+
output_path.mkdir(parents=True, exist_ok=True)
5053

5154
name = "spawn99/wine-reviews"
5255
columns = [
@@ -66,7 +69,26 @@ def main(output: str):
6669

6770
add_embedding_projection(df, text="description")
6871

69-
df.to_parquet(os.path.join(output, "dataset.parquet"), index=False)
72+
# Setup DuckDB with Hilbert support
73+
# See https://duckdb.org/2025/06/06/advanced-sorting-for-fast-selective-queries.html
74+
conn = duckdb.connect()
75+
76+
conn.execute("INSTALL lindel FROM community;")
77+
conn.execute("LOAD lindel;")
78+
79+
conn.register("wine_data", df)
80+
81+
# Sort data using Hilbert curve encoding of the projection.
82+
conn.execute(f"""
83+
COPY (
84+
SELECT *
85+
FROM wine_data
86+
ORDER BY hilbert_encode([
87+
projection_x,
88+
projection_y
89+
]::FLOAT[2])
90+
) TO '{output_path / "dataset.parquet"}' (FORMAT PARQUET)
91+
""")
7092

7193
metadata = {
7294
"columns": {
@@ -79,7 +101,7 @@ def main(output: str):
79101
"database": {"type": "wasm", "load": True},
80102
}
81103

82-
with open(os.path.join(output, "metadata.json"), "wb") as f:
104+
with open(output_path / "metadata.json", "wb") as f:
83105
f.write(json.dumps(metadata).encode("utf-8"))
84106

85107

0 commit comments

Comments
 (0)