Create hf_sync.py
Browse files- hf_sync.py +200 -0
hf_sync.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# hf_sync.py
|
2 |
+
from huggingface_hub import HfApi
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
import datetime
|
6 |
+
|
7 |
+
def manage_backups(api, repo_id, max_files=50):
    """Prune dated backups in a Hugging Face dataset repo, keeping the newest ones.

    Lists the files of the dataset repo ``repo_id``, selects those named exactly
    ``webui_backup_YYYYMMDD.db``, orders them by the date embedded in the name
    and deletes every file beyond the ``max_files`` most recent ones. Files
    whose name does not match that pattern are reported and left untouched.

    Args:
        api: Object exposing ``list_repo_files(repo_id=..., repo_type=...)`` and
            ``delete_file(path_in_repo=..., repo_id=..., repo_type=...)``.
        repo_id: Identifier of the dataset repo to clean up.
        max_files: Number of most recent backups to keep. Must not be negative.

    Returns:
        None. Listing and deletion errors are printed instead of raised, so a
        failed cleanup never aborts the caller.
    """
    prefix = 'webui_backup_'
    suffix = '.db'

    if max_files < 0:
        # With a negative limit the slice below would select the oldest
        # abs(max_files) entries instead of "everything past the limit".
        print(f"错误: max_files 不能为负数: {max_files}")
        return

    try:
        print(f"开始管理 Hugging Face 仓库 '{repo_id}' 的备份...")
        files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        # Keep only files that follow the backup naming convention.
        backup_files = [f for f in files if f.startswith(prefix) and f.endswith(suffix)]
        print(f"在仓库中找到 {len(backup_files)} 个备份文件。")

        # Extract the date from each name.
        dated_files = []
        for f in backup_files:
            # Take everything between prefix and suffix so that names such as
            # "webui_backup_old_20240101.db" or "webui_backup_20240101.keep.db"
            # are not mistaken for regular dated backups.
            date_str = f[len(prefix):-len(suffix)]
            if len(date_str) != 8 or not date_str.isdigit():
                print(f"警告: 文件名格式不符,跳过: {f}")
                continue
            try:
                date_obj = datetime.datetime.strptime(date_str, '%Y%m%d').date()
            except ValueError as e:
                # Eight digits that do not form a real calendar date.
                print(f"警告: 无法从文件名解析日期,跳过: {f} (错误: {e})")
                continue
            dated_files.append((date_obj, f))

        # Newest first, so everything past index max_files is surplus.
        dated_files.sort(key=lambda item: item[0], reverse=True)
        files_to_delete = [name for _, name in dated_files[max_files:]]

        if not files_to_delete:
            print(f"备份文件数量 ({len(dated_files)}) 未超过限制 ({max_files}),无需删除。")
            return

        print(f"找到 {len(files_to_delete)} 个需要删除的旧备份。")
        for file_to_delete in files_to_delete:
            # One failed deletion must not prevent the remaining ones.
            try:
                print(f"正在删除旧备份: {file_to_delete}")
                api.delete_file(path_in_repo=file_to_delete, repo_id=repo_id, repo_type="dataset")
                print(f"已删除旧备份: {file_to_delete}")
            except Exception as e_del:
                print(f"删除文件 {file_to_delete} 时出错: {str(e_del)}")

    except Exception as e_list:
        print(f"列出或管理备份文件时出错: {str(e_list)}")
|
50 |
+
|
51 |
+
|
52 |
+
def upload_backup(file_path, file_name, token, repo_id):
    """Upload a local backup file to a Hugging Face dataset repo, then prune old backups.

    Args:
        file_path: Path of the local file to upload.
        file_name: Path the file gets inside the repo (``path_in_repo``).
        token: Hugging Face access token used to build the ``HfApi`` client.
        repo_id: Identifier of the target dataset repo.

    Returns:
        True if the file was uploaded, False if the local file is missing or
        the upload raised. Errors are printed, not raised.
    """
    # isfile (not exists) so that a directory path is rejected up front
    # instead of being handed to the upload call.
    if not os.path.isfile(file_path):
        print(f"错误: 本地文件未找到,无法上传: {file_path}")
        return False

    print(f"尝试将 {file_path} 中的文件上传为 {file_name} 到仓库 {repo_id}")
    api = HfApi(token=token)
    try:
        # NOTE(review): the original comment states upload_file overwrites an
        # existing file of the same name — confirm against huggingface_hub docs.
        print(f"正在上传 {file_name}...")
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_name,
            repo_id=repo_id,
            repo_type="dataset"
        )
    except Exception as e:
        print(f"上传文件 {file_name} 失败: {str(e)}")
        return False

    print(f"成功上传: {file_name}")
    # Rotate old backups only after the new one is safely stored.
    print("上传成功,开始管理备份...")
    manage_backups(api, repo_id)
    return True
|
75 |
+
|
76 |
+
|
77 |
+
def download_latest_backup(token, repo_id, target_dir):
    """Restore the newest ``webui_backup_YYYYMMDD.db`` of a dataset repo as ``webui.db``.

    Lists the files of the dataset repo, picks the backup with the most recent
    date in its name, downloads it into a scratch cache below ``target_dir``
    and installs it as ``<target_dir>/webui.db``. The scratch cache is removed
    before returning, whether or not the restore succeeded.

    Args:
        token: Hugging Face access token used to build the ``HfApi`` client.
        repo_id: Identifier of the dataset repo holding the backups.
        target_dir: Directory that receives ``webui.db``; created if missing.

    Returns:
        True if ``webui.db`` was written, False otherwise (no usable backup,
        download failure, or any error; the error is printed with a traceback).
    """
    import shutil
    import traceback

    prefix = 'webui_backup_'
    suffix = '.db'
    # Scratch download location inside target_dir (the original comment says
    # this avoids permission problems with the default cache location).
    cache_dir = os.path.join(target_dir, '.hf_cache')

    print(f"尝试从仓库 {repo_id} 下载最新备份到目录 {target_dir}")
    api = HfApi(token=token)
    try:
        files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        backup_files = [f for f in files if f.startswith(prefix) and f.endswith(suffix)]

        if not backup_files:
            print("在 Hugging Face 数据集中未找到备份文件。")
            return False

        # Collect (date, name) for every file named exactly
        # webui_backup_YYYYMMDD.db; anything else is ignored.
        candidates = []
        for f in backup_files:
            date_str = f[len(prefix):-len(suffix)]
            if len(date_str) != 8 or not date_str.isdigit():
                continue
            try:
                date_obj = datetime.datetime.strptime(date_str, '%Y%m%d').date()
            except ValueError:
                continue  # eight digits that are not a real calendar date
            candidates.append((date_obj, f))

        if not candidates:
            print("无法确定最新的备份文件(可能所有文件名格式都不正确)。")
            return False

        latest_file = max(candidates, key=lambda item: item[0])[1]
        print(f"找到最新的备份文件: {latest_file}")
        print(f"正在下载 {latest_file}...")
        # local_dir_use_symlinks is not passed: per the huggingface_hub docs it
        # only applies together with local_dir, which is not used here.
        downloaded_path = api.hf_hub_download(
            repo_id=repo_id,
            filename=latest_file,
            repo_type="dataset",
            cache_dir=cache_dir,
        )

        if not downloaded_path or not os.path.exists(downloaded_path):
            print(f"下载失败或文件路径无效: {downloaded_path}")
            return False

        os.makedirs(target_dir, exist_ok=True)
        target_path = os.path.join(target_dir, "webui.db")
        print(f"下载完成,路径: {downloaded_path},准备移动到: {target_path}")

        # In the hub cache layout the returned path is normally a relative
        # symlink into blobs/. Moving that link would leave webui.db dangling,
        # so resolve it and move the real file instead.
        source_path = os.path.realpath(downloaded_path)
        try:
            # Atomic when source and target are on the same filesystem.
            os.replace(source_path, target_path)
            print(f"成功将 {latest_file} 移动到 {target_path}")
        except OSError as e_mv:
            print(f"移动文件失败 (错误: {e_mv}),尝试复制...")
            try:
                shutil.copy2(source_path, target_path)
                print(f"成功将 {latest_file} 复制到 {target_path}")
            except Exception as e_cp:
                print(f"复制文件也失败: {e_cp}")
                return False

        print(f"成功从 Hugging Face 恢复: {latest_file}")
        return True
    except Exception as e:
        print(f"下载或恢复过程中发生错误: {str(e)}")
        # Print the full stack trace to make failures debuggable.
        traceback.print_exc()
        return False
    finally:
        # Best-effort removal of the scratch cache on every exit path.
        if os.path.isdir(cache_dir):
            try:
                shutil.rmtree(cache_dir)
            except OSError as e_rm:
                print(f"警告: 删除临时下载文件 {cache_dir} 失败: {e_rm}")
|
169 |
+
|
170 |
+
if __name__ == "__main__":
    # Command-line entry point:
    #   upload   <HF_TOKEN> <DATASET_ID> <FILE_PATH> <FILE_NAME_IN_REPO>
    #   download <HF_TOKEN> <DATASET_ID> <TARGET_DIR>
    # Exits with status 1 on bad usage, unknown action or failed download.
    argv = sys.argv

    # Every action needs at least: script, action, token, repo id.
    if len(argv) < 4:
        for usage_line in (
            "使用方法:",
            " 上传: python3 hf_sync.py upload <HF_TOKEN> <DATASET_ID> <FILE_PATH> <FILE_NAME_IN_REPO>",
            " 下载: python3 hf_sync.py download <HF_TOKEN> <DATASET_ID> <TARGET_DIR>",
        ):
            print(usage_line)
        sys.exit(1)

    action, token, repo_id = argv[1], argv[2], argv[3]
    extra_args = argv[4:]

    if action not in ("upload", "download"):
        print(f"未知的操作: {action}")
        sys.exit(1)

    if action == "upload":
        # Upload takes exactly two more arguments: local path and name in repo.
        if len(extra_args) != 2:
            print("上传命令参数错误。需要: upload <HF_TOKEN> <DATASET_ID> <FILE_PATH> <FILE_NAME_IN_REPO>")
            sys.exit(1)
        file_path, file_name = extra_args
        upload_backup(file_path, file_name, token, repo_id)
    else:
        # Download takes exactly one more argument: the target directory.
        if len(extra_args) != 1:
            print("下载命令参数错误。需要: download <HF_TOKEN> <DATASET_ID> <TARGET_DIR>")
            sys.exit(1)
        target_dir = extra_args[0]
        restored = download_latest_backup(token, repo_id, target_dir)
        if not restored:
            # A failed restore is signalled through the exit status.
            sys.exit(1)
|
200 |
+
|