danhtran2mind committed on
Commit 0b2412b · verified · 1 Parent(s): 4dca5ff

Delete scripts/old2-download_ckpts.py

Files changed (1):
  1. scripts/old2-download_ckpts.py +0 -80
scripts/old2-download_ckpts.py DELETED
@@ -1,80 +0,0 @@
- import argparse
- import yaml
- import os
- import requests
- from huggingface_hub import snapshot_download, hf_hub_download
-
- def load_config(config_path):
-     with open(config_path, 'r') as file:
-         return yaml.safe_load(file)
-
- def download_model(model_config, full_ckpts=False):
-     model_id = model_config['model_id']
-     local_dir = model_config['local_dir']
-     platform = model_config['platform']
-     url = model_config.get('url')  # Get URL if it exists, None otherwise
-     filename = model_config.get('filename')
-
-     # Ensure the local directory exists
-     os.makedirs(local_dir, exist_ok=True)
-
-     if platform == "HuggingFace":
-         if full_ckpts:
-             print(f"Downloading full model {model_id} from HuggingFace to {local_dir}")
-             snapshot_download(
-                 repo_id=model_id,
-                 local_dir=local_dir,
-                 local_dir_use_symlinks=False,
-                 allow_patterns=["*.pth", "*.bin", "*.json"],  # Common model file extensions
-                 ignore_patterns=["*.md", "*.txt"],  # Ignore non-model files
-             )
-             print(f"Successfully downloaded {model_id} to {local_dir}")
-         else:
-             if not filename:
-                 raise ValueError(f"No filename provided for model: {model_id}")
-             print(f"Downloading file {filename} for model {model_id} from HuggingFace to {local_dir}")
-             hf_hub_download(
-                 repo_id=model_id,
-                 filename=filename,
-                 local_dir=local_dir,
-             )
-             print(f"Successfully downloaded {filename} to {local_dir}")
-     elif platform == "GitHub":
-         if not url:
-             raise ValueError(f"No URL provided for GitHub model: {model_id}")
-         if not filename:
-             filename = os.path.basename(url)
-         full_path = os.path.join(local_dir, filename)
-         print(f"Downloading model {model_id} from GitHub URL {url} to {full_path}")
-         response = requests.get(url, stream=True)
-         if response.status_code == 200:
-             with open(full_path, 'wb') as f:
-                 for chunk in response.iter_content(chunk_size=8192):
-                     if chunk:
-                         f.write(chunk)
-             print(f"Successfully downloaded {model_id} to {full_path}")
-         else:
-             raise ValueError(f"Failed to download {model_id} from {url}: HTTP {response.status_code}")
-     else:
-         raise ValueError(f"Unsupported platform: {platform}")
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description="Download model checkpoints from HuggingFace or GitHub.")
-     parser.add_argument('--config', type=str, default="configs/model_ckpts.yaml",
-                         help="Path to the YAML configuration file")
-     parser.add_argument('--full_ckpts', action='store_true',
-                         help="if true download all models using snapdownload, else just download model with for_inference in yaml")
-     parser.add_argument('--include_base_model', action='store_true',
-                         help="if true download all model base_model true and false, else just download base_model is false")
-     args = parser.parse_args()
-
-     # Load the YAML configuration
-     config = load_config(args.config)
-
-     # Iterate through models in the config
-     for model_config in config:
-         if not args.full_ckpts and not model_config.get('for_inference', False):
-             continue
-         if not args.include_base_model and model_config.get('base_model', False):
-             continue
-         download_model(model_config, full_ckpts=args.full_ckpts)
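
For context, the deleted script consumed a list of model entries from configs/model_ckpts.yaml. The sketch below is a hypothetical reconstruction of that config's shape, inferred only from the keys the script reads (model_id, platform, local_dir, filename, url, for_inference, base_model); the model names, paths, and URL are placeholders, not values from this repository.

```python
# Hypothetical sketch, not taken from the repo: parse an example config with
# the same keys the deleted script expected, then apply its default filtering.
import yaml

EXAMPLE_CONFIG = """
- model_id: some-org/some-model          # placeholder HuggingFace repo id
  platform: HuggingFace
  local_dir: ckpts/some-model
  filename: pytorch_model.bin            # single file fetched when --full_ckpts is not set
  for_inference: true
  base_model: false
- model_id: some-github-checkpoint       # placeholder GitHub-hosted checkpoint
  platform: GitHub
  local_dir: ckpts/some-github-checkpoint
  url: https://github.com/example/example/releases/download/v1/model.pth
  for_inference: false
  base_model: true
"""

config = yaml.safe_load(EXAMPLE_CONFIG)

# Mirror the deleted __main__ block's defaults: keep only entries marked
# for_inference and skip base models unless --include_base_model was passed.
selected = [m for m in config
            if m.get('for_inference', False) and not m.get('base_model', False)]
print([m['model_id'] for m in selected])
```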