File size: 5,844 Bytes
c3dc5d8
9cb371b
c3dc5d8
 
90914fa
9cb371b
90914fa
 
c3dc5d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90914fa
c3dc5d8
 
 
90914fa
 
c3dc5d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0db36e3
b340237
c3dc5d8
b340237
 
c3dc5d8
 
 
0db36e3
c3dc5d8
 
 
 
 
0db36e3
 
c3dc5d8
b340237
 
 
 
 
0db36e3
 
 
b340237
0db36e3
 
 
 
 
b340237
c3dc5d8
0db36e3
c3dc5d8
b340237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3dc5d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0db36e3
c3dc5d8
 
 
 
 
90914fa
c3dc5d8
90914fa
 
c3dc5d8
90914fa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import logging
import os
from pathlib import Path
from typing import List, Optional

# Upgrade modelscope at start-up, before it is imported further down.
# Running pip through the current interpreter (instead of a shell string)
# guarantees the package lands in the environment this script runs in and
# avoids shell parsing. Failures are tolerated (check=False), matching the
# previous best-effort behaviour.
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "-U", "modelscope"],
    check=False,
)
import gradio as gr
from huggingface_hub import HfApi
from modelscope.hub.api import HubApi

# Root logging configuration: INFO and above, written through a single
# StreamHandler using a "timestamp - level - message" layout.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)

class HFToMSConverter:
    """Copy the files of a HuggingFace repo into a ModelScope repo.

    Each file is downloaded into ``local_dir``, moved into a throw-away
    staging folder, and the staging folder is then uploaded to ModelScope
    with a single ``upload_folder`` call.
    """

    def __init__(self, config: dict):
        """Create the API clients and the working directories.

        Args:
            config: Settings dictionary. ``'hf_token'`` and ``'ms_token'``
                are required; ``'cache_dir'`` and ``'local_dir'`` are
                optional and default to ``"hf2ms_cache"`` / ``"hf2ms_local"``.

        Raises:
            KeyError: If ``'hf_token'`` or ``'ms_token'`` is missing.
        """
        self.config = config
        self.cache_dir = config.get('cache_dir', "hf2ms_cache")
        self.local_dir = config.get('local_dir', "hf2ms_local")
        self.hf_api = HfApi(token=config['hf_token'])
        self.ms_api = HubApi()
        self.ms_api.login(config['ms_token'])

        # parents=True so nested, user-configured directories also work.
        for dir_path in (self.local_dir, self.cache_dir):
            Path(dir_path).mkdir(parents=True, exist_ok=True)

    def get_hf_files(self, repo_id: str, repo_type: str = "dataset") -> List[str]:
        """Return the list of file paths stored in a HuggingFace repo."""
        return self.hf_api.list_repo_files(repo_id=repo_id, repo_type=repo_type)

    def download_file(self, repo_id: str, filename: str,
                      repo_type: str = "dataset") -> Optional[str]:
        """Download one file from HuggingFace into ``local_dir``.

        Args:
            repo_id: HuggingFace repo to download from.
            filename: Path of the file inside the repo.
            repo_type: HuggingFace repo type; defaults to ``"dataset"``.

        Returns:
            The local path of the downloaded file, or ``None`` on failure.
        """
        save_path = Path(self.local_dir) / filename
        try:
            if save_path.exists():
                # A file can only be left here by an earlier, interrupted run
                # (successful runs move it away) and it may belong to another
                # repo. Discard it and download again instead of failing
                # forever or uploading stale data.
                logger.warning(f"文件已存在，将重新下载: {filename}")
                save_path.unlink()

            self.hf_api.hf_hub_download(
                repo_id=repo_id,
                repo_type=repo_type,
                filename=filename,
                cache_dir=self.cache_dir,
                local_dir=self.local_dir,
                local_dir_use_symlinks=False
            )
            logger.info(f"成功下载文件: {filename}")
            return str(save_path)
        except Exception as e:
            logger.error(f"下载失败 {filename}: {e}")
            return None

    def handle_file_operation(self, operation_type: str, *args) -> bool:
        """Run a ``"move"`` or ``"push"`` operation, logging any failure.

        Args:
            operation_type: ``"move"`` or ``"push"``.
            *args: For ``"move"``: ``(src, dst)`` as ``Path`` objects.
                For ``"push"``: ``(ms_repo_id, clone_dir)`` optionally
                followed by the repo type (defaults to ``"dataset"``).

        Returns:
            ``True`` when the operation succeeded, ``False`` when it raised
            or when ``operation_type`` is not recognised.
        """
        import shutil

        try:
            if operation_type == "move":
                src, dst = args
                dst.parent.mkdir(parents=True, exist_ok=True)
                # shutil.move also works when src and dst are on different
                # filesystems, where Path.rename would raise.
                shutil.move(str(src), str(dst))
                logger.info(f"移动文件成功: {src.name}")
            elif operation_type == "push":
                ms_repo_id, clone_dir, *extra = args
                repo_type = extra[0] if extra else "dataset"
                logger.info(f"开始推送文件夹: {clone_dir}")
                self.ms_api.upload_folder(
                    repo_id=f"{ms_repo_id}",
                    folder_path=str(clone_dir),  # the API expects a string path
                    commit_message='upload dataset folder',
                    repo_type=repo_type
                )
                logger.info(f"推送文件夹成功: {clone_dir}")
            else:
                # Previously an unknown operation silently reported success.
                logger.error(f"未知的操作类型: {operation_type}")
                return False
            return True
        except Exception as e:
            logger.error(f"{operation_type}操作失败: {e}")
            return False

    def process_files(self, hf_repo: str, ms_repo: str, files: List[str],
                      repo_type: str = "dataset") -> bool:
        """Download, stage and upload every file in ``files``.

        Args:
            hf_repo: Source HuggingFace repo id.
            ms_repo: Target ModelScope repo id; its last ``/`` segment names
                the staging folder.
            files: File paths (relative to the repo root) to transfer.
            repo_type: Repo type used for both download and upload;
                defaults to ``"dataset"``.

        Returns:
            ``True`` if every file was transferred and the upload succeeded,
            ``False`` otherwise.
        """
        import shutil
        import tempfile

        # Reject ids whose last segment is empty or a directory reference:
        # they would make the staging path point at an existing directory.
        repo_name = ms_repo.split("/")[-1] if isinstance(ms_repo, str) else ""
        if repo_name in ("", ".", ".."):
            logger.error(f"无效的ModelScope仓库名: {ms_repo!r}")
            return False

        staging_root = None
        try:
            # Stage inside a freshly created, uniquely named directory so the
            # cleanup below can never delete a pre-existing folder (the old
            # code removed "<cwd>/<repo_name>", i.e. the working directory
            # itself for an empty name).
            staging_root = Path(
                tempfile.mkdtemp(prefix="hf2ms_stage_", dir=os.path.abspath('.'))
            )
            clone_dir = staging_root / repo_name
            clone_dir.mkdir(parents=True, exist_ok=True)

            # Download then move each file; `and` short-circuits so a failed
            # download does not trigger a pointless move attempt.
            for filename in files:
                if not (self.download_file(hf_repo, filename, repo_type)
                        and self.move_file(filename, str(clone_dir))):
                    return False

            # Upload the whole staging folder in one go.
            return self.push_to_ms(ms_repo, str(clone_dir), repo_type)

        except Exception as e:
            logger.error(f"处理文件失败: {e}")
            return False
        finally:
            # Best-effort cleanup of the staging area only.
            if staging_root is not None:
                shutil.rmtree(staging_root, ignore_errors=True)

    def move_file(self, filename: str, clone_dir: str) -> bool:
        """Move ``filename`` from ``local_dir`` into ``clone_dir``."""
        return self.handle_file_operation(
            "move",
            Path(self.local_dir) / filename,
            Path(clone_dir) / filename
        )

    def push_to_ms(self, ms_repo_id: str, clone_dir: str,
                   repo_type: str = "dataset") -> bool:
        """Upload the folder ``clone_dir`` to the ModelScope repo."""
        return self.handle_file_operation("push", ms_repo_id, clone_dir, repo_type)

def create_ui() -> gr.Blocks:
    """Build the Gradio interface of the migration tool.

    The form collects both access tokens, the repo type and the source /
    target repo ids. Submitting runs the migration and shows the outcome in
    a read-only status box; the clear button resets every field.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # HuggingFace to ModelScope 数据迁移工具
            请确保您拥有相应仓库的权限。
            """
        )

        with gr.Row():
            # Tokens are secrets: mask them instead of showing plain text.
            hf_token = gr.Textbox(label="HuggingFace Token", type="password")
            ms_token = gr.Textbox(label="ModelScope Token", type="password")

        with gr.Row():
            repo_type = gr.Textbox(label="仓库类型", value="dataset")
            hf_repo = gr.Textbox(label="HuggingFace仓库")
            ms_repo = gr.Textbox(label="ModelScope仓库")

        with gr.Row():
            submit = gr.Button("开始迁移", variant="primary")
            clear = gr.Button("清除")

        # Read-only box that reports the result of the last migration.
        status = gr.Textbox(label="状态", interactive=False)

        def handle_submit(hf_token, ms_token, repo_type, hf_repo, ms_repo):
            """Validate the form, run the migration and return a status text."""
            hf_token = (hf_token or "").strip()
            ms_token = (ms_token or "").strip()
            hf_repo = (hf_repo or "").strip()
            ms_repo = (ms_repo or "").strip()
            repo_type = (repo_type or "").strip() or "dataset"

            # Refuse to start with blank fields rather than failing midway.
            if not all((hf_token, ms_token, hf_repo, ms_repo)):
                return "请填写所有必填项"

            # NOTE(review): 'username' and 'email' are not read by
            # HFToMSConverter in this file, and the email value looks like a
            # redaction placeholder — confirm whether they are still needed.
            config = {
                'hf_token': hf_token,
                'ms_token': ms_token,
                'username': "thomas",
                'email': "[email protected]",
            }

            try:
                converter = HFToMSConverter(config)
                # NOTE(review): repo_type is only forwarded to the file
                # listing; process_files is called without it.
                files = converter.get_hf_files(hf_repo, repo_type)
                succeeded = converter.process_files(hf_repo, ms_repo, files)
            except Exception as e:
                # Login / listing errors surface here; report them to the user.
                logger.error(f"迁移失败: {e}")
                return f"迁移失败: {e}"

            return "迁移成功" if succeeded else "迁移失败，请查看日志"

        submit.click(
            handle_submit,
            inputs=[hf_token, ms_token, repo_type, hf_repo, ms_repo],
            outputs=status,
        )

        # Reset every field to its initial value.
        clear.click(
            lambda: ("", "", "dataset", "", "", ""),
            outputs=[hf_token, ms_token, repo_type, hf_repo, ms_repo, status],
        )

    return demo

if __name__ == "__main__":
    # Script entry point: build the interface, enable the request queue
    # (max_size=1) and start the server without a public share link,
    # passing max_threads=1 to launch().
    demo = create_ui()
    demo.queue(max_size=1).launch(share=False, max_threads=1)