diff --git a/.vscode/settings.json b/.vscode/settings.json index 595652d..a243a50 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { "python.analysis.typeCheckingMode": "standard", + "python.defaultInterpreterPath": "packages/backend/.venv/bin/python", "files.trimTrailingWhitespace": true, "[python]": { "editor.defaultFormatter": "charliermarsh.ruff", diff --git a/packages/backend/embedding_atlas/cli.py b/packages/backend/embedding_atlas/cli.py index d367cd0..3df11b6 100644 --- a/packages/backend/embedding_atlas/cli.py +++ b/packages/backend/embedding_atlas/cli.py @@ -102,69 +102,97 @@ def find_available_port(start_port: int, max_attempts: int = 10, host="localhost @click.command() @click.argument("inputs", nargs=-1, required=True) -@click.option("--text", default=None, help="The column name for text.") -@click.option("--image", default=None, help="The column name for image.") +@click.option("--text", default=None, help="Column containing text data.") +@click.option("--image", default=None, help="Column containing image data.") @click.option( "--split", default=[], multiple=True, - help="The data split name for Hugging Face data. Repeat this command for multiple splits.", + help="Dataset split name(s) to load from Hugging Face datasets. Can be specified multiple times for multiple splits.", +) +@click.option( + "--embedding/--no-embedding", + "enable_embedding", + default=True, + help="Whether to compute embeddings for the data. Disable if embeddings are pre-computed or if you do not want an embedding view.", ) @click.option( "--model", default=None, - help="The model for producing text embeddings.", + help="Model name for generating embeddings (e.g., 'all-MiniLM-L6-v2').", ) @click.option( "--trust-remote-code", is_flag=True, default=False, - help="Trust remote code when loading models.", + help="Allow execution of remote code when loading models from Hugging Face Hub.", +) +@click.option( + "--x", + "x_column", + help="Column containing pre-computed X coordinates for the embedding view.", +) +@click.option( + "--y", + "y_column", + help="Column containing pre-computed Y coordinates for the embedding view.", ) -@click.option("--x", "x_column", help="The column name for x coordinate.") -@click.option("--y", "y_column", help="The column name for y coordinate.") @click.option( "--neighbors", "neighbors_column", - help="""The column name for pre-computed nearest neighbors. The values should be in {"ids": [n1, n2, ...], "distances": [d1, d2, ...]} format.""", + help='Column containing pre-computed nearest neighbors in format: {"ids": [n1, n2, ...], "distances": [d1, d2, ...]}.', ) @click.option( "--sample", default=None, type=int, - help="The number of rows to sample from the original dataset.", + help="Number of random samples to draw from the dataset. Useful for large datasets.", ) @click.option( "--umap-n-neighbors", type=int, - help="The n_neighbors parameter for UMAP.", + help="Number of neighbors to consider for UMAP dimensionality reduction (default: 15).", ) @click.option( "--umap-min-dist", type=float, help="The min_dist parameter for UMAP.", ) -@click.option("--umap-metric", default="cosine", help="The metric for UMAP.") -@click.option("--umap-random-state", type=int, help="The random seed for UMAP.") +@click.option( + "--umap-metric", + default="cosine", + help="Distance metric for UMAP computation (default: 'cosine').", +) +@click.option( + "--umap-random-state", type=int, help="Random seed for reproducible UMAP results." 
+) @click.option( "--duckdb", type=str, default="wasm", - help="DuckDB server URI (e.g., ws://localhost:3000, http://localhost:3000), or 'wasm' to run DuckDB in browser, or 'server' to run DuckDB in this server. Default to 'wasm'.", + help="DuckDB connection mode: 'wasm' (run in browser), 'server' (run on this server), or URI (e.g., 'ws://localhost:3000').", +) +@click.option( + "--host", + default="localhost", + help="Host address for the web server (default: localhost).", +) +@click.option( + "--port", default=5055, help="Port number for the web server (default: 5055)." ) -@click.option("--host", default="localhost", help="The hostname of the http server.") -@click.option("--port", default=5055, help="The port of the http server.") @click.option( "--auto-port/--no-auto-port", "enable_auto_port", default=True, - help="Enable / disable auto port selection. If disabled, the application crashes if the specified port is already used.", + help="Automatically find an available port if the specified port is in use.", +) +@click.option( + "--static", type=str, help="Custom path to frontend static files directory." ) -@click.option("--static", type=str, help="Path to the static files for frontend.") @click.option( "--export-application", type=str, - help="Export a static Web application to the given zip file and exit.", + help="Export the visualization as a standalone web application to the specified ZIP file and exit.", ) @click.version_option(version=__version__, package_name="embedding_atlas") def main( @@ -172,6 +200,7 @@ def main( text: str | None, image: str | None, split: list[str] | None, + enable_embedding: bool, model: str | None, trust_remote_code: bool, x_column: str | None, @@ -198,6 +227,58 @@ def main( print(df) + if enable_embedding and (x_column is None or y_column is None): + # No x, y column selected, first see if text column is specified, if not, ask for it + if text is None and image is None: + text = prompt_for_column( + df, "Select a column you want to run the embedding on" + ) + umap_args = {} + if umap_min_dist is not None: + umap_args["min_dist"] = umap_min_dist + if umap_n_neighbors is not None: + umap_args["n_neighbors"] = umap_n_neighbors + if umap_random_state is not None: + umap_args["random_state"] = umap_random_state + if umap_metric is not None: + umap_args["metric"] = umap_metric + # Run embedding and projection + if text is not None or image is not None: + from .projection import compute_image_projection, compute_text_projection + + x_column = find_column_name(df.columns, "projection_x") + y_column = find_column_name(df.columns, "projection_y") + if neighbors_column is None: + neighbors_column = find_column_name(df.columns, "__neighbors") + new_neighbors_column = neighbors_column + else: + # If neighbors_column is already specified, don't overwrite it. 
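+                # The user-provided neighbors column is passed through to the viewer unchanged;
+                # no new neighbors are computed in this case.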
+ new_neighbors_column = None + if text is not None: + compute_text_projection( + df, + text, + x=x_column, + y=y_column, + neighbors=new_neighbors_column, + model=model, + trust_remote_code=trust_remote_code, + umap_args=umap_args, + ) + elif image is not None: + compute_image_projection( + df, + image, + x=x_column, + y=y_column, + neighbors=new_neighbors_column, + model=model, + trust_remote_code=trust_remote_code, + umap_args=umap_args, + ) + else: + raise RuntimeError("unreachable") + id_column = find_column_name(df.columns, "_row_index") df[id_column] = range(df.shape[0]) diff --git a/packages/backend/embedding_atlas/projection.py b/packages/backend/embedding_atlas/projection.py new file mode 100644 index 0000000..3de478c --- /dev/null +++ b/packages/backend/embedding_atlas/projection.py @@ -0,0 +1,297 @@ +# Copyright (c) 2025 Apple Inc. Licensed under MIT License. + +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import pandas as pd + +from .utils import Hasher, cache_path, logger + + +@dataclass +class Projection: + # Array with shape (N, embedding_dim), the high-dimensional embedding + projection: np.ndarray + + knn_indices: np.ndarray + knn_distances: np.ndarray + + @staticmethod + def exists(path: Path): + return ( + path.with_suffix(".projection.npy").exists() + and path.with_suffix(".knn_indices.npy").exists() + and path.with_suffix(".knn_distances.npy").exists() + ) + + @staticmethod + def save(path: Path, value: "Projection"): + np.save( + path.with_suffix(".projection.npy"), + value.projection, + allow_pickle=False, + ) + np.save( + path.with_suffix(".knn_indices.npy"), + value.knn_indices, + allow_pickle=False, + ) + np.save( + path.with_suffix(".knn_distances.npy"), + value.knn_distances, + allow_pickle=False, + ) + + @staticmethod + def load(path: Path) -> "Projection": + return Projection( + projection=np.load( + path.with_suffix(".projection.npy"), + allow_pickle=False, + ), + knn_indices=np.load( + path.with_suffix(".knn_indices.npy"), + allow_pickle=False, + ), + knn_distances=np.load( + path.with_suffix(".knn_distances.npy"), + allow_pickle=False, + ), + ) + + +def _run_umap( + hidden_vectors: np.ndarray, + umap_args: dict = {}, +) -> Projection: + logger.info("Running UMAP for input with shape %s...", str(hidden_vectors.shape)) # type: ignore + + import umap + from umap.umap_ import nearest_neighbors + + metric = umap_args.get("metric", "cosine") + n_neighbors = umap_args.get("n_neighbors", 15) + + knn = nearest_neighbors( + hidden_vectors, + n_neighbors=n_neighbors, + metric=metric, + metric_kwds=None, + angular=False, + random_state=None, + ) + + proj = umap.UMAP(**umap_args, precomputed_knn=knn) + result: np.ndarray = proj.fit_transform(hidden_vectors) # type: ignore + + return Projection(projection=result, knn_indices=knn[0], knn_distances=knn[1]) + + +def _projection_for_texts( + texts: list[str], + model: str | None = None, + trust_remote_code: bool = False, + umap_args: dict = {}, +) -> Projection: + if model is None: + model = "all-MiniLM-L6-v2" + hasher = Hasher() + hasher.update( + { + "version": 1, + "texts": texts, + "model": model, + "umap_args": umap_args, + } + ) + digest = hasher.hexdigest() + cpath = cache_path("projections") / digest + + if Projection.exists(cpath): + logger.info("Using cached projection from %s", str(cpath)) + return Projection.load(cpath) + + # Import on demand. 
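+    # sentence-transformers is only imported here, so the dependency is loaded
+    # only when text embeddings are actually computed.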
+ from sentence_transformers import SentenceTransformer + + logger.info("Loading model %s...", model) + transformer = SentenceTransformer(model, trust_remote_code=trust_remote_code) + + logger.info("Running embedding for %d texts...", len(texts)) + hidden_vectors = transformer.encode(texts) + + result = _run_umap(hidden_vectors, umap_args) + Projection.save(cpath, result) + return result + + +def _projection_for_images( + images: list, + model: str | None = None, + trust_remote_code: bool = False, + umap_args: dict = {}, +) -> Projection: + if model is None: + model = "google/vit-base-patch16-384" + hasher = Hasher() + hasher.update( + { + "version": 1, + "images": images, + "model": model, + "umap_args": umap_args, + } + ) + digest = hasher.hexdigest() + cpath = cache_path("projections") / (digest + ".npy") + + if Projection.exists(cpath): + logger.info("Using cached projection from %s", str(cpath)) + return Projection.load(cpath) + + # Import on demand. + from io import BytesIO + + import torch + import tqdm + from PIL import Image + from transformers import pipeline + + def load_image(value): + if isinstance(value, bytes): + return Image.open(BytesIO(value)).convert("RGB") + elif isinstance(value, dict) and "bytes" in value: + return Image.open(BytesIO(value["bytes"])).convert("RGB") + else: + raise ValueError("invalid image value") + + logger.info("Loading model %s...", model) + + pipe = pipeline("image-feature-extraction", model=model, device_map="auto") + + logger.info("Running embedding for %d images...", len(images)) + tensors = [] + batch_size = 16 + + current_batch = [] + + @torch.no_grad() + def process_batch(): + rs: torch.Tensor = pipe(current_batch, return_tensors=True) # type: ignore + current_batch.clear() + for r in rs: + if len(r.shape) == 3: + r = r.mean(1) + assert len(r.shape) == 2 + tensors.append(r) + + for image in tqdm.tqdm(images, smoothing=0.1): + current_batch.append(load_image(image)) + if len(current_batch) >= batch_size: + process_batch() + process_batch() + + hidden_vectors = torch.concat(tensors).to(torch.float32).cpu().numpy() + + result = _run_umap(hidden_vectors, umap_args) + Projection.save(cpath, result) + return result + + +def compute_text_projection( + data_frame: pd.DataFrame, + text: str, + x: str = "projection_x", + y: str = "projection_y", + neighbors: str | None = "neighbors", + model: str | None = None, + trust_remote_code: bool = False, + umap_args: dict = {}, +): + """ + Compute text embeddings and generate 2D projections using UMAP. + + This function processes text data by creating embeddings using a SentenceTransformer + model and then reducing the dimensionality to 2D coordinates using UMAP for + visualization purposes. + + Args: + data_frame: pandas DataFrame containing the text data to process. + text: str, column name containing the texts to embed. + x: str, column name where the UMAP X coordinates will be stored. + y: str, column name where the UMAP Y coordinates will be stored. + neighbors: str, column name where the nearest neighbor indices will be stored. + model: str, name or path of the SentenceTransformer model to use for embedding. + trust_remote_code: bool, whether to trust and execute remote code when loading + the model from HuggingFace Hub. Default is False. + umap_args: dict, additional keyword arguments to pass to the UMAP algorithm + (e.g., n_neighbors, min_dist, metric). + + Returns: + The input DataFrame with added columns for X, Y coordinates and nearest neighbors. 
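+
+    Example (illustrative, mirroring the example notebook; assumes a DataFrame ``df``
+    with a text column named "description"):
+
+        from embedding_atlas.projection import compute_text_projection
+
+        compute_text_projection(df, text="description",
+                                x="projection_x", y="projection_y", neighbors="neighbors")
+        # df now has "projection_x", "projection_y", and "neighbors" columns that can be
+        # passed to EmbeddingAtlasWidget or the Streamlit component.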
+ """ + + text_series = data_frame[text].astype(str).fillna("") + proj = _projection_for_texts( + list(text_series), + model=model, + trust_remote_code=trust_remote_code, + umap_args=umap_args, + ) + data_frame[x] = proj.projection[:, 0] + data_frame[y] = proj.projection[:, 1] + if neighbors is not None: + data_frame[neighbors] = [ + {"distances": b, "ids": a} # ID is always the same as the row index. + for a, b in zip(proj.knn_indices, proj.knn_distances) + ] + + +def compute_image_projection( + data_frame: pd.DataFrame, + image: str, + x: str = "projection_x", + y: str = "projection_y", + neighbors: str | None = "neighbors", + model: str | None = None, + trust_remote_code: bool = False, + umap_args: dict = {}, +): + """ + Compute image embeddings and generate 2D projections using UMAP. + + This function processes image data by creating embeddings using a model and + then reducing the dimensionality to 2D coordinates using UMAP for + visualization purposes. + + Args: + data_frame: pandas DataFrame containing the image data to process. + image: str, column name containing the images to embed. + x: str, column name where the UMAP X coordinates will be stored. + y: str, column name where the UMAP Y coordinates will be stored. + neighbors: str, column name where the nearest neighbor indices will be stored. + model: str, name or path of the model to use for embedding. + trust_remote_code: bool, whether to trust and execute remote code when loading + the model from HuggingFace Hub. Default is False. + umap_args: dict, additional keyword arguments to pass to the UMAP algorithm + (e.g., n_neighbors, min_dist, metric). + + Returns: + The input DataFrame with added columns for X, Y coordinates and nearest neighbors. + """ + + image_series = data_frame[image] + proj = _projection_for_images( + list(image_series), + model=model, + trust_remote_code=trust_remote_code, + umap_args=umap_args, + ) + data_frame[x] = proj.projection[:, 0] + data_frame[y] = proj.projection[:, 1] + if neighbors is not None: + data_frame[neighbors] = [ + {"distances": b, "ids": a} # ID is always the same as the row index. 
+ for a, b in zip(proj.knn_indices, proj.knn_distances) + ] diff --git a/packages/backend/examples/notebook.ipynb b/packages/backend/examples/notebook.ipynb new file mode 100644 index 0000000..d24f924 --- /dev/null +++ b/packages/backend/examples/notebook.ipynb @@ -0,0 +1,92 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "8ce7f288-2d83-47ee-ac51-741126d4e3bd", + "metadata": {}, + "outputs": [], + "source": [ + "from embedding_atlas.widget import EmbeddingAtlasWidget\n", + "from embedding_atlas.projection import compute_text_projection\n", + "import pandas as pd\n", + "from datasets import load_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69f21f68-c20e-4055-b073-5ab6c8f89913", + "metadata": {}, + "outputs": [], + "source": [ + "# Load a dataset\n", + "ds = load_dataset(\"james-burton/wine_reviews\", split=\"validation\")\n", + "df = pd.DataFrame(ds)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a927e7d-8dda-4d6b-ae88-60cceb6f9678", + "metadata": {}, + "outputs": [], + "source": [ + "# Compute text embedding and projection of the embedding\n", + "compute_text_projection(df, text=\"description\", x=\"projection_x\", y=\"projection_y\", neighbors=\"neighbors\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a82b7b13-2024-4ff9-b96d-8b21048ca2cf", + "metadata": {}, + "outputs": [], + "source": [ + "# Display the dataset with the Embedding Atlas widget\n", + "w = EmbeddingAtlasWidget(df, text=\"description\", x=\"projection_x\", y=\"projection_y\", neighbors=\"neighbors\")\n", + "w" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e42251b5-0e9f-4191-81fd-64522118f1cd", + "metadata": {}, + "outputs": [], + "source": [ + "# Get the selection from the widget as a dataframe\n", + "w.selection()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d2f6448-3ce9-4529-a81a-9caf3376be57", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/backend/examples/streamlit.py b/packages/backend/examples/streamlit.py new file mode 100644 index 0000000..d8936b0 --- /dev/null +++ b/packages/backend/examples/streamlit.py @@ -0,0 +1,56 @@ +import duckdb +import pandas as pd +import streamlit as st +from datasets import load_dataset +from embedding_atlas.projection import compute_text_projection +from embedding_atlas.streamlit import embedding_atlas + + +@st.cache_data +def load_data(): + ds = load_dataset("james-burton/wine_reviews", split="validation") + return pd.DataFrame(ds) + + +def main(): + # Embedding Atlas looks better in wide mode + st.set_page_config(layout="wide") + + st.title("Embedding Atlas + Streamlit") + + # Load some data + st.write("Load an example dataset") + df = load_data() + + # Compute text embedding and projection of the embedding + compute_text_projection( + df, + text="description", + x="projection_x", + y="projection_y", + neighbors="neighbors", + ) + + # Create the Embedding Atlas widget in Streamlit + value = embedding_atlas( + df, + text="description", + x="projection_x", + 
y="projection_y", + neighbors="neighbors", + show_table=True, + ) + + # Show selected rows in a Streamlit data frame + st.write("Selected rows:") + if value is not None and value.get("predicate") is not None: + subset = duckdb.query_df( + df, "dataframe", "SELECT * FROM dataframe WHERE " + value.get("predicate") + ) + st.dataframe(subset) + else: + st.write("No selection") + + +if __name__ == "__main__": + main() diff --git a/packages/docs/streamlit.md b/packages/docs/streamlit.md index 9aa55a8..ab48ae6 100644 --- a/packages/docs/streamlit.md +++ b/packages/docs/streamlit.md @@ -12,6 +12,7 @@ pip install embedding-atlas ```python from embedding_atlas.streamlit import embedding_atlas +from embedding_atlas.projection import compute_text_projection # Compute text embedding and projection of the embedding compute_text_projection(df, text="description", diff --git a/packages/docs/tool.md b/packages/docs/tool.md index 299328b..70c0708 100644 --- a/packages/docs/tool.md +++ b/packages/docs/tool.md @@ -46,27 +46,29 @@ You can instead load datasets from Hugging Face: embedding-atlas huggingface_org/dataset_name ``` -## Visualizing Embedding Projections +## Visualizing Embeddings -To visual embedding projections, pre-compute the X and Y coordinates, and specify the column names with `--x` and `--y`, such as: +The script will use [SentenceTransformers](https://sbert.net/) to compute embedding vectors for the specified column containing the text data. The script will then project the high-dimensional embedding vectors to 2D with [UMAP](https://umap-learn.readthedocs.io/en/latest/index.html). + +::: tip +Optionally, if you know what column your text data is in beforehand, you can specify which column to use with the `--text` flag, for example: ```bash -embedding-atlas path_to_dataset.parquet --x projection_x --y projection_y +embedding-atlas path_to_dataset.parquet --text text_column ``` -You may use the [SentenceTransformers](https://sbert.net/) package to compute high-dimensional embeddings from text data, and then use the [UMAP](https://umap-learn.readthedocs.io/en/latest/index.html) package to compute 2D projections. +::: -You may also specify a column for pre-computed nearest neighbors: +If you've already pre-computed the embedding projection (e.g., by running your own embedding model and projecting them with UMAP), you may store them as two columns such as `projection_x` and `projection_y`, and pass them into `embedding-atlas` with the `--x` and `--y` flags: ```bash -embedding-atlas path_to_dataset.parquet --x projection_x --y projection_y --neighbors neighbors +embedding-atlas path_to_dataset.parquet --x projection_x --y projection_y ``` +You may also pass in the `--neighbors` flag to specify the column name for pre-computed nearest neighbors. The `neighbors` column should have values in the following format: `{"ids": [id1, id2, ...], "distances": [d1, d2, ...]}`. If this column is specified, you'll be able to see nearest neighbors for a selected point in the tool. -::: - Once this script completes, it will print out a URL like `http://localhost:5055/`. Open the URL in a web browser to view the embedding. ## Usage @@ -75,36 +77,47 @@ Once this script completes, it will print out a URL like `http://localhost:5055/ Usage: embedding-atlas [OPTIONS] INPUTS... Options: - --text TEXT The column name for text. - --image TEXT The column name for image. - --split TEXT The data split name for Hugging Face data. - Repeat this command for multiple splits. 
- --model TEXT The model for producing text embeddings. - --trust-remote-code Trust remote code when loading models. - --x TEXT The column name for x coordinate. - --y TEXT The column name for y coordinate. - --neighbors TEXT The column name for pre-computed nearest - neighbors. The values should be in {"ids": - [n1, n2, ...], "distances": [d1, d2, ...]} - format. - --sample INTEGER The number of rows to sample from the original - dataset. - --umap-n-neighbors INTEGER The n_neighbors parameter for UMAP. + --text TEXT Column containing text data. + --image TEXT Column containing image data. + --split TEXT Dataset split name(s) to load from Hugging + Face datasets. Can be specified multiple times + for multiple splits. + --embedding / --no-embedding Whether to compute embeddings for the data. + Disable if embeddings are pre-computed or if + you do not want an embedding view. + --model TEXT Model name for generating embeddings (e.g., + 'all-MiniLM-L6-v2'). + --trust-remote-code Allow execution of remote code when loading + models from Hugging Face Hub. + --x TEXT Column containing pre-computed X coordinates + for the embedding view. + --y TEXT Column containing pre-computed Y coordinates + for the embedding view. + --neighbors TEXT Column containing pre-computed nearest + neighbors in format: {"ids": [n1, n2, ...], + "distances": [d1, d2, ...]}. + --sample INTEGER Number of random samples to draw from the + dataset. Useful for large datasets. + --umap-n-neighbors INTEGER Number of neighbors to consider for UMAP + dimensionality reduction (default: 15). --umap-min-dist FLOAT The min_dist parameter for UMAP. - --umap-metric TEXT The metric for UMAP. - --umap-random-state INTEGER The random seed for UMAP. - --duckdb TEXT DuckDB server URI (e.g., ws://localhost:3000, - http://localhost:3000), or 'wasm' to run - DuckDB in browser, or 'server' to run DuckDB - in this server. Default to 'wasm'. - --host TEXT The hostname of the http server. - --port INTEGER The port of the http server. - --auto-port / --no-auto-port Enable / disable auto port selection. If - disabled, the application crashes if the - specified port is already used. - --static TEXT Path to the static files for frontend. - --export-application TEXT Export a static Web application to the given - zip file and exit. + --umap-metric TEXT Distance metric for UMAP computation (default: + 'cosine'). + --umap-random-state INTEGER Random seed for reproducible UMAP results. + --duckdb TEXT DuckDB connection mode: 'wasm' (run in + browser), 'server' (run on this server), or + URI (e.g., 'ws://localhost:3000'). + --host TEXT Host address for the web server (default: + localhost). + --port INTEGER Port number for the web server (default: + 5055). + --auto-port / --no-auto-port Automatically find an available port if the + specified port is in use. + --static TEXT Custom path to frontend static files + directory. + --export-application TEXT Export the visualization as a standalone web + application to the specified ZIP file and + exit. --version Show the version and exit. --help Show this message and exit. 
``` diff --git a/packages/docs/widget.md b/packages/docs/widget.md index 1fcb8e3..ea2e65f 100644 --- a/packages/docs/widget.md +++ b/packages/docs/widget.md @@ -18,6 +18,8 @@ from embedding_atlas.widget import EmbeddingAtlasWidget EmbeddingAtlasWidget(df) # Compute text embedding and projection of the embedding +from embedding_atlas.projection import compute_text_projection + compute_text_projection(df, text="description", x="projection_x", y="projection_y", neighbors="neighbors" ) diff --git a/packages/viewer/src/AdhocViewer.svelte b/packages/viewer/src/AdhocViewer.svelte index 8597b88..3fa1760 100644 --- a/packages/viewer/src/AdhocViewer.svelte +++ b/packages/viewer/src/AdhocViewer.svelte @@ -8,6 +8,7 @@ import FileUpload from "./FileUpload.svelte"; import EmbeddingAtlas from "./lib/EmbeddingAtlas.svelte"; + import { computeEmbedding } from "./embedding/index.js"; import { systemDarkMode } from "./lib/dark_mode_store.js"; import { initializeDatabase } from "./lib/database_utils.js"; import { exportMosaicSelection, type ExportFormat } from "./lib/mosaic_exporter.js"; @@ -67,6 +68,28 @@ if (spec.embedding != null && "precomputed" in spec.embedding) { projectionColumns = { x: spec.embedding.precomputed.x, y: spec.embedding.precomputed.y }; } + + if (spec.embedding != null && "compute" in spec.embedding) { + let input = spec.embedding.compute.column; + let type = spec.embedding.compute.type; + let model = spec.embedding.compute.model; + let x = input + "_proj_x"; + let y = input + "_proj_y"; + await computeEmbedding({ + coordinator: coordinator, + table: "dataset", + idColumn: "__row_index__", + dataColumn: input, + type: type, + xColumn: x, + yColumn: y, + model: model, + callback: (message, progress) => { + log(`Embedding: ${message}`, progress); + }, + }); + projectionColumns = { x, y }; + } } catch (e: any) { logError(e.message); return; diff --git a/packages/viewer/src/ColumnsPicker.svelte b/packages/viewer/src/ColumnsPicker.svelte index d83d74e..9d4c4e6 100644 --- a/packages/viewer/src/ColumnsPicker.svelte +++ b/packages/viewer/src/ColumnsPicker.svelte @@ -3,6 +3,7 @@ import { untrack } from "svelte"; import Button from "./lib/widgets/Button.svelte"; + import ComboBox from "./lib/widgets/ComboBox.svelte"; import Select from "./lib/widgets/Select.svelte"; import { jsTypeFromDBType } from "./lib/database_utils.js"; @@ -116,7 +117,7 @@ embedding and its 2D projection.
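+      <!-- Embedding modes: "Precomputed" reads existing X/Y columns, "From Text" / "From Image"
+           compute the embedding and projection in the browser via embedding.worker.ts (added below),
+           and "Disabled" turns the embedding view off. -->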

-          {#each [["precomputed", "Precomputed"], ["none", "Disabled"]] as [mode, label]}
+          {#each [["precomputed", "Precomputed"], ["from-text", "From Text"], ["from-image", "From Image"], ["none", "Disabled"]] as [mode, label]}
+        {:else if embeddingMode == "from-text"}
+          <div>Text</div>
+          <ComboBox
+            onChange={(v) => (embeddingImageColumn = v)}
+            options={[
+              { value: null, label: "(none)" },
+              ...columns.map((x) => ({ value: x.column_name, label: `${x.column_name} (${x.column_type})` })),
+            ]}
+          />
+          <div>Model</div>
+          <ComboBox
+            onChange={(v) => (embeddingImageModel = v)}
+            options={imageModels}
+          />
+          Computing the embedding and 2D projection in browser may take a while.
{/if}
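The in-browser worker added below is driven through three RPC calls: `embedding.new` loads a Transformers.js model and returns an instance id, `embedding.batch` embeds one batch of rows, and `embedding.finalize` runs UMAP over all accumulated vectors and returns interleaved 2D coordinates. A minimal caller-side sketch (the model name is an illustrative assumption; `computeEmbedding` in `index.ts` below is the real entry point):

```ts
import { WorkerRPC } from "./worker_helper.js";

// Spawn the embedding worker and wait for its RPC interface to become ready.
const worker = new Worker(new URL("./embedding.worker.js", import.meta.url), { type: "module" });
const rpc = await WorkerRPC.connect(worker);

// 1. Create an embedding instance for a text model (model name assumed for illustration).
const instance = await rpc("embedding.new", { type: "text", model: "Xenova/all-MiniLM-L6-v2" });

// 2. Stream the input through in batches; each batch is embedded as it arrives.
await rpc("embedding.batch", instance, ["first document", "second document"]);

// 3. Run UMAP over everything embedded so far; returns [x0, y0, x1, y1, ...].
const coordinates: Float32Array = await rpc("embedding.finalize", instance);
```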
diff --git a/packages/viewer/src/embedding/embedding.worker.ts b/packages/viewer/src/embedding/embedding.worker.ts new file mode 100644 index 0000000..a64d6de --- /dev/null +++ b/packages/viewer/src/embedding/embedding.worker.ts @@ -0,0 +1,102 @@ +// Copyright (c) 2025 Apple Inc. Licensed under MIT License. + +import { createUMAP } from "@embedding-atlas/umap-wasm"; +import { load_image, pipeline } from "@huggingface/transformers"; + +import { imageToDataUrl } from "../lib/image_utils"; +import { WorkerRPC } from "./worker_helper"; + +let { handler, register } = WorkerRPC.runtime(); + +onmessage = handler; + +interface EmbeddingOptions { + type: "text" | "image"; + model: string; +} + +interface EmbeddingComputer { + batch(data: any[]): Promise; + finalize(): Promise; +} + +let embeddings = new Map(); + +function makeEmbeddingComputer(runBatch: (data: any[]) => Promise): EmbeddingComputer { + let batches: any[] = []; + return { + async batch(data) { + batches.push(await runBatch(data)); + }, + async finalize() { + let count = batches.reduce((a, b) => a + b.dims[0], 0); + let input_dim = batches[0].dims[1]; + let output_dim = 2; + let data = new Float32Array(count * input_dim); + let offset = 0; + for (let i = 0; i < batches.length; i++) { + let length = batches[i].dims[0] * input_dim; + data.set(batches[i].data.subarray(0, length), offset); + offset += length; + } + let umap = await createUMAP(count, input_dim, output_dim, data, { + metric: "cosine", + }); + umap.run(); + let result = new Float32Array(umap.embedding); + umap.destroy(); + return result; + }, + }; +} + +register("embedding.new", async (options: EmbeddingOptions) => { + let instance = new Date().getTime() + "-" + Math.random(); + let pipelineOptions: any = { device: "webgpu" }; + if (options.type == "text") { + let extractor = await pipeline("feature-extraction", options.model, pipelineOptions); + let computer = makeEmbeddingComputer(async (data) => { + let inputs = data.map((x) => x?.toString() ?? ""); + let embedding = await extractor(inputs, { pooling: "mean", normalize: true }); + if (embedding.dims.length == 3) { + embedding = embedding.mean(1); + } + if (embedding.dims.length != 2 || embedding.dims[0] != data.length) { + throw new Error("output embedding dimension mismatch"); + } + return embedding; + }); + embeddings.set(instance, computer); + return instance; + } else if (options.type == "image") { + let extractor = await pipeline("image-feature-extraction", options.model, pipelineOptions); + let computer = makeEmbeddingComputer(async (data) => { + let imgs = data.map((x) => imageToDataUrl(x) ?? 
""); + imgs = await Promise.all(imgs.map((x) => load_image(x))); + let embedding = await extractor(imgs); + if (embedding.dims.length == 3) { + embedding = embedding.mean(1); + } + if (embedding.dims.length != 2 || embedding.dims[0] != imgs.length) { + throw new Error("output embedding dimension mismatch"); + } + return embedding; + }); + embeddings.set(instance, computer); + return instance; + } else { + throw new Error("invalid data type"); + } +}); + +register("embedding.batch", async (instance: string, data: any[]) => { + await embeddings.get(instance)?.batch(data); +}); + +register("embedding.finalize", async (instance: string) => { + let obj = embeddings.get(instance); + if (obj) { + embeddings.delete(instance); + return obj.finalize(); + } +}); diff --git a/packages/viewer/src/embedding/index.ts b/packages/viewer/src/embedding/index.ts new file mode 100644 index 0000000..9d6b2d7 --- /dev/null +++ b/packages/viewer/src/embedding/index.ts @@ -0,0 +1,132 @@ +// Copyright (c) 2025 Apple Inc. Licensed under MIT License. + +import { type Coordinator } from "@uwdata/mosaic-core"; +import * as SQL from "@uwdata/mosaic-sql"; + +import { WorkerRPC } from "./worker_helper.js"; + +let _rpc: Promise<(name: string, ...args: any[]) => Promise> | null = null; +function connect() { + if (_rpc == null) { + let worker = new Worker(new URL("./embedding.worker.js", import.meta.url), { type: "module" }); + _rpc = WorkerRPC.connect(worker); + } + return _rpc; +} + +async function* inputBatches( + coordinator: Coordinator, + table: string, + idColumn: string, + valueColumn: string, + batchSize: number, +): AsyncGenerator { + let r = await coordinator.query(SQL.Query.from(table).select({ count: SQL.count() })); + let count = r.get(0).count; + let start = 0; + while (start < count) { + let range0 = start; + let range1 = start + batchSize; + if (range1 > count) { + range1 = count; + } + let data = await coordinator.query( + SQL.Query.from(table) + .select({ id: SQL.column(idColumn), value: SQL.column(valueColumn) }) + .offset(start) + .limit(range1 - range0), + ); + yield { total: count, data: data }; + start = range1; + } +} + +async function setResultColumns( + coordinator: Coordinator, + table: string, + idColumn: string, + xColumn: string, + yColumn: string, + allIDs: any[][], + coordinates: Float32Array, +) { + let offset = 0; + + await coordinator.exec(` + ALTER TABLE ${table} ADD COLUMN IF NOT EXISTS ${SQL.column(xColumn)} DOUBLE DEFAULT 0; + ALTER TABLE ${table} ADD COLUMN IF NOT EXISTS ${SQL.column(yColumn)} DOUBLE DEFAULT 0; + `); + + for (let ids of allIDs) { + let xy = coordinates.subarray(offset, offset + ids.length * 2); + + await coordinator.exec(` + WITH t1 AS ( + SELECT + UNNEST([${ids.map((x) => SQL.literal(x)).join(",")}]) AS id, + UNNEST([${ids.map((_, i) => xy[i * 2]).join(",")}]) AS x, + UNNEST([${ids.map((_, i) => xy[i * 2 + 1]).join(",")}]) AS y + ) + UPDATE ${table} + SET ${SQL.column(xColumn)} = t1.x, ${SQL.column(yColumn)} = t1.y + FROM t1 WHERE ${SQL.column(idColumn, table)} = t1.id + `); + + offset += ids.length * 2; + } +} + +export async function computeEmbedding(options: { + coordinator: Coordinator; + table: string; + idColumn: string; + dataColumn: string; + xColumn: string; + yColumn: string; + type: "text" | "image"; + model: string; + callback?: (message: string, progress?: number) => void; +}) { + function progress(message: string, progress?: number) { + options.callback?.(message, progress); + } + + progress(`Loading ${options.model}...`); + + let rpc = await connect(); 
+ + let instance = await rpc("embedding.new", { type: options.type, model: options.model }); + + let allIDs: any[][] = []; + let idsCount = 0; + + for await (const { total, data } of inputBatches( + options.coordinator, + options.table, + options.idColumn, + options.dataColumn, + options.type == "text" ? 64 : 16, + )) { + progress("Processing Batches...", (idsCount / total) * 100); + + let ids = Array.from(data.getChild("id")); + let values = Array.from(data.getChild("value")); + await rpc("embedding.batch", instance, values); + allIDs.push(ids); + idsCount += ids.length; + } + + progress("UMAP Projection..."); + + let coordinates: Float32Array = await rpc("embedding.finalize", instance); + + await setResultColumns( + options.coordinator, + options.table, + options.idColumn, + options.xColumn, + options.yColumn, + allIDs, + coordinates, + ); +} diff --git a/packages/viewer/src/embedding/worker_helper.ts b/packages/viewer/src/embedding/worker_helper.ts new file mode 100644 index 0000000..fe08da0 --- /dev/null +++ b/packages/viewer/src/embedding/worker_helper.ts @@ -0,0 +1,63 @@ +// Copyright (c) 2025 Apple Inc. Licensed under MIT License. + +export class WorkerRPC { + static connect(worker: Worker): Promise<(name: string, ...args: any[]) => Promise> { + return new Promise((resolve) => { + let callbacks = new Map(); + let rpc = (name: string, ...args: any[]) => { + return new Promise((resolve, reject) => { + let id = new Date().getTime() + "-" + Math.random(); + callbacks.set(id, [resolve, reject]); + worker.postMessage({ rpc: name, id: id, args: args }); + }); + }; + worker.postMessage({ ready: true }); + worker.onmessage = (e) => { + if (e.data.ready) { + resolve(rpc); + return; + } + let cb = callbacks.get(e.data.id); + if (cb != null) { + callbacks.delete(e.data.id); + if (e.data.error) { + cb[1](new Error(e.data.error)); + } else { + cb[0](e.data.result); + } + } + }; + }); + } + + /** Call from worker */ + static runtime() { + let methods = new Map(); + let onmessage = async (ev: MessageEvent) => { + if (ev.data.ready) { + postMessage({ ready: true }); + } + if (ev.data.rpc) { + let id = ev.data.id; + let msg = { + id: id, + result: null, + error: null, + }; + try { + msg.result = await methods.get(ev.data.rpc)?.(...ev.data.args); + } catch (e: any) { + msg.error = e.toString(); + } + postMessage(msg); + } + }; + postMessage({ ready: true }); + return { + handler: onmessage, + register: (name: string, func: any) => { + methods.set(name, func); + }, + }; + } +}
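
For reference, `WorkerRPC` is a small promise-based RPC layer over `postMessage`: the worker calls `WorkerRPC.runtime()` and registers named methods, while the main thread calls `WorkerRPC.connect(worker)` to get an async dispatch function. A hypothetical worker using the helper (the `echo` method and file name are illustrative, not part of this change):

```ts
// echo.worker.ts — illustrative use of WorkerRPC.runtime().
import { WorkerRPC } from "./worker_helper";

const { handler, register } = WorkerRPC.runtime();

// Route all incoming messages through the RPC runtime.
onmessage = handler;

// Methods are looked up by name and awaited; thrown errors are sent back and reject the caller's promise.
register("echo", async (value: string) => `echo: ${value}`);
```

On the main thread, `const rpc = await WorkerRPC.connect(new Worker(new URL("./echo.worker.js", import.meta.url), { type: "module" }))` resolves once the worker reports ready, after which `await rpc("echo", "hello")` resolves to `"echo: hello"`.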