Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"python.analysis.typeCheckingMode": "standard",
"python.defaultInterpreterPath": "packages/backend/.venv/bin/python",
"files.trimTrailingWhitespace": true,
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff",
Expand Down
117 changes: 99 additions & 18 deletions packages/backend/embedding_atlas/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,76 +102,105 @@ def find_available_port(start_port: int, max_attempts: int = 10, host="localhost

@click.command()
@click.argument("inputs", nargs=-1, required=True)
@click.option("--text", default=None, help="The column name for text.")
@click.option("--image", default=None, help="The column name for image.")
@click.option("--text", default=None, help="Column containing text data.")
@click.option("--image", default=None, help="Column containing image data.")
@click.option(
"--split",
default=[],
multiple=True,
help="The data split name for Hugging Face data. Repeat this command for multiple splits.",
help="Dataset split name(s) to load from Hugging Face datasets. Can be specified multiple times for multiple splits.",
)
@click.option(
"--embedding/--no-embedding",
"enable_embedding",
default=True,
help="Whether to compute embeddings for the data. Disable if embeddings are pre-computed or if you do not want an embedding view.",
)
@click.option(
"--model",
default=None,
help="The model for producing text embeddings.",
help="Model name for generating embeddings (e.g., 'all-MiniLM-L6-v2').",
)
@click.option(
"--trust-remote-code",
is_flag=True,
default=False,
help="Trust remote code when loading models.",
help="Allow execution of remote code when loading models from Hugging Face Hub.",
)
@click.option(
"--x",
"x_column",
help="Column containing pre-computed X coordinates for the embedding view.",
)
@click.option(
"--y",
"y_column",
help="Column containing pre-computed Y coordinates for the embedding view.",
)
@click.option("--x", "x_column", help="The column name for x coordinate.")
@click.option("--y", "y_column", help="The column name for y coordinate.")
@click.option(
"--neighbors",
"neighbors_column",
help="""The column name for pre-computed nearest neighbors. The values should be in {"ids": [n1, n2, ...], "distances": [d1, d2, ...]} format.""",
help='Column containing pre-computed nearest neighbors in format: {"ids": [n1, n2, ...], "distances": [d1, d2, ...]}.',
)
@click.option(
"--sample",
default=None,
type=int,
help="The number of rows to sample from the original dataset.",
help="Number of random samples to draw from the dataset. Useful for large datasets.",
)
@click.option(
"--umap-n-neighbors",
type=int,
help="The n_neighbors parameter for UMAP.",
help="Number of neighbors to consider for UMAP dimensionality reduction (default: 15).",
)
@click.option(
"--umap-min-dist",
type=float,
help="The min_dist parameter for UMAP.",
)
@click.option("--umap-metric", default="cosine", help="The metric for UMAP.")
@click.option("--umap-random-state", type=int, help="The random seed for UMAP.")
@click.option(
"--umap-metric",
default="cosine",
help="Distance metric for UMAP computation (default: 'cosine').",
)
@click.option(
"--umap-random-state", type=int, help="Random seed for reproducible UMAP results."
)
@click.option(
"--duckdb",
type=str,
default="wasm",
help="DuckDB server URI (e.g., ws://localhost:3000, http://localhost:3000), or 'wasm' to run DuckDB in browser, or 'server' to run DuckDB in this server. Default to 'wasm'.",
help="DuckDB connection mode: 'wasm' (run in browser), 'server' (run on this server), or URI (e.g., 'ws://localhost:3000').",
)
@click.option(
"--host",
default="localhost",
help="Host address for the web server (default: localhost).",
)
@click.option(
"--port", default=5055, help="Port number for the web server (default: 5055)."
)
@click.option("--host", default="localhost", help="The hostname of the http server.")
@click.option("--port", default=5055, help="The port of the http server.")
@click.option(
"--auto-port/--no-auto-port",
"enable_auto_port",
default=True,
help="Enable / disable auto port selection. If disabled, the application crashes if the specified port is already used.",
help="Automatically find an available port if the specified port is in use.",
)
@click.option(
"--static", type=str, help="Custom path to frontend static files directory."
)
@click.option("--static", type=str, help="Path to the static files for frontend.")
@click.option(
"--export-application",
type=str,
help="Export a static Web application to the given zip file and exit.",
help="Export the visualization as a standalone web application to the specified ZIP file and exit.",
)
@click.version_option(version=__version__, package_name="embedding_atlas")
def main(
inputs,
text: str | None,
image: str | None,
split: list[str] | None,
enable_embedding: bool,
model: str | None,
trust_remote_code: bool,
x_column: str | None,
Expand All @@ -198,6 +227,58 @@ def main(

print(df)

if enable_embedding and (x_column is None or y_column is None):
# No x, y column selected, first see if text column is specified, if not, ask for it
if text is None and image is None:
text = prompt_for_column(
df, "Select a column you want to run the embedding on"
)
umap_args = {}
if umap_min_dist is not None:
umap_args["min_dist"] = umap_min_dist
if umap_n_neighbors is not None:
umap_args["n_neighbors"] = umap_n_neighbors
if umap_random_state is not None:
umap_args["random_state"] = umap_random_state
if umap_metric is not None:
umap_args["metric"] = umap_metric
# Run embedding and projection
if text is not None or image is not None:
from .projection import compute_image_projection, compute_text_projection

x_column = find_column_name(df.columns, "projection_x")
y_column = find_column_name(df.columns, "projection_y")
if neighbors_column is None:
neighbors_column = find_column_name(df.columns, "__neighbors")
new_neighbors_column = neighbors_column
else:
# If neighbors_column is already specified, don't overwrite it.
new_neighbors_column = None
if text is not None:
compute_text_projection(
df,
text,
x=x_column,
y=y_column,
neighbors=new_neighbors_column,
model=model,
trust_remote_code=trust_remote_code,
umap_args=umap_args,
)
elif image is not None:
compute_image_projection(
df,
image,
x=x_column,
y=y_column,
neighbors=new_neighbors_column,
model=model,
trust_remote_code=trust_remote_code,
umap_args=umap_args,
)
else:
raise RuntimeError("unreachable")

id_column = find_column_name(df.columns, "_row_index")
df[id_column] = range(df.shape[0])

Expand Down
Loading