File size: 3,291 Bytes
241b6a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from pathlib import Path
from typing import Union

import gdown
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


def download_file_from_google_drive(file_id: str, destination: Path) -> None:
    """
    Fetches a single file from Google Drive via gdown and writes it to disk.

    Args:
        file_id (str): The ID of the file on Google Drive.
        destination (Path): The local path where the file should be saved.
    """
    download_url = f"https://drive.google.com/uc?id={file_id}"
    # gdown handles Google Drive's confirmation pages for large files.
    gdown.download(download_url, str(destination), quiet=False)


def download_file_from_hugging_face(destination: Path) -> None:
    """
    Downloads model weights from the Hugging Face Hub and saves them locally.

    The Hub stores the weights in safetensors format; the downloaded file is
    loaded and re-saved as a PyTorch state dict so it stays compatible with
    the rest of the codebase.

    Args:
        destination (Path): The local path where the converted weights should
            be saved. Its stem selects which ``.safetensors`` file to fetch
            from the repo, and its parent directory is used as the cache dir.
    """
    weights_name = destination.stem
    downloaded_path = hf_hub_download(
        repo_id="NickWright/OmniCloudMask",
        filename=f"{weights_name}.safetensors",
        force_download=True,
        cache_dir=destination.parent,
    )
    state_dict = load_file(downloaded_path)
    torch.save(state_dict, destination)


def download_file(file_id: str, destination: Path, source: str) -> None:
    """
    Dispatches a model-weight download to the requested backend.

    Args:
        file_id (str): Google Drive file ID (unused for the Hugging Face source).
        destination (Path): The local path where the file should be saved.
        source (str): Either "google_drive" or "hugging_face".

    Raises:
        ValueError: If ``source`` is not one of the supported values.
    """
    if source == "google_drive":
        download_file_from_google_drive(file_id, destination)
        return
    if source == "hugging_face":
        download_file_from_hugging_face(destination)
        return
    raise ValueError(
        "Invalid source. Supported sources are 'google_drive' and 'hugging_face'."
    )


def get_models(
    force_download: bool = False,
    model_dir: Union[str, Path, None] = None,
    source: str = "google_drive",
) -> list[dict]:
    """
    Ensures the model weights are available locally, downloading them if needed.

    Args:
        force_download (bool): Re-download the weights even if a local copy
            already exists.
        model_dir (Union[str, Path, None]): Directory for the weights; defaults
            to the bundled "models" directory next to this file. Created
            (including parents) if it does not exist.
        source (str): Download source; "google_drive" and "hugging_face" are
            supported.

    Returns:
        list[dict]: One entry per model with keys "Path" (local weight file)
        and "timm_model_name" (the backbone identifier for that model).
    """
    models_root = Path(__file__).resolve().parent / "models"
    df = pd.read_csv(models_root / "model_download_links.csv")

    # Resolve the target directory once -- it is invariant across rows
    # (the original resolved it inside the loop on every iteration).
    target_dir = Path(model_dir) if model_dir is not None else models_root
    target_dir.mkdir(parents=True, exist_ok=True)

    model_paths = []
    for _, row in df.iterrows():
        destination = target_dir / str(row["file_name"])

        # Download when the file is missing, a re-download is forced, or the
        # existing file is suspiciously small (<= 1 MiB), which indicates a
        # previous partial/failed download. Short-circuiting guarantees
        # stat() is only called when the file exists.
        needs_download = (
            force_download
            or not destination.exists()
            or destination.stat().st_size <= 1024 * 1024
        )
        if needs_download:
            download_file(
                file_id=str(row["google_drive_id"]),
                destination=destination,
                source=source,
            )

        model_paths.append(
            {"Path": destination, "timm_model_name": row["timm_model_name"]}
        )
    return model_paths