waifuc-gui / waifuc_gui /source_manager.py
LittleApple-fp16's picture
Update waifuc_gui/source_manager.py
8ffbcc8 verified
import inspect
import types
import typing

import gradio as gr
import waifuc.source

from .config_manager import ConfigManager
class SourceManager:
    """Discover waifuc source classes and map their constructor parameters to Gradio inputs.

    Every class in ``waifuc.source`` whose attribute name ends in ``"Source"``
    and that is not abstract is exposed. For each one, this manager builds
    hidden Gradio components (one per constructor parameter) and can turn the
    values collected from those components back into a source instance.
    """

    def __init__(self, config_manager: ConfigManager):
        """Collect the concrete source classes exported by ``waifuc.source``.

        Args:
            config_manager: Supplies the ``language``, ``pixiv_refresh_token``
                and ``default_num_items`` settings read by this class.
        """
        self.config_manager = config_manager
        # Concrete classes from waifuc.source named "*Source", in the
        # name-sorted order returned by inspect.getmembers.
        self.source_classes = [
            cls for name, cls in inspect.getmembers(waifuc.source, inspect.isclass)
            if name.endswith("Source") and not inspect.isabstract(cls)
        ]
        # Class names shown to the user; the other methods take one of these.
        self.source_names = [cls.__name__ for cls in self.source_classes]

    def _is_english(self):
        """Return True when the configured UI language is English."""
        return self.config_manager.get_config("language") == "en"

    def _get_source_cls(self, source_name):
        """Return the discovered source class whose ``__name__`` is *source_name*.

        Raises:
            ValueError: if no discovered source has that name. This is raised
                explicitly so that a ``StopIteration`` does not escape to callers.
        """
        for cls in self.source_classes:
            if cls.__name__ == source_name:
                return cls
        raise ValueError(
            f"Unknown source: {source_name}" if self._is_english()
            else f"未知的数据源: {source_name}"
        )

    @staticmethod
    def _annotation_matches(annotation, target):
        """Return True if *annotation* denotes the type *target*.

        The following forms match:
        - the bare type (``list``);
        - a parameterized form (``List[str]`` or ``list[str]``);
        - an ``Optional``/``Union`` whose non-None members all match.

        String (postponed) annotations are not resolved and never match.
        """
        if annotation is target:
            return True
        origin = typing.get_origin(annotation)
        if origin is target:
            return True
        # typing.Union covers Optional[X]/Union[...]; types.UnionType covers
        # the PEP 604 ``X | None`` form on Python versions that have it.
        union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
        if origin in union_origins:
            members = [arg for arg in typing.get_args(annotation) if arg is not type(None)]
            return bool(members) and all(
                SourceManager._annotation_matches(arg, target) for arg in members
            )
        return False

    def get_source_params(self, selected_source):
        """Return the constructor parameters of *selected_source* that can be passed by name.

        ``self`` is excluded. ``*args``/``**kwargs`` catch-alls are also
        excluded, because a single input widget cannot meaningfully fill them.

        Raises:
            ValueError: if *selected_source* is not a discovered source.
        """
        source_cls = self._get_source_cls(selected_source)
        sig = inspect.signature(source_cls.__init__)
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        return [
            p for p in sig.parameters.values()
            if p.name != 'self' and p.kind not in variadic
        ]

    def _make_param_input(self, param, is_en):
        """Build the hidden Gradio component used to enter constructor parameter *param*."""
        if param.name == "refresh_token":
            return gr.Textbox(
                label="Pixiv Refresh Token" if is_en else "Pixiv刷新令牌",
                type="password",
                visible=False,
                value=self.config_manager.get_config("pixiv_refresh_token"),
                info="Set in Config tab or here" if is_en else "在配置选项卡或此处设置"
            )
        if param.name == "num_items":
            return gr.Number(
                label="Number of Items" if is_en else "图片数量",
                visible=False,
                value=self.config_manager.get_config("default_num_items"),
                precision=0,  # round to and return an int rather than a float
                info="Number of images to collect" if is_en else "要收集的图片数量"
            )
        if self._annotation_matches(param.annotation, list):
            return gr.Textbox(
                label=param.name,
                placeholder="Comma-separated tags, e.g., tag1,tag2" if is_en else "逗号分隔的标签,例如:tag1,tag2",
                visible=False
            )
        if param.annotation is bool:
            # A text field would pass e.g. the string "False", which is truthy.
            return gr.Checkbox(
                label=param.name,
                value=param.default if isinstance(param.default, bool) else False,
                visible=False
            )
        if self._annotation_matches(param.annotation, int):
            return gr.Number(
                label=param.name,
                visible=False,
                precision=0  # round to and return an int rather than a float
            )
        return gr.Textbox(
            label=param.name,
            visible=False
        )

    def create_param_inputs(self):
        """Create hidden input components for every discovered source.

        Returns:
            A dict mapping each source name to ``{param_name: component}``.
            Every source gets an entry, which may be empty. ``LocalSource`` gets
            a single ``"file"`` upload component instead of one component per
            constructor parameter.
        """
        is_en = self._is_english()
        params = {}
        for source in self.source_names:
            if source == "LocalSource":
                params[source] = {
                    "file": gr.File(
                        label="Upload ZIP file" if is_en else "上传zip文件",
                        visible=False
                    )
                }
                continue
            inputs = params.setdefault(source, {})
            for param in self.get_source_params(source):
                inputs[param.name] = self._make_param_input(param, is_en)
        return params

    def _coerce_value(self, param, value):
        """Convert a raw widget value into the form constructor parameter *param* expects."""
        if isinstance(value, str) and (
            param.name == "tags" or self._annotation_matches(param.annotation, list)
        ):
            # Comma-separated text -> list. Empty items (from "" or a trailing
            # comma) are dropped.
            return [item.strip() for item in value.split(',') if item.strip()]
        if isinstance(value, float) and value.is_integer() and (
            param.name == "num_items" or self._annotation_matches(param.annotation, int)
        ):
            return int(value)
        return value

    def instantiate_source(self, selected_source, params, file_handler):
        """Construct *selected_source* from the values collected from its input components.

        Args:
            selected_source: Name of the source class (one of ``source_names``).
            params: Mapping of parameter name to widget value. Keys that are not
                constructor parameters of the source are ignored.
            file_handler: Used only for ``LocalSource``. The result of its
                ``extract_zip`` method (presumably the extraction directory) is
                passed to ``LocalSource``.

        Returns:
            The constructed source instance.

        Raises:
            ValueError: if ``LocalSource`` is selected without an uploaded file,
                or if *selected_source* is not a discovered source.
        """
        if selected_source == "LocalSource":
            uploaded_file = params.get("file", None)
            if not uploaded_file:
                raise ValueError(
                    "Please upload a ZIP file for LocalSource" if self._is_english() else
                    "请上传LocalSource的zip文件"
                )
            extract_dir = file_handler.extract_zip(uploaded_file)
            return waifuc.source.LocalSource(extract_dir)

        source_cls = self._get_source_cls(selected_source)
        accepted = {p.name: p for p in self.get_source_params(selected_source)}
        source_params = {k: v for k, v in params.items() if k in accepted}
        # An empty token field falls back to the token saved in the Config tab.
        if "refresh_token" in source_params and not source_params["refresh_token"]:
            source_params["refresh_token"] = self.config_manager.get_config("pixiv_refresh_token")

        kwargs = {}
        for name, value in source_params.items():
            param = accepted[name]
            is_blank = value is None or (isinstance(value, str) and not value.strip())
            if is_blank and param.default is not inspect.Parameter.empty:
                # If an optional field was left empty, use the constructor's
                # default rather than passing "" or None.
                continue
            kwargs[name] = self._coerce_value(param, value)
        return source_cls(**kwargs)