nbugs commited on
Commit
d696836
·
verified ·
1 Parent(s): 51af98b

Create hf_sync.py

Browse files
Files changed (1) hide show
  1. hf_sync.py +200 -0
hf_sync.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # hf_sync.py
2
+ from huggingface_hub import HfApi
3
+ import sys
4
+ import os
5
+ import datetime
6
+
7
+ def manage_backups(api, repo_id, max_files=50):
8
+ """管理备份文件,保留最新的max_files个文件"""
9
+ try:
10
+ print(f"开始管理 Hugging Face 仓库 '{repo_id}' 的备份...")
11
+ files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
12
+ # 筛选出符合备份文件命名规则的文件
13
+ backup_files = [f for f in files if f.startswith('webui_backup_') and f.endswith('.db')]
14
+ print(f"在仓库中找到 {len(backup_files)} 个备份文件。")
15
+
16
+ # 提取日期并排序
17
+ dated_files = []
18
+ for f in backup_files:
19
+ try:
20
+ # 假设文件名格式为 webui_backup_YYYYMMDD.db
21
+ date_str = f.split('_')[-1].split('.')[0]
22
+ # 严格匹配 YYYYMMDD 格式
23
+ if len(date_str) == 8 and date_str.isdigit():
24
+ date_obj = datetime.datetime.strptime(date_str, '%Y%m%d').date()
25
+ dated_files.append((date_obj, f))
26
+ else:
27
+ print(f"警告: 文件名格式不符,跳过: {f}")
28
+ except (IndexError, ValueError) as e:
29
+ print(f"警告: 无法从文件名解析日期,跳过: {f} (错误: {e})")
30
+ continue # 跳过格式不正确的文件
31
+
32
+ # 按日期降序排序(最新的在前)
33
+ dated_files.sort(key=lambda item: item[0], reverse=True)
34
+
35
+ if len(dated_files) > max_files:
36
+ files_to_delete = [item[1] for item in dated_files[max_files:]]
37
+ print(f"找到 {len(files_to_delete)} 个需要删除的旧备份。")
38
+ for file_to_delete in files_to_delete:
39
+ try:
40
+ print(f"正在删除旧备份: {file_to_delete}")
41
+ api.delete_file(path_in_repo=file_to_delete, repo_id=repo_id, repo_type="dataset")
42
+ print(f"已删除旧备份: {file_to_delete}")
43
+ except Exception as e_del:
44
+ print(f"删除文件 {file_to_delete} 时出错: {str(e_del)}")
45
+ else:
46
+ print(f"备份文件数量 ({len(dated_files)}) 未超过限制 ({max_files}),无需删除。")
47
+
48
+ except Exception as e_list:
49
+ print(f"列出或管理备份文件时出错: {str(e_list)}")
50
+
51
+
52
+ def upload_backup(file_path, file_name, token, repo_id):
53
+ """上传备份文件到Hugging Face"""
54
+ if not os.path.exists(file_path):
55
+ print(f"错误: 本地文件未找到,无法上传: {file_path}")
56
+ return
57
+
58
+ print(f"尝试将 {file_path} 中的文件上传为 {file_name} 到仓库 {repo_id}")
59
+ api = HfApi(token=token)
60
+ try:
61
+ # 上传新文件 (upload_file 会覆盖同名文件)
62
+ print(f"正在上传 {file_name}...")
63
+ api.upload_file(
64
+ path_or_fileobj=file_path,
65
+ path_in_repo=file_name,
66
+ repo_id=repo_id,
67
+ repo_type="dataset"
68
+ )
69
+ print(f"成功上传: {file_name}")
70
+ # 上传成功后管理备份
71
+ print("上传成功,开始管理备份...")
72
+ manage_backups(api, repo_id)
73
+ except Exception as e:
74
+ print(f"上传文件 {file_name} 失败: {str(e)}")
75
+
76
+
77
+ def download_latest_backup(token, repo_id, target_dir):
78
+ """从Hugging Face下载最新备份"""
79
+ print(f"尝试从仓库 {repo_id} 下载最新备份到目录 {target_dir}")
80
+ api = HfApi(token=token)
81
+ try:
82
+ files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
83
+ backup_files = [f for f in files if f.startswith('webui_backup_') and f.endswith('.db')]
84
+
85
+ if not backup_files:
86
+ print("在 Hugging Face 数据集中未找到备份文件。")
87
+ return False
88
+
89
+ # 根据文件名中的日期找到最新的文件
90
+ latest_file = None
91
+ latest_date = None
92
+ for f in backup_files:
93
+ try:
94
+ # 假设文件名格式为 webui_backup_YYYYMMDD.db
95
+ date_str = f.split('_')[-1].split('.')[0]
96
+ if len(date_str) == 8 and date_str.isdigit():
97
+ date_obj = datetime.datetime.strptime(date_str, '%Y%m%d').date()
98
+ if latest_date is None or date_obj > latest_date:
99
+ latest_date = date_obj
100
+ latest_file = f
101
+ else:
102
+ continue # 跳过格式不正确的文件
103
+ except (IndexError, ValueError):
104
+ continue # 跳过无法解析的文件
105
+
106
+ if latest_file is None:
107
+ print("无法确定最新的备份文件(可能所有文件名格式都不正确)。")
108
+ return False
109
+
110
+ print(f"找到最新的备份文件: {latest_file}")
111
+ print(f"正在下载 {latest_file}...")
112
+ # 先下载到临时路径
113
+ downloaded_path = api.hf_hub_download(
114
+ repo_id=repo_id,
115
+ filename=latest_file,
116
+ repo_type="dataset",
117
+ local_dir_use_symlinks=False, # 避免在容器中出现符号链接问题
118
+ cache_dir=os.path.join(target_dir, '.hf_cache') # 指定缓存目录,避免权限问题
119
+ )
120
+
121
+ if downloaded_path and os.path.exists(downloaded_path):
122
+ # 创建目标目录(如果不存在)
123
+ os.makedirs(target_dir, exist_ok=True)
124
+ # 目标文件路径
125
+ target_path = os.path.join(target_dir, "webui.db")
126
+ print(f"下载完成,路径: {downloaded_path},准备移动到: {target_path}")
127
+ # 使用 os.replace 进行原子性移动/覆盖,如果失败则回退到复制
128
+ try:
129
+ os.replace(downloaded_path, target_path)
130
+ print(f"成功将 {latest_file} 移动到 {target_path}")
131
+ except OSError as e_mv:
132
+ print(f"移动文件失败 (错误: {e_mv}),尝试复制...")
133
+ import shutil
134
+ try:
135
+ shutil.copy2(downloaded_path, target_path)
136
+ print(f"成功将 {latest_file} 复制到 {target_path}")
137
+ # 尝试删除临时下载文件
138
+ try:
139
+ os.remove(downloaded_path)
140
+ # 清理可能的缓存目录结构
141
+ hf_cache_dir = os.path.join(target_dir, '.hf_cache')
142
+ if os.path.exists(hf_cache_dir):
143
+ shutil.rmtree(hf_cache_dir)
144
+ except OSError as e_rm:
145
+ print(f"警告: 删除临时下载文件 {downloaded_path} 失败: {e_rm}")
146
+ except Exception as e_cp:
147
+ print(f"复制文件也失败: {e_cp}")
148
+ return False
149
+
150
+
151
+ print(f"成功从 Hugging Face 恢复: {latest_file}")
152
+ # 清理空的缓存目录(如果存在且为空)
153
+ try:
154
+ cache_root = os.path.dirname(downloaded_path)
155
+ if os.path.exists(cache_root) and not os.listdir(cache_root):
156
+ os.rmdir(cache_root)
157
+ except OSError:
158
+ pass # 忽略清理错误
159
+ return True
160
+ else:
161
+ print(f"下载失败或文件路径无效: {downloaded_path}")
162
+ return False
163
+ except Exception as e:
164
+ print(f"下载或恢复过程中发生错误: {str(e)}")
165
+ # 打印更详细的堆栈跟踪信息以便调试
166
+ import traceback
167
+ traceback.print_exc()
168
+ return False
169
+
170
+ if __name__ == "__main__":
171
+ # 基本的参数数量检查
172
+ if len(sys.argv) < 4:
173
+ print("使用方法:")
174
+ print(" 上传: python3 hf_sync.py upload <HF_TOKEN> <DATASET_ID> <FILE_PATH> <FILE_NAME_IN_REPO>")
175
+ print(" 下载: python3 hf_sync.py download <HF_TOKEN> <DATASET_ID> <TARGET_DIR>")
176
+ sys.exit(1)
177
+
178
+ action = sys.argv[1]
179
+ token = sys.argv[2]
180
+ repo_id = sys.argv[3]
181
+
182
+ if action == "upload":
183
+ if len(sys.argv) != 6:
184
+ print("上传命令参数错误。需要: upload <HF_TOKEN> <DATASET_ID> <FILE_PATH> <FILE_NAME_IN_REPO>")
185
+ sys.exit(1)
186
+ file_path = sys.argv[4]
187
+ file_name = sys.argv[5]
188
+ upload_backup(file_path, file_name, token, repo_id)
189
+ elif action == "download":
190
+ if len(sys.argv) != 5:
191
+ print("下载命令参数错误。需要: download <HF_TOKEN> <DATASET_ID> <TARGET_DIR>")
192
+ sys.exit(1)
193
+ target_dir = sys.argv[4] # 下载需要目标目录参数
194
+ # 调用下载函数,并根据返回值退出
195
+ if not download_latest_backup(token, repo_id, target_dir):
196
+ sys.exit(1) # 下载失败则脚本退出,表示恢复失败
197
+ else:
198
+ print(f"未知的操作: {action}")
199
+ sys.exit(1)
200
+