Spaces:
Sleeping
Sleeping
import logging | |
import os | |
from pathlib import Path | |
from typing import List, Optional | |
os.system('pip install modelscope -U') | |
import gradio as gr | |
from huggingface_hub import HfApi | |
from modelscope.hub.api import HubApi | |
# Root logging configuration: INFO and above, timestamped and level-tagged,
# emitted through a StreamHandler (stderr by default).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
    level=logging.INFO,
)
# Module-level logger used throughout this file.
logger = logging.getLogger(name=__name__)
class HFToMSConverter:
    """Migrate the files of a HuggingFace Hub repository to a ModelScope repository.

    Workflow (see ``process_files``): download each HuggingFace file into
    ``local_dir``, move it into a private staging directory, then upload that
    directory to ModelScope with a single ``upload_folder`` call.

    Config keys read here:
        hf_token (required): token passed to ``HfApi``.
        ms_token (required): token passed to ``HubApi.login``.
        cache_dir: HuggingFace download cache (default ``"hf2ms_cache"``).
        local_dir: first download target (default ``"hf2ms_local"``).
        repo_type: repository type used on both hubs whenever a method is not
            given one explicitly (default ``"dataset"``).
    Any other keys are kept in ``self.config`` but not used by this class.
    """

    def __init__(self, config: dict):
        """Create both hub clients, log in to ModelScope and create the work dirs.

        Raises:
            KeyError: if ``hf_token`` or ``ms_token`` is missing from ``config``.
        """
        self.config = config
        self.cache_dir = config.get('cache_dir', "hf2ms_cache")
        self.local_dir = config.get('local_dir', "hf2ms_local")
        # An empty value (e.g. a cleared UI field) also falls back to the default.
        self.repo_type = config.get('repo_type') or "dataset"
        self.hf_api = HfApi(token=config['hf_token'])
        self.ms_api = HubApi()
        # NOTE(review): presumably raises on a rejected token — confirm in modelscope docs.
        self.ms_api.login(config['ms_token'])
        for dir_path in (self.local_dir, self.cache_dir):
            Path(dir_path).mkdir(parents=True, exist_ok=True)

    def get_hf_files(self, repo_id: str, repo_type: Optional[str] = None) -> List[str]:
        """Return the paths of every file in a HuggingFace repository.

        Args:
            repo_id: HuggingFace repository id, e.g. ``"user/name"``.
            repo_type: repository type; defaults to ``self.repo_type``.
        """
        return self.hf_api.list_repo_files(
            repo_id=repo_id,
            repo_type=repo_type or self.repo_type,
        )

    def download_file(self, repo_id: str, filename: str,
                      repo_type: Optional[str] = None) -> Optional[str]:
        """Download one file from a HuggingFace repository into ``local_dir``.

        Args:
            repo_id: HuggingFace repository id.
            filename: path of the file inside the repository.
            repo_type: repository type; defaults to ``self.repo_type``.

        Returns:
            The local path ``local_dir/filename`` on success, or ``None`` if the
            download failed (the error is logged, not raised).
        """
        save_path = Path(self.local_dir) / filename
        try:
            if save_path.exists():
                logger.warning(f"文件已存在: {filename}")
                # A leftover copy is deleted rather than reused: it may belong to a
                # different repository, and treating it as a failure would abort
                # every later migration that contains a file with this name.
                save_path.unlink()
            self.hf_api.hf_hub_download(
                repo_id=repo_id,
                repo_type=repo_type or self.repo_type,
                filename=filename,
                cache_dir=self.cache_dir,
                local_dir=self.local_dir,
                # NOTE(review): deprecated in recent huggingface_hub releases —
                # kept for older versions; confirm against the installed one.
                local_dir_use_symlinks=False,
            )
            logger.info(f"成功下载文件: {filename}")
            return str(save_path)
        except Exception as e:
            logger.error(f"下载失败 {filename}: {e}")
            return None

    def handle_file_operation(self, operation_type: str, *args) -> bool:
        """Run a file operation, logging any exception instead of raising it.

        Supported operations:
            ``"move"``: ``args = (src: Path, dst: Path)``; creates ``dst``'s
                parent directories and moves ``src`` to ``dst``.
            ``"push"``: ``args = (ms_repo_id, folder[, repo_type])``; uploads the
                contents of ``folder`` to the ModelScope repository.
                ``repo_type`` defaults to ``self.repo_type``.

        Returns:
            True on success; False if the operation raised or is unknown.
        """
        import shutil

        try:
            if operation_type == "move":
                src, dst = args
                dst.parent.mkdir(parents=True, exist_ok=True)
                # shutil.move also works across filesystems, where Path.rename
                # raises OSError.
                shutil.move(str(src), str(dst))
                logger.info(f"移动文件成功: {src.name}")
            elif operation_type == "push":
                ms_repo_id, clone_dir, *extra = args
                repo_type = extra[0] if extra else self.repo_type
                logger.info(f"开始推送文件夹: {clone_dir}")
                self.ms_api.upload_folder(
                    repo_id=str(ms_repo_id),
                    folder_path=str(clone_dir),  # the API expects a string path
                    commit_message='upload dataset folder',
                    repo_type=repo_type,
                )
                logger.info(f"推送文件夹成功: {clone_dir}")
            else:
                # Unknown operations are reported as failures rather than
                # silently counting as success.
                logger.error(f"未知操作类型: {operation_type}")
                return False
            return True
        except Exception as e:
            logger.error(f"{operation_type}操作失败: {e}")
            return False

    def process_files(self, hf_repo: str, ms_repo: str, files: List[str],
                      repo_type: Optional[str] = None) -> bool:
        """Migrate ``files`` from ``hf_repo`` to ``ms_repo`` end to end.

        Every file is downloaded and moved into a fresh staging directory, which
        is then uploaded as a whole. The staging directory is always removed.

        Args:
            hf_repo: source HuggingFace repository id.
            ms_repo: target ModelScope repository id.
            files: file paths inside ``hf_repo`` (e.g. from ``get_hf_files``).
            repo_type: repository type on both hubs; defaults to ``self.repo_type``.

        Returns:
            True if every file was staged and the upload succeeded, else False.
        """
        import shutil
        import tempfile

        # Use a fresh private directory instead of one named after the
        # user-supplied repo id. A name such as "" (from "user/") or an existing
        # folder would otherwise make the cleanup below delete unrelated data.
        # It is created under the working directory so that, with the default
        # relative local_dir, moves stay on the same filesystem.
        try:
            clone_dir = Path(tempfile.mkdtemp(prefix="hf2ms_", dir=os.path.abspath('.')))
        except OSError as e:
            logger.error(f"处理文件失败: {e}")
            return False

        try:
            for filename in files:
                # Short-circuit so a file whose download failed is never moved.
                if not (self.download_file(hf_repo, filename, repo_type)
                        and self.move_file(filename, str(clone_dir))):
                    return False
            # Push the whole staged folder in one upload.
            return self.push_to_ms(ms_repo, clone_dir, repo_type)
        except Exception as e:
            logger.error(f"处理文件失败: {e}")
            return False
        finally:
            shutil.rmtree(clone_dir, ignore_errors=True)

    def move_file(self, filename: str, clone_dir: str) -> bool:
        """Move ``local_dir/filename`` to ``clone_dir/filename``; return success."""
        return self.handle_file_operation(
            "move",
            Path(self.local_dir) / filename,
            Path(clone_dir) / filename,
        )

    def push_to_ms(self, ms_repo_id: str, clone_dir: str,
                   repo_type: Optional[str] = None) -> bool:
        """Upload ``clone_dir`` to ModelScope repo ``ms_repo_id``; return success.

        ``repo_type`` defaults to ``self.repo_type``.
        """
        return self.handle_file_operation(
            "push", ms_repo_id, clone_dir, repo_type or self.repo_type
        )
def create_ui() -> gr.Blocks:
    """Build the Gradio interface for the HuggingFace → ModelScope migration tool.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # HuggingFace to ModelScope 数据迁移工具
            请确保您拥有相应仓库的权限。
            """
        )
        with gr.Row():
            # Masked so access tokens are not displayed on screen.
            hf_token = gr.Textbox(label="HuggingFace Token", type="password")
            ms_token = gr.Textbox(label="ModelScope Token", type="password")
        with gr.Row():
            repo_type = gr.Textbox(label="仓库类型", value="dataset")
            hf_repo = gr.Textbox(label="HuggingFace仓库")
            ms_repo = gr.Textbox(label="ModelScope仓库")
        with gr.Row():
            submit = gr.Button("开始迁移", variant="primary")
            clear = gr.Button("清除")
        # Read-only field reporting the outcome of the last migration.
        status = gr.Textbox(label="状态", interactive=False)

        def handle_submit(hf_token, ms_token, repo_type, hf_repo, ms_repo):
            """Run one migration and return a status message for the UI."""
            hf_token = (hf_token or "").strip()
            ms_token = (ms_token or "").strip()
            hf_repo = (hf_repo or "").strip()
            ms_repo = (ms_repo or "").strip()
            repo_type = (repo_type or "").strip() or "dataset"
            if not (hf_token and ms_token and hf_repo and ms_repo):
                return "请填写所有 Token 和仓库名称"
            config = {
                'hf_token': hf_token,
                'ms_token': ms_token,
                'repo_type': repo_type,
                'username': "thomas",
                'email': "[email protected]",
            }
            try:
                converter = HFToMSConverter(config)
                files = converter.get_hf_files(hf_repo, repo_type)
                ok = converter.process_files(hf_repo, ms_repo, files)
            except Exception as e:
                # Login or file listing failed; show the reason instead of crashing.
                logger.exception("迁移失败")
                return f"迁移失败: {e}"
            return "迁移完成" if ok else "迁移失败，详情请查看日志"

        submit.click(
            handle_submit,
            inputs=[hf_token, ms_token, repo_type, hf_repo, ms_repo],
            outputs=status,
        )
        # Reset every field; the repo type goes back to its default value.
        clear.click(
            lambda: ("", "", "dataset", "", "", ""),
            inputs=None,
            outputs=[hf_token, ms_token, repo_type, hf_repo, ms_repo, status],
        )
    return demo
if __name__ == "__main__":
    # Keep concurrency minimal: at most one queued event (max_size=1) and a
    # single worker thread (max_threads=1).
    demo = create_ui().queue(max_size=1)
    demo.launch(max_threads=1, share=False)