diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..df573b28e8a840f0092cd0e97609ace2c73500a4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a2d95e45505954e32011da6f18f3f47f844beb2
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,264 @@
+import hashlib
+import os
+import warnings
+import webbrowser
+from pathlib import Path
+from typing import Any
+
+from gradio.blocks import BUILT_IN_THEMES
+from gradio.themes import Default as DefaultTheme
+from gradio.themes import ThemeClass
+from gradio_client import Client
+from huggingface_hub import SpaceStorage
+
+from trackio import context_vars, deploy, utils
+from trackio.imports import import_csv, import_tf_events
+from trackio.media import TrackioImage, TrackioVideo
+from trackio.run import Run
+from trackio.sqlite_storage import SQLiteStorage
+from trackio.table import Table
+from trackio.ui import demo
+from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR
+
+__version__ = Path(__file__).parent.joinpath("version.txt").read_text().strip()
+
+__all__ = [
+ "init",
+ "log",
+ "finish",
+ "show",
+ "import_csv",
+ "import_tf_events",
+ "Image",
+ "Video",
+ "Table",
+]
+
+Image = TrackioImage
+Video = TrackioVideo
+
+
+config = {}
+
+DEFAULT_THEME = "citrus"
+
+
+def init(
+ project: str,
+ name: str | None = None,
+ space_id: str | None = None,
+ space_storage: SpaceStorage | None = None,
+ dataset_id: str | None = None,
+ config: dict | None = None,
+ resume: str = "never",
+ settings: Any = None,
+) -> Run:
+ """
+ Creates a new Trackio project and returns a [`Run`] object.
+
+ Args:
+ project (`str`):
+ The name of the project (can be an existing project to continue tracking or
+ a new project to start tracking from scratch).
+ name (`str` or `None`, *optional*, defaults to `None`):
+ The name of the run (if not provided, a default name will be generated).
+ space_id (`str` or `None`, *optional*, defaults to `None`):
+ If provided, the project will be logged to a Hugging Face Space instead of
+ a local directory. Should be a complete Space name like
+ `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
+ case the Space will be created in the currently-logged-in Hugging Face
+ user's namespace. If the Space does not exist, it will be created. If the
+ Space already exists, the project will be logged to it.
+ space_storage ([`~huggingface_hub.SpaceStorage`] or `None`, *optional*, defaults to `None`):
+ Choice of persistent storage tier.
+ dataset_id (`str` or `None`, *optional*, defaults to `None`):
+ If a `space_id` is provided, a persistent Hugging Face Dataset will be
+ created and the metrics will be synced to it every 5 minutes. Specify a
+ Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`,
+ or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace),
+ or `None` (uses the same name as the Space but with the `"_dataset"`
+ suffix). If the Dataset does not exist, it will be created. If the Dataset
+ already exists, the project will be appended to it.
+ config (`dict` or `None`, *optional*, defaults to `None`):
+ A dictionary of configuration options. Provided for compatibility with
+ `wandb.init()`.
+ resume (`str`, *optional*, defaults to `"never"`):
+ Controls how to handle resuming a run. Can be one of:
+
+ - `"must"`: Must resume the run with the given name, raises error if run
+ doesn't exist
+ - `"allow"`: Resume the run if it exists, otherwise create a new run
+ - `"never"`: Never resume a run, always create a new one
+ settings (`Any`, *optional*, defaults to `None`):
+ Not used. Provided for compatibility with `wandb.init()`.
+
+ Returns:
+ `Run`: A [`Run`] object that can be used to log metrics and finish the run.
+ """
+ if settings is not None:
+ warnings.warn(
+ "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
+ )
+
+ if space_id is None and dataset_id is not None:
+ raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
+ url = context_vars.current_server.get()
+
+ if url is None:
+ if space_id is None:
+ _, url, _ = demo.launch(
+ show_api=False,
+ inline=False,
+ quiet=True,
+ prevent_thread_lock=True,
+ show_error=True,
+ )
+ else:
+ url = space_id
+ context_vars.current_server.set(url)
+
+ if (
+ context_vars.current_project.get() is None
+ or context_vars.current_project.get() != project
+ ):
+ print(f"* Trackio project initialized: {project}")
+
+ if dataset_id is not None:
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
+ print(
+ f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
+ )
+ if space_id is None:
+ print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
+ utils.print_dashboard_instructions(project)
+ else:
+ deploy.create_space_if_not_exists(space_id, space_storage, dataset_id)
+ print(
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
+ )
+ context_vars.current_project.set(project)
+
+ client = None
+ if not space_id:
+ client = Client(url, verbose=False)
+
+ if resume == "must":
+ if name is None:
+ raise ValueError("Must provide a run name when resume='must'")
+ if name not in SQLiteStorage.get_runs(project):
+ raise ValueError(f"Run '{name}' does not exist in project '{project}'")
+ resumed = True
+ elif resume == "allow":
+ resumed = name is not None and name in SQLiteStorage.get_runs(project)
+ elif resume == "never":
+ if name is not None and name in SQLiteStorage.get_runs(project):
+ warnings.warn(
+ f"* Warning: resume='never' but a run '{name}' already exists in "
+ f"project '{project}'. Generating a new name and instead. If you want "
+ "to resume this run, call init() with resume='must' or resume='allow'."
+ )
+ name = None
+ resumed = False
+ else:
+ raise ValueError("resume must be one of: 'must', 'allow', or 'never'")
+
+ run = Run(
+ url=url,
+ project=project,
+ client=client,
+ name=name,
+ config=config,
+ space_id=space_id,
+ )
+
+ if resumed:
+ print(f"* Resumed existing run: {run.name}")
+ else:
+ print(f"* Created new run: {run.name}")
+
+ context_vars.current_run.set(run)
+ globals()["config"] = run.config
+ return run
+
+
+def log(metrics: dict, step: int | None = None) -> None:
+ """
+ Logs metrics to the current run.
+
+ Args:
+ metrics (`dict`):
+ A dictionary of metrics to log.
+ step (`int` or `None`, *optional*, defaults to `None`):
+ The step number. If not provided, the step will be incremented
+ automatically.
+ """
+ run = context_vars.current_run.get()
+ if run is None:
+ raise RuntimeError("Call trackio.init() before trackio.log().")
+ run.log(
+ metrics=metrics,
+ step=step,
+ )
+
+
+def finish():
+ """
+ Finishes the current run.
+ """
+ run = context_vars.current_run.get()
+ if run is None:
+ raise RuntimeError("Call trackio.init() before trackio.finish().")
+ run.finish()
+
+
+def show(project: str | None = None, theme: str | ThemeClass = DEFAULT_THEME):
+ """
+ Launches the Trackio dashboard.
+
+ Args:
+ project (`str` or `None`, *optional*, defaults to `None`):
+ The name of the project whose runs to show. If not provided, all projects
+ will be shown and the user can select one.
+ theme (`str` or `ThemeClass`, *optional*, defaults to `"citrus"`):
+ A Gradio Theme to use for the dashboard instead of the default `"citrus"`,
+ can be a built-in theme (e.g. `'soft'`, `'default'`), a theme from the Hub
+ (e.g. `"gstaff/xkcd"`), or a custom Theme class.
+ """
+ if theme != DEFAULT_THEME:
+ # TODO: It's a little hacky to reproduce this theme-setting logic from Gradio Blocks,
+ # but in Gradio 6.0, the theme will be set in `launch()` instead, which means that we
+ # will be able to remove this code.
+ if isinstance(theme, str):
+ if theme.lower() in BUILT_IN_THEMES:
+ theme = BUILT_IN_THEMES[theme.lower()]
+ else:
+ try:
+ theme = ThemeClass.from_hub(theme)
+ except Exception as e:
+ warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
+ theme = DefaultTheme()
+ if not isinstance(theme, ThemeClass):
+ warnings.warn("Theme should be a class loaded from gradio.themes")
+ theme = DefaultTheme()
+ demo.theme: ThemeClass = theme
+ demo.theme_css = theme._get_theme_css()
+ demo.stylesheets = theme._stylesheets
+ theme_hasher = hashlib.sha256()
+ theme_hasher.update(demo.theme_css.encode("utf-8"))
+ demo.theme_hash = theme_hasher.hexdigest()
+
+ _, url, share_url = demo.launch(
+ show_api=False,
+ quiet=True,
+ inline=False,
+ prevent_thread_lock=True,
+ favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
+ allowed_paths=[TRACKIO_LOGO_DIR],
+ )
+
+ base_url = share_url + "/" if share_url else url
+ dashboard_url = base_url + f"?project={project}" if project else base_url
+ print(f"* Trackio UI launched at: {dashboard_url}")
+ webbrowser.open(dashboard_url)
+ utils.block_except_in_notebook()
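
For reference, a minimal sketch of how the public API defined in this `__init__.py` is intended to be used; the project, run, and metric names below are illustrative.

```python
import trackio

# Start (or resume) a run; space_id/dataset_id are optional and default to local logging.
run = trackio.init(project="demo-project", name="baseline", config={"lr": 1e-3})

# Log metrics; `step` is optional and auto-increments when omitted.
for step in range(3):
    trackio.log({"loss": 1.0 / (step + 1)}, step=step)

trackio.finish()

# Launch the dashboard for this project (blocks outside notebooks).
trackio.show(project="demo-project")
```
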
diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fb771bfb289af7c5ed3fc544b23ec3cd4695688
Binary files /dev/null and b/__pycache__/__init__.cpython-310.pyc differ
diff --git a/__pycache__/__init__.cpython-311.pyc b/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0686ad1db33906c9e968b929f69d1eaf367e2e0
Binary files /dev/null and b/__pycache__/__init__.cpython-311.pyc differ
diff --git a/__pycache__/__init__.cpython-312.pyc b/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bad3227e0f9a57894354f6fdf86b63f5321d8d1c
Binary files /dev/null and b/__pycache__/__init__.cpython-312.pyc differ
diff --git a/__pycache__/api.cpython-312.pyc b/__pycache__/api.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db895ead0b2b19e5f124e694a00d930d2b875ef4
Binary files /dev/null and b/__pycache__/api.cpython-312.pyc differ
diff --git a/__pycache__/cli.cpython-311.pyc b/__pycache__/cli.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a11d442256198bc11f985299c4245b42138cc6f5
Binary files /dev/null and b/__pycache__/cli.cpython-311.pyc differ
diff --git a/__pycache__/cli.cpython-312.pyc b/__pycache__/cli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e81c2e2075bda7d412fbb5b472655a6b41e5830e
Binary files /dev/null and b/__pycache__/cli.cpython-312.pyc differ
diff --git a/__pycache__/commit_scheduler.cpython-311.pyc b/__pycache__/commit_scheduler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d3b72648ed7a92505efe3c5c594e6e2c057b219
Binary files /dev/null and b/__pycache__/commit_scheduler.cpython-311.pyc differ
diff --git a/__pycache__/commit_scheduler.cpython-312.pyc b/__pycache__/commit_scheduler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c496df12a82ee73ba1b8614d32eddb694babb74
Binary files /dev/null and b/__pycache__/commit_scheduler.cpython-312.pyc differ
diff --git a/__pycache__/context.cpython-312.pyc b/__pycache__/context.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a215aca5b91bb762b70c64bc5cc419f6c610e05
Binary files /dev/null and b/__pycache__/context.cpython-312.pyc differ
diff --git a/__pycache__/context_vars.cpython-311.pyc b/__pycache__/context_vars.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9f0e81ee10adda3c17c222bc516c1e79d5badcc
Binary files /dev/null and b/__pycache__/context_vars.cpython-311.pyc differ
diff --git a/__pycache__/context_vars.cpython-312.pyc b/__pycache__/context_vars.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96702ea9a765374170b36ac83b928a6b5e8a7309
Binary files /dev/null and b/__pycache__/context_vars.cpython-312.pyc differ
diff --git a/__pycache__/deploy.cpython-310.pyc b/__pycache__/deploy.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23e835fb9694a9735f3bf34bc7a454ee4f1a9dc3
Binary files /dev/null and b/__pycache__/deploy.cpython-310.pyc differ
diff --git a/__pycache__/deploy.cpython-311.pyc b/__pycache__/deploy.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb048416cd4f4f7cf33724e1459d63fb88e51141
Binary files /dev/null and b/__pycache__/deploy.cpython-311.pyc differ
diff --git a/__pycache__/deploy.cpython-312.pyc b/__pycache__/deploy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e25ac39ee03ff567b86e7e1b4ca2500ca48f41de
Binary files /dev/null and b/__pycache__/deploy.cpython-312.pyc differ
diff --git a/__pycache__/dummy_commit_scheduler.cpython-310.pyc b/__pycache__/dummy_commit_scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7465aad3c0d39db51f75ad0762999e10744537d3
Binary files /dev/null and b/__pycache__/dummy_commit_scheduler.cpython-310.pyc differ
diff --git a/__pycache__/dummy_commit_scheduler.cpython-311.pyc b/__pycache__/dummy_commit_scheduler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..517f6e2f4d1dfcfcc6a695b26cfb5b116e503b21
Binary files /dev/null and b/__pycache__/dummy_commit_scheduler.cpython-311.pyc differ
diff --git a/__pycache__/dummy_commit_scheduler.cpython-312.pyc b/__pycache__/dummy_commit_scheduler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40ae2a8fff8f37a435732b9f6964ebac3b90ee41
Binary files /dev/null and b/__pycache__/dummy_commit_scheduler.cpython-312.pyc differ
diff --git a/__pycache__/file_storage.cpython-311.pyc b/__pycache__/file_storage.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c42edf83e3861ea81f56c71c37959907e07051d
Binary files /dev/null and b/__pycache__/file_storage.cpython-311.pyc differ
diff --git a/__pycache__/file_storage.cpython-312.pyc b/__pycache__/file_storage.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f26e69fe8517b1efafe61df875f0b6994e79e04d
Binary files /dev/null and b/__pycache__/file_storage.cpython-312.pyc differ
diff --git a/__pycache__/imports.cpython-311.pyc b/__pycache__/imports.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51de8c063b2d01f3f69ddc5dd0338a91def55a0a
Binary files /dev/null and b/__pycache__/imports.cpython-311.pyc differ
diff --git a/__pycache__/imports.cpython-312.pyc b/__pycache__/imports.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f67cb1f5665ef3b833ee4738e2310063f248888b
Binary files /dev/null and b/__pycache__/imports.cpython-312.pyc differ
diff --git a/__pycache__/media.cpython-311.pyc b/__pycache__/media.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f57dfc7bf64e11ec7b6a71863da1c61a4d6dd885
Binary files /dev/null and b/__pycache__/media.cpython-311.pyc differ
diff --git a/__pycache__/media.cpython-312.pyc b/__pycache__/media.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5278976200978981f4f7e5a6d840c24e6dd4486b
Binary files /dev/null and b/__pycache__/media.cpython-312.pyc differ
diff --git a/__pycache__/run.cpython-310.pyc b/__pycache__/run.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce986079a4e824bd78d2c5fcdf47bc5a030b4193
Binary files /dev/null and b/__pycache__/run.cpython-310.pyc differ
diff --git a/__pycache__/run.cpython-311.pyc b/__pycache__/run.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a62b4b743094cf51c093228519c668b5628aeea
Binary files /dev/null and b/__pycache__/run.cpython-311.pyc differ
diff --git a/__pycache__/run.cpython-312.pyc b/__pycache__/run.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cec2e7ea8b7c22d1f1d3ea35a3fb76e050dbb3a
Binary files /dev/null and b/__pycache__/run.cpython-312.pyc differ
diff --git a/__pycache__/sqlite_storage.cpython-310.pyc b/__pycache__/sqlite_storage.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..111bbf14ce395fe05202b818ae4de3587ed92e6e
Binary files /dev/null and b/__pycache__/sqlite_storage.cpython-310.pyc differ
diff --git a/__pycache__/sqlite_storage.cpython-311.pyc b/__pycache__/sqlite_storage.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dcc6e812de7d24a968d7e6739cc477dcf595fafa
Binary files /dev/null and b/__pycache__/sqlite_storage.cpython-311.pyc differ
diff --git a/__pycache__/sqlite_storage.cpython-312.pyc b/__pycache__/sqlite_storage.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6584e0712a58c6538cc804719ee00f30f3d4ae9f
Binary files /dev/null and b/__pycache__/sqlite_storage.cpython-312.pyc differ
diff --git a/__pycache__/storage.cpython-312.pyc b/__pycache__/storage.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86681eb1df8341239c0e0bb8e03d836873807911
Binary files /dev/null and b/__pycache__/storage.cpython-312.pyc differ
diff --git a/__pycache__/table.cpython-311.pyc b/__pycache__/table.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77ceec71354cae777b968c220a24c5cee957473c
Binary files /dev/null and b/__pycache__/table.cpython-311.pyc differ
diff --git a/__pycache__/table.cpython-312.pyc b/__pycache__/table.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fa772de941751e68a5198af82db3dd2eda79875
Binary files /dev/null and b/__pycache__/table.cpython-312.pyc differ
diff --git a/__pycache__/typehints.cpython-311.pyc b/__pycache__/typehints.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9073d269f19ed295d2f3d67556b91b5ae42bd34d
Binary files /dev/null and b/__pycache__/typehints.cpython-311.pyc differ
diff --git a/__pycache__/typehints.cpython-312.pyc b/__pycache__/typehints.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7893460bed4c2b55f36bc2b509420a4026d337f1
Binary files /dev/null and b/__pycache__/typehints.cpython-312.pyc differ
diff --git a/__pycache__/types.cpython-312.pyc b/__pycache__/types.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b17d58d91e97665a9cbe3e275eb5467dca73707
Binary files /dev/null and b/__pycache__/types.cpython-312.pyc differ
diff --git a/__pycache__/ui.cpython-310.pyc b/__pycache__/ui.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08551e84ddeb0c06d5869cc556e40b2be889a702
Binary files /dev/null and b/__pycache__/ui.cpython-310.pyc differ
diff --git a/__pycache__/ui.cpython-311.pyc b/__pycache__/ui.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c56bc716811da81a97acca76395f204044c64c88
Binary files /dev/null and b/__pycache__/ui.cpython-311.pyc differ
diff --git a/__pycache__/ui.cpython-312.pyc b/__pycache__/ui.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05196127a497adffeb50c2673b4eeb5275d24b62
Binary files /dev/null and b/__pycache__/ui.cpython-312.pyc differ
diff --git a/__pycache__/utils.cpython-310.pyc b/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f54a74128e47f630e8b216c14e355f571fec93d
Binary files /dev/null and b/__pycache__/utils.cpython-310.pyc differ
diff --git a/__pycache__/utils.cpython-311.pyc b/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8c887cd370c241ff00d7c6fec7f41724a3591dc
Binary files /dev/null and b/__pycache__/utils.cpython-311.pyc differ
diff --git a/__pycache__/utils.cpython-312.pyc b/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6c274b4d5e1b3d79bfbe46853acf0216567b501
Binary files /dev/null and b/__pycache__/utils.cpython-312.pyc differ
diff --git a/__pycache__/video_writer.cpython-311.pyc b/__pycache__/video_writer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49e7f763939485f74379c2804bb9c9323a024f12
Binary files /dev/null and b/__pycache__/video_writer.cpython-311.pyc differ
diff --git a/__pycache__/video_writer.cpython-312.pyc b/__pycache__/video_writer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91a7849164ff42355a2ea004b793dfa51e4e9390
Binary files /dev/null and b/__pycache__/video_writer.cpython-312.pyc differ
diff --git a/assets/trackio_logo_dark.png b/assets/trackio_logo_dark.png
new file mode 100644
index 0000000000000000000000000000000000000000..5c5c08b2387d23599f177477ef7482ff6a601df3
Binary files /dev/null and b/assets/trackio_logo_dark.png differ
diff --git a/assets/trackio_logo_light.png b/assets/trackio_logo_light.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3438262c61989e6c6d16df4801a8935136115b3
Binary files /dev/null and b/assets/trackio_logo_light.png differ
diff --git a/assets/trackio_logo_old.png b/assets/trackio_logo_old.png
new file mode 100644
index 0000000000000000000000000000000000000000..48a07d40b83e23c9cc9dc0cb6544a3c6ad62b57f
--- /dev/null
+++ b/assets/trackio_logo_old.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8
+size 487101
diff --git a/assets/trackio_logo_type_dark.png b/assets/trackio_logo_type_dark.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f80a3191e514a8a0beaa6ab2011e5baf8df5eda
Binary files /dev/null and b/assets/trackio_logo_type_dark.png differ
diff --git a/assets/trackio_logo_type_dark_transparent.png b/assets/trackio_logo_type_dark_transparent.png
new file mode 100644
index 0000000000000000000000000000000000000000..95b2c1f3499c502a81f2ec1094c0e09f827fb1fa
Binary files /dev/null and b/assets/trackio_logo_type_dark_transparent.png differ
diff --git a/assets/trackio_logo_type_light.png b/assets/trackio_logo_type_light.png
new file mode 100644
index 0000000000000000000000000000000000000000..f07866d245ea897b9aba417b29403f7f91cc8bbd
Binary files /dev/null and b/assets/trackio_logo_type_light.png differ
diff --git a/assets/trackio_logo_type_light_transparent.png b/assets/trackio_logo_type_light_transparent.png
new file mode 100644
index 0000000000000000000000000000000000000000..a20b4d5e64c61c91546577645310593fe3493508
Binary files /dev/null and b/assets/trackio_logo_type_light_transparent.png differ
diff --git a/cli.py b/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..76be544f3d03671ce37cae802fc068f70d2c61e7
--- /dev/null
+++ b/cli.py
@@ -0,0 +1,32 @@
+import argparse
+
+from trackio import show
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Trackio CLI")
+ subparsers = parser.add_subparsers(dest="command")
+
+ ui_parser = subparsers.add_parser(
+ "show", help="Show the Trackio dashboard UI for a project"
+ )
+ ui_parser.add_argument(
+ "--project", required=False, help="Project name to show in the dashboard"
+ )
+ ui_parser.add_argument(
+ "--theme",
+ required=False,
+ default="citrus",
+ help="A Gradio Theme to use for the dashboard instead of the default 'citrus', can be a built-in theme (e.g. 'soft', 'default'), a theme from the Hub (e.g. 'gstaff/xkcd').",
+ )
+
+ args = parser.parse_args()
+
+ if args.command == "show":
+ show(args.project, args.theme)
+ else:
+ parser.print_help()
+
+
+if __name__ == "__main__":
+ main()
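
The CLI is a thin wrapper around `show()`; a sketch of the mapping, with an illustrative project name:

```python
# Shell equivalent (run from the package directory):
#   python cli.py show --project demo-project --theme soft
# which forwards to:
from trackio import show

show("demo-project", "soft")
```
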
diff --git a/commit_scheduler.py b/commit_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffa9dfab146ef594450e0ab0e5a25169abf1c07b
--- /dev/null
+++ b/commit_scheduler.py
@@ -0,0 +1,391 @@
+# Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py
+
+import atexit
+import logging
+import os
+import time
+from concurrent.futures import Future
+from dataclasses import dataclass
+from io import SEEK_END, SEEK_SET, BytesIO
+from pathlib import Path
+from threading import Lock, Thread
+from typing import Callable, Dict, List, Optional, Union
+
+from huggingface_hub.hf_api import (
+ DEFAULT_IGNORE_PATTERNS,
+ CommitInfo,
+ CommitOperationAdd,
+ HfApi,
+)
+from huggingface_hub.utils import filter_repo_objects
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class _FileToUpload:
+ """Temporary dataclass to store info about files to upload. Not meant to be used directly."""
+
+ local_path: Path
+ path_in_repo: str
+ size_limit: int
+ last_modified: float
+
+
+class CommitScheduler:
+ """
+ Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).
+
+ The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
+ properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
+ with the `stop` method. Check out the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
+ to learn more about how to use it.
+
+ Args:
+ repo_id (`str`):
+ The id of the repo to commit to.
+ folder_path (`str` or `Path`):
+ Path to the local folder to upload regularly.
+ every (`int` or `float`, *optional*):
+ The number of minutes between each commit. Defaults to 5 minutes.
+ path_in_repo (`str`, *optional*):
+ Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
+ of the repository.
+ repo_type (`str`, *optional*):
+ The type of the repo to commit to. Defaults to `model`.
+ revision (`str`, *optional*):
+ The revision of the repo to commit to. Defaults to `main`.
+ private (`bool`, *optional*):
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+ token (`str`, *optional*):
+ The token to use to commit to the repo. Defaults to the token saved on the machine.
+ allow_patterns (`List[str]` or `str`, *optional*):
+ If provided, only files matching at least one pattern are uploaded.
+ ignore_patterns (`List[str]` or `str`, *optional*):
+ If provided, files matching any of the patterns are not uploaded.
+ squash_history (`bool`, *optional*):
+ Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
+ useful to avoid degraded performance on the repo when it grows too large.
+ hf_api (`HfApi`, *optional*):
+ The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
+ on_before_commit (`Callable[[], None]`, *optional*):
+ If specified, a function that will be called before the CommitScheduler lists files to create a commit.
+
+ Example:
+ ```py
+ >>> from pathlib import Path
+ >>> from huggingface_hub import CommitScheduler
+
+ # Scheduler uploads every 10 minutes
+ >>> csv_path = Path("watched_folder/data.csv")
+ >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)
+
+ >>> with csv_path.open("a") as f:
+ ... f.write("first line")
+
+ # Some time later (...)
+ >>> with csv_path.open("a") as f:
+ ... f.write("second line")
+ ```
+
+ Example using a context manager:
+ ```py
+ >>> from pathlib import Path
+ >>> from huggingface_hub import CommitScheduler
+
+ >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
+ ... csv_path = Path("watched_folder/data.csv")
+ ... with csv_path.open("a") as f:
+ ... f.write("first line")
+ ... (...)
+ ... with csv_path.open("a") as f:
+ ... f.write("second line")
+
+ # Scheduler is now stopped and the last commit has been triggered
+ ```
+ """
+
+ def __init__(
+ self,
+ *,
+ repo_id: str,
+ folder_path: Union[str, Path],
+ every: Union[int, float] = 5,
+ path_in_repo: Optional[str] = None,
+ repo_type: Optional[str] = None,
+ revision: Optional[str] = None,
+ private: Optional[bool] = None,
+ token: Optional[str] = None,
+ allow_patterns: Optional[Union[List[str], str]] = None,
+ ignore_patterns: Optional[Union[List[str], str]] = None,
+ squash_history: bool = False,
+ hf_api: Optional["HfApi"] = None,
+ on_before_commit: Optional[Callable[[], None]] = None,
+ ) -> None:
+ self.api = hf_api or HfApi(token=token)
+ self.on_before_commit = on_before_commit
+
+ # Folder
+ self.folder_path = Path(folder_path).expanduser().resolve()
+ self.path_in_repo = path_in_repo or ""
+ self.allow_patterns = allow_patterns
+
+ if ignore_patterns is None:
+ ignore_patterns = []
+ elif isinstance(ignore_patterns, str):
+ ignore_patterns = [ignore_patterns]
+ self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS
+
+ if self.folder_path.is_file():
+ raise ValueError(
+ f"'folder_path' must be a directory, not a file: '{self.folder_path}'."
+ )
+ self.folder_path.mkdir(parents=True, exist_ok=True)
+
+ # Repository
+ repo_url = self.api.create_repo(
+ repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True
+ )
+ self.repo_id = repo_url.repo_id
+ self.repo_type = repo_type
+ self.revision = revision
+ self.token = token
+
+ self.last_uploaded: Dict[Path, float] = {}
+ self.last_push_time: float | None = None
+
+ if not every > 0:
+ raise ValueError(f"'every' must be a positive integer, not '{every}'.")
+ self.lock = Lock()
+ self.every = every
+ self.squash_history = squash_history
+
+ logger.info(
+ f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes."
+ )
+ self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
+ self._scheduler_thread.start()
+ atexit.register(self._push_to_hub)
+
+ self.__stopped = False
+
+ def stop(self) -> None:
+ """Stop the scheduler.
+
+ A stopped scheduler cannot be restarted. Mostly for testing purposes.
+ """
+ self.__stopped = True
+
+ def __enter__(self) -> "CommitScheduler":
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
+ # Upload last changes before exiting
+ self.trigger().result()
+ self.stop()
+ return
+
+ def _run_scheduler(self) -> None:
+ """Dumb thread waiting between each scheduled push to Hub."""
+ while True:
+ self.last_future = self.trigger()
+ time.sleep(self.every * 60)
+ if self.__stopped:
+ break
+
+ def trigger(self) -> Future:
+ """Trigger a `push_to_hub` and return a future.
+
+ This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
+ immediately, without waiting for the next scheduled commit.
+ """
+ return self.api.run_as_future(self._push_to_hub)
+
+ def _push_to_hub(self) -> Optional[CommitInfo]:
+ if self.__stopped: # If stopped, already scheduled commits are ignored
+ return None
+
+ logger.info("(Background) scheduled commit triggered.")
+ try:
+ value = self.push_to_hub()
+ if self.squash_history:
+ logger.info("(Background) squashing repo history.")
+ self.api.super_squash_history(
+ repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision
+ )
+ return value
+ except Exception as e:
+ logger.error(
+ f"Error while pushing to Hub: {e}"
+ ) # Depending on the setup, error might be silenced
+ raise
+
+ def push_to_hub(self) -> Optional[CommitInfo]:
+ """
+ Push folder to the Hub and return the commit info.
+
+
+
+ This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
+ queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
+ issues.
+
+
+
+ The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
+ uploads only changed files. If no changes are found, the method returns without committing anything. If you want
+ to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
+ for example to compress data together in a single file before committing. For more details and examples, check
+ out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
+ """
+ # Check files to upload (with lock)
+ with self.lock:
+ if self.on_before_commit is not None:
+ self.on_before_commit()
+
+ logger.debug("Listing files to upload for scheduled commit.")
+
+ # List files from folder (taken from `_prepare_upload_folder_additions`)
+ relpath_to_abspath = {
+ path.relative_to(self.folder_path).as_posix(): path
+ for path in sorted(
+ self.folder_path.glob("**/*")
+ ) # sorted to be deterministic
+ if path.is_file()
+ }
+ prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""
+
+ # Filter with pattern + filter out unchanged files + retrieve current file size
+ files_to_upload: List[_FileToUpload] = []
+ for relpath in filter_repo_objects(
+ relpath_to_abspath.keys(),
+ allow_patterns=self.allow_patterns,
+ ignore_patterns=self.ignore_patterns,
+ ):
+ local_path = relpath_to_abspath[relpath]
+ stat = local_path.stat()
+ if (
+ self.last_uploaded.get(local_path) is None
+ or self.last_uploaded[local_path] != stat.st_mtime
+ ):
+ files_to_upload.append(
+ _FileToUpload(
+ local_path=local_path,
+ path_in_repo=prefix + relpath,
+ size_limit=stat.st_size,
+ last_modified=stat.st_mtime,
+ )
+ )
+
+ # Return if nothing to upload
+ if len(files_to_upload) == 0:
+ logger.debug("Dropping schedule commit: no changed file to upload.")
+ return None
+
+ # Convert `_FileToUpload` to `CommitOperationAdd` (=> compute file shas + limit to file size)
+ logger.debug("Removing unchanged files since previous scheduled commit.")
+ add_operations = [
+ CommitOperationAdd(
+ # TODO: Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
+ # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload)
+ path_or_fileobj=file_to_upload.local_path,
+ path_in_repo=file_to_upload.path_in_repo,
+ )
+ for file_to_upload in files_to_upload
+ ]
+
+ # Upload files (append mode expected - no need for lock)
+ logger.debug("Uploading files for scheduled commit.")
+ commit_info = self.api.create_commit(
+ repo_id=self.repo_id,
+ repo_type=self.repo_type,
+ operations=add_operations,
+ commit_message="Scheduled Commit",
+ revision=self.revision,
+ )
+
+ for file in files_to_upload:
+ self.last_uploaded[file.local_path] = file.last_modified
+
+ self.last_push_time = time.time()
+
+ return commit_info
+
+
+class PartialFileIO(BytesIO):
+ """A file-like object that reads only the first part of a file.
+
+ Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
+ file is uploaded (i.e. the part that was available when the filesystem was first scanned).
+
+ In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
+ disturbance for the user. The object is passed to `CommitOperationAdd`.
+
+ Only supports `read`, `tell` and `seek` methods.
+
+ Args:
+ file_path (`str` or `Path`):
+ Path to the file to read.
+ size_limit (`int`):
+ The maximum number of bytes to read from the file. If the file is larger than this, only the first part
+ will be read (and uploaded).
+ """
+
+ def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
+ self._file_path = Path(file_path)
+ self._file = self._file_path.open("rb")
+ self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
+
+ def __del__(self) -> None:
+ self._file.close()
+ return super().__del__()
+
+ def __repr__(self) -> str:
+ return (
+ f"<PartialFileIO file_path={self._file_path} size_limit={self._size_limit}>"
+ )
+
+ def __len__(self) -> int:
+ return self._size_limit
+
+ def __getattribute__(self, name: str):
+ if name.startswith("_") or name in (
+ "read",
+ "tell",
+ "seek",
+ ): # only 3 public methods supported
+ return super().__getattribute__(name)
+ raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
+
+ def tell(self) -> int:
+ """Return the current file position."""
+ return self._file.tell()
+
+ def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
+ """Change the stream position to the given offset.
+
+ Behavior is the same as a regular file, except that the position is capped to the size limit.
+ """
+ if __whence == SEEK_END:
+ # SEEK_END => set from the truncated end
+ __offset = len(self) + __offset
+ __whence = SEEK_SET
+
+ pos = self._file.seek(__offset, __whence)
+ if pos > self._size_limit:
+ return self._file.seek(self._size_limit)
+ return pos
+
+ def read(self, __size: Optional[int] = -1) -> bytes:
+ """Read at most `__size` bytes from the file.
+
+ Behavior is the same as a regular file, except that it is capped to the size limit.
+ """
+ current = self._file.tell()
+ if __size is None or __size < 0:
+ # Read until file limit
+ truncated_size = self._size_limit - current
+ else:
+ # Read until file limit or __size
+ truncated_size = min(__size, self._size_limit - current)
+ return self._file.read(truncated_size)
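
A sketch of the `on_before_commit` hook exposed by this vendored scheduler; the repository ID and import path are assumptions, and the call requires Hub credentials.

```python
from trackio.commit_scheduler import CommitScheduler  # assumed import path


def flush_pending_writes() -> None:
    # e.g. flush buffered metrics to disk so the scheduler sees up-to-date files
    pass


scheduler = CommitScheduler(
    repo_id="username/trackio-sync",  # hypothetical Dataset repo
    repo_type="dataset",
    folder_path="watched_folder",
    every=5,  # minutes between scheduled commits
    on_before_commit=flush_pending_writes,
)
```
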
diff --git a/context_vars.py b/context_vars.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dab719f76cdcdf5e173bebea47666ca1b016694
--- /dev/null
+++ b/context_vars.py
@@ -0,0 +1,15 @@
+import contextvars
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from trackio.run import Run
+
+current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar(
+ "current_run", default=None
+)
+current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+ "current_project", default=None
+)
+current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+ "current_server", default=None
+)
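
A minimal sketch of how these context variables are consumed elsewhere in the package, mirroring the checks in `__init__.py`:

```python
from trackio import context_vars

run = context_vars.current_run.get()  # None until trackio.init() has been called
if run is None:
    raise RuntimeError("Call trackio.init() before trackio.log().")
run.log(metrics={"loss": 0.1}, step=0)
```
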
diff --git a/data_types/__pycache__/__init__.cpython-311.pyc b/data_types/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79597e766b34941f993cfb0bc47e86f4f12abe1a
Binary files /dev/null and b/data_types/__pycache__/__init__.cpython-311.pyc differ
diff --git a/data_types/__pycache__/__init__.cpython-312.pyc b/data_types/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38a9cb512c28ad442010e27a653093f71e40baa6
Binary files /dev/null and b/data_types/__pycache__/__init__.cpython-312.pyc differ
diff --git a/data_types/__pycache__/table.cpython-311.pyc b/data_types/__pycache__/table.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce6bea364a885f347b02772a9b5d722f412ebcfa
Binary files /dev/null and b/data_types/__pycache__/table.cpython-311.pyc differ
diff --git a/data_types/__pycache__/table.cpython-312.pyc b/data_types/__pycache__/table.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16aca75496bd22d60e7b213fb799873e9744c1f8
Binary files /dev/null and b/data_types/__pycache__/table.cpython-312.pyc differ
diff --git a/deploy.py b/deploy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2df76b131abe0b0fbf76f1743f859644f118ea5c
--- /dev/null
+++ b/deploy.py
@@ -0,0 +1,217 @@
+import importlib.metadata
+import io
+import os
+import time
+from importlib.resources import files
+from pathlib import Path
+
+import gradio
+import huggingface_hub
+from gradio_client import Client, handle_file
+from httpx import ReadTimeout
+from huggingface_hub.errors import RepositoryNotFoundError
+from requests import HTTPError
+
+import trackio
+from trackio.sqlite_storage import SQLiteStorage
+
+SPACE_URL = "https://huggingface.co/spaces/{space_id}"
+
+
+def _is_trackio_installed_from_source() -> bool:
+ """Check if trackio is installed from source/editable install vs PyPI."""
+ try:
+ trackio_file = trackio.__file__
+ if "site-packages" not in trackio_file:
+ return True
+
+ dist = importlib.metadata.distribution("trackio")
+ if dist.files:
+ files = list(dist.files)
+ has_pth = any(".pth" in str(f) for f in files)
+ if has_pth:
+ return True
+
+ return False
+ except (
+ AttributeError,
+ importlib.metadata.PackageNotFoundError,
+ importlib.metadata.MetadataError,
+ ValueError,
+ TypeError,
+ ):
+ return True
+
+
+def deploy_as_space(
+ space_id: str,
+ space_storage: huggingface_hub.SpaceStorage | None = None,
+ dataset_id: str | None = None,
+):
+ if (
+ os.getenv("SYSTEM") == "spaces"
+ ): # in case a repo with this function is uploaded to spaces
+ return
+
+ trackio_path = files("trackio")
+
+ hf_api = huggingface_hub.HfApi()
+
+ try:
+ huggingface_hub.create_repo(
+ space_id,
+ space_sdk="gradio",
+ space_storage=space_storage,
+ repo_type="space",
+ exist_ok=True,
+ )
+ except HTTPError as e:
+ if e.response.status_code in [401, 403]: # unauthorized or forbidden
+ print("Need 'write' access token to create a Spaces repo.")
+ huggingface_hub.login(add_to_git_credential=False)
+ huggingface_hub.create_repo(
+ space_id,
+ space_sdk="gradio",
+ space_storage=space_storage,
+ repo_type="space",
+ exist_ok=True,
+ )
+ else:
+ raise ValueError(f"Failed to create Space: {e}")
+
+ with open(Path(trackio_path, "README.md"), "r") as f:
+ readme_content = f.read()
+ readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
+ readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
+ hf_api.upload_file(
+ path_or_fileobj=readme_buffer,
+ path_in_repo="README.md",
+ repo_id=space_id,
+ repo_type="space",
+ )
+
+ # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space.
+ # Make sure necessary dependencies are installed by creating a requirements.txt.
+ is_source_install = _is_trackio_installed_from_source()
+
+ if is_source_install:
+ requirements_content = """pyarrow>=21.0"""
+ else:
+ requirements_content = f"""pyarrow>=21.0
+trackio=={trackio.__version__}"""
+
+ requirements_buffer = io.BytesIO(requirements_content.encode("utf-8"))
+ hf_api.upload_file(
+ path_or_fileobj=requirements_buffer,
+ path_in_repo="requirements.txt",
+ repo_id=space_id,
+ repo_type="space",
+ )
+
+ huggingface_hub.utils.disable_progress_bars()
+
+ if is_source_install:
+ hf_api.upload_folder(
+ repo_id=space_id,
+ repo_type="space",
+ folder_path=trackio_path,
+ ignore_patterns=["README.md"],
+ )
+ else:
+ app_file_content = """import trackio
+trackio.show()"""
+ app_file_buffer = io.BytesIO(app_file_content.encode("utf-8"))
+ hf_api.upload_file(
+ path_or_fileobj=app_file_buffer,
+ path_in_repo="ui.py",
+ repo_id=space_id,
+ repo_type="space",
+ )
+
+ if hf_token := huggingface_hub.utils.get_token():
+ huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
+ if dataset_id is not None:
+ huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)
+
+
+def create_space_if_not_exists(
+ space_id: str,
+ space_storage: huggingface_hub.SpaceStorage | None = None,
+ dataset_id: str | None = None,
+) -> None:
+ """
+ Creates a new Hugging Face Space if it does not exist. If a dataset_id is provided, it will be added as a space variable.
+
+ Args:
+ space_id: The ID of the Space to create.
+ dataset_id: The ID of the Dataset to add to the Space.
+ """
+ if "/" not in space_id:
+ raise ValueError(
+ f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame."
+ )
+ if dataset_id is not None and "/" not in dataset_id:
+ raise ValueError(
+ f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname."
+ )
+ try:
+ huggingface_hub.repo_info(space_id, repo_type="space")
+ print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
+ if dataset_id is not None:
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_DATASET_ID", dataset_id
+ )
+ return
+ except RepositoryNotFoundError:
+ pass
+ except HTTPError as e:
+ if e.response.status_code in [401, 403]: # unauthorized or forbidden
+ print("Need 'write' access token to create a Spaces repo.")
+ huggingface_hub.login(add_to_git_credential=False)
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_DATASET_ID", dataset_id
+ )
+ else:
+ raise ValueError(f"Failed to create Space: {e}")
+
+ print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
+ deploy_as_space(space_id, space_storage, dataset_id)
+
+
+def wait_until_space_exists(
+ space_id: str,
+) -> None:
+ """
+ Blocks the current thread until the Space exists. Raises a TimeoutError if the
+ Space is still not reachable after several retries with exponential backoff.
+
+ Args:
+ space_id: The ID of the Space to wait for.
+ """
+ delay = 1
+ for _ in range(10):
+ try:
+ Client(space_id, verbose=False)
+ return
+ except (ReadTimeout, ValueError):
+ time.sleep(delay)
+ delay = min(delay * 2, 30)
+ raise TimeoutError("Waiting for space to exist took longer than expected")
+
+
+def upload_db_to_space(project: str, space_id: str) -> None:
+ """
+ Uploads the database of a local Trackio project to a Hugging Face Space.
+
+ Args:
+ project: The name of the project to upload.
+ space_id: The ID of the Space to upload to.
+ """
+ db_path = SQLiteStorage.get_project_db_path(project)
+ client = Client(space_id, verbose=False)
+ client.predict(
+ api_name="/upload_db_to_space",
+ project=project,
+ uploaded_db=handle_file(db_path),
+ hf_token=huggingface_hub.utils.get_token(),
+ )
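
A sketch of the deployment flow these helpers implement (the same sequence `imports.py` uses); the Space and Dataset IDs are illustrative and the calls require a logged-in Hugging Face account.

```python
from trackio import deploy

space_id = "username/trackio-dashboard"  # hypothetical Space
dataset_id = "username/trackio-dashboard_dataset"  # hypothetical Dataset

deploy.create_space_if_not_exists(space_id, dataset_id=dataset_id)
deploy.wait_until_space_exists(space_id)
deploy.upload_db_to_space("demo-project", space_id)
print(f"Dashboard: {deploy.SPACE_URL.format(space_id=space_id)}")
```
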
diff --git a/dummy_commit_scheduler.py b/dummy_commit_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f5015e1479a175081080ef8908966979c1de179
--- /dev/null
+++ b/dummy_commit_scheduler.py
@@ -0,0 +1,12 @@
+# A dummy object to fit the interface of huggingface_hub's CommitScheduler
+class DummyCommitSchedulerLock:
+ def __enter__(self):
+ return None
+
+ def __exit__(self, exception_type, exception_value, exception_traceback):
+ pass
+
+
+class DummyCommitScheduler:
+ def __init__(self):
+ self.lock = DummyCommitSchedulerLock()
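
A minimal sketch of why the dummy exists: callers can hold `scheduler.lock` uniformly whether or not Hub syncing is configured (import path assumed).

```python
from trackio.dummy_commit_scheduler import DummyCommitScheduler  # assumed import path

scheduler = DummyCommitScheduler()
with scheduler.lock:
    # write to the local SQLite database; no Hub commit is ever scheduled
    pass
```
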
diff --git a/file_storage.py b/file_storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed947c0d8a5efde11b42c1c1f0cae05edb79f0a
--- /dev/null
+++ b/file_storage.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+
+try: # absolute imports when installed
+ from trackio.utils import MEDIA_DIR
+except ImportError: # relative imports for local execution on Spaces
+ from utils import MEDIA_DIR
+
+
+class FileStorage:
+ @staticmethod
+ def get_project_media_path(
+ project: str,
+ run: str | None = None,
+ step: int | None = None,
+ filename: str | None = None,
+ ) -> Path:
+ if filename is not None and step is None:
+ raise ValueError("filename requires step")
+ if step is not None and run is None:
+ raise ValueError("step requires run")
+
+ path = MEDIA_DIR / project
+ if run:
+ path /= run
+ if step is not None:
+ path /= str(step)
+ if filename:
+ path /= filename
+ return path
+
+ @staticmethod
+ def init_project_media_path(
+ project: str, run: str | None = None, step: int | None = None
+ ) -> Path:
+ path = FileStorage.get_project_media_path(project, run, step)
+ path.mkdir(parents=True, exist_ok=True)
+ return path
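
A sketch of the media directory layout `FileStorage` produces; the project, run, step, and filename are illustrative, and the import path is assumed.

```python
from trackio.file_storage import FileStorage  # assumed import path

step_dir = FileStorage.init_project_media_path("demo-project", run="baseline", step=3)
file_path = FileStorage.get_project_media_path(
    "demo-project", run="baseline", step=3, filename="sample.png"
)
# file_path == step_dir / "sample.png", rooted under MEDIA_DIR
```
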
diff --git a/imports.py b/imports.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b0d58499976280d0b2506210d38ec692abf24d
--- /dev/null
+++ b/imports.py
@@ -0,0 +1,286 @@
+import os
+from pathlib import Path
+
+import pandas as pd
+
+from trackio import deploy, utils
+from trackio.sqlite_storage import SQLiteStorage
+
+
+def import_csv(
+ csv_path: str | Path,
+ project: str,
+ name: str | None = None,
+ space_id: str | None = None,
+ dataset_id: str | None = None,
+) -> None:
+ """
+ Imports a CSV file into a Trackio project. The CSV file must contain a `"step"`
+ column, may optionally contain a `"timestamp"` column, and any other columns will be
+ treated as metrics. It should also include a header row with the column names.
+
+ TODO: call init() and return a Run object so that the user can continue to log metrics to it.
+
+ Args:
+ csv_path (`str` or `Path`):
+ The str or Path to the CSV file to import.
+ project (`str`):
+ The name of the project to import the CSV file into. Must not be an existing
+ project.
+ name (`str` or `None`, *optional*, defaults to `None`):
+ The name of the Run to import the CSV file into. If not provided, a default
+ name will be generated.
+ space_id (`str` or `None`, *optional*, defaults to `None`):
+ If provided, the project will be logged to a Hugging Face Space instead of a
+ local directory. Should be a complete Space name like `"username/reponame"`
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
+ be created in the currently-logged-in Hugging Face user's namespace. If the
+ Space does not exist, it will be created. If the Space already exists, the
+ project will be logged to it.
+ dataset_id (`str` or `None`, *optional*, defaults to `None`):
+ If provided, a persistent Hugging Face Dataset will be created and the
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
+ `"datasetname"` in which case the Dataset will be created in the
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
+ exist, it will be created. If the Dataset already exists, the project will
+ be appended to it. If not provided, the metrics will be logged to a local
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
+ will be automatically created with the same name as the Space but with the
+ `"_dataset"` suffix.
+ """
+ if SQLiteStorage.get_runs(project):
+ raise ValueError(
+ f"Project '{project}' already exists. Cannot import CSV into existing project."
+ )
+
+ csv_path = Path(csv_path)
+ if not csv_path.exists():
+ raise FileNotFoundError(f"CSV file not found: {csv_path}")
+
+ df = pd.read_csv(csv_path)
+ if df.empty:
+ raise ValueError("CSV file is empty")
+
+ column_mapping = utils.simplify_column_names(df.columns.tolist())
+ df = df.rename(columns=column_mapping)
+
+ step_column = None
+ for col in df.columns:
+ if col.lower() == "step":
+ step_column = col
+ break
+
+ if step_column is None:
+ raise ValueError("CSV file must contain a 'step' or 'Step' column")
+
+ if name is None:
+ name = csv_path.stem
+
+ metrics_list = []
+ steps = []
+ timestamps = []
+
+ numeric_columns = []
+ for column in df.columns:
+ if column == step_column:
+ continue
+ if column == "timestamp":
+ continue
+
+ try:
+ pd.to_numeric(df[column], errors="raise")
+ numeric_columns.append(column)
+ except (ValueError, TypeError):
+ continue
+
+ for _, row in df.iterrows():
+ metrics = {}
+ for column in numeric_columns:
+ value = row[column]
+ if bool(pd.notna(value)):
+ metrics[column] = float(value)
+
+ if metrics:
+ metrics_list.append(metrics)
+ steps.append(int(row[step_column]))
+
+ if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])):
+ timestamps.append(str(row["timestamp"]))
+ else:
+ timestamps.append("")
+
+ if metrics_list:
+ SQLiteStorage.bulk_log(
+ project=project,
+ run=name,
+ metrics_list=metrics_list,
+ steps=steps,
+ timestamps=timestamps,
+ )
+
+ print(
+ f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'"
+ )
+ print(f"* Metrics found: {', '.join(metrics_list[0].keys())}")
+
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
+ if dataset_id is not None:
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
+
+ if space_id is None:
+ utils.print_dashboard_instructions(project)
+ else:
+ deploy.create_space_if_not_exists(space_id, dataset_id=dataset_id)
+ deploy.wait_until_space_exists(space_id)
+ deploy.upload_db_to_space(project, space_id)
+ print(
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
+ )
+
+
+def import_tf_events(
+ log_dir: str | Path,
+ project: str,
+ name: str | None = None,
+ space_id: str | None = None,
+ dataset_id: str | None = None,
+) -> None:
+ """
+ Imports TensorFlow Events files from a directory into a Trackio project. Each
+ subdirectory in the log directory will be imported as a separate run.
+
+ Args:
+ log_dir (`str` or `Path`):
+ The str or Path to the directory containing TensorFlow Events files.
+ project (`str`):
+ The name of the project to import the TensorFlow Events files into. Must not
+ be an existing project.
+ name (`str` or `None`, *optional*, defaults to `None`):
+ The name prefix for runs (if not provided, will use directory names). Each
+ subdirectory will create a separate run.
+ space_id (`str` or `None`, *optional*, defaults to `None`):
+ If provided, the project will be logged to a Hugging Face Space instead of a
+ local directory. Should be a complete Space name like `"username/reponame"`
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
+ be created in the currently-logged-in Hugging Face user's namespace. If the
+ Space does not exist, it will be created. If the Space already exists, the
+ project will be logged to it.
+ dataset_id (`str` or `None`, *optional*, defaults to `None`):
+ If provided, a persistent Hugging Face Dataset will be created and the
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
+ `"datasetname"` in which case the Dataset will be created in the
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
+ exist, it will be created. If the Dataset already exists, the project will
+ be appended to it. If not provided, the metrics will be logged to a local
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
+ will be automatically created with the same name as the Space but with the
+ `"_dataset"` suffix.
+ """
+ try:
+ from tbparse import SummaryReader
+ except ImportError:
+ raise ImportError(
+ "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`."
+ )
+
+ if SQLiteStorage.get_runs(project):
+ raise ValueError(
+ f"Project '{project}' already exists. Cannot import TF events into existing project."
+ )
+
+ path = Path(log_dir)
+ if not path.exists():
+ raise FileNotFoundError(f"TF events directory not found: {path}")
+
+ # Use tbparse to read all tfevents files in the directory structure
+ reader = SummaryReader(str(path), extra_columns={"dir_name"})
+ df = reader.scalars
+
+ if df.empty:
+ raise ValueError(f"No TensorFlow events data found in {path}")
+
+ total_imported = 0
+ imported_runs = []
+
+ # Group by dir_name to create separate runs
+ for dir_name, group_df in df.groupby("dir_name"):
+ try:
+ # Determine run name based on directory name
+ if dir_name == "":
+ run_name = "main" # For files in the root directory
+ else:
+ run_name = dir_name # Use directory name
+
+ if name:
+ run_name = f"{name}_{run_name}"
+
+ if group_df.empty:
+ print(f"* Skipping directory {dir_name}: no scalar data found")
+ continue
+
+ metrics_list = []
+ steps = []
+ timestamps = []
+
+ for _, row in group_df.iterrows():
+ # Convert row values to appropriate types
+ tag = str(row["tag"])
+ value = float(row["value"])
+ step = int(row["step"])
+
+ metrics = {tag: value}
+ metrics_list.append(metrics)
+ steps.append(step)
+
+ # Use wall_time if present, else fallback
+ if "wall_time" in group_df.columns and not bool(
+ pd.isna(row["wall_time"])
+ ):
+ timestamps.append(str(row["wall_time"]))
+ else:
+ timestamps.append("")
+
+ if metrics_list:
+ SQLiteStorage.bulk_log(
+ project=project,
+ run=str(run_name),
+ metrics_list=metrics_list,
+ steps=steps,
+ timestamps=timestamps,
+ )
+
+ total_imported += len(metrics_list)
+ imported_runs.append(run_name)
+
+ print(
+ f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'"
+ )
+ print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}")
+
+ except Exception as e:
+ print(f"* Error processing directory {dir_name}: {e}")
+ continue
+
+ if not imported_runs:
+ raise ValueError("No valid TensorFlow events data could be imported")
+
+ print(f"* Total imported events: {total_imported}")
+ print(f"* Created runs: {', '.join(imported_runs)}")
+
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
+ if dataset_id is not None:
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
+
+ if space_id is None:
+ utils.print_dashboard_instructions(project)
+ else:
+ deploy.create_space_if_not_exists(space_id, dataset_id)
+ deploy.wait_until_space_exists(space_id)
+ deploy.upload_db_to_space(project, space_id)
+ print(
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
+ )
diff --git a/media.py b/media.py
new file mode 100644
index 0000000000000000000000000000000000000000..a557b3a4be11997c16b5b656f4b5247f84fd191b
--- /dev/null
+++ b/media.py
@@ -0,0 +1,286 @@
+import os
+import shutil
+import uuid
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+from PIL import Image as PILImage
+
+try: # absolute imports when installed
+ from trackio.file_storage import FileStorage
+ from trackio.utils import MEDIA_DIR
+ from trackio.video_writer import write_video
+except ImportError: # relative imports for local execution on Spaces
+ from file_storage import FileStorage
+ from utils import MEDIA_DIR
+ from video_writer import write_video
+
+
+class TrackioMedia(ABC):
+ """
+ Abstract base class for Trackio media objects.
+ Provides shared functionality for file handling and serialization.
+ """
+
+ TYPE: str
+
+ def __init_subclass__(cls, **kwargs):
+ """Ensure subclasses define the TYPE attribute."""
+ super().__init_subclass__(**kwargs)
+ if not hasattr(cls, "TYPE") or cls.TYPE is None:
+ raise TypeError(f"Class {cls.__name__} must define TYPE attribute")
+
+ def __init__(self, value, caption: str | None = None):
+ self.caption = caption
+ self._value = value
+ self._file_path: Path | None = None
+
+ # Validate file existence for string/Path inputs
+ if isinstance(self._value, str | Path):
+ if not os.path.isfile(self._value):
+ raise ValueError(f"File not found: {self._value}")
+
+ def _file_extension(self) -> str:
+ if self._file_path:
+ return self._file_path.suffix[1:].lower()
+ if isinstance(self._value, str | Path):
+ path = Path(self._value)
+ return path.suffix[1:].lower()
+ if hasattr(self, "_format") and self._format:
+ return self._format
+ return "unknown"
+
+ def _get_relative_file_path(self) -> Path | None:
+ return self._file_path
+
+ def _get_absolute_file_path(self) -> Path | None:
+ if self._file_path:
+ return MEDIA_DIR / self._file_path
+ return None
+
+ def _save(self, project: str, run: str, step: int = 0):
+ if self._file_path:
+ return
+
+ media_dir = FileStorage.init_project_media_path(project, run, step)
+ filename = f"{uuid.uuid4()}.{self._file_extension()}"
+ file_path = media_dir / filename
+
+ # Delegate to subclass-specific save logic
+ self._save_media(file_path)
+
+ self._file_path = file_path.relative_to(MEDIA_DIR)
+
+ @abstractmethod
+ def _save_media(self, file_path: Path):
+ """
+ Performs the actual media saving logic.
+ """
+ pass
+
+ def _to_dict(self) -> dict:
+ if not self._file_path:
+ raise ValueError("Media must be saved to file before serialization")
+ return {
+ "_type": self.TYPE,
+ "file_path": str(self._get_relative_file_path()),
+ "caption": self.caption,
+ }
+
+
+TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image
+
+
+class TrackioImage(TrackioMedia):
+ """
+ Initializes an Image object.
+
+ Example:
+ ```python
+ import trackio
+ import numpy as np
+ from PIL import Image
+
+ # Create an image from numpy array
+ image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ image = trackio.Image(image_data, caption="Random image")
+ trackio.log({"my_image": image})
+
+ # Create an image from PIL Image
+ pil_image = Image.new('RGB', (100, 100), color='red')
+ image = trackio.Image(pil_image, caption="Red square")
+ trackio.log({"red_image": image})
+
+ # Create an image from file path
+ image = trackio.Image("path/to/image.jpg", caption="Photo from file")
+ trackio.log({"file_image": image})
+ ```
+
+ Args:
+ value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`, *optional*, defaults to `None`):
+ A path to an image, a PIL Image, or a numpy array of shape (height, width, channels).
+ caption (`str`, *optional*, defaults to `None`):
+ A string caption for the image.
+ """
+
+ TYPE = "trackio.image"
+
+ def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
+ super().__init__(value, caption)
+ self._format: str | None = None
+
+ if (
+ isinstance(self._value, np.ndarray | PILImage.Image)
+ and self._format is None
+ ):
+ self._format = "png"
+
+ def _as_pil(self) -> PILImage.Image | None:
+ try:
+ if isinstance(self._value, np.ndarray):
+ arr = np.asarray(self._value).astype("uint8")
+ return PILImage.fromarray(arr).convert("RGBA")
+ if isinstance(self._value, PILImage.Image):
+ return self._value.convert("RGBA")
+ except Exception as e:
+ raise ValueError(f"Failed to process image data: {self._value}") from e
+ return None
+
+ def _save_media(self, file_path: Path):
+ if pil := self._as_pil():
+ pil.save(file_path, format=self._format)
+ elif isinstance(self._value, str | Path):
+ if os.path.isfile(self._value):
+ shutil.copy(self._value, file_path)
+ else:
+ raise ValueError(f"File not found: {self._value}")
+
+
+TrackioVideoSourceType = str | Path | np.ndarray
+TrackioVideoFormatType = Literal["gif", "mp4", "webm"]
+
+
+class TrackioVideo(TrackioMedia):
+ """
+ Initializes a Video object.
+
+ Example:
+ ```python
+ import trackio
+ import numpy as np
+
+ # Create a simple video from numpy array
+ frames = np.random.randint(0, 255, (10, 3, 64, 64), dtype=np.uint8)
+ video = trackio.Video(frames, caption="Random video", fps=30)
+
+ # Create a batch of videos
+ batch_frames = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8)
+ batch_video = trackio.Video(batch_frames, caption="Batch of videos", fps=15)
+
+ # Create video from file path
+ video = trackio.Video("path/to/video.mp4", caption="Video from file")
+ ```
+
+ Args:
+ value (`str`, `Path`, or `numpy.ndarray`, *optional*, defaults to `None`):
+ A path to a video file, or a numpy array.
+ The array should be of type `np.uint8` with RGB values in the range `[0, 255]`.
+ It is expected to have shape of either (frames, channels, height, width) or (batch, frames, channels, height, width).
+ For the latter, the videos will be tiled into a grid.
+ caption (`str`, *optional*, defaults to `None`):
+ A string caption for the video.
+ fps (`int`, *optional*, defaults to `None`):
+ Frames per second for the video. Only used when value is an ndarray. Default is `24`.
+ format (`Literal["gif", "mp4", "webm"]`, *optional*, defaults to `None`):
+ Video format ("gif", "mp4", or "webm"). Only used when value is an ndarray. Default is "gif".
+ """
+
+ TYPE = "trackio.video"
+
+ def __init__(
+ self,
+ value: TrackioVideoSourceType,
+ caption: str | None = None,
+ fps: int | None = None,
+ format: TrackioVideoFormatType | None = None,
+ ):
+ super().__init__(value, caption)
+ if isinstance(value, np.ndarray):
+ if format is None:
+ format = "gif"
+ if fps is None:
+ fps = 24
+ self._fps = fps
+ self._format = format
+
+ @property
+ def _codec(self) -> str:
+ match self._format:
+ case "gif":
+ return "gif"
+ case "mp4":
+ return "h264"
+ case "webm":
+ return "vp9"
+ case _:
+ raise ValueError(f"Unsupported format: {self._format}")
+
+ def _save_media(self, file_path: Path):
+ if isinstance(self._value, np.ndarray):
+ video = TrackioVideo._process_ndarray(self._value)
+ write_video(file_path, video, fps=self._fps, codec=self._codec)
+ elif isinstance(self._value, str | Path):
+ if os.path.isfile(self._value):
+ shutil.copy(self._value, file_path)
+ else:
+ raise ValueError(f"File not found: {self._value}")
+
+ @staticmethod
+ def _process_ndarray(value: np.ndarray) -> np.ndarray:
+ # Verify value is either 4D (single video) or 5D array (batched videos).
+ # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width)
+ if value.ndim < 4:
+ raise ValueError(
+ "Video requires at least 4 dimensions (frames, channels, height, width)"
+ )
+ if value.ndim > 5:
+ raise ValueError(
+ "Videos can have at most 5 dimensions (batch, frames, channels, height, width)"
+ )
+ if value.ndim == 4:
+ # Reshape to 5D with single batch: (1, frames, channels, height, width)
+ value = value[np.newaxis, ...]
+
+ value = TrackioVideo._tile_batched_videos(value)
+ return value
+
+ @staticmethod
+ def _tile_batched_videos(video: np.ndarray) -> np.ndarray:
+ """
+ Tiles a batch of videos into a grid of videos.
+
+ Input format: (batch, frames, channels, height, width) - original FCHW format
+ Output format: (frames, total_height, total_width, channels)
+ """
+ batch_size, frames, channels, height, width = video.shape
+
+ next_pow2 = 1 << (batch_size - 1).bit_length()
+ if batch_size != next_pow2:
+ pad_len = next_pow2 - batch_size
+ pad_shape = (pad_len, frames, channels, height, width)
+ padding = np.zeros(pad_shape, dtype=video.dtype)
+ video = np.concatenate((video, padding), axis=0)
+ batch_size = next_pow2
+
+ n_rows = 1 << ((batch_size.bit_length() - 1) // 2)
+ n_cols = batch_size // n_rows
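+ # e.g. a batch of 3 videos is zero-padded to 4 and tiled as a 2x2 grid; a batch of 8 becomes a 2x4 grid.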
+
+ # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width)
+ video = video.reshape(n_rows, n_cols, frames, channels, height, width)
+
+ # Rearrange dimensions to (frames, total_height, total_width, channels)
+ video = video.transpose(2, 0, 4, 1, 5, 3)
+ video = video.reshape(frames, n_rows * height, n_cols * width, channels)
+ return video
diff --git a/py.typed b/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/run.py b/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..c165bf32a2a4170124f3680fbee8b2d6df7dbdee
--- /dev/null
+++ b/run.py
@@ -0,0 +1,156 @@
+import threading
+import time
+
+import huggingface_hub
+from gradio_client import Client, handle_file
+
+from trackio.media import TrackioMedia
+from trackio.sqlite_storage import SQLiteStorage
+from trackio.table import Table
+from trackio.typehints import LogEntry, UploadEntry
+from trackio.utils import (
+ RESERVED_KEYS,
+ fibo,
+ generate_readable_name,
+ serialize_values,
+)
+
+BATCH_SEND_INTERVAL = 0.5
+
+
+class Run:
+ def __init__(
+ self,
+ url: str,
+ project: str,
+ client: Client | None,
+ name: str | None = None,
+ config: dict | None = None,
+ space_id: str | None = None,
+ ):
+ self.url = url
+ self.project = project
+ self._client_lock = threading.Lock()
+ self._client_thread = None
+ self._client = client
+ self._space_id = space_id
+ self.name = name or generate_readable_name(
+ SQLiteStorage.get_runs(project), space_id
+ )
+ self.config = config or {}
+ self._queued_logs: list[LogEntry] = []
+ self._queued_uploads: list[UploadEntry] = []
+ self._stop_flag = threading.Event()
+
+ self._client_thread = threading.Thread(target=self._init_client_background)
+ self._client_thread.daemon = True
+ self._client_thread.start()
+
+ def _batch_sender(self):
+ """Send batched logs every BATCH_SEND_INTERVAL."""
+ while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
+ # If the stop flag has been set, then just quickly send all
+ # the logs and exit.
+ if not self._stop_flag.is_set():
+ time.sleep(BATCH_SEND_INTERVAL)
+
+ with self._client_lock:
+ if self._client is None:
+ return
+ if self._queued_logs:
+ logs_to_send = self._queued_logs.copy()
+ self._queued_logs.clear()
+ self._client.predict(
+ api_name="/bulk_log",
+ logs=logs_to_send,
+ hf_token=huggingface_hub.utils.get_token(),
+ )
+ if self._queued_uploads:
+ uploads_to_send = self._queued_uploads.copy()
+ self._queued_uploads.clear()
+ self._client.predict(
+ api_name="/bulk_upload_media",
+ uploads=uploads_to_send,
+ hf_token=huggingface_hub.utils.get_token(),
+ )
+
+ def _init_client_background(self):
+ if self._client is None:
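+ # Retry connecting to the dashboard with Fibonacci backoff (sleeps of 0.1s, 0.1s, 0.2s, 0.3s, 0.5s, ...).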
+ fib = fibo()
+ for sleep_coefficient in fib:
+ try:
+ client = Client(self.url, verbose=False)
+
+ with self._client_lock:
+ self._client = client
+ break
+ except Exception:
+ pass
+ if sleep_coefficient is not None:
+ time.sleep(0.1 * sleep_coefficient)
+
+ self._batch_sender()
+
+ def _process_media(self, metrics, step: int | None) -> dict:
+ """
+ Serialize media in metrics and upload to space if needed.
+ """
+ serializable_metrics = {}
+ if not step:
+ step = 0
+ for key, value in metrics.items():
+ if isinstance(value, TrackioMedia):
+ value._save(self.project, self.name, step)
+ serializable_metrics[key] = value._to_dict()
+ if self._space_id:
+ # Upload local media when deploying to space
+ upload_entry: UploadEntry = {
+ "project": self.project,
+ "run": self.name,
+ "step": step,
+ "uploaded_file": handle_file(value._get_absolute_file_path()),
+ }
+ with self._client_lock:
+ self._queued_uploads.append(upload_entry)
+ else:
+ serializable_metrics[key] = value
+ return serializable_metrics
+
+ @staticmethod
+ def _replace_tables(metrics):
+ for k, v in metrics.items():
+ if isinstance(v, Table):
+ metrics[k] = v._to_dict()
+
+ def log(self, metrics: dict, step: int | None = None):
+ for k in metrics.keys():
+ if k in RESERVED_KEYS or k.startswith("__"):
+ raise ValueError(
+ f"Please do not use this reserved key as a metric: {k}"
+ )
+ Run._replace_tables(metrics)
+
+ metrics = self._process_media(metrics, step)
+ metrics = serialize_values(metrics)
+ log_entry: LogEntry = {
+ "project": self.project,
+ "run": self.name,
+ "metrics": metrics,
+ "step": step,
+ }
+
+ with self._client_lock:
+ self._queued_logs.append(log_entry)
+
+ def finish(self):
+ """Cleanup when run is finished."""
+ self._stop_flag.set()
+
+ # Wait for the batch sender to finish before joining the client thread.
+ time.sleep(2 * BATCH_SEND_INTERVAL)
+
+ if self._client_thread is not None:
+ print(
+ f"* Run finished. Uploading logs to Trackio Space: {self.url} (please wait...)"
+ )
+ self._client_thread.join()
diff --git a/sqlite_storage.py b/sqlite_storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e6207cb643d807cbfee562a3cf1baf5be254ddf
--- /dev/null
+++ b/sqlite_storage.py
@@ -0,0 +1,440 @@
+import fcntl
+import json
+import os
+import sqlite3
+import time
+from datetime import datetime
+from pathlib import Path
+from threading import Lock
+
+import huggingface_hub as hf
+import pandas as pd
+
+try: # absolute imports when installed
+ from trackio.commit_scheduler import CommitScheduler
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
+ from trackio.utils import (
+ TRACKIO_DIR,
+ deserialize_values,
+ serialize_values,
+ )
+except Exception: # relative imports for local execution on Spaces
+ from commit_scheduler import CommitScheduler
+ from dummy_commit_scheduler import DummyCommitScheduler
+ from utils import TRACKIO_DIR, deserialize_values, serialize_values
+
+
+class ProcessLock:
+ """A simple file-based lock that works across processes."""
+
+ def __init__(self, lockfile_path: Path):
+ self.lockfile_path = lockfile_path
+ self.lockfile = None
+
+ def __enter__(self):
+ """Acquire the lock with retry logic."""
+ self.lockfile_path.parent.mkdir(parents=True, exist_ok=True)
+ self.lockfile = open(self.lockfile_path, "w")
+
+ max_retries = 100
+ for attempt in range(max_retries):
+ try:
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
+ return self
+ except IOError:
+ if attempt < max_retries - 1:
+ time.sleep(0.1)
+ else:
+ raise IOError("Could not acquire database lock after 10 seconds")
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Release the lock."""
+ if self.lockfile:
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
+ self.lockfile.close()
+
+
+class SQLiteStorage:
+ _dataset_import_attempted = False
+ _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
+ _scheduler_lock = Lock()
+
+ @staticmethod
+ def _get_connection(db_path: Path) -> sqlite3.Connection:
+ conn = sqlite3.connect(str(db_path), timeout=30.0)
+ conn.execute("PRAGMA journal_mode = WAL")
+ conn.row_factory = sqlite3.Row
+ return conn
+
+ @staticmethod
+ def _get_process_lock(project: str) -> ProcessLock:
+ lockfile_path = TRACKIO_DIR / f"{project}.lock"
+ return ProcessLock(lockfile_path)
+
+ @staticmethod
+ def get_project_db_filename(project: str) -> str:
+ """Get the database filename for a specific project."""
+ safe_project_name = "".join(
+ c for c in project if c.isalnum() or c in ("-", "_")
+ ).rstrip()
+ if not safe_project_name:
+ safe_project_name = "default"
+ return f"{safe_project_name}.db"
+
+ @staticmethod
+ def get_project_db_path(project: str) -> Path:
+ """Get the database path for a specific project."""
+ filename = SQLiteStorage.get_project_db_filename(project)
+ return TRACKIO_DIR / filename
+
+ @staticmethod
+ def init_db(project: str) -> Path:
+ """
+ Initialize the SQLite database with required tables.
+ If there is a dataset ID provided, copies from that dataset instead.
+ Returns the database path.
+ """
+ db_path = SQLiteStorage.get_project_db_path(project)
+ db_path.parent.mkdir(parents=True, exist_ok=True)
+ with SQLiteStorage._get_process_lock(project):
+ with sqlite3.connect(db_path, timeout=30.0) as conn:
+ conn.execute("PRAGMA journal_mode = WAL")
+ cursor = conn.cursor()
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS metrics (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL,
+ run_name TEXT NOT NULL,
+ step INTEGER NOT NULL,
+ metrics TEXT NOT NULL
+ )
+ """)
+ cursor.execute(
+ """
+ CREATE INDEX IF NOT EXISTS idx_metrics_run_step
+ ON metrics(run_name, step)
+ """
+ )
+ conn.commit()
+ return db_path
+
+ @staticmethod
+ def export_to_parquet():
+ """
+ Exports all projects' DB files as Parquet under the same path but with extension ".parquet".
+ """
+ # don't attempt to export (potentially wrong/blank) data before importing for the first time
+ if not SQLiteStorage._dataset_import_attempted:
+ return
+ all_paths = os.listdir(TRACKIO_DIR)
+ db_paths = [f for f in all_paths if f.endswith(".db")]
+ for db_path in db_paths:
+ db_path = TRACKIO_DIR / db_path
+ parquet_path = db_path.with_suffix(".parquet")
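+ # Only re-export when the parquet copy is missing or older than the DB file.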
+ if (not parquet_path.exists()) or (
+ db_path.stat().st_mtime > parquet_path.stat().st_mtime
+ ):
+ with sqlite3.connect(db_path) as conn:
+ df = pd.read_sql("SELECT * from metrics", conn)
+ # break out the single JSON metrics column into individual columns
+ metrics = df["metrics"].copy()
+ metrics = pd.DataFrame(
+ metrics.apply(
+ lambda x: deserialize_values(json.loads(x))
+ ).values.tolist(),
+ index=df.index,
+ )
+ del df["metrics"]
+ for col in metrics.columns:
+ df[col] = metrics[col]
+ df.to_parquet(parquet_path)
+
+ @staticmethod
+ def import_from_parquet():
+ """
+ Imports data into every DB file that has a matching file at the same path with extension ".parquet".
+ """
+ all_paths = os.listdir(TRACKIO_DIR)
+ parquet_paths = [f for f in all_paths if f.endswith(".parquet")]
+ for parquet_path in parquet_paths:
+ parquet_path = TRACKIO_DIR / parquet_path
+ db_path = parquet_path.with_suffix(".db")
+ df = pd.read_parquet(parquet_path)
+ with sqlite3.connect(db_path) as conn:
+ # fix up df to have a single JSON metrics column
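+ # (export_to_parquet() writes each metric as its own column, so re-pack them here)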
+ if "metrics" not in df.columns:
+ # separate other columns from metrics
+ metrics = df.copy()
+ other_cols = ["id", "timestamp", "run_name", "step"]
+ df = df[other_cols]
+ for col in other_cols:
+ del metrics[col]
+ # combine them all into a single metrics col
+ metrics = json.loads(metrics.to_json(orient="records"))
+ df["metrics"] = [
+ json.dumps(serialize_values(row)) for row in metrics
+ ]
+ df.to_sql("metrics", conn, if_exists="replace", index=False)
+
+ @staticmethod
+ def get_scheduler():
+ """
+ Get the scheduler for the database based on the environment variables.
+ This applies to both local and Spaces.
+ """
+ with SQLiteStorage._scheduler_lock:
+ if SQLiteStorage._current_scheduler is not None:
+ return SQLiteStorage._current_scheduler
+ hf_token = os.environ.get("HF_TOKEN")
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
+ if dataset_id is None or space_repo_name is None:
+ scheduler = DummyCommitScheduler()
+ else:
+ scheduler = CommitScheduler(
+ repo_id=dataset_id,
+ repo_type="dataset",
+ folder_path=TRACKIO_DIR,
+ private=True,
+ allow_patterns=["*.parquet", "media/**/*"],
+ squash_history=True,
+ token=hf_token,
+ on_before_commit=SQLiteStorage.export_to_parquet,
+ )
+ SQLiteStorage._current_scheduler = scheduler
+ return scheduler
+
+ @staticmethod
+ def log(project: str, run: str, metrics: dict, step: int | None = None):
+ """
+ Safely log metrics to the database. Before logging, this method will ensure the database exists
+ and is set up with the correct tables. It also uses a cross-process lock to prevent
+ database locking errors when multiple processes access the same database.
+
+ This method is not used in the latest versions of Trackio (replaced by bulk_log) but
+ is kept for backwards compatibility for users who are connecting to a newer version of
+ a Trackio Spaces dashboard with an older version of Trackio installed locally.
+ """
+ db_path = SQLiteStorage.init_db(project)
+
+ with SQLiteStorage._get_process_lock(project):
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ SELECT MAX(step)
+ FROM metrics
+ WHERE run_name = ?
+ """,
+ (run,),
+ )
+ last_step = cursor.fetchone()[0]
+ if step is None:
+ current_step = 0 if last_step is None else last_step + 1
+ else:
+ current_step = step
+
+ current_timestamp = datetime.now().isoformat()
+
+ cursor.execute(
+ """
+ INSERT INTO metrics
+ (timestamp, run_name, step, metrics)
+ VALUES (?, ?, ?, ?)
+ """,
+ (
+ current_timestamp,
+ run,
+ current_step,
+ json.dumps(serialize_values(metrics)),
+ ),
+ )
+ conn.commit()
+
+ @staticmethod
+ def bulk_log(
+ project: str,
+ run: str,
+ metrics_list: list[dict],
+ steps: list[int] | None = None,
+ timestamps: list[str] | None = None,
+ ):
+ """
+ Safely log bulk metrics to the database. Before logging, this method will ensure the database exists
+ and is set up with the correct tables. It also uses a cross-process lock to prevent
+ database locking errors when multiple processes access the same database.
+ """
+ if not metrics_list:
+ return
+
+ if timestamps is None:
+ timestamps = [datetime.now().isoformat()] * len(metrics_list)
+
+ db_path = SQLiteStorage.init_db(project)
+ with SQLiteStorage._get_process_lock(project):
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+
+ if steps is None:
+ steps = list(range(len(metrics_list)))
+ elif any(s is None for s in steps):
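+ # Some steps are missing: continue auto-incrementing from the run's highest recorded step.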
+ cursor.execute(
+ "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
+ )
+ last_step = cursor.fetchone()[0]
+ current_step = 0 if last_step is None else last_step + 1
+
+ processed_steps = []
+ for step in steps:
+ if step is None:
+ processed_steps.append(current_step)
+ current_step += 1
+ else:
+ processed_steps.append(step)
+ steps = processed_steps
+
+ if len(metrics_list) != len(steps) or len(metrics_list) != len(
+ timestamps
+ ):
+ raise ValueError(
+ "metrics_list, steps, and timestamps must have the same length"
+ )
+
+ data = []
+ for i, metrics in enumerate(metrics_list):
+ data.append(
+ (
+ timestamps[i],
+ run,
+ steps[i],
+ json.dumps(serialize_values(metrics)),
+ )
+ )
+
+ cursor.executemany(
+ """
+ INSERT INTO metrics
+ (timestamp, run_name, step, metrics)
+ VALUES (?, ?, ?, ?)
+ """,
+ data,
+ )
+ conn.commit()
+
+ @staticmethod
+ def get_logs(project: str, run: str) -> list[dict]:
+ """Retrieve logs for a specific run. Logs include the step count (int) and the timestamp (datetime object)."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return []
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT timestamp, step, metrics
+ FROM metrics
+ WHERE run_name = ?
+ ORDER BY timestamp
+ """,
+ (run,),
+ )
+
+ rows = cursor.fetchall()
+ results = []
+ for row in rows:
+ metrics = json.loads(row["metrics"])
+ metrics = deserialize_values(metrics)
+ metrics["timestamp"] = row["timestamp"]
+ metrics["step"] = row["step"]
+ results.append(metrics)
+ return results
+
+ @staticmethod
+ def load_from_dataset():
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
+ if dataset_id is not None and space_repo_name is not None:
+ hfapi = hf.HfApi()
+ updated = False
+ if not TRACKIO_DIR.exists():
+ TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
+ with SQLiteStorage.get_scheduler().lock:
+ try:
+ files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
+ for file in files:
+ # Download parquet and media assets
+ if not (file.endswith(".parquet") or file.startswith("media/")):
+ continue
+ if (TRACKIO_DIR / file).exists():
+ continue
+ hf.hf_hub_download(
+ dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR
+ )
+ updated = True
+ except hf.errors.EntryNotFoundError:
+ pass
+ except hf.errors.RepositoryNotFoundError:
+ pass
+ if updated:
+ SQLiteStorage.import_from_parquet()
+ SQLiteStorage._dataset_import_attempted = True
+
+ @staticmethod
+ def get_projects() -> list[str]:
+ """
+ Get list of all projects by scanning the database files in the trackio directory.
+ """
+ if not SQLiteStorage._dataset_import_attempted:
+ SQLiteStorage.load_from_dataset()
+
+ projects: set[str] = set()
+ if not TRACKIO_DIR.exists():
+ return []
+
+ for db_file in TRACKIO_DIR.glob("*.db"):
+ project_name = db_file.stem
+ projects.add(project_name)
+ return sorted(projects)
+
+ @staticmethod
+ def get_runs(project: str) -> list[str]:
+ """Get list of all runs for a project."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return []
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT DISTINCT run_name FROM metrics",
+ )
+ return [row[0] for row in cursor.fetchall()]
+
+ @staticmethod
+ def get_max_steps_for_runs(project: str) -> dict[str, int]:
+ """Get the maximum step for each run in a project."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return {}
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT run_name, MAX(step) as max_step
+ FROM metrics
+ GROUP BY run_name
+ """
+ )
+
+ results = {}
+ for row in cursor.fetchall():
+ results[row["run_name"]] = row["max_step"]
+
+ return results
+
+ def finish(self):
+ """Cleanup when run is finished."""
+ pass
diff --git a/table.py b/table.py
new file mode 100644
index 0000000000000000000000000000000000000000..61b65426c64c0ea9cabe1db1cc46803abe21f90c
--- /dev/null
+++ b/table.py
@@ -0,0 +1,55 @@
+from typing import Any, Literal, Optional, Union
+
+from pandas import DataFrame
+
+
+class Table:
+ """
+ Initializes a Table object.
+
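+ Example (illustrative; the logged values below are placeholders):
+ ```python
+ import trackio
+ import pandas as pd
+
+ df = pd.DataFrame({"epoch": [1, 2], "accuracy": [0.81, 0.86]})
+ trackio.log({"eval_results": trackio.Table(dataframe=df)})
+ ```
+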
+ Args:
+ columns (`list[str]`, *optional*, defaults to `None`):
+ Names of the columns in the table. Optional if `data` is provided. Not
+ expected if `dataframe` is provided. Currently ignored.
+ data (`list[list[Any]]`, *optional*, defaults to `None`):
+ 2D row-oriented array of values.
+ dataframe (`pandas.DataFrame`, *optional*, defaults to `None`):
+ DataFrame object used to create the table. When set, `data` and `columns`
+ arguments are ignored.
+ rows (`list[list[Any]]`, *optional*, defaults to `None`):
+ Currently ignored.
+ optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
+ Currently ignored.
+ allow_mixed_types (`bool`, *optional*, defaults to `False`):
+ Currently ignored.
+ log_mode (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
+ Currently ignored.
+ """
+
+ TYPE = "trackio.table"
+
+ def __init__(
+ self,
+ columns: Optional[list[str]] = None,
+ data: Optional[list[list[Any]]] = None,
+ dataframe: Optional[DataFrame] = None,
+ rows: Optional[list[list[Any]]] = None,
+ optional: Union[bool, list[bool]] = True,
+ allow_mixed_types: bool = False,
+ log_mode: Optional[
+ Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]
+ ] = "IMMUTABLE",
+ ):
+ # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
+ # for now (like `rows`) they are included for API compat but don't do anything.
+
+ if dataframe is None:
+ self.data = data
+ else:
+ self.data = dataframe.to_dict(orient="records")
+
+ def _to_dict(self):
+ return {
+ "_type": self.TYPE,
+ "_value": self.data,
+ }
diff --git a/typehints.py b/typehints.py
new file mode 100644
index 0000000000000000000000000000000000000000..e270424d8d7604d467ffaedf456696067139365b
--- /dev/null
+++ b/typehints.py
@@ -0,0 +1,17 @@
+from typing import Any, TypedDict
+
+from gradio import FileData
+
+
+class LogEntry(TypedDict):
+ project: str
+ run: str
+ metrics: dict[str, Any]
+ step: int | None
+
+
+class UploadEntry(TypedDict):
+ project: str
+ run: str
+ step: int | None
+ uploaded_file: FileData
diff --git a/ui.py b/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4e47a839b1c95f0919cb126226e3324c6c8822
--- /dev/null
+++ b/ui.py
@@ -0,0 +1,857 @@
+import os
+import re
+import shutil
+from dataclasses import dataclass
+from typing import Any
+
+import gradio as gr
+import huggingface_hub as hf
+import numpy as np
+import pandas as pd
+
+HfApi = hf.HfApi()
+
+try:
+ import trackio.utils as utils
+ from trackio.file_storage import FileStorage
+ from trackio.media import TrackioImage, TrackioVideo
+ from trackio.sqlite_storage import SQLiteStorage
+ from trackio.table import Table
+ from trackio.typehints import LogEntry, UploadEntry
+except: # noqa: E722
+ import utils
+ from file_storage import FileStorage
+ from media import TrackioImage, TrackioVideo
+ from sqlite_storage import SQLiteStorage
+ from table import Table
+ from typehints import LogEntry, UploadEntry
+
+
+def get_project_info() -> str | None:
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
+ space_id = os.environ.get("SPACE_ID")
+ if utils.persistent_storage_enabled():
+ return "✨ Persistent Storage is enabled; logs are stored directly in this Space."
+ if dataset_id:
+ sync_status = utils.get_sync_status(SQLiteStorage.get_scheduler())
+ upgrade_message = f"New changes are synced every 5 min. To avoid losing data between syncs, click here to open this Space's settings and add Persistent Storage. Make sure data is synced prior to enabling."
+ if sync_status is not None:
+ info = f"↻ Backed up {sync_status} min ago to {dataset_id} | {upgrade_message}"
+ else:
+ info = f"↻ Not backed up yet to {dataset_id} | {upgrade_message}"
+ return info
+ return None
+
+
+def get_projects(request: gr.Request):
+ projects = SQLiteStorage.get_projects()
+ if project := request.query_params.get("project"):
+ interactive = False
+ else:
+ interactive = True
+ project = projects[0] if projects else None
+
+ return gr.Dropdown(
+ label="Project",
+ choices=projects,
+ value=project,
+ allow_custom_value=True,
+ interactive=interactive,
+ info=get_project_info(),
+ )
+
+
+def get_runs(project) -> list[str]:
+ if not project:
+ return []
+ return SQLiteStorage.get_runs(project)
+
+
+def get_available_metrics(project: str, runs: list[str]) -> list[str]:
+ """Get all available metrics across all runs for x-axis selection."""
+ if not project or not runs:
+ return ["step", "time"]
+
+ all_metrics = set()
+ for run in runs:
+ metrics = SQLiteStorage.get_logs(project, run)
+ if metrics:
+ df = pd.DataFrame(metrics)
+ numeric_cols = df.select_dtypes(include="number").columns
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
+ all_metrics.update(numeric_cols)
+
+ all_metrics.add("step")
+ all_metrics.add("time")
+
+ sorted_metrics = utils.sort_metrics_by_prefix(list(all_metrics))
+
+ result = ["step", "time"]
+ for metric in sorted_metrics:
+ if metric not in result:
+ result.append(metric)
+
+ return result
+
+
+@dataclass
+class MediaData:
+ caption: str | None
+ file_path: str
+
+
+def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]:
+ media_by_key: dict[str, list[MediaData]] = {}
+ logs = sorted(logs, key=lambda x: x.get("step", 0))
+ for log in logs:
+ for key, value in log.items():
+ if isinstance(value, dict):
+ type = value.get("_type")
+ if type == TrackioImage.TYPE or type == TrackioVideo.TYPE:
+ if key not in media_by_key:
+ media_by_key[key] = []
+ try:
+ media_data = MediaData(
+ file_path=utils.MEDIA_DIR / value.get("file_path"),
+ caption=value.get("caption"),
+ )
+ media_by_key[key].append(media_data)
+ except Exception as e:
+ print(f"Media currently unavailable: {key}: {e}")
+ return media_by_key
+
+
+def load_run_data(
+ project: str | None,
+ run: str | None,
+ smoothing_granularity: int,
+ x_axis: str,
+ log_scale: bool = False,
+) -> tuple[pd.DataFrame, dict]:
+ if not project or not run:
+ return None, None
+
+ logs = SQLiteStorage.get_logs(project, run)
+ if not logs:
+ return None, None
+
+ media = extract_media(logs)
+ df = pd.DataFrame(logs)
+
+ if "step" not in df.columns:
+ df["step"] = range(len(df))
+
+ if x_axis == "time" and "timestamp" in df.columns:
+ df["timestamp"] = pd.to_datetime(df["timestamp"])
+ first_timestamp = df["timestamp"].min()
+ df["time"] = (df["timestamp"] - first_timestamp).dt.total_seconds()
+ x_column = "time"
+ elif x_axis == "step":
+ x_column = "step"
+ else:
+ x_column = x_axis
+
+ if log_scale and x_column in df.columns:
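+ # Non-positive x values would break log10, so shift by +1 (clamping negatives to 0) in that case.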
+ x_vals = df[x_column]
+ if (x_vals <= 0).any():
+ df[x_column] = np.log10(np.maximum(x_vals, 0) + 1)
+ else:
+ df[x_column] = np.log10(x_vals)
+
+ if smoothing_granularity > 0:
+ numeric_cols = df.select_dtypes(include="number").columns
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
+
+ df_original = df.copy()
+ df_original["run"] = run
+ df_original["data_type"] = "original"
+
+ df_smoothed = df.copy()
+ window_size = max(3, min(smoothing_granularity, len(df)))
+ df_smoothed[numeric_cols] = (
+ df_smoothed[numeric_cols]
+ .rolling(window=window_size, center=True, min_periods=1)
+ .mean()
+ )
+ df_smoothed["run"] = f"{run}_smoothed"
+ df_smoothed["data_type"] = "smoothed"
+
+ combined_df = pd.concat([df_original, df_smoothed], ignore_index=True)
+ combined_df["x_axis"] = x_column
+ return combined_df, media
+ else:
+ df["run"] = run
+ df["data_type"] = "original"
+ df["x_axis"] = x_column
+ return df, media
+
+
+def update_runs(
+ project, filter_text, user_interacted_with_runs=False, selected_runs_from_url=None
+):
+ if project is None:
+ runs = []
+ num_runs = 0
+ else:
+ runs = get_runs(project)
+ num_runs = len(runs)
+ if filter_text:
+ runs = [r for r in runs if filter_text in r]
+
+ if not user_interacted_with_runs:
+ if selected_runs_from_url:
+ value = [r for r in runs if r in selected_runs_from_url]
+ else:
+ value = runs
+ return gr.CheckboxGroup(choices=runs, value=value), gr.Textbox(
+ label=f"Runs ({num_runs})"
+ )
+ else:
+ return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})")
+
+
+def filter_runs(project, filter_text):
+ runs = get_runs(project)
+ runs = [r for r in runs if filter_text in r]
+ return gr.CheckboxGroup(choices=runs, value=runs)
+
+
+def update_x_axis_choices(project, runs):
+ """Update x-axis dropdown choices based on available metrics."""
+ available_metrics = get_available_metrics(project, runs)
+ return gr.Dropdown(
+ label="X-axis",
+ choices=available_metrics,
+ value="step",
+ )
+
+
+def toggle_timer(cb_value):
+ if cb_value:
+ return gr.Timer(active=True)
+ else:
+ return gr.Timer(active=False)
+
+
+def check_auth(hf_token: str | None) -> None:
+ if os.getenv("SYSTEM") == "spaces": # if we are running in Spaces
+ # check auth token passed in
+ if hf_token is None:
+ raise PermissionError(
+ "Expected a HF_TOKEN to be provided when logging to a Space"
+ )
+ who = HfApi.whoami(hf_token)
+ access_token = who["auth"]["accessToken"]
+ owner_name = os.getenv("SPACE_AUTHOR_NAME")
+ repo_name = os.getenv("SPACE_REPO_NAME")
+ # make sure the token user is either the author of the space,
+ # or is a member of an org that is the author.
+ orgs = [o["name"] for o in who["orgs"]]
+ if owner_name != who["name"] and owner_name not in orgs:
+ raise PermissionError(
+ "Expected the provided hf_token to be the user owner of the space, or be a member of the org owner of the space"
+ )
+ # reject fine-grained tokens without specific repo access
+ if access_token["role"] == "fineGrained":
+ matched = False
+ for item in access_token["fineGrained"]["scoped"]:
+ if (
+ item["entity"]["type"] == "space"
+ and item["entity"]["name"] == f"{owner_name}/{repo_name}"
+ and "repo.write" in item["permissions"]
+ ):
+ matched = True
+ break
+ if (
+ (
+ item["entity"]["type"] == "user"
+ or item["entity"]["type"] == "org"
+ )
+ and item["entity"]["name"] == owner_name
+ and "repo.write" in item["permissions"]
+ ):
+ matched = True
+ break
+ if not matched:
+ raise PermissionError(
+ "Expected the provided hf_token with fine grained permissions to provide write access to the space"
+ )
+ # reject read-only tokens
+ elif access_token["role"] != "write":
+ raise PermissionError(
+ "Expected the provided hf_token to provide write permissions"
+ )
+
+
+def upload_db_to_space(
+ project: str, uploaded_db: gr.FileData, hf_token: str | None
+) -> None:
+ check_auth(hf_token)
+ db_project_path = SQLiteStorage.get_project_db_path(project)
+ if os.path.exists(db_project_path):
+ raise gr.Error(
+ f"Trackio database file already exists for project {project}, cannot overwrite."
+ )
+ os.makedirs(os.path.dirname(db_project_path), exist_ok=True)
+ shutil.copy(uploaded_db["path"], db_project_path)
+
+
+def bulk_upload_media(uploads: list[UploadEntry], hf_token: str | None) -> None:
+ check_auth(hf_token)
+ for upload in uploads:
+ media_path = FileStorage.init_project_media_path(
+ upload["project"], upload["run"], upload["step"]
+ )
+ shutil.copy(upload["uploaded_file"]["path"], media_path)
+
+
+def log(
+ project: str,
+ run: str,
+ metrics: dict[str, Any],
+ step: int | None,
+ hf_token: str | None,
+) -> None:
+ """
+ Note: this method is not used in the latest versions of Trackio (replaced by bulk_log) but
+ is kept for backwards compatibility for users who are connecting to a newer version of
+ a Trackio Spaces dashboard with an older version of Trackio installed locally.
+ """
+ check_auth(hf_token)
+ SQLiteStorage.log(project=project, run=run, metrics=metrics, step=step)
+
+
+def bulk_log(
+ logs: list[LogEntry],
+ hf_token: str | None,
+) -> None:
+ check_auth(hf_token)
+
+ logs_by_run = {}
+ for log_entry in logs:
+ key = (log_entry["project"], log_entry["run"])
+ if key not in logs_by_run:
+ logs_by_run[key] = {"metrics": [], "steps": []}
+ logs_by_run[key]["metrics"].append(log_entry["metrics"])
+ logs_by_run[key]["steps"].append(log_entry.get("step"))
+
+ for (project, run), data in logs_by_run.items():
+ SQLiteStorage.bulk_log(
+ project=project,
+ run=run,
+ metrics_list=data["metrics"],
+ steps=data["steps"],
+ )
+
+
+def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]:
+ """
+ Filter metrics using regex pattern.
+
+ Args:
+ metrics: List of metric names to filter
+ filter_pattern: Regex pattern to match against metric names
+
+ Returns:
+ List of metric names that match the pattern
+ """
+ if not filter_pattern.strip():
+ return metrics
+
+ try:
+ pattern = re.compile(filter_pattern, re.IGNORECASE)
+ return [metric for metric in metrics if pattern.search(metric)]
+ except re.error:
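+ # Invalid regex: fall back to a case-insensitive substring match.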
+ return [
+ metric for metric in metrics if filter_pattern.lower() in metric.lower()
+ ]
+
+
+def configure(request: gr.Request):
+ sidebar_param = request.query_params.get("sidebar")
+ match sidebar_param:
+ case "collapsed":
+ sidebar = gr.Sidebar(open=False, visible=True)
+ case "hidden":
+ sidebar = gr.Sidebar(open=False, visible=False)
+ case _:
+ sidebar = gr.Sidebar(open=True, visible=True)
+
+ metrics_param = request.query_params.get("metrics", "")
+ runs_param = request.query_params.get("runs", "")
+ selected_runs = runs_param.split(",") if runs_param else []
+
+ return [], sidebar, metrics_param, selected_runs
+
+
+def create_media_section(media_by_run: dict[str, dict[str, list[MediaData]]]):
+ with gr.Accordion(label="media"):
+ with gr.Group(elem_classes=("media-group")):
+ for run, media_by_key in media_by_run.items():
+ with gr.Tab(label=run, elem_classes=("media-tab")):
+ for key, media_item in media_by_key.items():
+ gr.Gallery(
+ [(item.file_path, item.caption) for item in media_item],
+ label=key,
+ columns=6,
+ elem_classes=("media-gallery"),
+ )
+
+
+css = """
+#run-cb .wrap { gap: 2px; }
+#run-cb .wrap label {
+ line-height: 1;
+ padding: 6px;
+}
+.logo-light { display: block; }
+.logo-dark { display: none; }
+.dark .logo-light { display: none; }
+.dark .logo-dark { display: block; }
+.dark .caption-label { color: white; }
+
+.info-container {
+ position: relative;
+ display: inline;
+}
+.info-checkbox {
+ position: absolute;
+ opacity: 0;
+ pointer-events: none;
+}
+.info-icon {
+ border-bottom: 1px dotted;
+ cursor: pointer;
+ user-select: none;
+ color: var(--color-accent);
+}
+.info-expandable {
+ display: none;
+ opacity: 0;
+ transition: opacity 0.2s ease-in-out;
+}
+.info-checkbox:checked ~ .info-expandable {
+ display: inline;
+ opacity: 1;
+}
+.info-icon:hover { opacity: 0.8; }
+.accent-link { font-weight: bold; }
+
+.media-gallery .fixed-height { min-height: 275px; }
+.media-group, .media-group > div { background: none; }
+.media-group .tabs { padding: 0.5em; }
+.media-tab { max-height: 500px; overflow-y: scroll; }
+"""
+
+gr.set_static_paths(paths=[utils.MEDIA_DIR])
+with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo:
+ with gr.Sidebar(open=False) as sidebar:
+ logo = gr.Markdown(
+ f"""
+
+
+ """
+ )
+ project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
+
+ embed_code = gr.Code(
+ label="Embed this view",
+ max_lines=2,
+ lines=2,
+ language="html",
+ visible=bool(os.environ.get("SPACE_HOST")),
+ )
+ run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
+ run_cb = gr.CheckboxGroup(
+ label="Runs", choices=[], interactive=True, elem_id="run-cb"
+ )
+ gr.HTML("<hr>")
+ realtime_cb = gr.Checkbox(label="Refresh metrics realtime", value=True)
+ smoothing_slider = gr.Slider(
+ label="Smoothing Factor",
+ minimum=0,
+ maximum=20,
+ value=10,
+ step=1,
+ info="0 = no smoothing",
+ )
+ x_axis_dd = gr.Dropdown(
+ label="X-axis",
+ choices=["step", "time"],
+ value="step",
+ )
+ log_scale_cb = gr.Checkbox(label="Log scale X-axis", value=False)
+ metric_filter_tb = gr.Textbox(
+ label="Metric Filter (regex)",
+ placeholder="e.g., loss|ndcg@10|gpu",
+ value="",
+ info="Filter metrics using regex patterns. Leave empty to show all metrics.",
+ )
+
+ timer = gr.Timer(value=1)
+ metrics_subset = gr.State([])
+ user_interacted_with_run_cb = gr.State(False)
+ selected_runs_from_url = gr.State([])
+
+ gr.on(
+ [demo.load],
+ fn=configure,
+ outputs=[metrics_subset, sidebar, metric_filter_tb, selected_runs_from_url],
+ queue=False,
+ api_name=False,
+ )
+ gr.on(
+ [demo.load],
+ fn=get_projects,
+ outputs=project_dd,
+ show_progress="hidden",
+ queue=False,
+ api_name=False,
+ )
+ gr.on(
+ [timer.tick],
+ fn=update_runs,
+ inputs=[
+ project_dd,
+ run_tb,
+ user_interacted_with_run_cb,
+ selected_runs_from_url,
+ ],
+ outputs=[run_cb, run_tb],
+ show_progress="hidden",
+ api_name=False,
+ )
+ gr.on(
+ [timer.tick],
+ fn=lambda: gr.Dropdown(info=get_project_info()),
+ outputs=[project_dd],
+ show_progress="hidden",
+ api_name=False,
+ )
+ gr.on(
+ [demo.load, project_dd.change],
+ fn=update_runs,
+ inputs=[project_dd, run_tb, gr.State(False), selected_runs_from_url],
+ outputs=[run_cb, run_tb],
+ show_progress="hidden",
+ queue=False,
+ api_name=False,
+ ).then(
+ fn=update_x_axis_choices,
+ inputs=[project_dd, run_cb],
+ outputs=x_axis_dd,
+ show_progress="hidden",
+ queue=False,
+ api_name=False,
+ ).then(
+ fn=utils.generate_embed_code,
+ inputs=[project_dd, metric_filter_tb, run_cb],
+ outputs=embed_code,
+ show_progress="hidden",
+ api_name=False,
+ queue=False,
+ )
+
+ gr.on(
+ [run_cb.input],
+ fn=update_x_axis_choices,
+ inputs=[project_dd, run_cb],
+ outputs=x_axis_dd,
+ show_progress="hidden",
+ queue=False,
+ api_name=False,
+ )
+ gr.on(
+ [metric_filter_tb.change, run_cb.change],
+ fn=utils.generate_embed_code,
+ inputs=[project_dd, metric_filter_tb, run_cb],
+ outputs=embed_code,
+ show_progress="hidden",
+ api_name=False,
+ queue=False,
+ )
+
+ realtime_cb.change(
+ fn=toggle_timer,
+ inputs=realtime_cb,
+ outputs=timer,
+ api_name=False,
+ queue=False,
+ )
+ run_cb.input(
+ fn=lambda: True,
+ outputs=user_interacted_with_run_cb,
+ api_name=False,
+ queue=False,
+ )
+ run_tb.input(
+ fn=filter_runs,
+ inputs=[project_dd, run_tb],
+ outputs=run_cb,
+ api_name=False,
+ queue=False,
+ )
+
+ gr.api(
+ fn=upload_db_to_space,
+ api_name="upload_db_to_space",
+ )
+ gr.api(
+ fn=bulk_upload_media,
+ api_name="bulk_upload_media",
+ )
+ gr.api(
+ fn=log,
+ api_name="log",
+ )
+ gr.api(
+ fn=bulk_log,
+ api_name="bulk_log",
+ )
+
+ x_lim = gr.State(None)
+ last_steps = gr.State({})
+
+ def update_x_lim(select_data: gr.SelectData):
+ return select_data.index
+
+ def update_last_steps(project):
+ """Check the last step for each run to detect when new data is available."""
+ if not project:
+ return {}
+ return SQLiteStorage.get_max_steps_for_runs(project)
+
+ timer.tick(
+ fn=update_last_steps,
+ inputs=[project_dd],
+ outputs=last_steps,
+ show_progress="hidden",
+ api_name=False,
+ )
+
+ @gr.render(
+ triggers=[
+ demo.load,
+ run_cb.change,
+ last_steps.change,
+ smoothing_slider.change,
+ x_lim.change,
+ x_axis_dd.change,
+ log_scale_cb.change,
+ metric_filter_tb.change,
+ ],
+ inputs=[
+ project_dd,
+ run_cb,
+ smoothing_slider,
+ metrics_subset,
+ x_lim,
+ x_axis_dd,
+ log_scale_cb,
+ metric_filter_tb,
+ ],
+ show_progress="hidden",
+ queue=False,
+ )
+ def update_dashboard(
+ project,
+ runs,
+ smoothing_granularity,
+ metrics_subset,
+ x_lim_value,
+ x_axis,
+ log_scale,
+ metric_filter,
+ ):
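+ """Re-render all metric plots, media galleries, and tables for the selected project and runs."""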
+ dfs = []
+ images_by_run = {}
+ original_runs = runs.copy()
+
+ for run in runs:
+ df, images_by_key = load_run_data(
+ project, run, smoothing_granularity, x_axis, log_scale
+ )
+ if df is not None:
+ dfs.append(df)
+ images_by_run[run] = images_by_key
+
+ if dfs:
+ if smoothing_granularity > 0:
+ original_dfs = []
+ smoothed_dfs = []
+ for df in dfs:
+ original_data = df[df["data_type"] == "original"]
+ smoothed_data = df[df["data_type"] == "smoothed"]
+ if not original_data.empty:
+ original_dfs.append(original_data)
+ if not smoothed_data.empty:
+ smoothed_dfs.append(smoothed_data)
+
+ all_dfs = original_dfs + smoothed_dfs
+ master_df = (
+ pd.concat(all_dfs, ignore_index=True) if all_dfs else pd.DataFrame()
+ )
+
+ else:
+ master_df = pd.concat(dfs, ignore_index=True)
+ else:
+ master_df = pd.DataFrame()
+
+ if master_df.empty:
+ return
+
+ x_column = "step"
+ if dfs and not dfs[0].empty and "x_axis" in dfs[0].columns:
+ x_column = dfs[0]["x_axis"].iloc[0]
+
+ numeric_cols = master_df.select_dtypes(include="number").columns
+ numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS]
+ if x_column and x_column in numeric_cols:
+ numeric_cols.remove(x_column)
+
+ if metrics_subset:
+ numeric_cols = [c for c in numeric_cols if c in metrics_subset]
+
+ if metric_filter and metric_filter.strip():
+ numeric_cols = filter_metrics_by_regex(list(numeric_cols), metric_filter)
+
+ nested_metric_groups = utils.group_metrics_with_subprefixes(list(numeric_cols))
+ color_map = utils.get_color_mapping(original_runs, smoothing_granularity > 0)
+
+ metric_idx = 0
+ for group_name in sorted(nested_metric_groups.keys()):
+ group_data = nested_metric_groups[group_name]
+
+ with gr.Accordion(
+ label=group_name,
+ open=True,
+ key=f"accordion-{group_name}",
+ preserved_by_key=["value", "open"],
+ ):
+ # Render direct metrics at this level
+ if group_data["direct_metrics"]:
+ with gr.Draggable(
+ key=f"row-{group_name}-direct", orientation="row"
+ ):
+ for metric_name in group_data["direct_metrics"]:
+ metric_df = master_df.dropna(subset=[metric_name])
+ color = "run" if "run" in metric_df.columns else None
+ if not metric_df.empty:
+ plot = gr.LinePlot(
+ utils.downsample(
+ metric_df,
+ x_column,
+ metric_name,
+ color,
+ x_lim_value,
+ ),
+ x=x_column,
+ y=metric_name,
+ y_title=metric_name.split("/")[-1],
+ color=color,
+ color_map=color_map,
+ title=metric_name,
+ key=f"plot-{metric_idx}",
+ preserved_by_key=None,
+ x_lim=x_lim_value,
+ show_fullscreen_button=True,
+ min_width=400,
+ )
+ plot.select(
+ update_x_lim,
+ outputs=x_lim,
+ key=f"select-{metric_idx}",
+ )
+ plot.double_click(
+ lambda: None,
+ outputs=x_lim,
+ key=f"double-{metric_idx}",
+ )
+ metric_idx += 1
+
+ # If there are subgroups, create nested accordions
+ if group_data["subgroups"]:
+ for subgroup_name in sorted(group_data["subgroups"].keys()):
+ subgroup_metrics = group_data["subgroups"][subgroup_name]
+
+ with gr.Accordion(
+ label=subgroup_name,
+ open=True,
+ key=f"accordion-{group_name}-{subgroup_name}",
+ preserved_by_key=["value", "open"],
+ ):
+ with gr.Draggable(key=f"row-{group_name}-{subgroup_name}"):
+ for metric_name in subgroup_metrics:
+ metric_df = master_df.dropna(subset=[metric_name])
+ color = (
+ "run" if "run" in metric_df.columns else None
+ )
+ if not metric_df.empty:
+ plot = gr.LinePlot(
+ utils.downsample(
+ metric_df,
+ x_column,
+ metric_name,
+ color,
+ x_lim_value,
+ ),
+ x=x_column,
+ y=metric_name,
+ y_title=metric_name.split("/")[-1],
+ color=color,
+ color_map=color_map,
+ title=metric_name,
+ key=f"plot-{metric_idx}",
+ preserved_by_key=None,
+ x_lim=x_lim_value,
+ show_fullscreen_button=True,
+ min_width=400,
+ )
+ plot.select(
+ update_x_lim,
+ outputs=x_lim,
+ key=f"select-{metric_idx}",
+ )
+ plot.double_click(
+ lambda: None,
+ outputs=x_lim,
+ key=f"double-{metric_idx}",
+ )
+ metric_idx += 1
+ if images_by_run and any(any(images) for images in images_by_run.values()):
+ create_media_section(images_by_run)
+
+ table_cols = master_df.select_dtypes(include="object").columns
+ table_cols = [c for c in table_cols if c not in utils.RESERVED_KEYS]
+ if metrics_subset:
+ table_cols = [c for c in table_cols if c in metrics_subset]
+ if metric_filter and metric_filter.strip():
+ table_cols = filter_metrics_by_regex(list(table_cols), metric_filter)
+ if len(table_cols) > 0:
+ with gr.Accordion("tables", open=True):
+ with gr.Row(key="row"):
+ for metric_idx, metric_name in enumerate(table_cols):
+ metric_df = master_df.dropna(subset=[metric_name])
+ if not metric_df.empty:
+ value = metric_df[metric_name].iloc[-1]
+ if (
+ isinstance(value, dict)
+ and "_type" in value
+ and value["_type"] == Table.TYPE
+ ):
+ try:
+ df = pd.DataFrame(value["_value"])
+ gr.DataFrame(
+ df,
+ label=f"{metric_name} (latest)",
+ key=f"table-{metric_idx}",
+ wrap=True,
+ )
+ except Exception as e:
+ gr.Warning(
+ f"Column {metric_name} failed to render as a table: {e}"
+ )
+
+
+if __name__ == "__main__":
+ demo.launch(allowed_paths=[utils.TRACKIO_LOGO_DIR], show_api=False, show_error=True)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6675782ff7189e4471ced28b295ee53ed174f09
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,687 @@
+import math
+import os
+import re
+import sys
+import time
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import huggingface_hub
+import numpy as np
+import pandas as pd
+from huggingface_hub.constants import HF_HOME
+
+if TYPE_CHECKING:
+ from trackio.commit_scheduler import CommitScheduler
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
+
+RESERVED_KEYS = ["project", "run", "timestamp", "step", "time", "metrics"]
+
+TRACKIO_LOGO_DIR = Path(__file__).parent / "assets"
+
+
+def persistent_storage_enabled() -> bool:
+ return (
+ os.environ.get("PERSISTANT_STORAGE_ENABLED") == "true"
+ ) # typo in the name of the environment variable
+
+
+def _get_trackio_dir() -> Path:
+ if persistent_storage_enabled():
+ return Path("/data/trackio")
+ return Path(HF_HOME) / "trackio"
+
+
+TRACKIO_DIR = _get_trackio_dir()
+MEDIA_DIR = TRACKIO_DIR / "media"
+
+
+def generate_readable_name(used_names: list[str], space_id: str | None = None) -> str:
+ """
+ Generates a random, readable name like "dainty-sunset-0".
+ If space_id is provided, generates username-timestamp format instead.
+ """
+ if space_id is not None:
+ username = huggingface_hub.whoami()["name"]
+ timestamp = int(time.time())
+ return f"{username}-{timestamp}"
+ adjectives = [
+ "dainty",
+ "brave",
+ "calm",
+ "eager",
+ "fancy",
+ "gentle",
+ "happy",
+ "jolly",
+ "kind",
+ "lively",
+ "merry",
+ "nice",
+ "proud",
+ "quick",
+ "hugging",
+ "silly",
+ "tidy",
+ "witty",
+ "zealous",
+ "bright",
+ "shy",
+ "bold",
+ "clever",
+ "daring",
+ "elegant",
+ "faithful",
+ "graceful",
+ "honest",
+ "inventive",
+ "jovial",
+ "keen",
+ "lucky",
+ "modest",
+ "noble",
+ "optimistic",
+ "patient",
+ "quirky",
+ "resourceful",
+ "sincere",
+ "thoughtful",
+ "upbeat",
+ "valiant",
+ "warm",
+ "youthful",
+ "zesty",
+ "adventurous",
+ "breezy",
+ "cheerful",
+ "delightful",
+ "energetic",
+ "fearless",
+ "glad",
+ "hopeful",
+ "imaginative",
+ "joyful",
+ "kindly",
+ "luminous",
+ "mysterious",
+ "neat",
+ "outgoing",
+ "playful",
+ "radiant",
+ "spirited",
+ "tranquil",
+ "unique",
+ "vivid",
+ "wise",
+ "zany",
+ "artful",
+ "bubbly",
+ "charming",
+ "dazzling",
+ "earnest",
+ "festive",
+ "gentlemanly",
+ "hearty",
+ "intrepid",
+ "jubilant",
+ "knightly",
+ "lively",
+ "magnetic",
+ "nimble",
+ "orderly",
+ "peaceful",
+ "quick-witted",
+ "robust",
+ "sturdy",
+ "trusty",
+ "upstanding",
+ "vibrant",
+ "whimsical",
+ ]
+ nouns = [
+ "sunset",
+ "forest",
+ "river",
+ "mountain",
+ "breeze",
+ "meadow",
+ "ocean",
+ "valley",
+ "sky",
+ "field",
+ "cloud",
+ "star",
+ "rain",
+ "leaf",
+ "stone",
+ "flower",
+ "bird",
+ "tree",
+ "wave",
+ "trail",
+ "island",
+ "desert",
+ "hill",
+ "lake",
+ "pond",
+ "grove",
+ "canyon",
+ "reef",
+ "bay",
+ "peak",
+ "glade",
+ "marsh",
+ "cliff",
+ "dune",
+ "spring",
+ "brook",
+ "cave",
+ "plain",
+ "ridge",
+ "wood",
+ "blossom",
+ "petal",
+ "root",
+ "branch",
+ "seed",
+ "acorn",
+ "pine",
+ "willow",
+ "cedar",
+ "elm",
+ "falcon",
+ "eagle",
+ "sparrow",
+ "robin",
+ "owl",
+ "finch",
+ "heron",
+ "crane",
+ "duck",
+ "swan",
+ "fox",
+ "wolf",
+ "bear",
+ "deer",
+ "moose",
+ "otter",
+ "beaver",
+ "lynx",
+ "hare",
+ "badger",
+ "butterfly",
+ "bee",
+ "ant",
+ "beetle",
+ "dragonfly",
+ "firefly",
+ "ladybug",
+ "moth",
+ "spider",
+ "worm",
+ "coral",
+ "kelp",
+ "shell",
+ "pebble",
+ "face",
+ "boulder",
+ "cobble",
+ "sand",
+ "wavelet",
+ "tide",
+ "current",
+ "mist",
+ ]
+ number = 0
+ name = f"{adjectives[0]}-{nouns[0]}-{number}"
+ while name in used_names:
+ number += 1
+ adjective = adjectives[number % len(adjectives)]
+ noun = nouns[number % len(nouns)]
+ name = f"{adjective}-{noun}-{number}"
+ return name
+
+
+def block_except_in_notebook():
+ in_notebook = bool(getattr(sys, "ps1", sys.flags.interactive))
+ if in_notebook:
+ return
+ try:
+ while True:
+ time.sleep(0.1)
+ except (KeyboardInterrupt, OSError):
+ print("Keyboard interruption in main thread... closing dashboard.")
+
+
+def simplify_column_names(columns: list[str]) -> dict[str, str]:
+ """
+ Simplifies column names to first 10 alphanumeric or "/" characters with unique suffixes.
+
+ Args:
+ columns: List of original column names
+
+ Returns:
+ Dictionary mapping original column names to simplified names
+ """
+ simplified_names = {}
+ used_names = set()
+
+ for col in columns:
+ alphanumeric = re.sub(r"[^a-zA-Z0-9/]", "", col)
+ base_name = alphanumeric[:10] if alphanumeric else f"col_{len(used_names)}"
+
+ final_name = base_name
+ suffix = 1
+ while final_name in used_names:
+ final_name = f"{base_name}_{suffix}"
+ suffix += 1
+
+ simplified_names[col] = final_name
+ used_names.add(final_name)
+
+ return simplified_names
+
+
+def print_dashboard_instructions(project: str) -> None:
+ """
+ Prints instructions for viewing the Trackio dashboard.
+
+ Args:
+ project: The name of the project to show dashboard for.
+ """
+ YELLOW = "\033[93m"
+ BOLD = "\033[1m"
+ RESET = "\033[0m"
+
+ print("* View dashboard by running in your terminal:")
+ print(f'{BOLD}{YELLOW}trackio show --project "{project}"{RESET}')
+ print(f'* or by running in Python: trackio.show(project="{project}")')
+
+
+def preprocess_space_and_dataset_ids(
+ space_id: str | None, dataset_id: str | None
+) -> tuple[str | None, str | None]:
+ if space_id is not None and "/" not in space_id:
+ username = huggingface_hub.whoami()["name"]
+ space_id = f"{username}/{space_id}"
+ if dataset_id is not None and "/" not in dataset_id:
+ username = huggingface_hub.whoami()["name"]
+ dataset_id = f"{username}/{dataset_id}"
+ if space_id is not None and dataset_id is None:
+ dataset_id = f"{space_id}-dataset"
+ return space_id, dataset_id
+
+
+def fibo():
+ """Generator for Fibonacci backoff: 1, 1, 2, 3, 5, 8, ..."""
+ a, b = 1, 1
+ while True:
+ yield a
+ a, b = b, a + b
+
+
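+# Ten distinct hex colors; runs beyond ten cycle back through the palette (see get_color_mapping).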
+COLOR_PALETTE = [
+ "#3B82F6",
+ "#EF4444",
+ "#10B981",
+ "#F59E0B",
+ "#8B5CF6",
+ "#EC4899",
+ "#06B6D4",
+ "#84CC16",
+ "#F97316",
+ "#6366F1",
+]
+
+
+def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
+ """Generate color mapping for runs, with transparency for original data when smoothing is enabled."""
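+    # Example (illustrative): get_color_mapping(["run-a"], smoothing=True) returns
+    # {"run-a": "#3B82F64D", "run-a_smoothed": "#3B82F6"} -- "4D" appends ~30% alpha
+    # so the raw series is drawn translucently beneath the solid smoothed series.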
+ color_map = {}
+
+ for i, run in enumerate(runs):
+ base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
+
+ if smoothing:
+ color_map[run] = base_color + "4D"
+ color_map[f"{run}_smoothed"] = base_color
+ else:
+ color_map[run] = base_color
+
+ return color_map
+
+
+def downsample(
+ df: pd.DataFrame,
+ x: str,
+ y: str,
+ color: str | None,
+ x_lim: tuple[float, float] | None = None,
+) -> pd.DataFrame:
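+    """
+    Downsample a dataframe for plotting by keeping, within each of ~100 x-axis bins,
+    the rows holding the minimum and maximum y values (per color group, if given).
+
+    Groups with fewer than 500 points are kept as-is. When ``x_lim`` is provided,
+    one point just outside each limit is also kept so lines extend to the plot edges.
+    """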
+ if df.empty:
+ return df
+
+ columns_to_keep = [x, y]
+ if color is not None and color in df.columns:
+ columns_to_keep.append(color)
+ df = df[columns_to_keep].copy()
+
+ n_bins = 100
+
+ if color is not None and color in df.columns:
+ groups = df.groupby(color)
+ else:
+ groups = [(None, df)]
+
+ downsampled_indices = []
+
+ for _, group_df in groups:
+ if group_df.empty:
+ continue
+
+ group_df = group_df.sort_values(x)
+
+ if x_lim is not None:
+ x_min, x_max = x_lim
+ before_point = group_df[group_df[x] < x_min].tail(1)
+ after_point = group_df[group_df[x] > x_max].head(1)
+ group_df = group_df[(group_df[x] >= x_min) & (group_df[x] <= x_max)]
+ else:
+ before_point = after_point = None
+ x_min = group_df[x].min()
+ x_max = group_df[x].max()
+
+ if before_point is not None and not before_point.empty:
+ downsampled_indices.extend(before_point.index.tolist())
+ if after_point is not None and not after_point.empty:
+ downsampled_indices.extend(after_point.index.tolist())
+
+ if group_df.empty:
+ continue
+
+ if x_min == x_max:
+ min_y_idx = group_df[y].idxmin()
+ max_y_idx = group_df[y].idxmax()
+ if min_y_idx != max_y_idx:
+ downsampled_indices.extend([min_y_idx, max_y_idx])
+ else:
+ downsampled_indices.append(min_y_idx)
+ continue
+
+ if len(group_df) < 500:
+ downsampled_indices.extend(group_df.index.tolist())
+ continue
+
+ bins = np.linspace(x_min, x_max, n_bins + 1)
+ group_df["bin"] = pd.cut(
+ group_df[x], bins=bins, labels=False, include_lowest=True
+ )
+
+ for bin_idx in group_df["bin"].dropna().unique():
+ bin_data = group_df[group_df["bin"] == bin_idx]
+ if bin_data.empty:
+ continue
+
+ min_y_idx = bin_data[y].idxmin()
+ max_y_idx = bin_data[y].idxmax()
+
+ downsampled_indices.append(min_y_idx)
+ if min_y_idx != max_y_idx:
+ downsampled_indices.append(max_y_idx)
+
+ unique_indices = list(set(downsampled_indices))
+
+ downsampled_df = df.loc[unique_indices].copy()
+
+ if color is not None:
+ downsampled_df = (
+ downsampled_df.groupby(color, sort=False)[downsampled_df.columns]
+ .apply(lambda group: group.sort_values(x))
+ .reset_index(drop=True)
+ )
+ else:
+ downsampled_df = downsampled_df.sort_values(x).reset_index(drop=True)
+
+ downsampled_df = downsampled_df.drop(columns=["bin"], errors="ignore")
+
+ return downsampled_df
+
+
+def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
+ """
+ Sort metrics by grouping prefixes together for dropdown/list display.
+ Metrics without prefixes come first, then grouped by prefix.
+
+ Args:
+ metrics: List of metric names
+
+ Returns:
+ List of metric names sorted by prefix
+
+ Example:
+ Input: ["train/loss", "loss", "train/acc", "val/loss"]
+ Output: ["loss", "train/acc", "train/loss", "val/loss"]
+ """
+ groups = group_metrics_by_prefix(metrics)
+ result = []
+
+ if "charts" in groups:
+ result.extend(groups["charts"])
+
+ for group_name in sorted(groups.keys()):
+ if group_name != "charts":
+ result.extend(groups[group_name])
+
+ return result
+
+
+def group_metrics_by_prefix(metrics: list[str]) -> dict[str, list[str]]:
+ """
+ Group metrics by their prefix. Metrics without prefix go to 'charts' group.
+
+ Args:
+ metrics: List of metric names
+
+ Returns:
+ Dictionary with prefix names as keys and lists of metrics as values
+
+ Example:
+ Input: ["loss", "accuracy", "train/loss", "train/acc", "val/loss"]
+ Output: {
+ "charts": ["loss", "accuracy"],
+ "train": ["train/loss", "train/acc"],
+ "val": ["val/loss"]
+ }
+ """
+ no_prefix = []
+ with_prefix = []
+
+ for metric in metrics:
+ if "/" in metric:
+ with_prefix.append(metric)
+ else:
+ no_prefix.append(metric)
+
+ no_prefix.sort()
+
+ prefix_groups = {}
+ for metric in with_prefix:
+ prefix = metric.split("/")[0]
+ if prefix not in prefix_groups:
+ prefix_groups[prefix] = []
+ prefix_groups[prefix].append(metric)
+
+ for prefix in prefix_groups:
+ prefix_groups[prefix].sort()
+
+ groups = {}
+ if no_prefix:
+ groups["charts"] = no_prefix
+
+ for prefix in sorted(prefix_groups.keys()):
+ groups[prefix] = prefix_groups[prefix]
+
+ return groups
+
+
+def group_metrics_with_subprefixes(metrics: list[str]) -> dict:
+ """
+ Group metrics with simple 2-level nested structure detection.
+
+ Returns a dictionary where each prefix group can have:
+ - direct_metrics: list of metrics at this level (e.g., "train/acc")
+ - subgroups: dict of subgroup name -> list of metrics (e.g., "loss" -> ["train/loss/norm", "train/loss/unnorm"])
+
+ Example:
+ Input: ["loss", "train/acc", "train/loss/normalized", "train/loss/unnormalized", "val/loss"]
+ Output: {
+ "charts": {
+ "direct_metrics": ["loss"],
+ "subgroups": {}
+ },
+ "train": {
+ "direct_metrics": ["train/acc"],
+ "subgroups": {
+ "loss": ["train/loss/normalized", "train/loss/unnormalized"]
+ }
+ },
+ "val": {
+ "direct_metrics": ["val/loss"],
+ "subgroups": {}
+ }
+ }
+ """
+ result = {}
+
+ for metric in metrics:
+ if "/" not in metric:
+ if "charts" not in result:
+ result["charts"] = {"direct_metrics": [], "subgroups": {}}
+ result["charts"]["direct_metrics"].append(metric)
+ else:
+ parts = metric.split("/")
+ main_prefix = parts[0]
+
+ if main_prefix not in result:
+ result[main_prefix] = {"direct_metrics": [], "subgroups": {}}
+
+ if len(parts) == 2:
+ result[main_prefix]["direct_metrics"].append(metric)
+ else:
+ subprefix = parts[1]
+ if subprefix not in result[main_prefix]["subgroups"]:
+ result[main_prefix]["subgroups"][subprefix] = []
+ result[main_prefix]["subgroups"][subprefix].append(metric)
+
+ for group_data in result.values():
+ group_data["direct_metrics"].sort()
+ for subgroup_metrics in group_data["subgroups"].values():
+ subgroup_metrics.sort()
+
+ if "charts" in result and not result["charts"]["direct_metrics"]:
+ del result["charts"]
+
+ return result
+
+
+def get_sync_status(scheduler: "CommitScheduler | DummyCommitScheduler") -> int | None:
+ """Get the sync status from the CommitScheduler in an integer number of minutes, or None if not synced yet."""
+ if getattr(
+ scheduler, "last_push_time", None
+ ): # DummyCommitScheduler doesn't have last_push_time
+ time_diff = time.time() - scheduler.last_push_time
+ return int(time_diff / 60)
+ else:
+ return None
+
+
+def generate_embed_code(
+    project: str, metrics: str, selected_runs: list[str] | None = None
+) -> str:
+ """Generate the embed iframe code based on current settings."""
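+    # Example (illustrative): project="demo", metrics="train/loss", selected_runs=["run-1"]
+    # produces a URL like https://<SPACE_HOST>?project=demo&metrics=train/loss&runs=run-1&sidebar=hidden
+    # (values are not URL-encoded here, so names are assumed to be URL-safe).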
+ space_host = os.environ.get("SPACE_HOST", "")
+ if not space_host:
+ return ""
+
+ params = []
+
+ if project:
+ params.append(f"project={project}")
+
+ if metrics and metrics.strip():
+ params.append(f"metrics={metrics}")
+
+ if selected_runs:
+ runs_param = ",".join(selected_runs)
+ params.append(f"runs={runs_param}")
+
+ params.append("sidebar=hidden")
+
+ query_string = "&".join(params)
+ embed_url = f"https://{space_host}?{query_string}"
+
+    # Note: the iframe width/height below are illustrative defaults.
+    return (
+        f'<iframe src="{embed_url}" '
+        'style="width:100%; height:500px; border:0;"></iframe>'
+    )
+
+
+def serialize_values(metrics):
+ """
+ Serialize infinity and NaN values in metrics dict to make it JSON-compliant.
+ Only handles top-level float values.
+
+ Converts:
+ - float('inf') -> "Infinity"
+ - float('-inf') -> "-Infinity"
+ - float('nan') -> "NaN"
+
+ Example:
+ {"loss": float('inf'), "accuracy": 0.95} -> {"loss": "Infinity", "accuracy": 0.95}
+ """
+ if not isinstance(metrics, dict):
+ return metrics
+
+ result = {}
+ for key, value in metrics.items():
+ if isinstance(value, float):
+ if math.isinf(value):
+ result[key] = "Infinity" if value > 0 else "-Infinity"
+ elif math.isnan(value):
+ result[key] = "NaN"
+ else:
+ result[key] = value
+ elif isinstance(value, np.floating):
+ float_val = float(value)
+ if math.isinf(float_val):
+ result[key] = "Infinity" if float_val > 0 else "-Infinity"
+ elif math.isnan(float_val):
+ result[key] = "NaN"
+ else:
+ result[key] = float_val
+ else:
+ result[key] = value
+ return result
+
+
+def deserialize_values(metrics):
+ """
+ Deserialize infinity and NaN string values back to their numeric forms.
+ Only handles top-level string values.
+
+ Converts:
+ - "Infinity" -> float('inf')
+ - "-Infinity" -> float('-inf')
+ - "NaN" -> float('nan')
+
+ Example:
+ {"loss": "Infinity", "accuracy": 0.95} -> {"loss": float('inf'), "accuracy": 0.95}
+ """
+ if not isinstance(metrics, dict):
+ return metrics
+
+ result = {}
+ for key, value in metrics.items():
+ if value == "Infinity":
+ result[key] = float("inf")
+ elif value == "-Infinity":
+ result[key] = float("-inf")
+ elif value == "NaN":
+ result[key] = float("nan")
+ else:
+ result[key] = value
+ return result
diff --git a/version.txt b/version.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f64a1b563b2adff2a065b04d4d81492d237923c2
--- /dev/null
+++ b/version.txt
@@ -0,0 +1 @@
+0.3.4.dev0
diff --git a/video_writer.py b/video_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..17374944816de049899fb40cef5836ce8aa7b63a
--- /dev/null
+++ b/video_writer.py
@@ -0,0 +1,126 @@
+import shutil
+import subprocess
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+
+VideoCodec = Literal["h264", "vp9", "gif"]
+
+
+def _check_ffmpeg_installed() -> None:
+ """Raise an error if ffmpeg is not available on the system PATH."""
+ if shutil.which("ffmpeg") is None:
+ raise RuntimeError(
+ "ffmpeg is required to write video but was not found on your system. "
+ "Please install ffmpeg and ensure it is available on your PATH."
+ )
+
+
+def _check_array_format(video: np.ndarray) -> None:
+ """Raise an error if the array is not in the expected format."""
+ if not (video.ndim == 4 and video.shape[-1] == 3):
+ raise ValueError(
+ f"Expected RGB input shaped (F, H, W, 3), got {video.shape}. "
+ f"Input has {video.ndim} dimensions, expected 4."
+ )
+ if video.dtype != np.uint8:
+ raise TypeError(
+ f"Expected dtype=uint8, got {video.dtype}. "
+ "Please convert your video data to uint8 format."
+ )
+
+
+def _check_path(file_path: str | Path) -> None:
+    """Create the parent directory of ``file_path`` if it does not exist; raise if it cannot be created."""
+ file_path = Path(file_path)
+ if not file_path.parent.exists():
+ try:
+ file_path.parent.mkdir(parents=True, exist_ok=True)
+ except OSError as e:
+ raise ValueError(
+ f"Failed to create parent directory {file_path.parent}: {e}"
+ )
+
+
+def write_video(
+ file_path: str | Path, video: np.ndarray, fps: float, codec: VideoCodec
+) -> None:
+    """Encode an RGB uint8 array of shape (F, H, W, 3) to ``file_path`` with ffmpeg using the given codec."""
+ _check_ffmpeg_installed()
+ _check_path(file_path)
+
+ if codec not in {"h264", "vp9", "gif"}:
+ raise ValueError("Unsupported codec. Use h264, vp9, or gif.")
+
+ arr = np.asarray(video)
+ _check_array_format(arr)
+
+ frames = np.ascontiguousarray(arr)
+ _, height, width, _ = frames.shape
+ out_path = str(file_path)
+
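+    # Frames are piped to ffmpeg over stdin as raw rgb24 bytes; the arguments below
+    # describe that raw input, and codec-specific output flags are appended next.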
+ cmd = [
+ "ffmpeg",
+ "-y",
+ "-f",
+ "rawvideo",
+ "-s",
+ f"{width}x{height}",
+ "-pix_fmt",
+ "rgb24",
+ "-r",
+ str(fps),
+ "-i",
+ "-",
+ "-an",
+ ]
+
+ if codec == "gif":
+ video_filter = "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse"
+ cmd += [
+ "-vf",
+ video_filter,
+ "-loop",
+ "0",
+ ]
+ elif codec == "h264":
+ cmd += [
+ "-vcodec",
+ "libx264",
+ "-pix_fmt",
+ "yuv420p",
+ "-movflags",
+ "+faststart",
+ ]
+ elif codec == "vp9":
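+        # Heuristic rate control: target ~0.08 bits per pixel per frame
+        # (width * height * fps * 0.08 bits/s overall), formatted in M/k units for ffmpeg.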
+ bpp = 0.08
+ bps = int(width * height * fps * bpp)
+ if bps >= 1_000_000:
+ bitrate = f"{round(bps / 1_000_000)}M"
+ elif bps >= 1_000:
+ bitrate = f"{round(bps / 1_000)}k"
+ else:
+ bitrate = str(max(bps, 1))
+ cmd += [
+ "-vcodec",
+ "libvpx-vp9",
+ "-b:v",
+ bitrate,
+ "-pix_fmt",
+ "yuv420p",
+ ]
+ cmd += [out_path]
+ proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
+    try:
+        for frame in frames:
+            proc.stdin.write(frame.tobytes())
+    except BrokenPipeError:
+        # ffmpeg exited before consuming all frames; fall through so its
+        # stderr and return code are reported below instead of this error.
+        pass
+    finally:
+        if proc.stdin:
+            proc.stdin.close()
+ stderr = (
+ proc.stderr.read().decode("utf-8", errors="ignore") if proc.stderr else ""
+ )
+ ret = proc.wait()
+ if ret != 0:
+ raise RuntimeError(f"ffmpeg failed with code {ret}\n{stderr}")
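+
+
+# Example (illustrative):
+#     frames = np.zeros((30, 64, 64, 3), dtype=np.uint8)  # 30 black 64x64 frames
+#     write_video("out/clip.mp4", frames, fps=15, codec="h264")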