hf2ms / app.py
thomas-yanxin's picture
Update app.py
b340237 verified
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import List, Optional
# Install/upgrade ModelScope at startup (presumably because the hosting
# environment does not ship it). Run pip through the *current* interpreter
# (`sys.executable -m pip`) so the package lands in the environment this script
# imports from; a bare `pip` on PATH may belong to a different Python. The
# argument list avoids going through a shell. check=False keeps the original
# best-effort behaviour: a failed install is not fatal here, and the
# `modelscope` import that follows fails loudly if the package is missing.
subprocess.run([sys.executable, "-m", "pip", "install", "modelscope", "-U"], check=False)
import gradio as gr
from huggingface_hub import HfApi
from modelscope.hub.api import HubApi
# Configure the root logger: INFO and above, timestamped lines, written by a
# StreamHandler (stderr by default). logging.basicConfig does nothing if the
# root logger already has handlers.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_console_handler = logging.StreamHandler()
logging.basicConfig(
    handlers=[_console_handler],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
class HFToMSConverter:
    """Migrate the files of a HuggingFace repository to a ModelScope repository.

    Each file is downloaded into ``local_dir`` (with ``cache_dir`` as the
    HuggingFace download cache) and moved into a per-run staging directory.
    The staging directory is then uploaded to ModelScope with a single
    ``upload_folder`` call.

    ``config`` keys read here:
        hf_token  -- HuggingFace access token (required).
        ms_token  -- ModelScope access token (required).
        cache_dir -- HuggingFace download cache (default "hf2ms_cache").
        local_dir -- download directory (default "hf2ms_local").
    """

    def __init__(self, config: dict):
        self.config = config
        self.cache_dir = config.get('cache_dir', "hf2ms_cache")
        self.local_dir = config.get('local_dir', "hf2ms_local")
        self.hf_api = HfApi(token=config['hf_token'])
        self.ms_api = HubApi()
        # upload_folder below is called without a token, so it presumably
        # relies on this login session.
        self.ms_api.login(config['ms_token'])
        for dir_path in [self.local_dir, self.cache_dir]:
            # parents=True so nested directories given in config also work.
            Path(dir_path).mkdir(parents=True, exist_ok=True)

    def get_hf_files(self, repo_id: str, repo_type: str = "dataset") -> List[str]:
        """Return the paths of all files in a HuggingFace repository.

        Args:
            repo_id: HuggingFace repository id, e.g. "user/name".
            repo_type: HuggingFace repository type (default "dataset").
        """
        return self.hf_api.list_repo_files(repo_id=repo_id, repo_type=repo_type)

    def download_file(self, repo_id: str, filename: str,
                      repo_type: str = "dataset") -> Optional[str]:
        """Download one file from HuggingFace into ``local_dir``.

        Args:
            repo_id: HuggingFace repository id.
            filename: Path of the file inside the repository.
            repo_type: HuggingFace repository type (default "dataset").

        Returns:
            The local path ``local_dir/filename`` on success. Returns None if
            a file with that name already exists in ``local_dir`` or if the
            download raised. Errors are logged, not raised.
        """
        save_path = Path(self.local_dir) / filename
        if save_path.exists():
            logger.warning(f"文件已存在: {filename}")
            return None
        try:
            self.hf_api.hf_hub_download(
                repo_id=repo_id,
                repo_type=repo_type,
                filename=filename,
                cache_dir=self.cache_dir,
                local_dir=self.local_dir,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            logger.error(f"下载失败 {filename}: {e}")
            return None
        logger.info(f"成功下载文件: {filename}")
        return str(save_path)

    def handle_file_operation(self, operation_type: str, *args) -> bool:
        """Run a file operation, logging any exception instead of raising it.

        Supported operations:
            "move": args = (src: Path, dst: Path). Moves src to dst and
                creates dst's parent directories.
            "push": args = (ms_repo_id, folder[, repo_type]). Uploads folder
                to ModelScope; repo_type defaults to "dataset".

        Returns:
            True on success. False if the operation raised or
            ``operation_type`` is not recognised.
        """
        try:
            if operation_type == "move":
                src, dst = args
                dst.parent.mkdir(parents=True, exist_ok=True)
                # shutil.move also works across filesystems, unlike Path.rename.
                shutil.move(str(src), str(dst))
                logger.info(f"移动文件成功: {src.name}")
            elif operation_type == "push":
                ms_repo_id, clone_dir, *extra = args
                repo_type = extra[0] if extra else 'dataset'
                logger.info(f"开始推送文件夹: {clone_dir}")
                self.ms_api.upload_folder(
                    repo_id=str(ms_repo_id),
                    folder_path=str(clone_dir),  # make sure the path is a str
                    commit_message=f'upload {repo_type} folder',
                    repo_type=repo_type,
                )
                logger.info(f"推送文件夹成功: {clone_dir}")
            else:
                # Previously an unknown operation silently reported success.
                logger.error(f"未知操作类型: {operation_type}")
                return False
            return True
        except Exception as e:
            logger.error(f"{operation_type}操作失败: {e}")
            return False

    def process_files(self, hf_repo: str, ms_repo: str, files: List[str],
                      repo_type: str = "dataset") -> bool:
        """Copy ``files`` from ``hf_repo`` (HuggingFace) to ``ms_repo`` (ModelScope).

        Every file is downloaded and moved into a fresh staging directory.
        The directory is uploaded in one push and always removed afterwards.

        Args:
            hf_repo: Source HuggingFace repository id.
            ms_repo: Target ModelScope repository id.
            files: Repository-relative paths to copy.
            repo_type: Repository type on both hubs (default "dataset").

        Returns:
            True if every download and move succeeded and the push succeeded;
            otherwise False. Details are logged.
        """
        clone_dir: Optional[Path] = None
        try:
            # Use a uniquely named staging directory rather than one derived
            # from the user-supplied repo name. A name such as "user/" or
            # "user/.." would resolve to the working directory (or its
            # parent), which the cleanup below would then delete. It stays
            # under the working directory, as before.
            clone_dir = Path(tempfile.mkdtemp(prefix="hf2ms_push_",
                                              dir=os.path.abspath('.')))
            for filename in files:
                # Check each step in turn: a failed download must not be
                # followed by a move attempt.
                if not self.download_file(hf_repo, filename, repo_type):
                    return False
                if not self.move_file(filename, str(clone_dir)):
                    return False
            # Push the whole staging directory in one go.
            return self.push_to_ms(ms_repo, str(clone_dir), repo_type)
        except Exception as e:
            logger.error(f"处理文件失败: {e}")
            return False
        finally:
            # Remove the staged files whether or not the upload succeeded.
            if clone_dir is not None and clone_dir.exists():
                try:
                    shutil.rmtree(clone_dir)
                except OSError as e:
                    logger.warning(f"清理临时目录失败 {clone_dir}: {e}")

    def move_file(self, filename: str, clone_dir: str) -> bool:
        """Move ``local_dir/filename`` to ``clone_dir/filename``; True on success."""
        return self.handle_file_operation(
            "move",
            Path(self.local_dir) / filename,
            Path(clone_dir) / filename,
        )

    def push_to_ms(self, ms_repo_id: str, clone_dir: str,
                   repo_type: str = "dataset") -> bool:
        """Upload ``clone_dir`` to the ModelScope repo ``ms_repo_id``; True on success."""
        return self.handle_file_operation("push", ms_repo_id, clone_dir, repo_type)
def create_ui() -> gr.Blocks:
    """Build the Gradio interface for the HuggingFace -> ModelScope migration tool.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # HuggingFace to ModelScope 数据迁移工具
            请确保您拥有相应仓库的权限。
            """
        )
        with gr.Row():
            # type="password" keeps the access tokens masked in the browser.
            hf_token = gr.Textbox(label="HuggingFace Token", type="password")
            ms_token = gr.Textbox(label="ModelScope Token", type="password")
        with gr.Row():
            repo_type = gr.Textbox(label="仓库类型", value="dataset")
            hf_repo = gr.Textbox(label="HuggingFace仓库")
            ms_repo = gr.Textbox(label="ModelScope仓库")
        with gr.Row():
            submit = gr.Button("开始迁移", variant="primary")
            clear = gr.Button("清除")
        # Read-only field that reports the outcome of the last migration.
        status = gr.Textbox(label="状态", interactive=False)

        def handle_submit(hf_token, ms_token, repo_type, hf_repo, ms_repo):
            """Run one migration and return a status message for the UI."""
            # Strip whitespace that often comes along with copy-pasted values.
            hf_token = (hf_token or "").strip()
            ms_token = (ms_token or "").strip()
            hf_repo = (hf_repo or "").strip()
            ms_repo = (ms_repo or "").strip()
            repo_type = (repo_type or "").strip() or "dataset"
            if not (hf_token and ms_token and hf_repo and ms_repo):
                return "请填写所有字段"
            config = {
                'hf_token': hf_token,
                'ms_token': ms_token,
            }
            try:
                converter = HFToMSConverter(config)
                files = converter.get_hf_files(hf_repo, repo_type)
                # NOTE(review): process_files is not given repo_type here, so
                # only the file listing honours it; the downloads and the
                # upload use the dataset repo type.
                ok = converter.process_files(hf_repo, ms_repo, files)
            except Exception as e:
                # Invalid tokens or unknown repos raise here; report them in the
                # UI instead of letting the exception escape the handler.
                logger.exception("迁移失败")
                return f"迁移失败: {e}"
            return "迁移成功" if ok else "迁移失败，请查看日志"

        submit.click(
            handle_submit,
            inputs=[hf_token, ms_token, repo_type, hf_repo, ms_repo],
            outputs=status,
        )
        # Reset every input to its initial value and clear the status box.
        clear.click(
            lambda: ("", "", "dataset", "", "", ""),
            outputs=[hf_token, ms_token, repo_type, hf_repo, ms_repo, status],
        )
    return demo
if __name__ == "__main__":
    # Build the UI and queue it with room for one waiting request. Serve it
    # without a public share link, using a single worker thread.
    demo = create_ui().queue(max_size=1)
    demo.launch(max_threads=1, share=False)