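"""Gradio Space: WD EVA02 image tagger with Chinese tag translation.

Loads SmilingWolf/wd-eva02-large-tagger-v3 via ONNX Runtime, predicts rating,
general, and character tags for an uploaded image, and translates the tags to
Chinese through the local `translator` module (Tencent / Baidu APIs).
"""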

import os
import json

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from huggingface_hub import whoami

from translator import translate_texts

# ------------------------------------------------------------------
# Model Configuration
# ------------------------------------------------------------------
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# It's recommended to manage the token via HF Spaces secrets.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Derive the space owner from SPACE_ID rather than hard-coding it.
SPACE_ID = os.environ.get("SPACE_ID")
SPACE_OWNER = SPACE_ID.split("/")[0] if SPACE_ID else None
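
# Translation credentials read later in this app (set them as Space secrets):
#   TENCENT_SECRET_ID, TENCENT_SECRET_KEY, BAIDU_CREDENTIALS_JSON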

# ------------------------------------------------------------------
# Tagger Class (Global Instance)
# ------------------------------------------------------------------
class Tagger:
    def __init__(self):
        self.hf_token = HF_TOKEN
        self.tag_names = []
        self.categories = {}
        self.model = None
        self.input_size = 0
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        try:
            label_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, LABEL_FILENAME, token=self.hf_token, resume_download=True
            )
            model_path = huggingface_hub.hf_hub_download(
                MODEL_REPO, MODEL_FILENAME, token=self.hf_token, resume_download=True
            )
            tags_df = pd.read_csv(label_path)
            self.tag_names = tags_df["name"].tolist()
            # Category codes in selected_tags.csv: 9 = rating, 0 = general, 4 = character.
            self.categories = {
                "rating": np.where(tags_df["category"] == 9)[0],
                "general": np.where(tags_df["category"] == 0)[0],
                "character": np.where(tags_df["category"] == 4)[0],
            }
            self.model = rt.InferenceSession(model_path)
            self.input_size = self.model.get_inputs()[0].shape[1]
            print("✅ Model and labels loaded successfully.")
        except Exception as e:
            print(f"❌ Failed to load model or labels: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: Image.Image) -> np.ndarray:
        if img is None:
            raise ValueError("Input image cannot be None.")
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Pad to a white square, then resize to the model's input resolution.
        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        # RGB -> BGR channel order, as expected by the tagger's ONNX model.
        return np.array(canvas)[:, :, ::-1].astype(np.float32)

    # --------------------------- predict --------------------------
    def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
        if self.model is None:
            raise RuntimeError("Model not loaded, cannot predict.")
        inp_name = self.model.get_inputs()[0].name
        outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]
        res = {"ratings": {}, "general": {}, "characters": {}}
        tag_categories_for_translation = {"ratings": [], "general": [], "characters": []}
        # Map internal category names to the keys used in the result dicts above.
        key_map = {"rating": "ratings", "character": "characters"}
        for cat_key, cat_indices in self.categories.items():
            sub_res = {}
            if cat_key == "rating":
                # Ratings are always reported, regardless of threshold.
                for idx in cat_indices:
                    tag_name = self.tag_names[idx].replace("_", " ")
                    sub_res[tag_name] = float(outputs[idx])
            else:
                threshold = char_th if cat_key == "character" else gen_th
                for idx in cat_indices:
                    if outputs[idx] > threshold:
                        tag_name = self.tag_names[idx].replace("_", " ")
                        sub_res[tag_name] = float(outputs[idx])
            res_key = key_map.get(cat_key, cat_key)
            res[res_key] = dict(sorted(sub_res.items(), key=lambda kv: kv[1], reverse=True))
            tag_categories_for_translation[res_key] = list(res[res_key].keys())
        return res, tag_categories_for_translation
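
# Standalone usage sketch (hypothetical image path; the app itself uses the
# global instance created below):
#     tagger = Tagger()
#     tags, _ = tagger.predict(Image.open("example.jpg"))
#     print(tags["general"])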

# Global Tagger instance
try:
    tagger_instance = Tagger()
except RuntimeError as e:
    print(f"Tagger initialization failed on app startup: {e}")
    tagger_instance = None

# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
custom_css = """
.label-container { max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px; background-color: #f9f9f9; }
.tag-item { display: flex; justify-content: space-between; align-items: center; margin: 2px 0; padding: 2px 5px; border-radius: 3px; background-color: #fff; transition: background-color 0.2s; }
.tag-item:hover { background-color: #f0f0f0; }
.tag-en { font-weight: bold; color: #333; cursor: pointer; }
.tag-zh { color: #666; margin-left: 10px; }
.tag-score { color: #999; font-size: 0.9em; }
.btn-analyze-container { margin-top: 15px; margin-bottom: 15px; }
"""

# Executed once on page load via gr.Blocks(js=...); it installs a global
# copyToClipboard handler used by the inline onclick attributes emitted in
# format_tags_html below.
_js_functions = """
function setupClipboard() {
    window.copyToClipboard = function (text) {
        if (typeof text === 'undefined' || text === null) {
            console.warn('copyToClipboard was called with undefined or null text.');
            return;
        }
        navigator.clipboard.writeText(text).then(() => {
            const feedback = document.createElement('div');
            const displayText = String(text).substring(0, 30) + (String(text).length > 30 ? '...' : '');
            feedback.textContent = '已复制: ' + displayText;
            Object.assign(feedback.style, {
                position: 'fixed', bottom: '20px', left: '50%', transform: 'translateX(-50%)',
                backgroundColor: '#4CAF50', color: 'white', padding: '10px 20px',
                borderRadius: '5px', zIndex: '10000', transition: 'opacity 0.5s ease-out'
            });
            document.body.appendChild(feedback);
            setTimeout(() => {
                feedback.style.opacity = '0';
                setTimeout(() => { if (document.body.contains(feedback)) document.body.removeChild(feedback); }, 500);
            }, 1500);
        }).catch(err => {
            console.error('Failed to copy tag. Error:', err, 'Attempted to copy text:', text);
        });
    };
}
"""

with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,支持中英文显示和一键复制。[NovelAI在线绘画](https://nai.idlecloud.cc/)")

    with gr.Row():
        with gr.Column(scale=1):
            login_button = gr.LoginButton(value="🤗 通过 Hugging Face 登录")
            user_status_md = gr.Markdown("ℹ️ 正在检查登录状态...")

    state_res = gr.State({})
    state_translations_dict = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片", height=300)
            btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])
            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider = gr.Slider(0, 1, value=0.35, step=0.01, label="通用标签阈值")
                char_slider = gr.Slider(0, 1, value=0.85, step=0.01, label="角色标签阈值")
                show_tag_scores = gr.Checkbox(True, label="在列表中显示标签置信度")
            with gr.Accordion("🔑 自定义翻译密钥 (可选)", open=False, visible=False) as api_key_accordion:
                gr.Markdown("如果你不是空间所有者,需要在这里提供自己的API密钥才能使用翻译功能。")
                tencent_id_in = gr.Textbox(label="腾讯云 Secret ID", lines=1)
                tencent_key_in = gr.Textbox(label="腾讯云 Secret Key", lines=1, type="password")
                baidu_json_in = gr.Textbox(label="百度翻译凭证 (JSON 格式)", lines=3, placeholder='[{"app_id": "...", "secret_key": "..."}]')
            with gr.Accordion("📊 标签汇总设置", open=True):
                sum_cats = gr.CheckboxGroup(["通用标签", "角色标签", "评分标签"], value=["通用标签", "角色标签"], label="汇总类别")
                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签分隔符")
                sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")
            processing_info = gr.Markdown("", visible=False)
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"):
                    out_rating = gr.HTML(label="Rating Tags")
            gr.Markdown("### 标签汇总结果")
            out_summary = gr.Textbox(label="标签汇总", lines=5, show_copy_button=True)

    def get_token_from_request(request: gr.Request) -> str | None:
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            return auth_header.split(" ")[1]
        return None

    def is_user_space_owner(user_info: dict | None) -> bool:
        """Check whether the logged-in user owns this Space, based on SPACE_ID."""
        if not user_info or not SPACE_OWNER:
            if not SPACE_OWNER:
                print("⚠️ Warning: SPACE_ID environment variable not found.")
            return False
        user_name = user_info.get("name")
        user_orgs = [org.get("name") for org in user_info.get("orgs", [])]
        print(f"ℹ️ [Auth Check] Space Owner: '{SPACE_OWNER}', User: '{user_name}', User Orgs: {user_orgs}")
        return (user_name == SPACE_OWNER) or (SPACE_OWNER in user_orgs)

    def check_user_status(request: gr.Request):
        token = get_token_from_request(request)
        if token:
            try:
                user_info = whoami(token=token)
                if is_user_space_owner(user_info):
                    return f"✅ 以所有者 **{user_info.get('fullname', user_info.get('name'))}** 身份登录,将使用空间配置的密钥。", gr.update(visible=False)
                else:
                    return f"👋 你好, **{user_info.get('fullname', '用户')}**!请在下方提供你自己的翻译 API 密钥。", gr.update(visible=True, open=True)
            except Exception as e:
                print(f"Error getting user info: {e}")
                return "⚠️ 无法验证您的登录状态。请提供 API 密钥。", gr.update(visible=True, open=True)
        return "ℹ️ **访客模式**。如需使用翻译功能,请<a href='/login?redirect=/'>登录</a>或提供 API 密钥。", gr.update(visible=True, open=True)

    def format_tags_html(tags_dict, translations_list, show_scores):
        if not tags_dict:
            return "<p>暂无标签</p>"
        html = '<div class="label-container">'
        for i, (tag, score) in enumerate(tags_dict.items()):
            escaped_tag = tag.replace("'", "\\'")
            html += '<div class="tag-item">'
            tag_display_html = f'<span class="tag-en" onclick="copyToClipboard(\'{escaped_tag}\')">{tag}</span>'
            if i < len(translations_list) and translations_list[i]:
                tag_display_html += f'<span class="tag-zh">({translations_list[i]})</span>'
            html += f'<div>{tag_display_html}</div>'
            if show_scores:
                html += f'<span class="tag-score">{score:.3f}</span>'
            html += '</div>'
        return html + '</div>'

    def generate_summary_text_content(current_res, translations, sum_cats, sep_type, show_zh):
        if not current_res:
            return "请先分析图像。"
        parts, sep = [], {"逗号": ", ", "换行": "\n", "空格": " "}.get(sep_type, ", ")
        cat_map = {"通用标签": "general", "角色标签": "characters", "评分标签": "ratings"}
        for cat_name in sum_cats:
            cat_key = cat_map.get(cat_name)
            if cat_key and current_res.get(cat_key):
                tags_en, trans = list(current_res[cat_key].keys()), translations.get(cat_key, [])
                tags_to_join = [f"{en}({zh})" if show_zh and i < len(trans) and trans[i] else en for i, en in enumerate(tags_en)]
                if tags_to_join:
                    parts.append(sep.join(tags_to_join))
        return "\n".join(parts) if parts else "选定的类别中没有找到标签。"

    def process_image_and_generate_outputs(
        img, g_th, c_th, s_scores,
        user_tencent_id, user_tencent_key, user_baidu_json,
        sum_cats, s_sep, s_zh_in_sum,
        request: gr.Request
    ):
        if img is None:
            raise gr.Error("请先上传图片。")
        if tagger_instance is None:
            raise gr.Error("分析器未成功初始化,请检查后台错误。")
        # First yield: disable the button and show placeholders while the image is analyzed.
        yield gr.update(interactive=False, value="🔄 处理中..."), gr.update(visible=True, value="🔄 正在分析..."), *["<p>分析中...</p>"] * 3, "分析中...", {}, {}

        # The space owner uses the secrets configured on the Space; other users must
        # supply their own translation credentials.
        token = get_token_from_request(request)
        is_owner = False
        if token:
            try:
                user_info = whoami(token=token)
                if is_user_space_owner(user_info):
                    is_owner = True
            except Exception:
                pass
        final_tencent_id, final_tencent_key, baidu_json_str = (
            (os.environ.get("TENCENT_SECRET_ID"), os.environ.get("TENCENT_SECRET_KEY"), os.environ.get("BAIDU_CREDENTIALS_JSON", "[]"))
            if is_owner else (user_tencent_id, user_tencent_key, user_baidu_json)
        )
        final_baidu_creds_list = []
        if baidu_json_str and baidu_json_str.strip():
            try:
                parsed_data = json.loads(baidu_json_str)
                if isinstance(parsed_data, list):
                    final_baidu_creds_list = parsed_data
            except json.JSONDecodeError:
                print("提供的百度凭证JSON无效。")
        try:
            res, tag_cats_original = tagger_instance.predict(img, g_th, c_th)
            all_tags = [tag for cat in tag_cats_original.values() for tag in cat]
            translations_flat = translate_texts(
                all_tags,
                tencent_secret_id=final_tencent_id,
                tencent_secret_key=final_tencent_key,
                baidu_credentials_list=final_baidu_creds_list
            ) if all_tags else []
            # Re-split the flat translation list back into per-category lists.
            translations, offset = {}, 0
            for cat_key, tags in tag_cats_original.items():
                translations[cat_key] = translations_flat[offset : offset + len(tags)]
                offset += len(tags)
            outputs_html = {k: format_tags_html(res.get(k, {}), translations.get(k, []), s_scores) for k in ["general", "characters", "ratings"]}
            summary = generate_summary_text_content(res, translations, sum_cats, s_sep, s_zh_in_sum)
            yield gr.update(interactive=True, value="🚀 开始分析"), gr.update(visible=True, value="✅ 分析完成!"), outputs_html["general"], outputs_html["characters"], outputs_html["ratings"], summary, res, translations
        except Exception as e:
            import traceback
            traceback.print_exc()
            raise gr.Error(f"处理时发生错误: {e}")

    demo.load(fn=check_user_status, inputs=None, outputs=[user_status_md, api_key_accordion], queue=False)

    btn.click(
        process_image_and_generate_outputs,
        inputs=[
            img_in, gen_slider, char_slider, show_tag_scores,
            tencent_id_in, tencent_key_in, baidu_json_in,
            sum_cats, sum_sep, sum_show_zh
        ],
        outputs=[
            btn, processing_info,
            out_general, out_char, out_rating,
            out_summary,
            state_res, state_translations_dict
        ],
    )

    # Regenerate the summary whenever any of the summary settings change.
    summary_controls = [sum_cats, sum_sep, sum_show_zh]
    for ctrl in summary_controls:
        ctrl.change(
            fn=generate_summary_text_content,
            inputs=[state_res, state_translations_dict] + summary_controls,
            outputs=[out_summary],
        )

if __name__ == "__main__":
    if tagger_instance is None:
        print("CRITICAL: Tagger failed to initialize. App functionality will be limited.")
    demo.launch(server_name="0.0.0.0", server_port=7860)