diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..df573b28e8a840f0092cd0e97609ace2c73500a4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2d95e45505954e32011da6f18f3f47f844beb2 --- /dev/null +++ b/__init__.py @@ -0,0 +1,264 @@ +import hashlib +import os +import warnings +import webbrowser +from pathlib import Path +from typing import Any + +from gradio.blocks import BUILT_IN_THEMES +from gradio.themes import Default as DefaultTheme +from gradio.themes import ThemeClass +from gradio_client import Client +from huggingface_hub import SpaceStorage + +from trackio import context_vars, deploy, utils +from trackio.imports import import_csv, import_tf_events +from trackio.media import TrackioImage, TrackioVideo +from trackio.run import Run +from trackio.sqlite_storage import SQLiteStorage +from trackio.table import Table +from trackio.ui import demo +from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR + +__version__ = Path(__file__).parent.joinpath("version.txt").read_text().strip() + +__all__ = [ + "init", + "log", + "finish", + "show", + "import_csv", + "import_tf_events", + "Image", + "Video", + "Table", +] + +Image = TrackioImage +Video = TrackioVideo + + +config = {} + +DEFAULT_THEME = "citrus" + + +def init( + project: str, + name: str | None = None, + space_id: str | None = None, + space_storage: SpaceStorage | None = None, + dataset_id: str | None = None, + config: dict | None = None, + resume: str = "never", + settings: Any = None, +) -> Run: + """ + Creates a new Trackio project and returns a [`Run`] object. + + Args: + project (`str`): + The name of the project (can be an existing project to continue tracking or + a new project to start tracking from scratch). + name (`str` or `None`, *optional*, defaults to `None`): + The name of the run (if not provided, a default name will be generated). + space_id (`str` or `None`, *optional*, defaults to `None`): + If provided, the project will be logged to a Hugging Face Space instead of + a local directory. Should be a complete Space name like + `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which + case the Space will be created in the currently-logged-in Hugging Face + user's namespace. If the Space does not exist, it will be created. If the + Space already exists, the project will be logged to it. + space_storage ([`~huggingface_hub.SpaceStorage`] or `None`, *optional*, defaults to `None`): + Choice of persistent storage tier. + dataset_id (`str` or `None`, *optional*, defaults to `None`): + If a `space_id` is provided, a persistent Hugging Face Dataset will be + created and the metrics will be synced to it every 5 minutes. Specify a + Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`, + or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace), + or `None` (uses the same name as the Space but with the `"_dataset"` + suffix). If the Dataset does not exist, it will be created. If the Dataset + already exists, the project will be appended to it. + config (`dict` or `None`, *optional*, defaults to `None`): + A dictionary of configuration options. 
Provided for compatibility with + `wandb.init()`. + resume (`str`, *optional*, defaults to `"never"`): + Controls how to handle resuming a run. Can be one of: + + - `"must"`: Must resume the run with the given name, raises error if run + doesn't exist + - `"allow"`: Resume the run if it exists, otherwise create a new run + - `"never"`: Never resume a run, always create a new one + settings (`Any`, *optional*, defaults to `None`): + Not used. Provided for compatibility with `wandb.init()`. + + Returns: + `Run`: A [`Run`] object that can be used to log metrics and finish the run. + """ + if settings is not None: + warnings.warn( + "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented." + ) + + if space_id is None and dataset_id is not None: + raise ValueError("Must provide a `space_id` when `dataset_id` is provided.") + space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id) + url = context_vars.current_server.get() + + if url is None: + if space_id is None: + _, url, _ = demo.launch( + show_api=False, + inline=False, + quiet=True, + prevent_thread_lock=True, + show_error=True, + ) + else: + url = space_id + context_vars.current_server.set(url) + + if ( + context_vars.current_project.get() is None + or context_vars.current_project.get() != project + ): + print(f"* Trackio project initialized: {project}") + + if dataset_id is not None: + os.environ["TRACKIO_DATASET_ID"] = dataset_id + print( + f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}" + ) + if space_id is None: + print(f"* Trackio metrics logged to: {TRACKIO_DIR}") + utils.print_dashboard_instructions(project) + else: + deploy.create_space_if_not_exists(space_id, space_storage, dataset_id) + print( + f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}" + ) + context_vars.current_project.set(project) + + client = None + if not space_id: + client = Client(url, verbose=False) + + if resume == "must": + if name is None: + raise ValueError("Must provide a run name when resume='must'") + if name not in SQLiteStorage.get_runs(project): + raise ValueError(f"Run '{name}' does not exist in project '{project}'") + resumed = True + elif resume == "allow": + resumed = name is not None and name in SQLiteStorage.get_runs(project) + elif resume == "never": + if name is not None and name in SQLiteStorage.get_runs(project): + warnings.warn( + f"* Warning: resume='never' but a run '{name}' already exists in " + f"project '{project}'. Generating a new name and instead. If you want " + "to resume this run, call init() with resume='must' or resume='allow'." + ) + name = None + resumed = False + else: + raise ValueError("resume must be one of: 'must', 'allow', or 'never'") + + run = Run( + url=url, + project=project, + client=client, + name=name, + config=config, + space_id=space_id, + ) + + if resumed: + print(f"* Resumed existing run: {run.name}") + else: + print(f"* Created new run: {run.name}") + + context_vars.current_run.set(run) + globals()["config"] = run.config + return run + + +def log(metrics: dict, step: int | None = None) -> None: + """ + Logs metrics to the current run. + + Args: + metrics (`dict`): + A dictionary of metrics to log. + step (`int` or `None`, *optional*, defaults to `None`): + The step number. If not provided, the step will be incremented + automatically. 
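Taken together, `init()`, `log()`, and `finish()` above form the public logging API added by this `__init__.py`. A minimal local-only sketch, with invented project, run, and config values:

```python
import trackio

# Project, run name, and config values are placeholders for illustration.
run = trackio.init(
    project="demo-project",
    name="baseline",
    config={"lr": 1e-3, "epochs": 3},
    resume="allow",  # reuse the "baseline" run if it already exists, otherwise create it
)

for step in range(10):
    # `step` is optional; when omitted it is incremented automatically.
    trackio.log({"loss": 1.0 / (step + 1), "lr": 1e-3}, step=step)

trackio.finish()
```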
+ """ + run = context_vars.current_run.get() + if run is None: + raise RuntimeError("Call trackio.init() before trackio.log().") + run.log( + metrics=metrics, + step=step, + ) + + +def finish(): + """ + Finishes the current run. + """ + run = context_vars.current_run.get() + if run is None: + raise RuntimeError("Call trackio.init() before trackio.finish().") + run.finish() + + +def show(project: str | None = None, theme: str | ThemeClass = DEFAULT_THEME): + """ + Launches the Trackio dashboard. + + Args: + project (`str` or `None`, *optional*, defaults to `None`): + The name of the project whose runs to show. If not provided, all projects + will be shown and the user can select one. + theme (`str` or `ThemeClass`, *optional*, defaults to `"citrus"`): + A Gradio Theme to use for the dashboard instead of the default `"citrus"`, + can be a built-in theme (e.g. `'soft'`, `'default'`), a theme from the Hub + (e.g. `"gstaff/xkcd"`), or a custom Theme class. + """ + if theme != DEFAULT_THEME: + # TODO: It's a little hacky to reproduce this theme-setting logic from Gradio Blocks, + # but in Gradio 6.0, the theme will be set in `launch()` instead, which means that we + # will be able to remove this code. + if isinstance(theme, str): + if theme.lower() in BUILT_IN_THEMES: + theme = BUILT_IN_THEMES[theme.lower()] + else: + try: + theme = ThemeClass.from_hub(theme) + except Exception as e: + warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}") + theme = DefaultTheme() + if not isinstance(theme, ThemeClass): + warnings.warn("Theme should be a class loaded from gradio.themes") + theme = DefaultTheme() + demo.theme: ThemeClass = theme + demo.theme_css = theme._get_theme_css() + demo.stylesheets = theme._stylesheets + theme_hasher = hashlib.sha256() + theme_hasher.update(demo.theme_css.encode("utf-8")) + demo.theme_hash = theme_hasher.hexdigest() + + _, url, share_url = demo.launch( + show_api=False, + quiet=True, + inline=False, + prevent_thread_lock=True, + favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png", + allowed_paths=[TRACKIO_LOGO_DIR], + ) + + base_url = share_url + "/" if share_url else url + dashboard_url = base_url + f"?project={project}" if project else base_url + print(f"* Trackio UI launched at: {dashboard_url}") + webbrowser.open(dashboard_url) + utils.block_except_in_notebook() diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb771bfb289af7c5ed3fc544b23ec3cd4695688 Binary files /dev/null and b/__pycache__/__init__.cpython-310.pyc differ diff --git a/__pycache__/__init__.cpython-311.pyc b/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0686ad1db33906c9e968b929f69d1eaf367e2e0 Binary files /dev/null and b/__pycache__/__init__.cpython-311.pyc differ diff --git a/__pycache__/__init__.cpython-312.pyc b/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bad3227e0f9a57894354f6fdf86b63f5321d8d1c Binary files /dev/null and b/__pycache__/__init__.cpython-312.pyc differ diff --git a/__pycache__/api.cpython-312.pyc b/__pycache__/api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db895ead0b2b19e5f124e694a00d930d2b875ef4 Binary files /dev/null and b/__pycache__/api.cpython-312.pyc differ diff --git a/__pycache__/cli.cpython-311.pyc b/__pycache__/cli.cpython-311.pyc new file mode 100644 index 
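`show()` accepts a built-in Gradio theme name, a Hub theme id, or a `ThemeClass`, and falls back to the default theme if the requested one cannot be loaded. A short sketch (the project name is a placeholder, and `show()` blocks outside notebooks):

```python
import trackio

# Launch the dashboard for one project with a built-in Gradio theme.
trackio.show(project="demo-project", theme="soft")

# A Hub theme id also works; invalid themes fall back to the default:
# trackio.show(project="demo-project", theme="gstaff/xkcd")

# The same dashboard can be opened from the shell via cli.py (shown further below):
#   python cli.py show --project demo-project --theme soft
```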
0000000000000000000000000000000000000000..a11d442256198bc11f985299c4245b42138cc6f5 Binary files /dev/null and b/__pycache__/cli.cpython-311.pyc differ diff --git a/__pycache__/cli.cpython-312.pyc b/__pycache__/cli.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e81c2e2075bda7d412fbb5b472655a6b41e5830e Binary files /dev/null and b/__pycache__/cli.cpython-312.pyc differ diff --git a/__pycache__/commit_scheduler.cpython-311.pyc b/__pycache__/commit_scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d3b72648ed7a92505efe3c5c594e6e2c057b219 Binary files /dev/null and b/__pycache__/commit_scheduler.cpython-311.pyc differ diff --git a/__pycache__/commit_scheduler.cpython-312.pyc b/__pycache__/commit_scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c496df12a82ee73ba1b8614d32eddb694babb74 Binary files /dev/null and b/__pycache__/commit_scheduler.cpython-312.pyc differ diff --git a/__pycache__/context.cpython-312.pyc b/__pycache__/context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a215aca5b91bb762b70c64bc5cc419f6c610e05 Binary files /dev/null and b/__pycache__/context.cpython-312.pyc differ diff --git a/__pycache__/context_vars.cpython-311.pyc b/__pycache__/context_vars.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9f0e81ee10adda3c17c222bc516c1e79d5badcc Binary files /dev/null and b/__pycache__/context_vars.cpython-311.pyc differ diff --git a/__pycache__/context_vars.cpython-312.pyc b/__pycache__/context_vars.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96702ea9a765374170b36ac83b928a6b5e8a7309 Binary files /dev/null and b/__pycache__/context_vars.cpython-312.pyc differ diff --git a/__pycache__/deploy.cpython-310.pyc b/__pycache__/deploy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23e835fb9694a9735f3bf34bc7a454ee4f1a9dc3 Binary files /dev/null and b/__pycache__/deploy.cpython-310.pyc differ diff --git a/__pycache__/deploy.cpython-311.pyc b/__pycache__/deploy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb048416cd4f4f7cf33724e1459d63fb88e51141 Binary files /dev/null and b/__pycache__/deploy.cpython-311.pyc differ diff --git a/__pycache__/deploy.cpython-312.pyc b/__pycache__/deploy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e25ac39ee03ff567b86e7e1b4ca2500ca48f41de Binary files /dev/null and b/__pycache__/deploy.cpython-312.pyc differ diff --git a/__pycache__/dummy_commit_scheduler.cpython-310.pyc b/__pycache__/dummy_commit_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7465aad3c0d39db51f75ad0762999e10744537d3 Binary files /dev/null and b/__pycache__/dummy_commit_scheduler.cpython-310.pyc differ diff --git a/__pycache__/dummy_commit_scheduler.cpython-311.pyc b/__pycache__/dummy_commit_scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..517f6e2f4d1dfcfcc6a695b26cfb5b116e503b21 Binary files /dev/null and b/__pycache__/dummy_commit_scheduler.cpython-311.pyc differ diff --git a/__pycache__/dummy_commit_scheduler.cpython-312.pyc b/__pycache__/dummy_commit_scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40ae2a8fff8f37a435732b9f6964ebac3b90ee41 Binary files /dev/null and 
b/__pycache__/dummy_commit_scheduler.cpython-312.pyc differ diff --git a/__pycache__/file_storage.cpython-311.pyc b/__pycache__/file_storage.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c42edf83e3861ea81f56c71c37959907e07051d Binary files /dev/null and b/__pycache__/file_storage.cpython-311.pyc differ diff --git a/__pycache__/file_storage.cpython-312.pyc b/__pycache__/file_storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f26e69fe8517b1efafe61df875f0b6994e79e04d Binary files /dev/null and b/__pycache__/file_storage.cpython-312.pyc differ diff --git a/__pycache__/imports.cpython-311.pyc b/__pycache__/imports.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51de8c063b2d01f3f69ddc5dd0338a91def55a0a Binary files /dev/null and b/__pycache__/imports.cpython-311.pyc differ diff --git a/__pycache__/imports.cpython-312.pyc b/__pycache__/imports.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f67cb1f5665ef3b833ee4738e2310063f248888b Binary files /dev/null and b/__pycache__/imports.cpython-312.pyc differ diff --git a/__pycache__/media.cpython-311.pyc b/__pycache__/media.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f57dfc7bf64e11ec7b6a71863da1c61a4d6dd885 Binary files /dev/null and b/__pycache__/media.cpython-311.pyc differ diff --git a/__pycache__/media.cpython-312.pyc b/__pycache__/media.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5278976200978981f4f7e5a6d840c24e6dd4486b Binary files /dev/null and b/__pycache__/media.cpython-312.pyc differ diff --git a/__pycache__/run.cpython-310.pyc b/__pycache__/run.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce986079a4e824bd78d2c5fcdf47bc5a030b4193 Binary files /dev/null and b/__pycache__/run.cpython-310.pyc differ diff --git a/__pycache__/run.cpython-311.pyc b/__pycache__/run.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a62b4b743094cf51c093228519c668b5628aeea Binary files /dev/null and b/__pycache__/run.cpython-311.pyc differ diff --git a/__pycache__/run.cpython-312.pyc b/__pycache__/run.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cec2e7ea8b7c22d1f1d3ea35a3fb76e050dbb3a Binary files /dev/null and b/__pycache__/run.cpython-312.pyc differ diff --git a/__pycache__/sqlite_storage.cpython-310.pyc b/__pycache__/sqlite_storage.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..111bbf14ce395fe05202b818ae4de3587ed92e6e Binary files /dev/null and b/__pycache__/sqlite_storage.cpython-310.pyc differ diff --git a/__pycache__/sqlite_storage.cpython-311.pyc b/__pycache__/sqlite_storage.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcc6e812de7d24a968d7e6739cc477dcf595fafa Binary files /dev/null and b/__pycache__/sqlite_storage.cpython-311.pyc differ diff --git a/__pycache__/sqlite_storage.cpython-312.pyc b/__pycache__/sqlite_storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6584e0712a58c6538cc804719ee00f30f3d4ae9f Binary files /dev/null and b/__pycache__/sqlite_storage.cpython-312.pyc differ diff --git a/__pycache__/storage.cpython-312.pyc b/__pycache__/storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86681eb1df8341239c0e0bb8e03d836873807911 Binary files /dev/null 
and b/__pycache__/storage.cpython-312.pyc differ diff --git a/__pycache__/table.cpython-311.pyc b/__pycache__/table.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77ceec71354cae777b968c220a24c5cee957473c Binary files /dev/null and b/__pycache__/table.cpython-311.pyc differ diff --git a/__pycache__/table.cpython-312.pyc b/__pycache__/table.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fa772de941751e68a5198af82db3dd2eda79875 Binary files /dev/null and b/__pycache__/table.cpython-312.pyc differ diff --git a/__pycache__/typehints.cpython-311.pyc b/__pycache__/typehints.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9073d269f19ed295d2f3d67556b91b5ae42bd34d Binary files /dev/null and b/__pycache__/typehints.cpython-311.pyc differ diff --git a/__pycache__/typehints.cpython-312.pyc b/__pycache__/typehints.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7893460bed4c2b55f36bc2b509420a4026d337f1 Binary files /dev/null and b/__pycache__/typehints.cpython-312.pyc differ diff --git a/__pycache__/types.cpython-312.pyc b/__pycache__/types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b17d58d91e97665a9cbe3e275eb5467dca73707 Binary files /dev/null and b/__pycache__/types.cpython-312.pyc differ diff --git a/__pycache__/ui.cpython-310.pyc b/__pycache__/ui.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08551e84ddeb0c06d5869cc556e40b2be889a702 Binary files /dev/null and b/__pycache__/ui.cpython-310.pyc differ diff --git a/__pycache__/ui.cpython-311.pyc b/__pycache__/ui.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c56bc716811da81a97acca76395f204044c64c88 Binary files /dev/null and b/__pycache__/ui.cpython-311.pyc differ diff --git a/__pycache__/ui.cpython-312.pyc b/__pycache__/ui.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05196127a497adffeb50c2673b4eeb5275d24b62 Binary files /dev/null and b/__pycache__/ui.cpython-312.pyc differ diff --git a/__pycache__/utils.cpython-310.pyc b/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f54a74128e47f630e8b216c14e355f571fec93d Binary files /dev/null and b/__pycache__/utils.cpython-310.pyc differ diff --git a/__pycache__/utils.cpython-311.pyc b/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8c887cd370c241ff00d7c6fec7f41724a3591dc Binary files /dev/null and b/__pycache__/utils.cpython-311.pyc differ diff --git a/__pycache__/utils.cpython-312.pyc b/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6c274b4d5e1b3d79bfbe46853acf0216567b501 Binary files /dev/null and b/__pycache__/utils.cpython-312.pyc differ diff --git a/__pycache__/video_writer.cpython-311.pyc b/__pycache__/video_writer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49e7f763939485f74379c2804bb9c9323a024f12 Binary files /dev/null and b/__pycache__/video_writer.cpython-311.pyc differ diff --git a/__pycache__/video_writer.cpython-312.pyc b/__pycache__/video_writer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91a7849164ff42355a2ea004b793dfa51e4e9390 Binary files /dev/null and b/__pycache__/video_writer.cpython-312.pyc differ diff --git a/assets/trackio_logo_dark.png 
b/assets/trackio_logo_dark.png new file mode 100644 index 0000000000000000000000000000000000000000..5c5c08b2387d23599f177477ef7482ff6a601df3 Binary files /dev/null and b/assets/trackio_logo_dark.png differ diff --git a/assets/trackio_logo_light.png b/assets/trackio_logo_light.png new file mode 100644 index 0000000000000000000000000000000000000000..b3438262c61989e6c6d16df4801a8935136115b3 Binary files /dev/null and b/assets/trackio_logo_light.png differ diff --git a/assets/trackio_logo_old.png b/assets/trackio_logo_old.png new file mode 100644 index 0000000000000000000000000000000000000000..48a07d40b83e23c9cc9dc0cb6544a3c6ad62b57f --- /dev/null +++ b/assets/trackio_logo_old.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8 +size 487101 diff --git a/assets/trackio_logo_type_dark.png b/assets/trackio_logo_type_dark.png new file mode 100644 index 0000000000000000000000000000000000000000..6f80a3191e514a8a0beaa6ab2011e5baf8df5eda Binary files /dev/null and b/assets/trackio_logo_type_dark.png differ diff --git a/assets/trackio_logo_type_dark_transparent.png b/assets/trackio_logo_type_dark_transparent.png new file mode 100644 index 0000000000000000000000000000000000000000..95b2c1f3499c502a81f2ec1094c0e09f827fb1fa Binary files /dev/null and b/assets/trackio_logo_type_dark_transparent.png differ diff --git a/assets/trackio_logo_type_light.png b/assets/trackio_logo_type_light.png new file mode 100644 index 0000000000000000000000000000000000000000..f07866d245ea897b9aba417b29403f7f91cc8bbd Binary files /dev/null and b/assets/trackio_logo_type_light.png differ diff --git a/assets/trackio_logo_type_light_transparent.png b/assets/trackio_logo_type_light_transparent.png new file mode 100644 index 0000000000000000000000000000000000000000..a20b4d5e64c61c91546577645310593fe3493508 Binary files /dev/null and b/assets/trackio_logo_type_light_transparent.png differ diff --git a/cli.py b/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..76be544f3d03671ce37cae802fc068f70d2c61e7 --- /dev/null +++ b/cli.py @@ -0,0 +1,32 @@ +import argparse + +from trackio import show + + +def main(): + parser = argparse.ArgumentParser(description="Trackio CLI") + subparsers = parser.add_subparsers(dest="command") + + ui_parser = subparsers.add_parser( + "show", help="Show the Trackio dashboard UI for a project" + ) + ui_parser.add_argument( + "--project", required=False, help="Project name to show in the dashboard" + ) + ui_parser.add_argument( + "--theme", + required=False, + default="citrus", + help="A Gradio Theme to use for the dashboard instead of the default 'citrus', can be a built-in theme (e.g. 'soft', 'default'), a theme from the Hub (e.g. 
'gstaff/xkcd').", + ) + + args = parser.parse_args() + + if args.command == "show": + show(args.project, args.theme) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/commit_scheduler.py b/commit_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa9dfab146ef594450e0ab0e5a25169abf1c07b --- /dev/null +++ b/commit_scheduler.py @@ -0,0 +1,391 @@ +# Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py + +import atexit +import logging +import os +import time +from concurrent.futures import Future +from dataclasses import dataclass +from io import SEEK_END, SEEK_SET, BytesIO +from pathlib import Path +from threading import Lock, Thread +from typing import Callable, Dict, List, Optional, Union + +from huggingface_hub.hf_api import ( + DEFAULT_IGNORE_PATTERNS, + CommitInfo, + CommitOperationAdd, + HfApi, +) +from huggingface_hub.utils import filter_repo_objects + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class _FileToUpload: + """Temporary dataclass to store info about files to upload. Not meant to be used directly.""" + + local_path: Path + path_in_repo: str + size_limit: int + last_modified: float + + +class CommitScheduler: + """ + Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes). + + The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is + properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually + with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads) + to learn more about how to use it. + + Args: + repo_id (`str`): + The id of the repo to commit to. + folder_path (`str` or `Path`): + Path to the local folder to upload regularly. + every (`int` or `float`, *optional*): + The number of minutes between each commit. Defaults to 5 minutes. + path_in_repo (`str`, *optional*): + Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder + of the repository. + repo_type (`str`, *optional*): + The type of the repo to commit to. Defaults to `model`. + revision (`str`, *optional*): + The revision of the repo to commit to. Defaults to `main`. + private (`bool`, *optional*): + Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists. + token (`str`, *optional*): + The token to use to commit to the repo. Defaults to the token saved on the machine. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are uploaded. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not uploaded. + squash_history (`bool`, *optional*): + Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is + useful to avoid degraded performances on the repo when it grows too large. + hf_api (`HfApi`, *optional*): + The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...). + on_before_commit (`Callable[[], None]`, *optional*): + If specified, a function that will be called before the CommitScheduler lists files to create a commit. 
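A hedged sketch of wiring up the Trackio-specific `on_before_commit` hook just described; the repo id, folder, and file names are invented, and in the installed package the class lives at `trackio.commit_scheduler`:

```python
from pathlib import Path

from commit_scheduler import CommitScheduler  # vendored scheduler from this repo

log_dir = Path("watched_folder")
log_dir.mkdir(exist_ok=True)


def flush_before_commit() -> None:
    # Runs inside the scheduler's lock, right before it lists files to upload,
    # so anything written here lands in the same scheduled commit.
    (log_dir / "status.txt").write_text("flushed\n")


scheduler = CommitScheduler(
    repo_id="my-user/scheduled-logs",
    repo_type="dataset",
    folder_path=log_dir,
    every=5,                        # minutes between background pushes
    on_before_commit=flush_before_commit,
)

# ... append to files under watched_folder/ while the program runs ...
scheduler.trigger().result()        # optionally force an immediate commit
scheduler.stop()
```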
+ + Example: + ```py + >>> from pathlib import Path + >>> from huggingface_hub import CommitScheduler + + # Scheduler uploads every 10 minutes + >>> csv_path = Path("watched_folder/data.csv") + >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10) + + >>> with csv_path.open("a") as f: + ... f.write("first line") + + # Some time later (...) + >>> with csv_path.open("a") as f: + ... f.write("second line") + ``` + + Example using a context manager: + ```py + >>> from pathlib import Path + >>> from huggingface_hub import CommitScheduler + + >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler: + ... csv_path = Path("watched_folder/data.csv") + ... with csv_path.open("a") as f: + ... f.write("first line") + ... (...) + ... with csv_path.open("a") as f: + ... f.write("second line") + + # Scheduler is now stopped and last commit have been triggered + ``` + """ + + def __init__( + self, + *, + repo_id: str, + folder_path: Union[str, Path], + every: Union[int, float] = 5, + path_in_repo: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + private: Optional[bool] = None, + token: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + squash_history: bool = False, + hf_api: Optional["HfApi"] = None, + on_before_commit: Optional[Callable[[], None]] = None, + ) -> None: + self.api = hf_api or HfApi(token=token) + self.on_before_commit = on_before_commit + + # Folder + self.folder_path = Path(folder_path).expanduser().resolve() + self.path_in_repo = path_in_repo or "" + self.allow_patterns = allow_patterns + + if ignore_patterns is None: + ignore_patterns = [] + elif isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS + + if self.folder_path.is_file(): + raise ValueError( + f"'folder_path' must be a directory, not a file: '{self.folder_path}'." + ) + self.folder_path.mkdir(parents=True, exist_ok=True) + + # Repository + repo_url = self.api.create_repo( + repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True + ) + self.repo_id = repo_url.repo_id + self.repo_type = repo_type + self.revision = revision + self.token = token + + self.last_uploaded: Dict[Path, float] = {} + self.last_push_time: float | None = None + + if not every > 0: + raise ValueError(f"'every' must be a positive integer, not '{every}'.") + self.lock = Lock() + self.every = every + self.squash_history = squash_history + + logger.info( + f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes." + ) + self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True) + self._scheduler_thread.start() + atexit.register(self._push_to_hub) + + self.__stopped = False + + def stop(self) -> None: + """Stop the scheduler. + + A stopped scheduler cannot be restarted. Mostly for tests purposes. 
+ """ + self.__stopped = True + + def __enter__(self) -> "CommitScheduler": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + # Upload last changes before exiting + self.trigger().result() + self.stop() + return + + def _run_scheduler(self) -> None: + """Dumb thread waiting between each scheduled push to Hub.""" + while True: + self.last_future = self.trigger() + time.sleep(self.every * 60) + if self.__stopped: + break + + def trigger(self) -> Future: + """Trigger a `push_to_hub` and return a future. + + This method is automatically called every `every` minutes. You can also call it manually to trigger a commit + immediately, without waiting for the next scheduled commit. + """ + return self.api.run_as_future(self._push_to_hub) + + def _push_to_hub(self) -> Optional[CommitInfo]: + if self.__stopped: # If stopped, already scheduled commits are ignored + return None + + logger.info("(Background) scheduled commit triggered.") + try: + value = self.push_to_hub() + if self.squash_history: + logger.info("(Background) squashing repo history.") + self.api.super_squash_history( + repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision + ) + return value + except Exception as e: + logger.error( + f"Error while pushing to Hub: {e}" + ) # Depending on the setup, error might be silenced + raise + + def push_to_hub(self) -> Optional[CommitInfo]: + """ + Push folder to the Hub and return the commit info. + + + + This method is not meant to be called directly. It is run in the background by the scheduler, respecting a + queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency + issues. + + + + The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and + uploads only changed files. If no changes are found, the method returns without committing anything. If you want + to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful + for example to compress data together in a single file before committing. For more details and examples, check + out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads). 
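As one possible take on the docstring's suggestion of subclassing and overriding `push_to_hub()`, a sketch that bundles the watched folder into a single archive per scheduled commit; the subclass name, archive name, and commit message are illustrative, not part of Trackio:

```python
import tarfile
from pathlib import Path

from huggingface_hub import CommitInfo, CommitOperationAdd

from commit_scheduler import CommitScheduler  # vendored scheduler defined above


class CompressingScheduler(CommitScheduler):
    """Bundle the watched folder into one archive per scheduled commit."""

    def push_to_hub(self) -> CommitInfo | None:
        archive = Path("snapshot.tar.gz")
        with self.lock:  # same lock the base class holds while listing files
            with tarfile.open(archive, "w:gz") as tar:
                tar.add(str(self.folder_path), arcname=".")
        return self.api.create_commit(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            revision=self.revision,
            operations=[
                CommitOperationAdd(path_or_fileobj=archive, path_in_repo=archive.name)
            ],
            commit_message="Scheduled snapshot",
        )
```

The base class's scheduling, locking, and squash logic stay untouched; only what gets committed changes.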
+ """ + # Check files to upload (with lock) + with self.lock: + if self.on_before_commit is not None: + self.on_before_commit() + + logger.debug("Listing files to upload for scheduled commit.") + + # List files from folder (taken from `_prepare_upload_folder_additions`) + relpath_to_abspath = { + path.relative_to(self.folder_path).as_posix(): path + for path in sorted( + self.folder_path.glob("**/*") + ) # sorted to be deterministic + if path.is_file() + } + prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else "" + + # Filter with pattern + filter out unchanged files + retrieve current file size + files_to_upload: List[_FileToUpload] = [] + for relpath in filter_repo_objects( + relpath_to_abspath.keys(), + allow_patterns=self.allow_patterns, + ignore_patterns=self.ignore_patterns, + ): + local_path = relpath_to_abspath[relpath] + stat = local_path.stat() + if ( + self.last_uploaded.get(local_path) is None + or self.last_uploaded[local_path] != stat.st_mtime + ): + files_to_upload.append( + _FileToUpload( + local_path=local_path, + path_in_repo=prefix + relpath, + size_limit=stat.st_size, + last_modified=stat.st_mtime, + ) + ) + + # Return if nothing to upload + if len(files_to_upload) == 0: + logger.debug("Dropping schedule commit: no changed file to upload.") + return None + + # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size) + logger.debug("Removing unchanged files since previous scheduled commit.") + add_operations = [ + CommitOperationAdd( + # TODO: Cap the file to its current size, even if the user append data to it while a scheduled commit is happening + # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload) + path_or_fileobj=file_to_upload.local_path, + path_in_repo=file_to_upload.path_in_repo, + ) + for file_to_upload in files_to_upload + ] + + # Upload files (append mode expected - no need for lock) + logger.debug("Uploading files for scheduled commit.") + commit_info = self.api.create_commit( + repo_id=self.repo_id, + repo_type=self.repo_type, + operations=add_operations, + commit_message="Scheduled Commit", + revision=self.revision, + ) + + for file in files_to_upload: + self.last_uploaded[file.local_path] = file.last_modified + + self.last_push_time = time.time() + + return commit_info + + +class PartialFileIO(BytesIO): + """A file-like object that reads only the first part of a file. + + Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the + file is uploaded (i.e. the part that was available when the filesystem was first scanned). + + In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal + disturbance for the user. The object is passed to `CommitOperationAdd`. + + Only supports `read`, `tell` and `seek` methods. + + Args: + file_path (`str` or `Path`): + Path to the file to read. + size_limit (`int`): + The maximum number of bytes to read from the file. If the file is larger than this, only the first part + will be read (and uploaded). 
+ """ + + def __init__(self, file_path: Union[str, Path], size_limit: int) -> None: + self._file_path = Path(file_path) + self._file = self._file_path.open("rb") + self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size) + + def __del__(self) -> None: + self._file.close() + return super().__del__() + + def __repr__(self) -> str: + return ( + f"" + ) + + def __len__(self) -> int: + return self._size_limit + + def __getattribute__(self, name: str): + if name.startswith("_") or name in ( + "read", + "tell", + "seek", + ): # only 3 public methods supported + return super().__getattribute__(name) + raise NotImplementedError(f"PartialFileIO does not support '{name}'.") + + def tell(self) -> int: + """Return the current file position.""" + return self._file.tell() + + def seek(self, __offset: int, __whence: int = SEEK_SET) -> int: + """Change the stream position to the given offset. + + Behavior is the same as a regular file, except that the position is capped to the size limit. + """ + if __whence == SEEK_END: + # SEEK_END => set from the truncated end + __offset = len(self) + __offset + __whence = SEEK_SET + + pos = self._file.seek(__offset, __whence) + if pos > self._size_limit: + return self._file.seek(self._size_limit) + return pos + + def read(self, __size: Optional[int] = -1) -> bytes: + """Read at most `__size` bytes from the file. + + Behavior is the same as a regular file, except that it is capped to the size limit. + """ + current = self._file.tell() + if __size is None or __size < 0: + # Read until file limit + truncated_size = self._size_limit - current + else: + # Read until file limit or __size + truncated_size = min(__size, self._size_limit - current) + return self._file.read(truncated_size) diff --git a/context_vars.py b/context_vars.py new file mode 100644 index 0000000000000000000000000000000000000000..0dab719f76cdcdf5e173bebea47666ca1b016694 --- /dev/null +++ b/context_vars.py @@ -0,0 +1,15 @@ +import contextvars +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from trackio.run import Run + +current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar( + "current_run", default=None +) +current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "current_project", default=None +) +current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "current_server", default=None +) diff --git a/data_types/__pycache__/__init__.cpython-311.pyc b/data_types/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79597e766b34941f993cfb0bc47e86f4f12abe1a Binary files /dev/null and b/data_types/__pycache__/__init__.cpython-311.pyc differ diff --git a/data_types/__pycache__/__init__.cpython-312.pyc b/data_types/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38a9cb512c28ad442010e27a653093f71e40baa6 Binary files /dev/null and b/data_types/__pycache__/__init__.cpython-312.pyc differ diff --git a/data_types/__pycache__/table.cpython-311.pyc b/data_types/__pycache__/table.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce6bea364a885f347b02772a9b5d722f412ebcfa Binary files /dev/null and b/data_types/__pycache__/table.cpython-311.pyc differ diff --git a/data_types/__pycache__/table.cpython-312.pyc b/data_types/__pycache__/table.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16aca75496bd22d60e7b213fb799873e9744c1f8 Binary files /dev/null and 
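Returning to `PartialFileIO` from `commit_scheduler.py` above, a small sketch of its size-capping behavior for a file that keeps growing after the scheduler scanned it; the file name is made up:

```python
from pathlib import Path

from commit_scheduler import PartialFileIO  # defined in commit_scheduler.py above

path = Path("growing.log")
path.write_text("first line\n")            # 11 bytes exist when the folder is scanned
snapshot_size = path.stat().st_size

with path.open("a") as f:                  # data appended after the scan...
    f.write("second line\n")

partial = PartialFileIO(path, size_limit=snapshot_size)
assert len(partial) == snapshot_size
assert partial.read() == b"first line\n"   # ...never leaks into the capped read
```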
b/data_types/__pycache__/table.cpython-312.pyc differ diff --git a/deploy.py b/deploy.py new file mode 100644 index 0000000000000000000000000000000000000000..2df76b131abe0b0fbf76f1743f859644f118ea5c --- /dev/null +++ b/deploy.py @@ -0,0 +1,217 @@ +import importlib.metadata +import io +import os +import time +from importlib.resources import files +from pathlib import Path + +import gradio +import huggingface_hub +from gradio_client import Client, handle_file +from httpx import ReadTimeout +from huggingface_hub.errors import RepositoryNotFoundError +from requests import HTTPError + +import trackio +from trackio.sqlite_storage import SQLiteStorage + +SPACE_URL = "https://huggingface.co/spaces/{space_id}" + + +def _is_trackio_installed_from_source() -> bool: + """Check if trackio is installed from source/editable install vs PyPI.""" + try: + trackio_file = trackio.__file__ + if "site-packages" not in trackio_file: + return True + + dist = importlib.metadata.distribution("trackio") + if dist.files: + files = list(dist.files) + has_pth = any(".pth" in str(f) for f in files) + if has_pth: + return True + + return False + except ( + AttributeError, + importlib.metadata.PackageNotFoundError, + importlib.metadata.MetadataError, + ValueError, + TypeError, + ): + return True + + +def deploy_as_space( + space_id: str, + space_storage: huggingface_hub.SpaceStorage | None = None, + dataset_id: str | None = None, +): + if ( + os.getenv("SYSTEM") == "spaces" + ): # in case a repo with this function is uploaded to spaces + return + + trackio_path = files("trackio") + + hf_api = huggingface_hub.HfApi() + + try: + huggingface_hub.create_repo( + space_id, + space_sdk="gradio", + space_storage=space_storage, + repo_type="space", + exist_ok=True, + ) + except HTTPError as e: + if e.response.status_code in [401, 403]: # unauthorized or forbidden + print("Need 'write' access token to create a Spaces repo.") + huggingface_hub.login(add_to_git_credential=False) + huggingface_hub.create_repo( + space_id, + space_sdk="gradio", + space_storage=space_storage, + repo_type="space", + exist_ok=True, + ) + else: + raise ValueError(f"Failed to create Space: {e}") + + with open(Path(trackio_path, "README.md"), "r") as f: + readme_content = f.read() + readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__) + readme_buffer = io.BytesIO(readme_content.encode("utf-8")) + hf_api.upload_file( + path_or_fileobj=readme_buffer, + path_in_repo="README.md", + repo_id=space_id, + repo_type="space", + ) + + # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space. + # Make sure necessary dependencies are installed by creating a requirements.txt. 
+ is_source_install = _is_trackio_installed_from_source() + + if is_source_install: + requirements_content = """pyarrow>=21.0""" + else: + requirements_content = f"""pyarrow>=21.0 +trackio=={trackio.__version__}""" + + requirements_buffer = io.BytesIO(requirements_content.encode("utf-8")) + hf_api.upload_file( + path_or_fileobj=requirements_buffer, + path_in_repo="requirements.txt", + repo_id=space_id, + repo_type="space", + ) + + huggingface_hub.utils.disable_progress_bars() + + if is_source_install: + hf_api.upload_folder( + repo_id=space_id, + repo_type="space", + folder_path=trackio_path, + ignore_patterns=["README.md"], + ) + else: + app_file_content = """import trackio +trackio.show()""" + app_file_buffer = io.BytesIO(app_file_content.encode("utf-8")) + hf_api.upload_file( + path_or_fileobj=app_file_buffer, + path_in_repo="ui.py", + repo_id=space_id, + repo_type="space", + ) + + if hf_token := huggingface_hub.utils.get_token(): + huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token) + if dataset_id is not None: + huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id) + + +def create_space_if_not_exists( + space_id: str, + space_storage: huggingface_hub.SpaceStorage | None = None, + dataset_id: str | None = None, +) -> None: + """ + Creates a new Hugging Face Space if it does not exist. If a dataset_id is provided, it will be added as a space variable. + + Args: + space_id: The ID of the Space to create. + dataset_id: The ID of the Dataset to add to the Space. + """ + if "/" not in space_id: + raise ValueError( + f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame." + ) + if dataset_id is not None and "/" not in dataset_id: + raise ValueError( + f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname." + ) + try: + huggingface_hub.repo_info(space_id, repo_type="space") + print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}") + if dataset_id is not None: + huggingface_hub.add_space_variable( + space_id, "TRACKIO_DATASET_ID", dataset_id + ) + return + except RepositoryNotFoundError: + pass + except HTTPError as e: + if e.response.status_code in [401, 403]: # unauthorized or forbidden + print("Need 'write' access token to create a Spaces repo.") + huggingface_hub.login(add_to_git_credential=False) + huggingface_hub.add_space_variable( + space_id, "TRACKIO_DATASET_ID", dataset_id + ) + else: + raise ValueError(f"Failed to create Space: {e}") + + print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}") + deploy_as_space(space_id, space_storage, dataset_id) + + +def wait_until_space_exists( + space_id: str, +) -> None: + """ + Blocks the current thread until the space exists. + May raise a TimeoutError if this takes quite a while. + + Args: + space_id: The ID of the Space to wait for. + """ + delay = 1 + for _ in range(10): + try: + Client(space_id, verbose=False) + return + except (ReadTimeout, ValueError): + time.sleep(delay) + delay = min(delay * 2, 30) + raise TimeoutError("Waiting for space to exist took longer than expected") + + +def upload_db_to_space(project: str, space_id: str) -> None: + """ + Uploads the database of a local Trackio project to a Hugging Face Space. + + Args: + project: The name of the project to upload. + space_id: The ID of the Space to upload to. 
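The deploy helpers above are typically chained together; a sketch assuming a logged-in Hugging Face account with write access, using placeholder Space, Dataset, and project names:

```python
from trackio import deploy  # deploy.py above; plain `import deploy` inside the Space

space_id = "my-user/trackio-dashboard"            # placeholder Space
dataset_id = "my-user/trackio-dashboard_dataset"  # placeholder backing Dataset
project = "demo-project"

# Create the Space (and register the backing Dataset) if needed, wait until it
# responds, then copy the project's local SQLite database into it.
deploy.create_space_if_not_exists(space_id, dataset_id=dataset_id)
deploy.wait_until_space_exists(space_id)
deploy.upload_db_to_space(project, space_id)
print(deploy.SPACE_URL.format(space_id=space_id))
```

This mirrors what `import_csv` and `import_tf_events` below do after importing local data.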
+ """ + db_path = SQLiteStorage.get_project_db_path(project) + client = Client(space_id, verbose=False) + client.predict( + api_name="/upload_db_to_space", + project=project, + uploaded_db=handle_file(db_path), + hf_token=huggingface_hub.utils.get_token(), + ) diff --git a/dummy_commit_scheduler.py b/dummy_commit_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..0f5015e1479a175081080ef8908966979c1de179 --- /dev/null +++ b/dummy_commit_scheduler.py @@ -0,0 +1,12 @@ +# A dummy object to fit the interface of huggingface_hub's CommitScheduler +class DummyCommitSchedulerLock: + def __enter__(self): + return None + + def __exit__(self, exception_type, exception_value, exception_traceback): + pass + + +class DummyCommitScheduler: + def __init__(self): + self.lock = DummyCommitSchedulerLock() diff --git a/file_storage.py b/file_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..fed947c0d8a5efde11b42c1c1f0cae05edb79f0a --- /dev/null +++ b/file_storage.py @@ -0,0 +1,37 @@ +from pathlib import Path + +try: # absolute imports when installed + from trackio.utils import MEDIA_DIR +except ImportError: # relative imports for local execution on Spaces + from utils import MEDIA_DIR + + +class FileStorage: + @staticmethod + def get_project_media_path( + project: str, + run: str | None = None, + step: int | None = None, + filename: str | None = None, + ) -> Path: + if filename is not None and step is None: + raise ValueError("filename requires step") + if step is not None and run is None: + raise ValueError("step requires run") + + path = MEDIA_DIR / project + if run: + path /= run + if step is not None: + path /= str(step) + if filename: + path /= filename + return path + + @staticmethod + def init_project_media_path( + project: str, run: str | None = None, step: int | None = None + ) -> Path: + path = FileStorage.get_project_media_path(project, run, step) + path.mkdir(parents=True, exist_ok=True) + return path diff --git a/imports.py b/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..45b0d58499976280d0b2506210d38ec692abf24d --- /dev/null +++ b/imports.py @@ -0,0 +1,288 @@ +import os +from pathlib import Path + +import pandas as pd + +from trackio import deploy, utils +from trackio.sqlite_storage import SQLiteStorage + + +def import_csv( + csv_path: str | Path, + project: str, + name: str | None = None, + space_id: str | None = None, + dataset_id: str | None = None, +) -> None: + """ + Imports a CSV file into a Trackio project. The CSV file must contain a `"step"` + column, may optionally contain a `"timestamp"` column, and any other columns will be + treated as metrics. It should also include a header row with the column names. + + TODO: call init() and return a Run object so that the user can continue to log metrics to it. + + Args: + csv_path (`str` or `Path`): + The str or Path to the CSV file to import. + project (`str`): + The name of the project to import the CSV file into. Must not be an existing + project. + name (`str` or `None`, *optional*, defaults to `None`): + The name of the Run to import the CSV file into. If not provided, a default + name will be generated. + name (`str` or `None`, *optional*, defaults to `None`): + The name of the run (if not provided, a default name will be generated). + space_id (`str` or `None`, *optional*, defaults to `None`): + If provided, the project will be logged to a Hugging Face Space instead of a + local directory. 
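Returning to the `FileStorage` helper in `file_storage.py` above, a sketch of how media paths are composed; project, run, step, and filename are placeholders:

```python
from file_storage import FileStorage  # file_storage.py above (trackio.file_storage when installed)

# Media paths nest as MEDIA_DIR/<project>/<run>/<step>/<filename>.
path = FileStorage.get_project_media_path(
    "demo-project", run="baseline", step=3, filename="sample.png"
)
print(path)  # <MEDIA_DIR>/demo-project/baseline/3/sample.png

# A filename without a step, or a step without a run, raises ValueError;
# init_project_media_path() additionally creates the directory on disk.
media_dir = FileStorage.init_project_media_path("demo-project", run="baseline", step=3)
```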
Should be a complete Space name like `"username/reponame"` + or `"orgname/reponame"`, or just `"reponame"` in which case the Space will + be created in the currently-logged-in Hugging Face user's namespace. If the + Space does not exist, it will be created. If the Space already exists, the + project will be logged to it. + dataset_id (`str` or `None`, *optional*, defaults to `None`): + If provided, a persistent Hugging Face Dataset will be created and the + metrics will be synced to it every 5 minutes. Should be a complete Dataset + name like `"username/datasetname"` or `"orgname/datasetname"`, or just + `"datasetname"` in which case the Dataset will be created in the + currently-logged-in Hugging Face user's namespace. If the Dataset does not + exist, it will be created. If the Dataset already exists, the project will + be appended to it. If not provided, the metrics will be logged to a local + SQLite database, unless a `space_id` is provided, in which case a Dataset + will be automatically created with the same name as the Space but with the + `"_dataset"` suffix. + """ + if SQLiteStorage.get_runs(project): + raise ValueError( + f"Project '{project}' already exists. Cannot import CSV into existing project." + ) + + csv_path = Path(csv_path) + if not csv_path.exists(): + raise FileNotFoundError(f"CSV file not found: {csv_path}") + + df = pd.read_csv(csv_path) + if df.empty: + raise ValueError("CSV file is empty") + + column_mapping = utils.simplify_column_names(df.columns.tolist()) + df = df.rename(columns=column_mapping) + + step_column = None + for col in df.columns: + if col.lower() == "step": + step_column = col + break + + if step_column is None: + raise ValueError("CSV file must contain a 'step' or 'Step' column") + + if name is None: + name = csv_path.stem + + metrics_list = [] + steps = [] + timestamps = [] + + numeric_columns = [] + for column in df.columns: + if column == step_column: + continue + if column == "timestamp": + continue + + try: + pd.to_numeric(df[column], errors="raise") + numeric_columns.append(column) + except (ValueError, TypeError): + continue + + for _, row in df.iterrows(): + metrics = {} + for column in numeric_columns: + value = row[column] + if bool(pd.notna(value)): + metrics[column] = float(value) + + if metrics: + metrics_list.append(metrics) + steps.append(int(row[step_column])) + + if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])): + timestamps.append(str(row["timestamp"])) + else: + timestamps.append("") + + if metrics_list: + SQLiteStorage.bulk_log( + project=project, + run=name, + metrics_list=metrics_list, + steps=steps, + timestamps=timestamps, + ) + + print( + f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'" + ) + print(f"* Metrics found: {', '.join(metrics_list[0].keys())}") + + space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id) + if dataset_id is not None: + os.environ["TRACKIO_DATASET_ID"] = dataset_id + print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}") + + if space_id is None: + utils.print_dashboard_instructions(project) + else: + deploy.create_space_if_not_exists(space_id, dataset_id) + deploy.wait_until_space_exists(space_id) + deploy.upload_db_to_space(project, space_id) + print( + f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}" + ) + + +def import_tf_events( + log_dir: str | Path, + project: str, + name: str | None = None, + space_id: str | None = None, + dataset_id: str | None = 
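Before continuing with `import_tf_events`, a sketch of preparing a CSV in the shape `import_csv` above expects: a header row, a `step` column, and numeric metric columns. File, project, and run names are invented:

```python
from pathlib import Path

import pandas as pd

import trackio

# Build a CSV with a header row, a "step" column, and numeric metric columns.
csv_path = Path("history.csv")
pd.DataFrame(
    {
        "step": [0, 1, 2],
        "loss": [1.0, 0.6, 0.4],
        "accuracy": [0.3, 0.55, 0.7],
    }
).to_csv(csv_path, index=False)

# The target project must not already exist; the run name defaults to the file stem.
trackio.import_csv(csv_path, project="csv-import-demo", name="history")
```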
None, +) -> None: + """ + Imports TensorFlow Events files from a directory into a Trackio project. Each + subdirectory in the log directory will be imported as a separate run. + + Args: + log_dir (`str` or `Path`): + The str or Path to the directory containing TensorFlow Events files. + project (`str`): + The name of the project to import the TensorFlow Events files into. Must not + be an existing project. + name (`str` or `None`, *optional*, defaults to `None`): + The name prefix for runs (if not provided, will use directory names). Each + subdirectory will create a separate run. + space_id (`str` or `None`, *optional*, defaults to `None`): + If provided, the project will be logged to a Hugging Face Space instead of a + local directory. Should be a complete Space name like `"username/reponame"` + or `"orgname/reponame"`, or just `"reponame"` in which case the Space will + be created in the currently-logged-in Hugging Face user's namespace. If the + Space does not exist, it will be created. If the Space already exists, the + project will be logged to it. + dataset_id (`str` or `None`, *optional*, defaults to `None`): + If provided, a persistent Hugging Face Dataset will be created and the + metrics will be synced to it every 5 minutes. Should be a complete Dataset + name like `"username/datasetname"` or `"orgname/datasetname"`, or just + `"datasetname"` in which case the Dataset will be created in the + currently-logged-in Hugging Face user's namespace. If the Dataset does not + exist, it will be created. If the Dataset already exists, the project will + be appended to it. If not provided, the metrics will be logged to a local + SQLite database, unless a `space_id` is provided, in which case a Dataset + will be automatically created with the same name as the Space but with the + `"_dataset"` suffix. + """ + try: + from tbparse import SummaryReader + except ImportError: + raise ImportError( + "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`." + ) + + if SQLiteStorage.get_runs(project): + raise ValueError( + f"Project '{project}' already exists. Cannot import TF events into existing project." 
+ ) + + path = Path(log_dir) + if not path.exists(): + raise FileNotFoundError(f"TF events directory not found: {path}") + + # Use tbparse to read all tfevents files in the directory structure + reader = SummaryReader(str(path), extra_columns={"dir_name"}) + df = reader.scalars + + if df.empty: + raise ValueError(f"No TensorFlow events data found in {path}") + + total_imported = 0 + imported_runs = [] + + # Group by dir_name to create separate runs + for dir_name, group_df in df.groupby("dir_name"): + try: + # Determine run name based on directory name + if dir_name == "": + run_name = "main" # For files in the root directory + else: + run_name = dir_name # Use directory name + + if name: + run_name = f"{name}_{run_name}" + + if group_df.empty: + print(f"* Skipping directory {dir_name}: no scalar data found") + continue + + metrics_list = [] + steps = [] + timestamps = [] + + for _, row in group_df.iterrows(): + # Convert row values to appropriate types + tag = str(row["tag"]) + value = float(row["value"]) + step = int(row["step"]) + + metrics = {tag: value} + metrics_list.append(metrics) + steps.append(step) + + # Use wall_time if present, else fallback + if "wall_time" in group_df.columns and not bool( + pd.isna(row["wall_time"]) + ): + timestamps.append(str(row["wall_time"])) + else: + timestamps.append("") + + if metrics_list: + SQLiteStorage.bulk_log( + project=project, + run=str(run_name), + metrics_list=metrics_list, + steps=steps, + timestamps=timestamps, + ) + + total_imported += len(metrics_list) + imported_runs.append(run_name) + + print( + f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'" + ) + print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}") + + except Exception as e: + print(f"* Error processing directory {dir_name}: {e}") + continue + + if not imported_runs: + raise ValueError("No valid TensorFlow events data could be imported") + + print(f"* Total imported events: {total_imported}") + print(f"* Created runs: {', '.join(imported_runs)}") + + space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id) + if dataset_id is not None: + os.environ["TRACKIO_DATASET_ID"] = dataset_id + print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}") + + if space_id is None: + utils.print_dashboard_instructions(project) + else: + deploy.create_space_if_not_exists(space_id, dataset_id) + deploy.wait_until_space_exists(space_id) + deploy.upload_db_to_space(project, space_id) + print( + f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}" + ) diff --git a/media.py b/media.py new file mode 100644 index 0000000000000000000000000000000000000000..a557b3a4be11997c16b5b656f4b5247f84fd191b --- /dev/null +++ b/media.py @@ -0,0 +1,286 @@ +import os +import shutil +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Literal + +import numpy as np +from PIL import Image as PILImage + +try: # absolute imports when installed + from trackio.file_storage import FileStorage + from trackio.utils import MEDIA_DIR + from trackio.video_writer import write_video +except ImportError: # relative imports for local execution on Spaces + from file_storage import FileStorage + from utils import MEDIA_DIR + from video_writer import write_video + + +class TrackioMedia(ABC): + """ + Abstract base class for Trackio media objects + Provides shared functionality for file handling and serialization. 
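A usage sketch for `import_tf_events` above; it requires the `tensorboard` extra (`tbparse`), and the paths, project, and name prefix below are placeholders:

```python
import trackio

# Install the extra first: pip install trackio[tensorboard]
trackio.import_tf_events(
    "tb_logs/",            # directory whose subdirectories contain *.tfevents.* files
    project="tb-import",   # must not be an existing Trackio project
    name="exp",            # each subdirectory becomes a run named "exp_<dirname>"
)
```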
+ """ + + TYPE: str + + def __init_subclass__(cls, **kwargs): + """Ensure subclasses define the TYPE attribute.""" + super().__init_subclass__(**kwargs) + if not hasattr(cls, "TYPE") or cls.TYPE is None: + raise TypeError(f"Class {cls.__name__} must define TYPE attribute") + + def __init__(self, value, caption: str | None = None): + self.caption = caption + self._value = value + self._file_path: Path | None = None + + # Validate file existence for string/Path inputs + if isinstance(self._value, str | Path): + if not os.path.isfile(self._value): + raise ValueError(f"File not found: {self._value}") + + def _file_extension(self) -> str: + if self._file_path: + return self._file_path.suffix[1:].lower() + if isinstance(self._value, str | Path): + path = Path(self._value) + return path.suffix[1:].lower() + if hasattr(self, "_format") and self._format: + return self._format + return "unknown" + + def _get_relative_file_path(self) -> Path | None: + return self._file_path + + def _get_absolute_file_path(self) -> Path | None: + if self._file_path: + return MEDIA_DIR / self._file_path + return None + + def _save(self, project: str, run: str, step: int = 0): + if self._file_path: + return + + media_dir = FileStorage.init_project_media_path(project, run, step) + filename = f"{uuid.uuid4()}.{self._file_extension()}" + file_path = media_dir / filename + + # Delegate to subclass-specific save logic + self._save_media(file_path) + + self._file_path = file_path.relative_to(MEDIA_DIR) + + @abstractmethod + def _save_media(self, file_path: Path): + """ + Performs the actual media saving logic. + """ + pass + + def _to_dict(self) -> dict: + if not self._file_path: + raise ValueError("Media must be saved to file before serialization") + return { + "_type": self.TYPE, + "file_path": str(self._get_relative_file_path()), + "caption": self.caption, + } + + +TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image + + +class TrackioImage(TrackioMedia): + """ + Initializes an Image object. + + Example: + ```python + import trackio + import numpy as np + from PIL import Image + + # Create an image from numpy array + image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + image = trackio.Image(image_data, caption="Random image") + trackio.log({"my_image": image}) + + # Create an image from PIL Image + pil_image = Image.new('RGB', (100, 100), color='red') + image = trackio.Image(pil_image, caption="Red square") + trackio.log({"red_image": image}) + + # Create an image from file path + image = trackio.Image("path/to/image.jpg", caption="Photo from file") + trackio.log({"file_image": image}) + ``` + + Args: + value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`, *optional*, defaults to `None`): + A path to an image, a PIL Image, or a numpy array of shape (height, width, channels). + caption (`str`, *optional*, defaults to `None`): + A string caption for the image. 
+ """ + + TYPE = "trackio.image" + + def __init__(self, value: TrackioImageSourceType, caption: str | None = None): + super().__init__(value, caption) + self._format: str | None = None + + if ( + isinstance(self._value, np.ndarray | PILImage.Image) + and self._format is None + ): + self._format = "png" + + def _as_pil(self) -> PILImage.Image | None: + try: + if isinstance(self._value, np.ndarray): + arr = np.asarray(self._value).astype("uint8") + return PILImage.fromarray(arr).convert("RGBA") + if isinstance(self._value, PILImage.Image): + return self._value.convert("RGBA") + except Exception as e: + raise ValueError(f"Failed to process image data: {self._value}") from e + return None + + def _save_media(self, file_path: Path): + if pil := self._as_pil(): + pil.save(file_path, format=self._format) + elif isinstance(self._value, str | Path): + if os.path.isfile(self._value): + shutil.copy(self._value, file_path) + else: + raise ValueError(f"File not found: {self._value}") + + +TrackioVideoSourceType = str | Path | np.ndarray +TrackioVideoFormatType = Literal["gif", "mp4", "webm"] + + +class TrackioVideo(TrackioMedia): + """ + Initializes a Video object. + + Example: + ```python + import trackio + import numpy as np + + # Create a simple video from numpy array + frames = np.random.randint(0, 255, (10, 3, 64, 64), dtype=np.uint8) + video = trackio.Video(frames, caption="Random video", fps=30) + + # Create a batch of videos + batch_frames = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8) + batch_video = trackio.Video(batch_frames, caption="Batch of videos", fps=15) + + # Create video from file path + video = trackio.Video("path/to/video.mp4", caption="Video from file") + ``` + + Args: + value (`str`, `Path`, or `numpy.ndarray`, *optional*, defaults to `None`): + A path to a video file, or a numpy array. + The array should be of type `np.uint8` with RGB values in the range `[0, 255]`. + It is expected to have shape of either (frames, channels, height, width) or (batch, frames, channels, height, width). + For the latter, the videos will be tiled into a grid. + caption (`str`, *optional*, defaults to `None`): + A string caption for the video. + fps (`int`, *optional*, defaults to `None`): + Frames per second for the video. Only used when value is an ndarray. Default is `24`. + format (`Literal["gif", "mp4", "webm"]`, *optional*, defaults to `None`): + Video format ("gif", "mp4", or "webm"). Only used when value is an ndarray. Default is "gif". 
+ """ + + TYPE = "trackio.video" + + def __init__( + self, + value: TrackioVideoSourceType, + caption: str | None = None, + fps: int | None = None, + format: TrackioVideoFormatType | None = None, + ): + super().__init__(value, caption) + if isinstance(value, np.ndarray): + if format is None: + format = "gif" + if fps is None: + fps = 24 + self._fps = fps + self._format = format + + @property + def _codec(self) -> str: + match self._format: + case "gif": + return "gif" + case "mp4": + return "h264" + case "webm": + return "vp9" + case _: + raise ValueError(f"Unsupported format: {self._format}") + + def _save_media(self, file_path: Path): + if isinstance(self._value, np.ndarray): + video = TrackioVideo._process_ndarray(self._value) + write_video(file_path, video, fps=self._fps, codec=self._codec) + elif isinstance(self._value, str | Path): + if os.path.isfile(self._value): + shutil.copy(self._value, file_path) + else: + raise ValueError(f"File not found: {self._value}") + + @staticmethod + def _process_ndarray(value: np.ndarray) -> np.ndarray: + # Verify value is either 4D (single video) or 5D array (batched videos). + # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width) + if value.ndim < 4: + raise ValueError( + "Video requires at least 4 dimensions (frames, channels, height, width)" + ) + if value.ndim > 5: + raise ValueError( + "Videos can have at most 5 dimensions (batch, frames, channels, height, width)" + ) + if value.ndim == 4: + # Reshape to 5D with single batch: (1, frames, channels, height, width) + value = value[np.newaxis, ...] + + value = TrackioVideo._tile_batched_videos(value) + return value + + @staticmethod + def _tile_batched_videos(video: np.ndarray) -> np.ndarray: + """ + Tiles a batch of videos into a grid of videos. 
+ + Input format: (batch, frames, channels, height, width) - original FCHW format + Output format: (frames, total_height, total_width, channels) + """ + batch_size, frames, channels, height, width = video.shape + + next_pow2 = 1 << (batch_size - 1).bit_length() + if batch_size != next_pow2: + pad_len = next_pow2 - batch_size + pad_shape = (pad_len, frames, channels, height, width) + padding = np.zeros(pad_shape, dtype=video.dtype) + video = np.concatenate((video, padding), axis=0) + batch_size = next_pow2 + + n_rows = 1 << ((batch_size.bit_length() - 1) // 2) + n_cols = batch_size // n_rows + + # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width) + video = video.reshape(n_rows, n_cols, frames, channels, height, width) + + # Rearrange dimensions to (frames, total_height, total_width, channels) + video = video.transpose(2, 0, 4, 1, 5, 3) + video = video.reshape(frames, n_rows * height, n_cols * width, channels) + return video diff --git a/py.typed b/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/run.py b/run.py new file mode 100644 index 0000000000000000000000000000000000000000..c165bf32a2a4170124f3680fbee8b2d6df7dbdee --- /dev/null +++ b/run.py @@ -0,0 +1,156 @@ +import threading +import time + +import huggingface_hub +from gradio_client import Client, handle_file + +from trackio.media import TrackioMedia +from trackio.sqlite_storage import SQLiteStorage +from trackio.table import Table +from trackio.typehints import LogEntry, UploadEntry +from trackio.utils import ( + RESERVED_KEYS, + fibo, + generate_readable_name, + serialize_values, +) + +BATCH_SEND_INTERVAL = 0.5 + + +class Run: + def __init__( + self, + url: str, + project: str, + client: Client | None, + name: str | None = None, + config: dict | None = None, + space_id: str | None = None, + ): + self.url = url + self.project = project + self._client_lock = threading.Lock() + self._client_thread = None + self._client = client + self._space_id = space_id + self.name = name or generate_readable_name( + SQLiteStorage.get_runs(project), space_id + ) + self.config = config or {} + self._queued_logs: list[LogEntry] = [] + self._queued_uploads: list[UploadEntry] = [] + self._stop_flag = threading.Event() + + self._client_thread = threading.Thread(target=self._init_client_background) + self._client_thread.daemon = True + self._client_thread.start() + + def _batch_sender(self): + """Send batched logs every BATCH_SEND_INTERVAL.""" + while not self._stop_flag.is_set() or len(self._queued_logs) > 0: + # If the stop flag has been set, then just quickly send all + # the logs and exit. 
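+ # Note: the loop condition keeps iterating after the stop flag is set
+ # until self._queued_logs has been fully drained.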
+ if not self._stop_flag.is_set(): + time.sleep(BATCH_SEND_INTERVAL) + + with self._client_lock: + if self._client is None: + return + if self._queued_logs: + logs_to_send = self._queued_logs.copy() + self._queued_logs.clear() + self._client.predict( + api_name="/bulk_log", + logs=logs_to_send, + hf_token=huggingface_hub.utils.get_token(), + ) + if self._queued_uploads: + uploads_to_send = self._queued_uploads.copy() + self._queued_uploads.clear() + self._client.predict( + api_name="/bulk_upload_media", + uploads=uploads_to_send, + hf_token=huggingface_hub.utils.get_token(), + ) + + def _init_client_background(self): + if self._client is None: + fib = fibo() + for sleep_coefficient in fib: + try: + client = Client(self.url, verbose=False) + + with self._client_lock: + self._client = client + break + except Exception: + pass + if sleep_coefficient is not None: + time.sleep(0.1 * sleep_coefficient) + + self._batch_sender() + + def _process_media(self, metrics, step: int | None) -> dict: + """ + Serialize media in metrics and upload to space if needed. + """ + serializable_metrics = {} + if not step: + step = 0 + for key, value in metrics.items(): + if isinstance(value, TrackioMedia): + value._save(self.project, self.name, step) + serializable_metrics[key] = value._to_dict() + if self._space_id: + # Upload local media when deploying to space + upload_entry: UploadEntry = { + "project": self.project, + "run": self.name, + "step": step, + "uploaded_file": handle_file(value._get_absolute_file_path()), + } + with self._client_lock: + self._queued_uploads.append(upload_entry) + else: + serializable_metrics[key] = value + return serializable_metrics + + @staticmethod + def _replace_tables(metrics): + for k, v in metrics.items(): + if isinstance(v, Table): + metrics[k] = v._to_dict() + + def log(self, metrics: dict, step: int | None = None): + for k in metrics.keys(): + if k in RESERVED_KEYS or k.startswith("__"): + raise ValueError( + f"Please do not use this reserved key as a metric: {k}" + ) + Run._replace_tables(metrics) + + metrics = self._process_media(metrics, step) + metrics = serialize_values(metrics) + log_entry: LogEntry = { + "project": self.project, + "run": self.name, + "metrics": metrics, + "step": step, + } + + with self._client_lock: + self._queued_logs.append(log_entry) + + def finish(self): + """Cleanup when run is finished.""" + self._stop_flag.set() + + # Wait for the batch sender to finish before joining the client thread. + time.sleep(2 * BATCH_SEND_INTERVAL) + + if self._client_thread is not None: + print( + f"* Run finished. 
Uploading logs to Trackio Space: {self.url} (please wait...)" + ) + self._client_thread.join() diff --git a/sqlite_storage.py b/sqlite_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6207cb643d807cbfee562a3cf1baf5be254ddf --- /dev/null +++ b/sqlite_storage.py @@ -0,0 +1,440 @@ +import fcntl +import json +import os +import sqlite3 +import time +from datetime import datetime +from pathlib import Path +from threading import Lock + +import huggingface_hub as hf +import pandas as pd + +try: # absolute imports when installed + from trackio.commit_scheduler import CommitScheduler + from trackio.dummy_commit_scheduler import DummyCommitScheduler + from trackio.utils import ( + TRACKIO_DIR, + deserialize_values, + serialize_values, + ) +except Exception: # relative imports for local execution on Spaces + from commit_scheduler import CommitScheduler + from dummy_commit_scheduler import DummyCommitScheduler + from utils import TRACKIO_DIR, deserialize_values, serialize_values + + +class ProcessLock: + """A simple file-based lock that works across processes.""" + + def __init__(self, lockfile_path: Path): + self.lockfile_path = lockfile_path + self.lockfile = None + + def __enter__(self): + """Acquire the lock with retry logic.""" + self.lockfile_path.parent.mkdir(parents=True, exist_ok=True) + self.lockfile = open(self.lockfile_path, "w") + + max_retries = 100 + for attempt in range(max_retries): + try: + fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + return self + except IOError: + if attempt < max_retries - 1: + time.sleep(0.1) + else: + raise IOError("Could not acquire database lock after 10 seconds") + + def __exit__(self, exc_type, exc_val, exc_tb): + """Release the lock.""" + if self.lockfile: + fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN) + self.lockfile.close() + + +class SQLiteStorage: + _dataset_import_attempted = False + _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None + _scheduler_lock = Lock() + + @staticmethod + def _get_connection(db_path: Path) -> sqlite3.Connection: + conn = sqlite3.connect(str(db_path), timeout=30.0) + conn.execute("PRAGMA journal_mode = WAL") + conn.row_factory = sqlite3.Row + return conn + + @staticmethod + def _get_process_lock(project: str) -> ProcessLock: + lockfile_path = TRACKIO_DIR / f"{project}.lock" + return ProcessLock(lockfile_path) + + @staticmethod + def get_project_db_filename(project: str) -> Path: + """Get the database filename for a specific project.""" + safe_project_name = "".join( + c for c in project if c.isalnum() or c in ("-", "_") + ).rstrip() + if not safe_project_name: + safe_project_name = "default" + return f"{safe_project_name}.db" + + @staticmethod + def get_project_db_path(project: str) -> Path: + """Get the database path for a specific project.""" + filename = SQLiteStorage.get_project_db_filename(project) + return TRACKIO_DIR / filename + + @staticmethod + def init_db(project: str) -> Path: + """ + Initialize the SQLite database with required tables. + If there is a dataset ID provided, copies from that dataset instead. + Returns the database path. 
+ """ + db_path = SQLiteStorage.get_project_db_path(project) + db_path.parent.mkdir(parents=True, exist_ok=True) + with SQLiteStorage._get_process_lock(project): + with sqlite3.connect(db_path, timeout=30.0) as conn: + conn.execute("PRAGMA journal_mode = WAL") + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS metrics ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + run_name TEXT NOT NULL, + step INTEGER NOT NULL, + metrics TEXT NOT NULL + ) + """) + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_metrics_run_step + ON metrics(run_name, step) + """ + ) + conn.commit() + return db_path + + @staticmethod + def export_to_parquet(): + """ + Exports all projects' DB files as Parquet under the same path but with extension ".parquet". + """ + # don't attempt to export (potentially wrong/blank) data before importing for the first time + if not SQLiteStorage._dataset_import_attempted: + return + all_paths = os.listdir(TRACKIO_DIR) + db_paths = [f for f in all_paths if f.endswith(".db")] + for db_path in db_paths: + db_path = TRACKIO_DIR / db_path + parquet_path = db_path.with_suffix(".parquet") + if (not parquet_path.exists()) or ( + db_path.stat().st_mtime > parquet_path.stat().st_mtime + ): + with sqlite3.connect(db_path) as conn: + df = pd.read_sql("SELECT * from metrics", conn) + # break out the single JSON metrics column into individual columns + metrics = df["metrics"].copy() + metrics = pd.DataFrame( + metrics.apply( + lambda x: deserialize_values(json.loads(x)) + ).values.tolist(), + index=df.index, + ) + del df["metrics"] + for col in metrics.columns: + df[col] = metrics[col] + df.to_parquet(parquet_path) + + @staticmethod + def import_from_parquet(): + """ + Imports to all DB files that have matching files under the same path but with extension ".parquet". + """ + all_paths = os.listdir(TRACKIO_DIR) + parquet_paths = [f for f in all_paths if f.endswith(".parquet")] + for parquet_path in parquet_paths: + parquet_path = TRACKIO_DIR / parquet_path + db_path = parquet_path.with_suffix(".db") + df = pd.read_parquet(parquet_path) + with sqlite3.connect(db_path) as conn: + # fix up df to have a single JSON metrics column + if "metrics" not in df.columns: + # separate other columns from metrics + metrics = df.copy() + other_cols = ["id", "timestamp", "run_name", "step"] + df = df[other_cols] + for col in other_cols: + del metrics[col] + # combine them all into a single metrics col + metrics = json.loads(metrics.to_json(orient="records")) + df["metrics"] = [ + json.dumps(serialize_values(row)) for row in metrics + ] + df.to_sql("metrics", conn, if_exists="replace", index=False) + + @staticmethod + def get_scheduler(): + """ + Get the scheduler for the database based on the environment variables. + This applies to both local and Spaces. 
+ """ + with SQLiteStorage._scheduler_lock: + if SQLiteStorage._current_scheduler is not None: + return SQLiteStorage._current_scheduler + hf_token = os.environ.get("HF_TOKEN") + dataset_id = os.environ.get("TRACKIO_DATASET_ID") + space_repo_name = os.environ.get("SPACE_REPO_NAME") + if dataset_id is None or space_repo_name is None: + scheduler = DummyCommitScheduler() + else: + scheduler = CommitScheduler( + repo_id=dataset_id, + repo_type="dataset", + folder_path=TRACKIO_DIR, + private=True, + allow_patterns=["*.parquet", "media/**/*"], + squash_history=True, + token=hf_token, + on_before_commit=SQLiteStorage.export_to_parquet, + ) + SQLiteStorage._current_scheduler = scheduler + return scheduler + + @staticmethod + def log(project: str, run: str, metrics: dict, step: int | None = None): + """ + Safely log metrics to the database. Before logging, this method will ensure the database exists + and is set up with the correct tables. It also uses a cross-process lock to prevent + database locking errors when multiple processes access the same database. + + This method is not used in the latest versions of Trackio (replaced by bulk_log) but + is kept for backwards compatibility for users who are connecting to a newer version of + a Trackio Spaces dashboard with an older version of Trackio installed locally. + """ + db_path = SQLiteStorage.init_db(project) + + with SQLiteStorage._get_process_lock(project): + with SQLiteStorage._get_connection(db_path) as conn: + cursor = conn.cursor() + + cursor.execute( + """ + SELECT MAX(step) + FROM metrics + WHERE run_name = ? + """, + (run,), + ) + last_step = cursor.fetchone()[0] + if step is None: + current_step = 0 if last_step is None else last_step + 1 + else: + current_step = step + + current_timestamp = datetime.now().isoformat() + + cursor.execute( + """ + INSERT INTO metrics + (timestamp, run_name, step, metrics) + VALUES (?, ?, ?, ?) + """, + ( + current_timestamp, + run, + current_step, + json.dumps(serialize_values(metrics)), + ), + ) + conn.commit() + + @staticmethod + def bulk_log( + project: str, + run: str, + metrics_list: list[dict], + steps: list[int] | None = None, + timestamps: list[str] | None = None, + ): + """ + Safely log bulk metrics to the database. Before logging, this method will ensure the database exists + and is set up with the correct tables. It also uses a cross-process lock to prevent + database locking errors when multiple processes access the same database. 
+ """ + if not metrics_list: + return + + if timestamps is None: + timestamps = [datetime.now().isoformat()] * len(metrics_list) + + db_path = SQLiteStorage.init_db(project) + with SQLiteStorage._get_process_lock(project): + with SQLiteStorage._get_connection(db_path) as conn: + cursor = conn.cursor() + + if steps is None: + steps = list(range(len(metrics_list))) + elif any(s is None for s in steps): + cursor.execute( + "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,) + ) + last_step = cursor.fetchone()[0] + current_step = 0 if last_step is None else last_step + 1 + + processed_steps = [] + for step in steps: + if step is None: + processed_steps.append(current_step) + current_step += 1 + else: + processed_steps.append(step) + steps = processed_steps + + if len(metrics_list) != len(steps) or len(metrics_list) != len( + timestamps + ): + raise ValueError( + "metrics_list, steps, and timestamps must have the same length" + ) + + data = [] + for i, metrics in enumerate(metrics_list): + data.append( + ( + timestamps[i], + run, + steps[i], + json.dumps(serialize_values(metrics)), + ) + ) + + cursor.executemany( + """ + INSERT INTO metrics + (timestamp, run_name, step, metrics) + VALUES (?, ?, ?, ?) + """, + data, + ) + conn.commit() + + @staticmethod + def get_logs(project: str, run: str) -> list[dict]: + """Retrieve logs for a specific run. Logs include the step count (int) and the timestamp (datetime object).""" + db_path = SQLiteStorage.get_project_db_path(project) + if not db_path.exists(): + return [] + + with SQLiteStorage._get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT timestamp, step, metrics + FROM metrics + WHERE run_name = ? + ORDER BY timestamp + """, + (run,), + ) + + rows = cursor.fetchall() + results = [] + for row in rows: + metrics = json.loads(row["metrics"]) + metrics = deserialize_values(metrics) + metrics["timestamp"] = row["timestamp"] + metrics["step"] = row["step"] + results.append(metrics) + return results + + @staticmethod + def load_from_dataset(): + dataset_id = os.environ.get("TRACKIO_DATASET_ID") + space_repo_name = os.environ.get("SPACE_REPO_NAME") + if dataset_id is not None and space_repo_name is not None: + hfapi = hf.HfApi() + updated = False + if not TRACKIO_DIR.exists(): + TRACKIO_DIR.mkdir(parents=True, exist_ok=True) + with SQLiteStorage.get_scheduler().lock: + try: + files = hfapi.list_repo_files(dataset_id, repo_type="dataset") + for file in files: + # Download parquet and media assets + if not (file.endswith(".parquet") or file.startswith("media/")): + continue + if (TRACKIO_DIR / file).exists(): + continue + hf.hf_hub_download( + dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR + ) + updated = True + except hf.errors.EntryNotFoundError: + pass + except hf.errors.RepositoryNotFoundError: + pass + if updated: + SQLiteStorage.import_from_parquet() + SQLiteStorage._dataset_import_attempted = True + + @staticmethod + def get_projects() -> list[str]: + """ + Get list of all projects by scanning the database files in the trackio directory. 
+ """ + if not SQLiteStorage._dataset_import_attempted: + SQLiteStorage.load_from_dataset() + + projects: set[str] = set() + if not TRACKIO_DIR.exists(): + return [] + + for db_file in TRACKIO_DIR.glob("*.db"): + project_name = db_file.stem + projects.add(project_name) + return sorted(projects) + + @staticmethod + def get_runs(project: str) -> list[str]: + """Get list of all runs for a project.""" + db_path = SQLiteStorage.get_project_db_path(project) + if not db_path.exists(): + return [] + + with SQLiteStorage._get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT DISTINCT run_name FROM metrics", + ) + return [row[0] for row in cursor.fetchall()] + + @staticmethod + def get_max_steps_for_runs(project: str) -> dict[str, int]: + """Get the maximum step for each run in a project.""" + db_path = SQLiteStorage.get_project_db_path(project) + if not db_path.exists(): + return {} + + with SQLiteStorage._get_connection(db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT run_name, MAX(step) as max_step + FROM metrics + GROUP BY run_name + """ + ) + + results = {} + for row in cursor.fetchall(): + results[row["run_name"]] = row["max_step"] + + return results + + def finish(self): + """Cleanup when run is finished.""" + pass diff --git a/table.py b/table.py new file mode 100644 index 0000000000000000000000000000000000000000..61b65426c64c0ea9cabe1db1cc46803abe21f90c --- /dev/null +++ b/table.py @@ -0,0 +1,55 @@ +from typing import Any, Literal, Optional, Union + +from pandas import DataFrame + + +class Table: + """ + Initializes a Table object. + + Args: + columns (`list[str]`, *optional*, defaults to `None`): + Names of the columns in the table. Optional if `data` is provided. Not + expected if `dataframe` is provided. Currently ignored. + data (`list[list[Any]]`, *optional*, defaults to `None`): + 2D row-oriented array of values. + dataframe (`pandas.`DataFrame``, *optional*, defaults to `None`): + DataFrame object used to create the table. When set, `data` and `columns` + arguments are ignored. + rows (`list[list[any]]`, *optional*, defaults to `None`): + Currently ignored. + optional (`bool` or `list[bool]`, *optional*, defaults to `True`): + Currently ignored. + allow_mixed_types (`bool`, *optional*, defaults to `False`): + Currently ignored. + log_mode: (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`): + Currently ignored. + """ + + TYPE = "trackio.table" + + def __init__( + self, + columns: Optional[list[str]] = None, + data: Optional[list[list[Any]]] = None, + dataframe: Optional[DataFrame] = None, + rows: Optional[list[list[Any]]] = None, + optional: Union[bool, list[bool]] = True, + allow_mixed_types: bool = False, + log_mode: Optional[ + Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"] + ] = "IMMUTABLE", + ): + # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode. + # for now (like `rows`) they are included for API compat but don't do anything. 
+ + if dataframe is None: + self.data = data + else: + self.data = dataframe.to_dict(orient="records") + + def _to_dict(self): + return { + "_type": self.TYPE, + "_value": self.data, + } diff --git a/typehints.py b/typehints.py new file mode 100644 index 0000000000000000000000000000000000000000..e270424d8d7604d467ffaedf456696067139365b --- /dev/null +++ b/typehints.py @@ -0,0 +1,17 @@ +from typing import Any, TypedDict + +from gradio import FileData + + +class LogEntry(TypedDict): + project: str + run: str + metrics: dict[str, Any] + step: int | None + + +class UploadEntry(TypedDict): + project: str + run: str + step: int | None + uploaded_file: FileData diff --git a/ui.py b/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4e47a839b1c95f0919cb126226e3324c6c8822 --- /dev/null +++ b/ui.py @@ -0,0 +1,857 @@ +import os +import re +import shutil +from dataclasses import dataclass +from typing import Any + +import gradio as gr +import huggingface_hub as hf +import numpy as np +import pandas as pd + +HfApi = hf.HfApi() + +try: + import trackio.utils as utils + from trackio.file_storage import FileStorage + from trackio.media import TrackioImage, TrackioVideo + from trackio.sqlite_storage import SQLiteStorage + from trackio.table import Table + from trackio.typehints import LogEntry, UploadEntry +except: # noqa: E722 + import utils + from file_storage import FileStorage + from media import TrackioImage, TrackioVideo + from sqlite_storage import SQLiteStorage + from table import Table + from typehints import LogEntry, UploadEntry + + +def get_project_info() -> str | None: + dataset_id = os.environ.get("TRACKIO_DATASET_ID") + space_id = os.environ.get("SPACE_ID") + if utils.persistent_storage_enabled(): + return "✨ Persistent Storage is enabled, logs are stored directly in this Space." + if dataset_id: + sync_status = utils.get_sync_status(SQLiteStorage.get_scheduler()) + upgrade_message = f"New changes are synced every 5 min To avoid losing data between syncs, click here to open this Space's settings and add Persistent Storage. Make sure data is synced prior to enabling." 
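+ # sync_status is the age of the last Dataset backup in minutes, or None if
+ # nothing has been pushed yet.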
+ if sync_status is not None: + info = f"↻ Backed up {sync_status} min ago to {dataset_id} | {upgrade_message}" + else: + info = f"↻ Not backed up yet to {dataset_id} | {upgrade_message}" + return info + return None + + +def get_projects(request: gr.Request): + projects = SQLiteStorage.get_projects() + if project := request.query_params.get("project"): + interactive = False + else: + interactive = True + project = projects[0] if projects else None + + return gr.Dropdown( + label="Project", + choices=projects, + value=project, + allow_custom_value=True, + interactive=interactive, + info=get_project_info(), + ) + + +def get_runs(project) -> list[str]: + if not project: + return [] + return SQLiteStorage.get_runs(project) + + +def get_available_metrics(project: str, runs: list[str]) -> list[str]: + """Get all available metrics across all runs for x-axis selection.""" + if not project or not runs: + return ["step", "time"] + + all_metrics = set() + for run in runs: + metrics = SQLiteStorage.get_logs(project, run) + if metrics: + df = pd.DataFrame(metrics) + numeric_cols = df.select_dtypes(include="number").columns + numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS] + all_metrics.update(numeric_cols) + + all_metrics.add("step") + all_metrics.add("time") + + sorted_metrics = utils.sort_metrics_by_prefix(list(all_metrics)) + + result = ["step", "time"] + for metric in sorted_metrics: + if metric not in result: + result.append(metric) + + return result + + +@dataclass +class MediaData: + caption: str | None + file_path: str + + +def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]: + media_by_key: dict[str, list[MediaData]] = {} + logs = sorted(logs, key=lambda x: x.get("step", 0)) + for log in logs: + for key, value in log.items(): + if isinstance(value, dict): + type = value.get("_type") + if type == TrackioImage.TYPE or type == TrackioVideo.TYPE: + if key not in media_by_key: + media_by_key[key] = [] + try: + media_data = MediaData( + file_path=utils.MEDIA_DIR / value.get("file_path"), + caption=value.get("caption"), + ) + media_by_key[key].append(media_data) + except Exception as e: + print(f"Media currently unavailable: {key}: {e}") + return media_by_key + + +def load_run_data( + project: str | None, + run: str | None, + smoothing_granularity: int, + x_axis: str, + log_scale: bool = False, +) -> tuple[pd.DataFrame, dict]: + if not project or not run: + return None, None + + logs = SQLiteStorage.get_logs(project, run) + if not logs: + return None, None + + media = extract_media(logs) + df = pd.DataFrame(logs) + + if "step" not in df.columns: + df["step"] = range(len(df)) + + if x_axis == "time" and "timestamp" in df.columns: + df["timestamp"] = pd.to_datetime(df["timestamp"]) + first_timestamp = df["timestamp"].min() + df["time"] = (df["timestamp"] - first_timestamp).dt.total_seconds() + x_column = "time" + elif x_axis == "step": + x_column = "step" + else: + x_column = x_axis + + if log_scale and x_column in df.columns: + x_vals = df[x_column] + if (x_vals <= 0).any(): + df[x_column] = np.log10(np.maximum(x_vals, 0) + 1) + else: + df[x_column] = np.log10(x_vals) + + if smoothing_granularity > 0: + numeric_cols = df.select_dtypes(include="number").columns + numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS] + + df_original = df.copy() + df_original["run"] = run + df_original["data_type"] = "original" + + df_smoothed = df.copy() + window_size = max(3, min(smoothing_granularity, len(df))) + df_smoothed[numeric_cols] = ( + 
df_smoothed[numeric_cols] + .rolling(window=window_size, center=True, min_periods=1) + .mean() + ) + df_smoothed["run"] = f"{run}_smoothed" + df_smoothed["data_type"] = "smoothed" + + combined_df = pd.concat([df_original, df_smoothed], ignore_index=True) + combined_df["x_axis"] = x_column + return combined_df, media + else: + df["run"] = run + df["data_type"] = "original" + df["x_axis"] = x_column + return df, media + + +def update_runs( + project, filter_text, user_interacted_with_runs=False, selected_runs_from_url=None +): + if project is None: + runs = [] + num_runs = 0 + else: + runs = get_runs(project) + num_runs = len(runs) + if filter_text: + runs = [r for r in runs if filter_text in r] + + if not user_interacted_with_runs: + if selected_runs_from_url: + value = [r for r in runs if r in selected_runs_from_url] + else: + value = runs + return gr.CheckboxGroup(choices=runs, value=value), gr.Textbox( + label=f"Runs ({num_runs})" + ) + else: + return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})") + + +def filter_runs(project, filter_text): + runs = get_runs(project) + runs = [r for r in runs if filter_text in r] + return gr.CheckboxGroup(choices=runs, value=runs) + + +def update_x_axis_choices(project, runs): + """Update x-axis dropdown choices based on available metrics.""" + available_metrics = get_available_metrics(project, runs) + return gr.Dropdown( + label="X-axis", + choices=available_metrics, + value="step", + ) + + +def toggle_timer(cb_value): + if cb_value: + return gr.Timer(active=True) + else: + return gr.Timer(active=False) + + +def check_auth(hf_token: str | None) -> None: + if os.getenv("SYSTEM") == "spaces": # if we are running in Spaces + # check auth token passed in + if hf_token is None: + raise PermissionError( + "Expected a HF_TOKEN to be provided when logging to a Space" + ) + who = HfApi.whoami(hf_token) + access_token = who["auth"]["accessToken"] + owner_name = os.getenv("SPACE_AUTHOR_NAME") + repo_name = os.getenv("SPACE_REPO_NAME") + # make sure the token user is either the author of the space, + # or is a member of an org that is the author. 
+ orgs = [o["name"] for o in who["orgs"]] + if owner_name != who["name"] and owner_name not in orgs: + raise PermissionError( + "Expected the provided hf_token to be the user owner of the space, or be a member of the org owner of the space" + ) + # reject fine-grained tokens without specific repo access + if access_token["role"] == "fineGrained": + matched = False + for item in access_token["fineGrained"]["scoped"]: + if ( + item["entity"]["type"] == "space" + and item["entity"]["name"] == f"{owner_name}/{repo_name}" + and "repo.write" in item["permissions"] + ): + matched = True + break + if ( + ( + item["entity"]["type"] == "user" + or item["entity"]["type"] == "org" + ) + and item["entity"]["name"] == owner_name + and "repo.write" in item["permissions"] + ): + matched = True + break + if not matched: + raise PermissionError( + "Expected the provided hf_token with fine grained permissions to provide write access to the space" + ) + # reject read-only tokens + elif access_token["role"] != "write": + raise PermissionError( + "Expected the provided hf_token to provide write permissions" + ) + + +def upload_db_to_space( + project: str, uploaded_db: gr.FileData, hf_token: str | None +) -> None: + check_auth(hf_token) + db_project_path = SQLiteStorage.get_project_db_path(project) + if os.path.exists(db_project_path): + raise gr.Error( + f"Trackio database file already exists for project {project}, cannot overwrite." + ) + os.makedirs(os.path.dirname(db_project_path), exist_ok=True) + shutil.copy(uploaded_db["path"], db_project_path) + + +def bulk_upload_media(uploads: list[UploadEntry], hf_token: str | None) -> None: + check_auth(hf_token) + for upload in uploads: + media_path = FileStorage.init_project_media_path( + upload["project"], upload["run"], upload["step"] + ) + shutil.copy(upload["uploaded_file"]["path"], media_path) + + +def log( + project: str, + run: str, + metrics: dict[str, Any], + step: int | None, + hf_token: str | None, +) -> None: + """ + Note: this method is not used in the latest versions of Trackio (replaced by bulk_log) but + is kept for backwards compatibility for users who are connecting to a newer version of + a Trackio Spaces dashboard with an older version of Trackio installed locally. + """ + check_auth(hf_token) + SQLiteStorage.log(project=project, run=run, metrics=metrics, step=step) + + +def bulk_log( + logs: list[LogEntry], + hf_token: str | None, +) -> None: + check_auth(hf_token) + + logs_by_run = {} + for log_entry in logs: + key = (log_entry["project"], log_entry["run"]) + if key not in logs_by_run: + logs_by_run[key] = {"metrics": [], "steps": []} + logs_by_run[key]["metrics"].append(log_entry["metrics"]) + logs_by_run[key]["steps"].append(log_entry.get("step")) + + for (project, run), data in logs_by_run.items(): + SQLiteStorage.bulk_log( + project=project, + run=run, + metrics_list=data["metrics"], + steps=data["steps"], + ) + + +def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]: + """ + Filter metrics using regex pattern. 
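+ Matching is case-insensitive; if the pattern is not valid regex, a plain
+ substring match is used as a fallback.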
+ + Args: + metrics: List of metric names to filter + filter_pattern: Regex pattern to match against metric names + + Returns: + List of metric names that match the pattern + """ + if not filter_pattern.strip(): + return metrics + + try: + pattern = re.compile(filter_pattern, re.IGNORECASE) + return [metric for metric in metrics if pattern.search(metric)] + except re.error: + return [ + metric for metric in metrics if filter_pattern.lower() in metric.lower() + ] + + +def configure(request: gr.Request): + sidebar_param = request.query_params.get("sidebar") + match sidebar_param: + case "collapsed": + sidebar = gr.Sidebar(open=False, visible=True) + case "hidden": + sidebar = gr.Sidebar(open=False, visible=False) + case _: + sidebar = gr.Sidebar(open=True, visible=True) + + metrics_param = request.query_params.get("metrics", "") + runs_param = request.query_params.get("runs", "") + selected_runs = runs_param.split(",") if runs_param else [] + + return [], sidebar, metrics_param, selected_runs + + +def create_media_section(media_by_run: dict[str, dict[str, list[MediaData]]]): + with gr.Accordion(label="media"): + with gr.Group(elem_classes=("media-group")): + for run, media_by_key in media_by_run.items(): + with gr.Tab(label=run, elem_classes=("media-tab")): + for key, media_item in media_by_key.items(): + gr.Gallery( + [(item.file_path, item.caption) for item in media_item], + label=key, + columns=6, + elem_classes=("media-gallery"), + ) + + +css = """ +#run-cb .wrap { gap: 2px; } +#run-cb .wrap label { + line-height: 1; + padding: 6px; +} +.logo-light { display: block; } +.logo-dark { display: none; } +.dark .logo-light { display: none; } +.dark .logo-dark { display: block; } +.dark .caption-label { color: white; } + +.info-container { + position: relative; + display: inline; +} +.info-checkbox { + position: absolute; + opacity: 0; + pointer-events: none; +} +.info-icon { + border-bottom: 1px dotted; + cursor: pointer; + user-select: none; + color: var(--color-accent); +} +.info-expandable { + display: none; + opacity: 0; + transition: opacity 0.2s ease-in-out; +} +.info-checkbox:checked ~ .info-expandable { + display: inline; + opacity: 1; +} +.info-icon:hover { opacity: 0.8; } +.accent-link { font-weight: bold; } + +.media-gallery .fixed-height { min-height: 275px; } +.media-group, .media-group > div { background: none; } +.media-group .tabs { padding: 0.5em; } +.media-tab { max-height: 500px; overflow-y: scroll; } +""" + +gr.set_static_paths(paths=[utils.MEDIA_DIR]) +with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo: + with gr.Sidebar(open=False) as sidebar: + logo = gr.Markdown( + f""" + + + """ + ) + project_dd = gr.Dropdown(label="Project", allow_custom_value=True) + + embed_code = gr.Code( + label="Embed this view", + max_lines=2, + lines=2, + language="html", + visible=bool(os.environ.get("SPACE_HOST")), + ) + run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...") + run_cb = gr.CheckboxGroup( + label="Runs", choices=[], interactive=True, elem_id="run-cb" + ) + gr.HTML("
") + realtime_cb = gr.Checkbox(label="Refresh metrics realtime", value=True) + smoothing_slider = gr.Slider( + label="Smoothing Factor", + minimum=0, + maximum=20, + value=10, + step=1, + info="0 = no smoothing", + ) + x_axis_dd = gr.Dropdown( + label="X-axis", + choices=["step", "time"], + value="step", + ) + log_scale_cb = gr.Checkbox(label="Log scale X-axis", value=False) + metric_filter_tb = gr.Textbox( + label="Metric Filter (regex)", + placeholder="e.g., loss|ndcg@10|gpu", + value="", + info="Filter metrics using regex patterns. Leave empty to show all metrics.", + ) + + timer = gr.Timer(value=1) + metrics_subset = gr.State([]) + user_interacted_with_run_cb = gr.State(False) + selected_runs_from_url = gr.State([]) + + gr.on( + [demo.load], + fn=configure, + outputs=[metrics_subset, sidebar, metric_filter_tb, selected_runs_from_url], + queue=False, + api_name=False, + ) + gr.on( + [demo.load], + fn=get_projects, + outputs=project_dd, + show_progress="hidden", + queue=False, + api_name=False, + ) + gr.on( + [timer.tick], + fn=update_runs, + inputs=[ + project_dd, + run_tb, + user_interacted_with_run_cb, + selected_runs_from_url, + ], + outputs=[run_cb, run_tb], + show_progress="hidden", + api_name=False, + ) + gr.on( + [timer.tick], + fn=lambda: gr.Dropdown(info=get_project_info()), + outputs=[project_dd], + show_progress="hidden", + api_name=False, + ) + gr.on( + [demo.load, project_dd.change], + fn=update_runs, + inputs=[project_dd, run_tb, gr.State(False), selected_runs_from_url], + outputs=[run_cb, run_tb], + show_progress="hidden", + queue=False, + api_name=False, + ).then( + fn=update_x_axis_choices, + inputs=[project_dd, run_cb], + outputs=x_axis_dd, + show_progress="hidden", + queue=False, + api_name=False, + ).then( + fn=utils.generate_embed_code, + inputs=[project_dd, metric_filter_tb, run_cb], + outputs=embed_code, + show_progress="hidden", + api_name=False, + queue=False, + ) + + gr.on( + [run_cb.input], + fn=update_x_axis_choices, + inputs=[project_dd, run_cb], + outputs=x_axis_dd, + show_progress="hidden", + queue=False, + api_name=False, + ) + gr.on( + [metric_filter_tb.change, run_cb.change], + fn=utils.generate_embed_code, + inputs=[project_dd, metric_filter_tb, run_cb], + outputs=embed_code, + show_progress="hidden", + api_name=False, + queue=False, + ) + + realtime_cb.change( + fn=toggle_timer, + inputs=realtime_cb, + outputs=timer, + api_name=False, + queue=False, + ) + run_cb.input( + fn=lambda: True, + outputs=user_interacted_with_run_cb, + api_name=False, + queue=False, + ) + run_tb.input( + fn=filter_runs, + inputs=[project_dd, run_tb], + outputs=run_cb, + api_name=False, + queue=False, + ) + + gr.api( + fn=upload_db_to_space, + api_name="upload_db_to_space", + ) + gr.api( + fn=bulk_upload_media, + api_name="bulk_upload_media", + ) + gr.api( + fn=log, + api_name="log", + ) + gr.api( + fn=bulk_log, + api_name="bulk_log", + ) + + x_lim = gr.State(None) + last_steps = gr.State({}) + + def update_x_lim(select_data: gr.SelectData): + return select_data.index + + def update_last_steps(project): + """Check the last step for each run to detect when new data is available.""" + if not project: + return {} + return SQLiteStorage.get_max_steps_for_runs(project) + + timer.tick( + fn=update_last_steps, + inputs=[project_dd], + outputs=last_steps, + show_progress="hidden", + api_name=False, + ) + + @gr.render( + triggers=[ + demo.load, + run_cb.change, + last_steps.change, + smoothing_slider.change, + x_lim.change, + x_axis_dd.change, + log_scale_cb.change, + 
metric_filter_tb.change, + ], + inputs=[ + project_dd, + run_cb, + smoothing_slider, + metrics_subset, + x_lim, + x_axis_dd, + log_scale_cb, + metric_filter_tb, + ], + show_progress="hidden", + queue=False, + ) + def update_dashboard( + project, + runs, + smoothing_granularity, + metrics_subset, + x_lim_value, + x_axis, + log_scale, + metric_filter, + ): + dfs = [] + images_by_run = {} + original_runs = runs.copy() + + for run in runs: + df, images_by_key = load_run_data( + project, run, smoothing_granularity, x_axis, log_scale + ) + if df is not None: + dfs.append(df) + images_by_run[run] = images_by_key + + if dfs: + if smoothing_granularity > 0: + original_dfs = [] + smoothed_dfs = [] + for df in dfs: + original_data = df[df["data_type"] == "original"] + smoothed_data = df[df["data_type"] == "smoothed"] + if not original_data.empty: + original_dfs.append(original_data) + if not smoothed_data.empty: + smoothed_dfs.append(smoothed_data) + + all_dfs = original_dfs + smoothed_dfs + master_df = ( + pd.concat(all_dfs, ignore_index=True) if all_dfs else pd.DataFrame() + ) + + else: + master_df = pd.concat(dfs, ignore_index=True) + else: + master_df = pd.DataFrame() + + if master_df.empty: + return + + x_column = "step" + if dfs and not dfs[0].empty and "x_axis" in dfs[0].columns: + x_column = dfs[0]["x_axis"].iloc[0] + + numeric_cols = master_df.select_dtypes(include="number").columns + numeric_cols = [c for c in numeric_cols if c not in utils.RESERVED_KEYS] + if x_column and x_column in numeric_cols: + numeric_cols.remove(x_column) + + if metrics_subset: + numeric_cols = [c for c in numeric_cols if c in metrics_subset] + + if metric_filter and metric_filter.strip(): + numeric_cols = filter_metrics_by_regex(list(numeric_cols), metric_filter) + + nested_metric_groups = utils.group_metrics_with_subprefixes(list(numeric_cols)) + color_map = utils.get_color_mapping(original_runs, smoothing_granularity > 0) + + metric_idx = 0 + for group_name in sorted(nested_metric_groups.keys()): + group_data = nested_metric_groups[group_name] + + with gr.Accordion( + label=group_name, + open=True, + key=f"accordion-{group_name}", + preserved_by_key=["value", "open"], + ): + # Render direct metrics at this level + if group_data["direct_metrics"]: + with gr.Draggable( + key=f"row-{group_name}-direct", orientation="row" + ): + for metric_name in group_data["direct_metrics"]: + metric_df = master_df.dropna(subset=[metric_name]) + color = "run" if "run" in metric_df.columns else None + if not metric_df.empty: + plot = gr.LinePlot( + utils.downsample( + metric_df, + x_column, + metric_name, + color, + x_lim_value, + ), + x=x_column, + y=metric_name, + y_title=metric_name.split("/")[-1], + color=color, + color_map=color_map, + title=metric_name, + key=f"plot-{metric_idx}", + preserved_by_key=None, + x_lim=x_lim_value, + show_fullscreen_button=True, + min_width=400, + ) + plot.select( + update_x_lim, + outputs=x_lim, + key=f"select-{metric_idx}", + ) + plot.double_click( + lambda: None, + outputs=x_lim, + key=f"double-{metric_idx}", + ) + metric_idx += 1 + + # If there are subgroups, create nested accordions + if group_data["subgroups"]: + for subgroup_name in sorted(group_data["subgroups"].keys()): + subgroup_metrics = group_data["subgroups"][subgroup_name] + + with gr.Accordion( + label=subgroup_name, + open=True, + key=f"accordion-{group_name}-{subgroup_name}", + preserved_by_key=["value", "open"], + ): + with gr.Draggable(key=f"row-{group_name}-{subgroup_name}"): + for metric_name in subgroup_metrics: + metric_df = 
master_df.dropna(subset=[metric_name]) + color = ( + "run" if "run" in metric_df.columns else None + ) + if not metric_df.empty: + plot = gr.LinePlot( + utils.downsample( + metric_df, + x_column, + metric_name, + color, + x_lim_value, + ), + x=x_column, + y=metric_name, + y_title=metric_name.split("/")[-1], + color=color, + color_map=color_map, + title=metric_name, + key=f"plot-{metric_idx}", + preserved_by_key=None, + x_lim=x_lim_value, + show_fullscreen_button=True, + min_width=400, + ) + plot.select( + update_x_lim, + outputs=x_lim, + key=f"select-{metric_idx}", + ) + plot.double_click( + lambda: None, + outputs=x_lim, + key=f"double-{metric_idx}", + ) + metric_idx += 1 + if images_by_run and any(any(images) for images in images_by_run.values()): + create_media_section(images_by_run) + + table_cols = master_df.select_dtypes(include="object").columns + table_cols = [c for c in table_cols if c not in utils.RESERVED_KEYS] + if metrics_subset: + table_cols = [c for c in table_cols if c in metrics_subset] + if metric_filter and metric_filter.strip(): + table_cols = filter_metrics_by_regex(list(table_cols), metric_filter) + if len(table_cols) > 0: + with gr.Accordion("tables", open=True): + with gr.Row(key="row"): + for metric_idx, metric_name in enumerate(table_cols): + metric_df = master_df.dropna(subset=[metric_name]) + if not metric_df.empty: + value = metric_df[metric_name].iloc[-1] + if ( + isinstance(value, dict) + and "_type" in value + and value["_type"] == Table.TYPE + ): + try: + df = pd.DataFrame(value["_value"]) + gr.DataFrame( + df, + label=f"{metric_name} (latest)", + key=f"table-{metric_idx}", + wrap=True, + ) + except Exception as e: + gr.Warning( + f"Column {metric_name} failed to render as a table: {e}" + ) + + +if __name__ == "__main__": + demo.launch(allowed_paths=[utils.TRACKIO_LOGO_DIR], show_api=False, show_error=True) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6675782ff7189e4471ced28b295ee53ed174f09 --- /dev/null +++ b/utils.py @@ -0,0 +1,687 @@ +import math +import os +import re +import sys +import time +from pathlib import Path +from typing import TYPE_CHECKING + +import huggingface_hub +import numpy as np +import pandas as pd +from huggingface_hub.constants import HF_HOME + +if TYPE_CHECKING: + from trackio.commit_scheduler import CommitScheduler + from trackio.dummy_commit_scheduler import DummyCommitScheduler + +RESERVED_KEYS = ["project", "run", "timestamp", "step", "time", "metrics"] + +TRACKIO_LOGO_DIR = Path(__file__).parent / "assets" + + +def persistent_storage_enabled() -> bool: + return ( + os.environ.get("PERSISTANT_STORAGE_ENABLED") == "true" + ) # typo in the name of the environment variable + + +def _get_trackio_dir() -> Path: + if persistent_storage_enabled(): + return Path("/data/trackio") + return Path(HF_HOME) / "trackio" + + +TRACKIO_DIR = _get_trackio_dir() +MEDIA_DIR = TRACKIO_DIR / "media" + + +def generate_readable_name(used_names: list[str], space_id: str | None = None) -> str: + """ + Generates a random, readable name like "dainty-sunset-0". + If space_id is provided, generates username-timestamp format instead. 
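+ Names are drawn from fixed adjective/noun lists, and the numeric suffix is
+ incremented until a name not already in `used_names` is found.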
+ """ + if space_id is not None: + username = huggingface_hub.whoami()["name"] + timestamp = int(time.time()) + return f"{username}-{timestamp}" + adjectives = [ + "dainty", + "brave", + "calm", + "eager", + "fancy", + "gentle", + "happy", + "jolly", + "kind", + "lively", + "merry", + "nice", + "proud", + "quick", + "hugging", + "silly", + "tidy", + "witty", + "zealous", + "bright", + "shy", + "bold", + "clever", + "daring", + "elegant", + "faithful", + "graceful", + "honest", + "inventive", + "jovial", + "keen", + "lucky", + "modest", + "noble", + "optimistic", + "patient", + "quirky", + "resourceful", + "sincere", + "thoughtful", + "upbeat", + "valiant", + "warm", + "youthful", + "zesty", + "adventurous", + "breezy", + "cheerful", + "delightful", + "energetic", + "fearless", + "glad", + "hopeful", + "imaginative", + "joyful", + "kindly", + "luminous", + "mysterious", + "neat", + "outgoing", + "playful", + "radiant", + "spirited", + "tranquil", + "unique", + "vivid", + "wise", + "zany", + "artful", + "bubbly", + "charming", + "dazzling", + "earnest", + "festive", + "gentlemanly", + "hearty", + "intrepid", + "jubilant", + "knightly", + "lively", + "magnetic", + "nimble", + "orderly", + "peaceful", + "quick-witted", + "robust", + "sturdy", + "trusty", + "upstanding", + "vibrant", + "whimsical", + ] + nouns = [ + "sunset", + "forest", + "river", + "mountain", + "breeze", + "meadow", + "ocean", + "valley", + "sky", + "field", + "cloud", + "star", + "rain", + "leaf", + "stone", + "flower", + "bird", + "tree", + "wave", + "trail", + "island", + "desert", + "hill", + "lake", + "pond", + "grove", + "canyon", + "reef", + "bay", + "peak", + "glade", + "marsh", + "cliff", + "dune", + "spring", + "brook", + "cave", + "plain", + "ridge", + "wood", + "blossom", + "petal", + "root", + "branch", + "seed", + "acorn", + "pine", + "willow", + "cedar", + "elm", + "falcon", + "eagle", + "sparrow", + "robin", + "owl", + "finch", + "heron", + "crane", + "duck", + "swan", + "fox", + "wolf", + "bear", + "deer", + "moose", + "otter", + "beaver", + "lynx", + "hare", + "badger", + "butterfly", + "bee", + "ant", + "beetle", + "dragonfly", + "firefly", + "ladybug", + "moth", + "spider", + "worm", + "coral", + "kelp", + "shell", + "pebble", + "face", + "boulder", + "cobble", + "sand", + "wavelet", + "tide", + "current", + "mist", + ] + number = 0 + name = f"{adjectives[0]}-{nouns[0]}-{number}" + while name in used_names: + number += 1 + adjective = adjectives[number % len(adjectives)] + noun = nouns[number % len(nouns)] + name = f"{adjective}-{noun}-{number}" + return name + + +def block_except_in_notebook(): + in_notebook = bool(getattr(sys, "ps1", sys.flags.interactive)) + if in_notebook: + return + try: + while True: + time.sleep(0.1) + except (KeyboardInterrupt, OSError): + print("Keyboard interruption in main thread... closing dashboard.") + + +def simplify_column_names(columns: list[str]) -> dict[str, str]: + """ + Simplifies column names to first 10 alphanumeric or "/" characters with unique suffixes. 
+ + Args: + columns: List of original column names + + Returns: + Dictionary mapping original column names to simplified names + """ + simplified_names = {} + used_names = set() + + for col in columns: + alphanumeric = re.sub(r"[^a-zA-Z0-9/]", "", col) + base_name = alphanumeric[:10] if alphanumeric else f"col_{len(used_names)}" + + final_name = base_name + suffix = 1 + while final_name in used_names: + final_name = f"{base_name}_{suffix}" + suffix += 1 + + simplified_names[col] = final_name + used_names.add(final_name) + + return simplified_names + + +def print_dashboard_instructions(project: str) -> None: + """ + Prints instructions for viewing the Trackio dashboard. + + Args: + project: The name of the project to show dashboard for. + """ + YELLOW = "\033[93m" + BOLD = "\033[1m" + RESET = "\033[0m" + + print("* View dashboard by running in your terminal:") + print(f'{BOLD}{YELLOW}trackio show --project "{project}"{RESET}') + print(f'* or by running in Python: trackio.show(project="{project}")') + + +def preprocess_space_and_dataset_ids( + space_id: str | None, dataset_id: str | None +) -> tuple[str | None, str | None]: + if space_id is not None and "/" not in space_id: + username = huggingface_hub.whoami()["name"] + space_id = f"{username}/{space_id}" + if dataset_id is not None and "/" not in dataset_id: + username = huggingface_hub.whoami()["name"] + dataset_id = f"{username}/{dataset_id}" + if space_id is not None and dataset_id is None: + dataset_id = f"{space_id}-dataset" + return space_id, dataset_id + + +def fibo(): + """Generator for Fibonacci backoff: 1, 1, 2, 3, 5, 8, ...""" + a, b = 1, 1 + while True: + yield a + a, b = b, a + b + + +COLOR_PALETTE = [ + "#3B82F6", + "#EF4444", + "#10B981", + "#F59E0B", + "#8B5CF6", + "#EC4899", + "#06B6D4", + "#84CC16", + "#F97316", + "#6366F1", +] + + +def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]: + """Generate color mapping for runs, with transparency for original data when smoothing is enabled.""" + color_map = {} + + for i, run in enumerate(runs): + base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)] + + if smoothing: + color_map[run] = base_color + "4D" + color_map[f"{run}_smoothed"] = base_color + else: + color_map[run] = base_color + + return color_map + + +def downsample( + df: pd.DataFrame, + x: str, + y: str, + color: str | None, + x_lim: tuple[float, float] | None = None, +) -> pd.DataFrame: + if df.empty: + return df + + columns_to_keep = [x, y] + if color is not None and color in df.columns: + columns_to_keep.append(color) + df = df[columns_to_keep].copy() + + n_bins = 100 + + if color is not None and color in df.columns: + groups = df.groupby(color) + else: + groups = [(None, df)] + + downsampled_indices = [] + + for _, group_df in groups: + if group_df.empty: + continue + + group_df = group_df.sort_values(x) + + if x_lim is not None: + x_min, x_max = x_lim + before_point = group_df[group_df[x] < x_min].tail(1) + after_point = group_df[group_df[x] > x_max].head(1) + group_df = group_df[(group_df[x] >= x_min) & (group_df[x] <= x_max)] + else: + before_point = after_point = None + x_min = group_df[x].min() + x_max = group_df[x].max() + + if before_point is not None and not before_point.empty: + downsampled_indices.extend(before_point.index.tolist()) + if after_point is not None and not after_point.empty: + downsampled_indices.extend(after_point.index.tolist()) + + if group_df.empty: + continue + + if x_min == x_max: + min_y_idx = group_df[y].idxmin() + max_y_idx = group_df[y].idxmax() + if min_y_idx != 
max_y_idx: + downsampled_indices.extend([min_y_idx, max_y_idx]) + else: + downsampled_indices.append(min_y_idx) + continue + + if len(group_df) < 500: + downsampled_indices.extend(group_df.index.tolist()) + continue + + bins = np.linspace(x_min, x_max, n_bins + 1) + group_df["bin"] = pd.cut( + group_df[x], bins=bins, labels=False, include_lowest=True + ) + + for bin_idx in group_df["bin"].dropna().unique(): + bin_data = group_df[group_df["bin"] == bin_idx] + if bin_data.empty: + continue + + min_y_idx = bin_data[y].idxmin() + max_y_idx = bin_data[y].idxmax() + + downsampled_indices.append(min_y_idx) + if min_y_idx != max_y_idx: + downsampled_indices.append(max_y_idx) + + unique_indices = list(set(downsampled_indices)) + + downsampled_df = df.loc[unique_indices].copy() + + if color is not None: + downsampled_df = ( + downsampled_df.groupby(color, sort=False)[downsampled_df.columns] + .apply(lambda group: group.sort_values(x)) + .reset_index(drop=True) + ) + else: + downsampled_df = downsampled_df.sort_values(x).reset_index(drop=True) + + downsampled_df = downsampled_df.drop(columns=["bin"], errors="ignore") + + return downsampled_df + + +def sort_metrics_by_prefix(metrics: list[str]) -> list[str]: + """ + Sort metrics by grouping prefixes together for dropdown/list display. + Metrics without prefixes come first, then grouped by prefix. + + Args: + metrics: List of metric names + + Returns: + List of metric names sorted by prefix + + Example: + Input: ["train/loss", "loss", "train/acc", "val/loss"] + Output: ["loss", "train/acc", "train/loss", "val/loss"] + """ + groups = group_metrics_by_prefix(metrics) + result = [] + + if "charts" in groups: + result.extend(groups["charts"]) + + for group_name in sorted(groups.keys()): + if group_name != "charts": + result.extend(groups[group_name]) + + return result + + +def group_metrics_by_prefix(metrics: list[str]) -> dict[str, list[str]]: + """ + Group metrics by their prefix. Metrics without prefix go to 'charts' group. + + Args: + metrics: List of metric names + + Returns: + Dictionary with prefix names as keys and lists of metrics as values + + Example: + Input: ["loss", "accuracy", "train/loss", "train/acc", "val/loss"] + Output: { + "charts": ["loss", "accuracy"], + "train": ["train/loss", "train/acc"], + "val": ["val/loss"] + } + """ + no_prefix = [] + with_prefix = [] + + for metric in metrics: + if "/" in metric: + with_prefix.append(metric) + else: + no_prefix.append(metric) + + no_prefix.sort() + + prefix_groups = {} + for metric in with_prefix: + prefix = metric.split("/")[0] + if prefix not in prefix_groups: + prefix_groups[prefix] = [] + prefix_groups[prefix].append(metric) + + for prefix in prefix_groups: + prefix_groups[prefix].sort() + + groups = {} + if no_prefix: + groups["charts"] = no_prefix + + for prefix in sorted(prefix_groups.keys()): + groups[prefix] = prefix_groups[prefix] + + return groups + + +def group_metrics_with_subprefixes(metrics: list[str]) -> dict: + """ + Group metrics with simple 2-level nested structure detection. 
+
+    Returns a dictionary where each prefix group can have:
+    - direct_metrics: list of metrics at this level (e.g., "train/acc")
+    - subgroups: dict of subgroup name -> list of metrics (e.g., "loss" -> ["train/loss/norm", "train/loss/unnorm"])
+
+    Example:
+        Input: ["loss", "train/acc", "train/loss/normalized", "train/loss/unnormalized", "val/loss"]
+        Output: {
+            "charts": {
+                "direct_metrics": ["loss"],
+                "subgroups": {}
+            },
+            "train": {
+                "direct_metrics": ["train/acc"],
+                "subgroups": {
+                    "loss": ["train/loss/normalized", "train/loss/unnormalized"]
+                }
+            },
+            "val": {
+                "direct_metrics": ["val/loss"],
+                "subgroups": {}
+            }
+        }
+    """
+    result = {}
+
+    for metric in metrics:
+        if "/" not in metric:
+            if "charts" not in result:
+                result["charts"] = {"direct_metrics": [], "subgroups": {}}
+            result["charts"]["direct_metrics"].append(metric)
+        else:
+            parts = metric.split("/")
+            main_prefix = parts[0]
+
+            if main_prefix not in result:
+                result[main_prefix] = {"direct_metrics": [], "subgroups": {}}
+
+            if len(parts) == 2:
+                result[main_prefix]["direct_metrics"].append(metric)
+            else:
+                subprefix = parts[1]
+                if subprefix not in result[main_prefix]["subgroups"]:
+                    result[main_prefix]["subgroups"][subprefix] = []
+                result[main_prefix]["subgroups"][subprefix].append(metric)
+
+    for group_data in result.values():
+        group_data["direct_metrics"].sort()
+        for subgroup_metrics in group_data["subgroups"].values():
+            subgroup_metrics.sort()
+
+    if "charts" in result and not result["charts"]["direct_metrics"]:
+        del result["charts"]
+
+    return result
+
+
+def get_sync_status(scheduler: "CommitScheduler | DummyCommitScheduler") -> int | None:
+    """Get the sync status from the CommitScheduler in an integer number of minutes, or None if not synced yet."""
+    if getattr(
+        scheduler, "last_push_time", None
+    ):  # DummyCommitScheduler doesn't have last_push_time
+        time_diff = time.time() - scheduler.last_push_time
+        return int(time_diff / 60)
+    else:
+        return None
+
+
+def generate_embed_code(
+    project: str, metrics: str, selected_runs: list[str] | None = None
+) -> str:
+    """Generate the embed iframe code based on current settings."""
+    space_host = os.environ.get("SPACE_HOST", "")
+    if not space_host:
+        return ""
+
+    params = []
+
+    if project:
+        params.append(f"project={project}")
+
+    if metrics and metrics.strip():
+        params.append(f"metrics={metrics}")
+
+    if selected_runs:
+        runs_param = ",".join(selected_runs)
+        params.append(f"runs={runs_param}")
+
+    params.append("sidebar=hidden")
+
+    query_string = "&".join(params)
+    embed_url = f"https://{space_host}?{query_string}"
+
+    # The iframe width/height below are assumed defaults; adjust to fit the
+    # embedding page.
+    return (
+        f'<iframe src="{embed_url}" width="100%" height="500" '
+        'frameborder="0"></iframe>'
+    )
+
+
+def serialize_values(metrics):
+    """
+    Serialize infinity and NaN values in metrics dict to make it JSON-compliant.
+    Only handles top-level float values.
+ + Converts: + - float('inf') -> "Infinity" + - float('-inf') -> "-Infinity" + - float('nan') -> "NaN" + + Example: + {"loss": float('inf'), "accuracy": 0.95} -> {"loss": "Infinity", "accuracy": 0.95} + """ + if not isinstance(metrics, dict): + return metrics + + result = {} + for key, value in metrics.items(): + if isinstance(value, float): + if math.isinf(value): + result[key] = "Infinity" if value > 0 else "-Infinity" + elif math.isnan(value): + result[key] = "NaN" + else: + result[key] = value + elif isinstance(value, np.floating): + float_val = float(value) + if math.isinf(float_val): + result[key] = "Infinity" if float_val > 0 else "-Infinity" + elif math.isnan(float_val): + result[key] = "NaN" + else: + result[key] = float_val + else: + result[key] = value + return result + + +def deserialize_values(metrics): + """ + Deserialize infinity and NaN string values back to their numeric forms. + Only handles top-level string values. + + Converts: + - "Infinity" -> float('inf') + - "-Infinity" -> float('-inf') + - "NaN" -> float('nan') + + Example: + {"loss": "Infinity", "accuracy": 0.95} -> {"loss": float('inf'), "accuracy": 0.95} + """ + if not isinstance(metrics, dict): + return metrics + + result = {} + for key, value in metrics.items(): + if value == "Infinity": + result[key] = float("inf") + elif value == "-Infinity": + result[key] = float("-inf") + elif value == "NaN": + result[key] = float("nan") + else: + result[key] = value + return result diff --git a/version.txt b/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..f64a1b563b2adff2a065b04d4d81492d237923c2 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +0.3.4.dev0 diff --git a/video_writer.py b/video_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..17374944816de049899fb40cef5836ce8aa7b63a --- /dev/null +++ b/video_writer.py @@ -0,0 +1,126 @@ +import shutil +import subprocess +from pathlib import Path +from typing import Literal + +import numpy as np + +VideoCodec = Literal["h264", "vp9", "gif"] + + +def _check_ffmpeg_installed() -> None: + """Raise an error if ffmpeg is not available on the system PATH.""" + if shutil.which("ffmpeg") is None: + raise RuntimeError( + "ffmpeg is required to write video but was not found on your system. " + "Please install ffmpeg and ensure it is available on your PATH." + ) + + +def _check_array_format(video: np.ndarray) -> None: + """Raise an error if the array is not in the expected format.""" + if not (video.ndim == 4 and video.shape[-1] == 3): + raise ValueError( + f"Expected RGB input shaped (F, H, W, 3), got {video.shape}. " + f"Input has {video.ndim} dimensions, expected 4." + ) + if video.dtype != np.uint8: + raise TypeError( + f"Expected dtype=uint8, got {video.dtype}. " + "Please convert your video data to uint8 format." + ) + + +def _check_path(file_path: str | Path) -> None: + """Raise an error if the parent directory does not exist.""" + file_path = Path(file_path) + if not file_path.parent.exists(): + try: + file_path.parent.mkdir(parents=True, exist_ok=True) + except OSError as e: + raise ValueError( + f"Failed to create parent directory {file_path.parent}: {e}" + ) + + +def write_video( + file_path: str | Path, video: np.ndarray, fps: float, codec: VideoCodec +) -> None: + """RGB uint8 only, shape (F, H, W, 3).""" + _check_ffmpeg_installed() + _check_path(file_path) + + if codec not in {"h264", "vp9", "gif"}: + raise ValueError("Unsupported codec. 
Use h264, vp9, or gif.")
+
+    arr = np.asarray(video)
+    _check_array_format(arr)
+
+    frames = np.ascontiguousarray(arr)
+    _, height, width, _ = frames.shape
+    out_path = str(file_path)
+
+    # Raw RGB frames are piped to ffmpeg via stdin, so the input geometry,
+    # pixel format, and frame rate must be declared explicitly.
+    cmd = [
+        "ffmpeg",
+        "-y",
+        "-f",
+        "rawvideo",
+        "-s",
+        f"{width}x{height}",
+        "-pix_fmt",
+        "rgb24",
+        "-r",
+        str(fps),
+        "-i",
+        "-",
+        "-an",
+    ]
+
+    if codec == "gif":
+        # Generate a palette and apply it in a single pass for better GIF colors.
+        video_filter = "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse"
+        cmd += [
+            "-vf",
+            video_filter,
+            "-loop",
+            "0",
+        ]
+    elif codec == "h264":
+        cmd += [
+            "-vcodec",
+            "libx264",
+            "-pix_fmt",
+            "yuv420p",
+            "-movflags",
+            "+faststart",
+        ]
+    elif codec == "vp9":
+        # Target roughly 0.08 bits per pixel per frame for the VP9 bitrate.
+        bpp = 0.08
+        bps = int(width * height * fps * bpp)
+        if bps >= 1_000_000:
+            bitrate = f"{round(bps / 1_000_000)}M"
+        elif bps >= 1_000:
+            bitrate = f"{round(bps / 1_000)}k"
+        else:
+            bitrate = str(max(bps, 1))
+        cmd += [
+            "-vcodec",
+            "libvpx-vp9",
+            "-b:v",
+            bitrate,
+            "-pix_fmt",
+            "yuv420p",
+        ]
+    cmd += [out_path]
+    proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
+    try:
+        # Stream frames one at a time; if ffmpeg exits early, fall through so
+        # its stderr is reported below instead of an opaque BrokenPipeError.
+        for frame in frames:
+            proc.stdin.write(frame.tobytes())
+    except BrokenPipeError:
+        pass
+    finally:
+        if proc.stdin:
+            try:
+                proc.stdin.close()
+            except BrokenPipeError:
+                pass
+    stderr = (
+        proc.stderr.read().decode("utf-8", errors="ignore") if proc.stderr else ""
+    )
+    ret = proc.wait()
+    if ret != 0:
+        raise RuntimeError(f"ffmpeg failed with code {ret}\n{stderr}")
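For readers who want to try the writer in isolation, here is a minimal usage sketch. It assumes the module is importable as `video_writer` (matching the file added in this diff) and that ffmpeg with libx264 support is on the PATH; the frame array, output paths, and frame rate are illustrative only.

import numpy as np

from video_writer import write_video

# Hypothetical clip: 30 frames of 64x64 RGB noise, uint8, shape (F, H, W, 3).
frames = np.random.randint(0, 256, size=(30, 64, 64, 3), dtype=np.uint8)

# Encode as an H.264 MP4 (frame dimensions must be even for yuv420p).
write_video("out/clip.mp4", frames, fps=10, codec="h264")

# The same frames can also be written as a looping GIF.
write_video("out/clip.gif", frames, fps=10, codec="gif")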