Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import logging | |
from .source_manager import SourceManager | |
from .action_manager import ActionManager | |
from .exporter_manager import ExporterManager | |
from .config_manager import ConfigManager | |
from .file_handler import FileHandler | |
class Interface:
    """Gradio front end for the Waifuc data collection tool.

    The class wires the source, action, exporter and config managers into a
    three-tab ``gr.Blocks`` application (data collection, configuration,
    logs) and provides the callbacks those tabs use. All user-visible labels
    are looked up in a per-language translation table via :meth:`get_text`.
    """

    # Name of the logger whose in-memory handler feeds the "Logs" tab.
    _LOGGER_NAME = "waifuc_gui"

    # Every key requested through get_text() must exist for both languages;
    # the "*_info" entries were previously only spelled out inline in the
    # language-switch callback, which made build() fail with KeyError.
    _TRANSLATIONS = {
        "zh": {
            "title": "Waifuc 数据收集工具",
            "data_collection": "数据收集",
            "config": "配置",
            "logs": "日志",
            "select_source": "选择数据源",
            "select_source_info": "选择数据源,如DanbooruSource或LocalSource",
            "select_actions": "选择动作",
            "select_actions_info": "选择处理图片的动作,例如NoMonochromeAction",
            "dataset_name": "数据集名称",
            "dataset_name_info": "输出数据集的名称",
            "select_exporter": "选择导出器",
            "select_exporter_info": "选择导出器,如TextualInversionExporter",
            "start_collection": "开始收集",
            "status": "状态",
            "download_data": "下载收集的数据",
            "config_management": "配置管理",
            "pixiv_token": "Pixiv刷新令牌",
            "pixiv_token_info": "PixivSearchSource所需,仅存储在会话中",
            "output_dir": "输出目录",
            "output_dir_info": "输出数据的目录,会话专用",
            "num_items": "默认图片数量",
            "num_items_info": "默认收集的图片数量",
            "resize_size": "默认调整大小",
            "resize_size_info": "ResizeAction的默认尺寸",
            "language": "语言",
            "save_config": "保存配置",
            "export_config": "导出配置",
            "import_config": "导入配置",
            "config_status": "配置状态",
            "view_logs": "查看日志",
            "download_log": "下载日志",
            "local_source_info": "上传包含图片的zip文件用于LocalSource",
            "pixiv_word": "搜索关键词(Pixiv)",
            "missing_params": "请填写所有必填参数(如搜索关键词)",
            "invalid_source": "无效的数据源配置",
            "invalid_dataset_name": "数据集名称不能包含路径分隔符或 '..'",
            "cleanup_session": "清理会话",
            "cleanup_status": "清理状态",
        },
        "en": {
            "title": "Waifuc Data Collection Tool",
            "data_collection": "Data Collection",
            "config": "Configuration",
            "logs": "Logs",
            "select_source": "Select Data Source",
            "select_source_info": "Choose a data source like DanbooruSource or LocalSource",
            "select_actions": "Select Actions",
            "select_actions_info": "Select actions to process images, e.g., NoMonochromeAction",
            "dataset_name": "Dataset Name",
            "dataset_name_info": "Name for the output dataset",
            "select_exporter": "Select Exporter",
            "select_exporter_info": "Choose an exporter like TextualInversionExporter",
            "start_collection": "Start Collection",
            "status": "Status",
            "download_data": "Download Collected Data",
            "config_management": "Configuration Management",
            "pixiv_token": "Pixiv Refresh Token",
            "pixiv_token_info": "Required for PixivSearchSource, stored in session only",
            "output_dir": "Output Directory",
            "output_dir_info": "Directory for output data, session-specific",
            "num_items": "Default Number of Items",
            "num_items_info": "Default number of images to collect",
            "resize_size": "Default Resize Size",
            "resize_size_info": "Default size for ResizeAction",
            "language": "Language",
            "save_config": "Save Configuration",
            "export_config": "Export Configuration",
            "import_config": "Import Configuration",
            "config_status": "Configuration Status",
            "view_logs": "View Logs",
            "download_log": "Download Log",
            "local_source_info": "Upload a ZIP file containing images for LocalSource",
            "pixiv_word": "Search Keyword (Pixiv)",
            "missing_params": "Please fill in all required parameters (e.g., search keyword)",
            "invalid_source": "Invalid data source configuration",
            "invalid_dataset_name": "Dataset name must not contain path separators or '..'",
            "cleanup_session": "Cleanup Session",
            "cleanup_status": "Cleanup Status",
        },
    }

    def __init__(
        self,
        source_manager: "SourceManager",
        action_manager: "ActionManager",
        exporter_manager: "ExporterManager",
        config_manager: "ConfigManager",
    ):
        """Store the managers and pre-compute the per-source parameter inputs.

        ``self.params`` maps source name -> {parameter name -> input}, as
        returned by ``source_manager.create_param_inputs()``.
        ``self.param_components`` is the same set of inputs flattened in
        iteration order; callbacks rely on that order.
        """
        self.source_manager = source_manager
        self.action_manager = action_manager
        self.exporter_manager = exporter_manager
        self.config_manager = config_manager
        self.params = source_manager.create_param_inputs()
        self.param_components = [
            param_input
            for source_params in self.params.values()
            for param_input in source_params.values()
        ]
        # Per-instance copy so one instance cannot mutate another's labels.
        self.language_dict = {
            language: dict(texts)
            for language, texts in self._TRANSLATIONS.items()
        }

    def get_text(self, key):
        """Return the label for ``key`` in the configured language.

        Falls back to the English table when the configured language is
        unknown, and to ``key`` itself when no translation exists, so that a
        missing label degrades the UI instead of crashing it.
        """
        language = self.config_manager.get_config("language")
        texts = self.language_dict.get(language) or self.language_dict["en"]
        return texts.get(key, key)

    @staticmethod
    def _is_safe_dataset_name(dataset_name):
        """Return True when ``dataset_name`` is a single path component."""
        if dataset_name in (".", ".."):
            return False
        return not any(token in dataset_name for token in ("/", "\\", "\0"))

    def _read_log_text(self):
        """Return the buffered log lines joined by newlines, or ``""``.

        Uses the first handler of the GUI logger that exposes a
        ``log_stream`` attribute (joined line by line, as the original code
        did with ``handlers[0]``); tolerates a logger with no such handler.
        """
        logger = logging.getLogger(self._LOGGER_NAME)
        for handler in logger.handlers:
            log_stream = getattr(handler, "log_stream", None)
            if log_stream is not None:
                return "\n".join(log_stream)
        return ""

    def update_params_visibility(self, selected_source):
        """Build visibility updates for a newly selected source.

        Returns one ``gr.update`` per entry of ``self.param_components`` (in
        the same flattened order), visible only for ``selected_source``,
        followed by one update for the LocalSource hint.
        """
        updates = [
            gr.update(visible=(source == selected_source))
            for source, source_params in self.params.items()
            for _ in source_params
        ]
        updates.append(gr.update(visible=(selected_source == "LocalSource")))
        return updates

    def collect_params(self, selected_source, *param_values):
        """Map positional component values to the selected source's names.

        ``param_values`` normally holds one value per entry of
        ``self.param_components`` (all sources, flattened); the slice that
        belongs to ``selected_source`` is picked out by offset. When a
        different number of values is given, they are assumed to belong to
        the selected source already and are paired from the start.

        Returns an empty dict for an unknown source.
        """
        source_params = self.params.get(selected_source)
        if not source_params:
            return {}

        total = sum(len(group) for group in self.params.values())
        if len(param_values) != total:
            return dict(zip(source_params, param_values))

        offset = 0
        for source, group in self.params.items():
            if source == selected_source:
                break
            offset += len(group)
        values = param_values[offset:offset + len(source_params)]
        return dict(zip(source_params, values))

    def update_action_params(self, selected_actions):
        """Create parameter inputs for the selected actions.

        Returns ``(components, action_params)`` where ``action_params`` is
        whatever ``action_manager.create_action_param_inputs`` returned
        (action -> {parameter name -> input}) and ``components`` is the
        flattened list of those inputs.
        """
        action_params = self.action_manager.create_action_param_inputs(selected_actions)
        components = [
            param_input
            for params in action_params.values()
            for param_input in params.values()
        ]
        return components, action_params

    def start_collection(
        self, selected_source, params, selected_actions, action_params, dataset_name, selected_exporter,
        source_manager, action_manager, exporter_manager, file_handler
    ):
        """Run one collection job and package its output.

        Returns a 3-tuple ``(status message, zip path or None, log text)``.
        Validation failures return a localized message without touching the
        filesystem; runtime failures are logged and reported in the status.
        """
        if not selected_source or not dataset_name or not selected_exporter:
            return self.get_text("missing_params"), None, ""

        params = params or {}
        if selected_source == "PixivSearchSource" and (
            not params.get("word") or not params.get("refresh_token")
        ):
            return self.get_text("missing_params"), None, ""

        # dataset_name comes straight from the web form and is used to build
        # a filesystem path, so it must stay inside the session directory.
        if not self._is_safe_dataset_name(dataset_name):
            return self.get_text("invalid_dataset_name"), None, ""

        logger = logging.getLogger(self._LOGGER_NAME)
        config_manager = source_manager.config_manager
        session_dir = f"/tmp/user_{config_manager.session_id}"
        output_dir = f"{session_dir}/{dataset_name}"
        try:
            logger.info("Starting collection for %s with dataset %s", selected_source, dataset_name)
            os.makedirs(output_dir, exist_ok=True)
            source = source_manager.instantiate_source(selected_source, params, file_handler)
            actions = action_manager.instantiate_actions(selected_actions, action_params)
            exporter = exporter_manager.instantiate_exporter(selected_exporter, dataset_name)
            logger.info("Attaching actions: %s", selected_actions)
            # NOTE(review): chdir changes the working directory of the whole
            # process; concurrent sessions will interfere with each other.
            # It is kept because the exporter and create_zip presumably
            # resolve dataset_name relative to the session directory.
            os.chdir(session_dir)
            source.attach(*actions).export(exporter)
            logger.info("Export completed, creating ZIP")
            zip_path = file_handler.create_zip(dataset_name)
            logger.info("Collection completed: %s", zip_path)
        except Exception as e:  # UI boundary: report any failure to the user
            logger.error("Collection failed: %s", e)
            if config_manager.get_config("language") == "en":
                message = f"Data collection failed: {str(e)}"
            else:
                message = f"数据收集失败:{str(e)}"
            return message, None, self._read_log_text()

        if config_manager.get_config("language") == "en":
            message = f"Data collection completed, output file: {dataset_name}.zip"
        else:
            message = f"数据收集完成,输出文件:{dataset_name}.zip"
        return message, zip_path, self._read_log_text()

    def download_log(self, file_handler):
        """Persist the buffered log via ``file_handler.save_log`` and return its result."""
        return file_handler.save_log(self._read_log_text())

    def cleanup_session(self, config_manager):
        """Call ``config_manager.cleanup()`` and return a localized confirmation."""
        config_manager.cleanup()
        confirmation = (
            "Session cleaned"
            if config_manager.get_config("language") == "en"
            else "会话已清理"
        )
        return self.get_text("cleanup_status") + ": " + confirmation

    def build(self):
        """Assemble the Gradio application and wire all event handlers.

        Returns a dict holding the ``gr.Blocks`` instance under ``"demo"``
        plus the main components, keyed by their variable names.
        """
        with gr.Blocks(title=self.get_text("title")) as demo:
            language_dropdown = gr.Dropdown(
                choices=["zh", "en"],
                label=self.get_text("language"),
                value=self.config_manager.get_config("language"),
            )

            with gr.Tab(self.get_text("data_collection")):
                gr.Markdown("### " + self.get_text("data_collection"))
                source_names = self.source_manager.source_names
                initial_source = source_names[0] if source_names else None
                source_dropdown = gr.Dropdown(
                    choices=source_names,
                    label=self.get_text("select_source"),
                    value=initial_source,
                    info=self.get_text("select_source_info"),
                )
                local_source_info = gr.Markdown(
                    self.get_text("local_source_info"),
                    visible=(initial_source == "LocalSource"),
                )
                param_components = self.param_components
                # The parameter inputs were created in __init__, i.e. outside
                # this Blocks context, so they must be attached explicitly.
                for component in param_components:
                    try:
                        component.render()
                    except gr.exceptions.DuplicateBlockError:
                        # Already part of the enclosing Blocks; nothing to do.
                        pass
                action_checkboxes = gr.CheckboxGroup(
                    choices=self.action_manager.action_names,
                    label=self.get_text("select_actions"),
                    value=[],
                    info=self.get_text("select_actions_info"),
                )
                action_param_components = gr.State(value={})
                action_component_state = gr.State(value=[])
                action_param_inputs = []
                dataset_name_input = gr.Textbox(
                    label=self.get_text("dataset_name"),
                    value="waifuc_dataset",
                    info=self.get_text("dataset_name_info"),
                )
                exporter_names = self.exporter_manager.exporter_names
                exporter_dropdown = gr.Dropdown(
                    choices=exporter_names,
                    label=self.get_text("select_exporter"),
                    value=exporter_names[0] if exporter_names else None,
                    info=self.get_text("select_exporter_info"),
                )
                start_btn = gr.Button(self.get_text("start_collection"))
                status = gr.Textbox(label=self.get_text("status"), interactive=False)
                output_file = gr.File(label=self.get_text("download_data"))

            with gr.Tab(self.get_text("config")):
                gr.Markdown("### " + self.get_text("config_management"))
                pixiv_token_input = gr.Textbox(
                    label=self.get_text("pixiv_token"),
                    value=self.config_manager.get_config("pixiv_refresh_token"),
                    type="password",
                    info=self.get_text("pixiv_token_info"),
                )
                output_dir_input = gr.Textbox(
                    label=self.get_text("output_dir"),
                    value=self.config_manager.get_config("output_dir"),
                    info=self.get_text("output_dir_info"),
                )
                num_items_input = gr.Number(
                    label=self.get_text("num_items"),
                    value=self.config_manager.get_config("default_num_items"),
                    info=self.get_text("num_items_info"),
                )
                resize_size_input = gr.Number(
                    label=self.get_text("resize_size"),
                    value=self.config_manager.get_config("default_resize_size"),
                    info=self.get_text("resize_size_info"),
                )
                config_export_btn = gr.Button(self.get_text("export_config"))
                config_export_file = gr.File(label=self.get_text("export_config"))
                config_import_file = gr.File(label=self.get_text("import_config"))
                save_config_btn = gr.Button(self.get_text("save_config"))
                config_status = gr.Textbox(label=self.get_text("config_status"), interactive=False)
                # These two were referenced by the event wiring and the
                # returned dict but never created.
                cleanup_btn = gr.Button(self.get_text("cleanup_session"))
                cleanup_status = gr.Textbox(label=self.get_text("cleanup_status"), interactive=False)

            with gr.Tab(self.get_text("logs")):
                gr.Markdown("### " + self.get_text("view_logs"))
                log_output = gr.Textbox(label=self.get_text("view_logs"), interactive=False, lines=10)
                log_download_btn = gr.Button(self.get_text("download_log"))
                log_download_file = gr.File(label=self.get_text("download_log"))

            def update_language(language, selected_source):
                """Store the new language and relabel every translatable component."""
                self.config_manager.set_config("language", language)
                text = self.get_text
                return {
                    language_dropdown: gr.update(label=text("language")),
                    source_dropdown: gr.update(label=text("select_source"), info=text("select_source_info")),
                    action_checkboxes: gr.update(label=text("select_actions"), info=text("select_actions_info")),
                    dataset_name_input: gr.update(label=text("dataset_name"), info=text("dataset_name_info")),
                    exporter_dropdown: gr.update(label=text("select_exporter"), info=text("select_exporter_info")),
                    start_btn: gr.update(value=text("start_collection")),
                    status: gr.update(label=text("status")),
                    output_file: gr.update(label=text("download_data")),
                    pixiv_token_input: gr.update(label=text("pixiv_token"), info=text("pixiv_token_info")),
                    output_dir_input: gr.update(label=text("output_dir"), info=text("output_dir_info")),
                    num_items_input: gr.update(label=text("num_items"), info=text("num_items_info")),
                    resize_size_input: gr.update(label=text("resize_size"), info=text("resize_size_info")),
                    config_export_btn: gr.update(value=text("export_config")),
                    config_export_file: gr.update(label=text("export_config")),
                    config_import_file: gr.update(label=text("import_config")),
                    save_config_btn: gr.update(value=text("save_config")),
                    config_status: gr.update(label=text("config_status")),
                    log_output: gr.update(label=text("view_logs")),
                    log_download_btn: gr.update(value=text("download_log")),
                    log_download_file: gr.update(label=text("download_log")),
                    # The live selection arrives as an input; the component's
                    # .value attribute only holds its construction-time value.
                    local_source_info: gr.update(
                        value=text("local_source_info"),
                        visible=(selected_source == "LocalSource"),
                    ),
                    cleanup_btn: gr.update(value=text("cleanup_session")),
                    cleanup_status: gr.update(label=text("cleanup_status")),
                }

            language_dropdown.change(
                fn=update_language,
                inputs=[language_dropdown, source_dropdown],
                outputs=[
                    language_dropdown, source_dropdown, action_checkboxes,
                    dataset_name_input, exporter_dropdown, start_btn, status,
                    output_file, pixiv_token_input, output_dir_input, num_items_input,
                    resize_size_input, config_export_btn, config_export_file,
                    config_import_file, save_config_btn, config_status, log_output,
                    log_download_btn, log_download_file, local_source_info,
                    cleanup_btn, cleanup_status,
                ],
            )
            source_dropdown.change(
                fn=self.update_params_visibility,
                inputs=source_dropdown,
                outputs=param_components + [local_source_info],
            )
            # NOTE(review): update_action_params returns whatever
            # create_action_param_inputs builds; if those are Gradio
            # components they are only stored in State here, not rendered.
            # Confirm against ActionManager.
            action_checkboxes.change(
                fn=self.update_action_params,
                inputs=action_checkboxes,
                outputs=[action_component_state, action_param_components],
            )

            def save_configs(pixiv_token, output_dir, num_items, resize_size):
                """Validate the numeric fields, then persist all four settings."""
                prefix = self.get_text("config_status") + ": "
                is_english = self.config_manager.get_config("language") == "en"
                try:
                    # An emptied Number field arrives as None.
                    num_items = int(num_items)
                    resize_size = int(resize_size)
                except (TypeError, ValueError):
                    return prefix + (
                        "Number of items and resize size must be numbers"
                        if is_english
                        else "图片数量和调整大小必须为数字"
                    )
                self.config_manager.set_config("pixiv_refresh_token", pixiv_token)
                self.config_manager.set_config("output_dir", output_dir)
                self.config_manager.set_config("default_num_items", num_items)
                self.config_manager.set_config("default_resize_size", resize_size)
                return prefix + ("Configuration saved" if is_english else "配置已保存")

            save_config_btn.click(
                fn=save_configs,
                inputs=[pixiv_token_input, output_dir_input, num_items_input, resize_size_input],
                outputs=config_status,
            )
            config_export_btn.click(
                fn=self.config_manager.export_config,
                inputs=[],
                outputs=config_export_file,
            )
            config_import_file.upload(
                fn=self.config_manager.import_config,
                inputs=config_import_file,
                outputs=config_status,
            )

            def run_collection(
                selected_source, selected_actions, action_params, dataset_name,
                selected_exporter, source_manager, action_manager,
                exporter_manager, file_handler, *param_values
            ):
                """Turn live component values into ``params`` and start the job."""
                params = self.collect_params(selected_source, *param_values)
                return self.start_collection(
                    selected_source, params, selected_actions, action_params,
                    dataset_name, selected_exporter, source_manager,
                    action_manager, exporter_manager, file_handler,
                )

            start_btn.click(
                fn=run_collection,
                inputs=[
                    source_dropdown,
                    action_checkboxes,
                    action_param_components,
                    dataset_name_input,
                    exporter_dropdown,
                    gr.State(value=self.source_manager),
                    gr.State(value=self.action_manager),
                    gr.State(value=self.exporter_manager),
                    gr.State(value=FileHandler()),
                    # The parameter inputs are passed as real event inputs so
                    # the callback receives what the user typed.
                    *param_components,
                ],
                outputs=[status, output_file, log_output],
                api_name="start_collection",
            )
            log_download_btn.click(
                fn=self.download_log,
                inputs=[gr.State(value=FileHandler())],
                outputs=log_download_file,
            )
            cleanup_btn.click(
                fn=self.cleanup_session,
                inputs=[gr.State(value=self.config_manager)],
                outputs=cleanup_status,
            )

        return {
            "demo": demo,
            "source_dropdown": source_dropdown,
            "param_components": param_components,
            "action_checkboxes": action_checkboxes,
            "action_param_components": action_param_components,
            "action_param_inputs": action_param_inputs,
            "dataset_name_input": dataset_name_input,
            "exporter_dropdown": exporter_dropdown,
            "start_btn": start_btn,
            "status": status,
            "output_file": output_file,
            "log_output": log_output,
            "log_download_btn": log_download_btn,
            "log_download_file": log_download_file,
            "cleanup_btn": cleanup_btn,
            "cleanup_status": cleanup_status,
        }