Spaces:
Runtime error
Runtime error
import inspect | |
import waifuc.source | |
import gradio as gr | |
from .config_manager import ConfigManager | |
class SourceManager:
    """Discover waifuc data sources and build/instantiate them from Gradio inputs.

    Every non-abstract class exported by ``waifuc.source`` whose name ends in
    ``"Source"`` is offered as a selectable source. Constructor parameters are
    read with ``inspect.signature`` and turned into (initially hidden) Gradio
    input components; the values the user enters are later mapped back onto
    the constructor's keyword arguments.
    """

    def __init__(self, config_manager: "ConfigManager"):
        """Collect the available source classes.

        Args:
            config_manager: Object whose ``get_config(key)`` supplies the UI
                language (``"language"``) and defaults such as
                ``"pixiv_refresh_token"`` and ``"default_num_items"``.
        """
        self.config_manager = config_manager
        self.source_classes = [
            cls
            for name, cls in inspect.getmembers(waifuc.source, inspect.isclass)
            if name.endswith("Source") and not inspect.isabstract(cls)
        ]
        self.source_names = [cls.__name__ for cls in self.source_classes]

    # ------------------------------------------------------------------ helpers

    def _t(self, en, zh):
        """Return *en* when the configured UI language is ``"en"``, else *zh*."""
        return en if self.config_manager.get_config("language") == "en" else zh

    def _get_source_class(self, name):
        """Return the discovered source class whose ``__name__`` is *name*.

        Raises:
            ValueError: if no discovered source has that name (instead of
                letting a bare ``StopIteration`` escape).
        """
        for cls in self.source_classes:
            if cls.__name__ == name:
                return cls
        raise ValueError(f"Unknown source: {name!r}")

    @staticmethod
    def _annotation_is(annotation, target):
        """Return True if *annotation* denotes *target*.

        Accepts the bare type (``list``), its generic alias (``List[str]``,
        ``list[str]``) and ``Optional``/union forms whose non-``None`` members
        all denote *target* (``Optional[int]``, ``int | None``).
        """
        import types  # local imports keep this block self-contained
        import typing

        if annotation is target:
            return True
        origin = typing.get_origin(annotation)
        if origin is target:
            return True
        union_type = getattr(types, "UnionType", None)  # PEP 604 unions, 3.10+
        if origin is typing.Union or (union_type is not None and origin is union_type):
            members = [a for a in typing.get_args(annotation) if a is not type(None)]
            return bool(members) and all(
                SourceManager._annotation_is(a, target) for a in members
            )
        return False

    @staticmethod
    def _is_blank(value):
        """True for ``None`` and empty/whitespace-only strings (an untouched input)."""
        return value is None or (isinstance(value, str) and not value.strip())

    # --------------------------------------------------------------- public API

    def get_source_params(self, selected_source):
        """Return the user-settable constructor parameters of *selected_source*.

        ``self`` and variadic ``*args`` / ``**kwargs`` parameters are excluded:
        they cannot be filled from a single named input.

        Raises:
            ValueError: if *selected_source* is not a discovered source name.
        """
        source_cls = self._get_source_class(selected_source)
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        return [
            p
            for p in inspect.signature(source_cls.__init__).parameters.values()
            if p.name != "self" and p.kind not in variadic
        ]

    def create_param_inputs(self):
        """Build hidden Gradio input components for every source's parameters.

        Returns:
            dict mapping source name -> {parameter name -> Gradio component}.
            Every discovered source gets an entry (possibly an empty dict).
            ``LocalSource`` gets a single ``"file"`` upload component.
        """
        params = {}
        for source in self.source_names:
            inputs = {}
            params[source] = inputs
            if source == "LocalSource":
                inputs["file"] = gr.File(
                    label=self._t("Upload ZIP file", "上传zip文件"),
                    visible=False,
                )
                continue
            for param in self.get_source_params(source):
                inputs[param.name] = self._make_input(param)
        return params

    def _make_input(self, param):
        """Create the hidden Gradio component used to edit constructor *param*."""
        name = param.name
        if name == "refresh_token":
            return gr.Textbox(
                label=self._t("Pixiv Refresh Token", "Pixiv刷新令牌"),
                type="password",
                visible=False,
                value=self.config_manager.get_config("pixiv_refresh_token"),
                info=self._t("Set in Config tab or here", "在配置选项卡或此处设置"),
            )
        if name == "num_items":
            return gr.Number(
                label=self._t("Number of Items", "图片数量"),
                visible=False,
                value=self.config_manager.get_config("default_num_items"),
                info=self._t("Number of images to collect", "要收集的图片数量"),
            )
        if self._annotation_is(param.annotation, list):
            return gr.Textbox(
                label=name,
                placeholder=self._t(
                    "Comma-separated tags, e.g., tag1,tag2",
                    "逗号分隔的标签,例如:tag1,tag2",
                ),
                visible=False,
            )
        if self._annotation_is(param.annotation, bool):
            # A Textbox would deliver "False" as a truthy string; use a real toggle.
            default = param.default if isinstance(param.default, bool) else False
            return gr.Checkbox(label=name, value=default, visible=False)
        if self._annotation_is(param.annotation, int):
            return gr.Number(label=name, visible=False)
        return gr.Textbox(label=name, visible=False)

    def _coerce(self, param, value):
        """Convert a raw UI value into the form constructor *param* expects."""
        if isinstance(value, str) and (
            param.name == "tags" or self._annotation_is(param.annotation, list)
        ):
            # Comma-separated text -> list; blank items ("a,,b", "") are dropped.
            return [item.strip() for item in value.split(",") if item.strip()]
        if (
            isinstance(value, float)
            and value.is_integer()
            and self._annotation_is(param.annotation, int)
        ):
            # gr.Number delivers floats; int-typed parameters get a real int.
            return int(value)
        return value

    def instantiate_source(self, selected_source, params, file_handler):
        """Construct *selected_source* from the values the user entered.

        Args:
            selected_source: Name of a discovered source class.
            params: Mapping of parameter name -> input value. Keys that are not
                settable constructor parameters are ignored.
            file_handler: Object providing ``extract_zip(uploaded_file)``, which
                returns the extraction directory; only used for ``LocalSource``.

        Returns:
            The constructed source instance.

        Raises:
            ValueError: if ``LocalSource`` is selected without an uploaded file,
                or *selected_source* is not a discovered source.
        """
        if selected_source == "LocalSource":
            uploaded_file = params.get("file", None)
            if not uploaded_file:
                raise ValueError(
                    self._t(
                        "Please upload a ZIP file for LocalSource",
                        "请上传LocalSource的zip文件",
                    )
                )
            extract_dir = file_handler.extract_zip(uploaded_file)
            return waifuc.source.LocalSource(extract_dir)

        source_cls = self._get_source_class(selected_source)
        kwargs = {}
        for param in self.get_source_params(selected_source):
            if param.name not in params:
                continue
            value = params[param.name]
            if param.name == "refresh_token" and not value:
                # Fall back to the token stored in the Config tab.
                value = self.config_manager.get_config("pixiv_refresh_token")
            if self._is_blank(value) and param.default is not inspect.Parameter.empty:
                # Untouched input: let the constructor's own default apply.
                continue
            kwargs[param.name] = self._coerce(param, value)
        return source_cls(**kwargs)