# Spaces: Running, Running
# (Status text captured from the Hugging Face Spaces page; it is not part of the program.)
import hmac
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, Iterable, List, Literal, Optional

import gradio as gr
import uvicorn
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub.inference._mcp.agent import Agent
from pydantic import BaseModel
# Load variables from a local .env file (if any) into the process environment
# before the lookups below read them.
load_dotenv()

# --- Environment-driven settings -------------------------------------------
# Shared secret compared against the X-Webhook-Secret request header.
# NOTE(review): the fallback is a placeholder; set WEBHOOK_SECRET explicitly
# for any real deployment.
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET", "your-webhook-secret")

# Hub token; the agent is only created when this is set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model handed to the Agent.
HF_MODEL = os.environ.get("HF_MODEL", "microsoft/DialoGPT-medium")

# Provider literal handed to the Agent.
DEFAULT_PROVIDER: Literal["hf-inference"] = "hf-inference"

# NOTE(review): no reader of HF_PROVIDER is visible in this file; the agent is
# built with DEFAULT_PROVIDER. Confirm whether the override is meant to apply.
HF_PROVIDER = os.environ.get("HF_PROVIDER", DEFAULT_PROVIDER)

# --- Process-wide state ------------------------------------------------------
# In-memory log of processed tag operations (one dict per webhook comment).
tag_operations_store: List[Dict[str, Any]] = []

# Lazily created, shared Agent; stays None until first successful creation.
agent_instance: Optional[Agent] = None
# Common ML tags that we recognize for auto-tagging.
# Every entry must be a non-blank, lower-case string: tag detection looks for
# these values inside lower-cased discussion text, so a blank entry (the
# original set contained a single space, " ") would be "found" in virtually
# every comment and would also show up in the dashboard's tag listing.
RECOGNIZED_TAGS = {
    # Frameworks / libraries
    "pytorch",
    "tensorflow",
    "jax",
    "transformers",
    "diffusers",
    # Tasks
    "text-generation",
    "text-classification",
    "question-answering",
    "text-to-image",
    "image-classification",
    "object-detection",
    "fill-mask",
    "token-classification",
    "translation",
    "summarization",
    "feature-extraction",
    "sentence-similarity",
    "zero-shot-classification",
    "image-to-text",
    "automatic-speech-recognition",
    "audio-classification",
    "voice-activity-detection",
    "depth-estimation",
    "image-segmentation",
    "video-classification",
    "reinforcement-learning",
    "tabular-classification",
    "tabular-regression",
    "time-series-forecasting",
    "graph-ml",
    "robotics",
    # Broad domains
    "computer-vision",
    "nlp",
    "cv",
    "multimodal",
}
class WebhookEvent(BaseModel):
    """Top-level shape of a Hub discussion webhook payload.

    Each section is kept as a loosely typed mapping; only the keys this file
    actually reads are relied upon elsewhere.
    """

    # Contains action and scope information (both strings).
    event: Dict[str, str]
    # Comment content and metadata.
    comment: Dict[str, Any]
    # Discussion information.
    discussion: Dict[str, Any]
    # Repository details. Values are not all strings (the payload carries a
    # boolean "private" flag, for example), so the value type is Any; the
    # previous Dict[str, str] annotation would reject such payloads.
    repo: Dict[str, Any]
# FastAPI application object that the handlers and the Gradio dashboard in this
# file are attached to.
app = FastAPI(title="HF Tagging Bot")

# Accept cross-origin requests from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
)
async def get_agent():
    """Return the shared Agent, creating it and loading its tools on first use.

    Returns:
        The module-level ``agent_instance``, or ``None`` when ``HF_TOKEN`` is
        not set or when creating the agent / loading its tools failed.

    The new agent is built in a local variable and only published to the
    module-level ``agent_instance`` after ``load_tools()`` has completed.
    Previously the global was assigned before the ``await``, so a concurrent
    caller could pick up an agent whose tools were not loaded yet, and a
    failed attempt reset the global to ``None`` even if another caller had
    stored a working agent in the meantime.
    """
    global agent_instance

    print("🤖 get_agent() called...")

    if agent_instance is not None:
        print("✅ Using existing agent instance")
        return agent_instance

    if not HF_TOKEN:
        print("❌ No HF_TOKEN available, cannot create agent")
        return agent_instance

    print("🔧 Creating new Agent instance...")
    print(f"🔑 HF_TOKEN present: {bool(HF_TOKEN)}")
    print(f"🤖 Model: {HF_MODEL}")
    print(f"🔗 Provider: {DEFAULT_PROVIDER}")

    try:
        # NOTE(review): HF_PROVIDER is configured at module level but the agent
        # is built with DEFAULT_PROVIDER; confirm which one is intended.
        candidate = Agent(
            model=HF_MODEL,
            provider=DEFAULT_PROVIDER,
            api_key=HF_TOKEN,
            servers=[
                {
                    "type": "stdio",
                    "config": {
                        "command": "python",
                        "args": ["mcp_server.py"],
                        "cwd": ".",  # Ensure correct working directory
                        # HF_TOKEN is known to be set on this path.
                        "env": {"HF_TOKEN": HF_TOKEN},
                    },
                }
            ],
        )
        print("✅ Agent instance created successfully")

        print("🔧 Loading tools...")
        await candidate.load_tools()
        print("✅ Tools loaded successfully")
    except Exception as e:
        # Best-effort: report the failure and leave the shared instance as is,
        # so a later call can retry.
        print(f"❌ Error creating/loading agent: {e}")
        return agent_instance

    # Publish only a fully initialised agent.
    agent_instance = candidate
    return agent_instance
def extract_tags_from_text(
    text: str, recognized_tags: Optional[Iterable[str]] = None
) -> List[str]:
    """Extract potential tags from discussion text.

    Three sources are combined (matching is case-insensitive; results are
    lower-case):

    1. Explicit lists such as ``"tag: pytorch"`` or ``"tags: a, b"``. The list
       ends at the end of the line, items are comma-separated, and only the
       first word of each item is kept so trailing prose
       (``"tags: pytorch please"``) does not become part of a tag.
    2. Hashtags such as ``"#pytorch"``.
    3. Recognized tags mentioned in natural text, matched as whole terms
       (``"jax"`` does not match inside ``"ajax"``).

    Args:
        text: Comment or title text. Empty or ``None`` yields ``[]``.
        recognized_tags: Tags to look for in natural text. Defaults to the
            module-level ``RECOGNIZED_TAGS``.

    Returns:
        A sorted list of unique, non-empty tags.
    """
    if not text:
        return []

    known_tags = RECOGNIZED_TAGS if recognized_tags is None else recognized_tags
    text_lower = text.lower()
    found = set()

    # Pattern 1: "tag: something" or "tags: a, b". The character class allows
    # spaces/tabs but not newlines, so the capture stops at the end of the line.
    for tag_list in re.findall(r"tags?:\s*([a-z0-9_,\- \t]+)", text_lower):
        for item in tag_list.split(","):
            words = item.split()
            if words:  # skips empty items from "a,,b" or a trailing comma
                found.add(words[0])

    # Pattern 2: "#hashtag" style.
    found.update(re.findall(r"#([a-z0-9_\-]+)", text_lower))

    # Pattern 3: recognized tags mentioned in natural text.
    for tag in known_tags:
        candidate = tag.strip().lower()
        # Skip blank entries, and use the cheap substring test first.
        if not candidate or candidate not in text_lower:
            continue
        # Require a non-alphanumeric character (or a string edge) on both
        # sides so short tags do not match inside longer words.
        whole_term = r"(?<![a-z0-9])" + re.escape(candidate) + r"(?![a-z0-9])"
        if re.search(whole_term, text_lower):
            found.add(candidate)

    found.discard("")
    # Sorted for a deterministic result (set iteration order varies per run).
    return sorted(found)
# async def process_webhook_comment(webhook_data: Dict[str, Any]): | |
# """Process webhook to detect and add tags""" | |
# print("π·οΈ Starting process_webhook_comment...") | |
# try: | |
# comment_content = webhook_data["comment"]["content"] | |
# discussion_title = webhook_data["discussion"]["title"] | |
# repo_name = webhook_data["repo"]["name"] | |
# discussion_num = webhook_data["discussion"]["num"] | |
# # Author is an object with "id" field | |
# comment_author = webhook_data["comment"]["author"].get("id", "unknown") | |
# print(f"π Comment content: {comment_content}") | |
# print(f"π° Discussion title: {discussion_title}") | |
# print(f"π¦ Repository: {repo_name}") | |
# # Extract potential tags from the comment and discussion title | |
# comment_tags = extract_tags_from_text(comment_content) | |
# title_tags = extract_tags_from_text(discussion_title) | |
# all_tags = list(set(comment_tags + title_tags)) | |
# print(f"π Comment tags found: {comment_tags}") | |
# print(f"π Title tags found: {title_tags}") | |
# print(f"π·οΈ All unique tags: {all_tags}") | |
# result_messages = [] | |
# if not all_tags: | |
# msg = "No recognizable tags found in the discussion." | |
# print(f"β {msg}") | |
# result_messages.append(msg) | |
# else: | |
# print("π€ Getting agent instance...") | |
# agent = await get_agent() | |
# if not agent: | |
# msg = "Error: Agent not configured (missing HF_TOKEN)" | |
# print(f"β {msg}") | |
# result_messages.append(msg) | |
# else: | |
# print("β Agent instance obtained successfully") | |
# # Process all tags in a single conversation with the agent | |
# try: | |
# # Create a comprehensive prompt for the agent | |
# user_prompt = f""" | |
# I need to add the following tags to the repository '{repo_name}': {", ".join(all_tags)} | |
# For each tag, please: | |
# 1. Check if the tag already exists on the repository using get_current_tags | |
# 2. If the tag doesn't exist, add it using add_new_tag | |
# 3. Provide a summary of what was done for each tag | |
# Please process all {len(all_tags)} tags: {", ".join(all_tags)} | |
# """ | |
# print("π¬ Sending comprehensive prompt to agent...") | |
# print(f"π Prompt: {user_prompt}") | |
# # Let the agent handle the entire conversation | |
# conversation_result = [] | |
# try: | |
# async for item in agent.run(user_prompt): | |
# # The agent yields different types of items | |
# item_str = str(item) | |
# conversation_result.append(item_str) | |
# # Log important events | |
# if ( | |
# "tool_call" in item_str.lower() | |
# or "function" in item_str.lower() | |
# ): | |
# print(f"π§ Agent using tools: {item_str[:200]}...") | |
# elif "content" in item_str and len(item_str) < 500: | |
# print(f"π Agent response: {item_str}") | |
# # Extract the final response from the conversation | |
# full_response = " ".join(conversation_result) | |
# print(f"π Agent conversation completed successfully") | |
# # Try to extract meaningful results for each tag | |
# for tag in all_tags: | |
# tag_mentioned = tag.lower() in full_response.lower() | |
# if ( | |
# "already exists" in full_response.lower() | |
# and tag_mentioned | |
# ): | |
# msg = f"Tag '{tag}': Already exists" | |
# elif ( | |
# "pr" in full_response.lower() | |
# or "pull request" in full_response.lower() | |
# ): | |
# if tag_mentioned: | |
# msg = f"Tag '{tag}': PR created successfully" | |
# else: | |
# msg = ( | |
# f"Tag '{tag}': Processed " | |
# "(PR may have been created)" | |
# ) | |
# elif "success" in full_response.lower() and tag_mentioned: | |
# msg = f"Tag '{tag}': Successfully processed" | |
# elif "error" in full_response.lower() and tag_mentioned: | |
# msg = f"Tag '{tag}': Error during processing" | |
# else: | |
# msg = f"Tag '{tag}': Processed by agent" | |
# print(f"β Result for tag '{tag}': {msg}") | |
# result_messages.append(msg) | |
# except Exception as agent_error: | |
# print(f"β οΈ Agent streaming failed: {str(agent_error)}") | |
# print("π Falling back to direct MCP tool calls...") | |
# # Import the MCP server functions directly as fallback | |
# try: | |
# import sys | |
# import importlib.util | |
# # Load the MCP server module | |
# spec = importlib.util.spec_from_file_location( | |
# "mcp_server", "./mcp_server.py" | |
# ) | |
# mcp_module = importlib.util.module_from_spec(spec) # type: ignore | |
# spec.loader.exec_module(mcp_module) # type: ignore | |
# # Use the MCP tools directly for each tag | |
# for tag in all_tags: | |
# try: | |
# print( | |
# f"π§ Directly calling get_current_tags for '{tag}'" | |
# ) | |
# current_tags_result = mcp_module.get_current_tags( | |
# repo_name | |
# ) | |
# print( | |
# f"π Current tags result: {current_tags_result}" | |
# ) | |
# # Parse the JSON result | |
# import json | |
# tags_data = json.loads(current_tags_result) | |
# if tags_data.get("status") == "success": | |
# current_tags = tags_data.get("current_tags", []) | |
# if tag in current_tags: | |
# msg = f"Tag '{tag}': Already exists" | |
# print(f"β {msg}") | |
# else: | |
# print( | |
# f"π§ Directly calling add_new_tag for '{tag}'" | |
# ) | |
# add_result = mcp_module.add_new_tag( | |
# repo_name, tag | |
# ) | |
# print(f"π Add tag result: {add_result}") | |
# add_data = json.loads(add_result) | |
# if add_data.get("status") == "success": | |
# pr_url = add_data.get("pr_url", "") | |
# msg = f"Tag '{tag}': PR created - {pr_url}" | |
# elif ( | |
# add_data.get("status") | |
# == "already_exists" | |
# ): | |
# msg = f"Tag '{tag}': Already exists" | |
# else: | |
# msg = f"Tag '{tag}': {add_data.get('message', 'Processed')}" | |
# print(f"β {msg}") | |
# else: | |
# error_msg = tags_data.get( | |
# "error", "Unknown error" | |
# ) | |
# msg = f"Tag '{tag}': Error - {error_msg}" | |
# print(f"β {msg}") | |
# result_messages.append(msg) | |
# except Exception as direct_error: | |
# error_msg = f"Tag '{tag}': Direct call error - {str(direct_error)}" | |
# print(f"β {error_msg}") | |
# result_messages.append(error_msg) | |
# except Exception as fallback_error: | |
# error_msg = ( | |
# f"Fallback approach failed: {str(fallback_error)}" | |
# ) | |
# print(f"β {error_msg}") | |
# result_messages.append(error_msg) | |
# except Exception as e: | |
# error_msg = f"Error during agent processing: {str(e)}" | |
# print(f"β {error_msg}") | |
# result_messages.append(error_msg) | |
# # Store the interaction | |
# base_url = "https://huggingface.co" | |
# discussion_url = f"{base_url}/{repo_name}/discussions/{discussion_num}" | |
# interaction = { | |
# "timestamp": datetime.now().isoformat(), | |
# "repo": repo_name, | |
# "discussion_title": discussion_title, | |
# "discussion_num": discussion_num, | |
# "discussion_url": discussion_url, | |
# "original_comment": comment_content, | |
# "comment_author": comment_author, | |
# "detected_tags": all_tags, | |
# "results": result_messages, | |
# } | |
# tag_operations_store.append(interaction) | |
# final_result = " | ".join(result_messages) | |
# print(f"πΎ Stored interaction and returning result: {final_result}") | |
# return final_result | |
# except Exception as e: | |
# error_msg = f"β Fatal error in process_webhook_comment: {str(e)}" | |
# print(error_msg) | |
# return error_msg | |
@app.post("/webhook")
async def webhook_handler(request: Request, background_tasks: BackgroundTasks):
    """
    Handle incoming webhooks from Hugging Face Hub.

    Following the pattern from: https://raw.githubusercontent.com/huggingface/hub-docs/refs/heads/main/docs/hub/webhooks-guide-discussion-bot.md

    Registered at POST /webhook, the path advertised by ``root()`` and the
    start-up banner.

    Args:
        request: Incoming request; the ``X-Webhook-Secret`` header and the JSON
            body are read from it.
        background_tasks: Used to schedule ``process_webhook_comment`` so the
            response is returned without waiting for tag processing.

    Returns:
        A dict with ``status`` "accepted" or "ignored" for well-formed
        requests, or a ``JSONResponse`` with HTTP 400 and an ``error`` message
        for a wrong secret, an unparsable body or missing event data.
        (Returning ``({...}, 400)`` is not a status code in FastAPI: the tuple
        is serialized as a JSON array with HTTP 200.)
    """
    print("🔔 Webhook received!")

    # Step 1: Validate webhook secret (security).
    # compare_digest avoids leaking, through timing, how much of the secret
    # matched. Both sides are encoded because it rejects non-ASCII str input.
    provided_secret = request.headers.get("X-Webhook-Secret")
    if provided_secret is None or not hmac.compare_digest(
        provided_secret.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")
    ):
        print("❌ Invalid webhook secret")
        return JSONResponse(
            status_code=400, content={"error": "Invalid webhook secret"}
        )

    # Step 2: Parse webhook data.
    try:
        webhook_data = await request.json()
        print(f"📥 Webhook data: {json.dumps(webhook_data, indent=2)}")
    except Exception as e:
        print(f"❌ Error parsing webhook data: {e}")
        return JSONResponse(status_code=400, content={"error": "invalid JSON"})

    # Step 3: Validate event structure. A JSON body that is not an object
    # (e.g. a list) has no "event" section either.
    event = webhook_data.get("event") if isinstance(webhook_data, dict) else None
    if not event or not isinstance(event, dict):
        print("❌ No event data in webhook")
        return JSONResponse(
            status_code=400, content={"error": "missing event data"}
        )

    scope = event.get("scope")
    action = event.get("action")
    print(f"📋 Event details - scope: {scope}, action: {action}")

    # Step 4: Only discussion comment creations are processed.
    if action == "create" and scope == "discussion.comment":
        print("✅ Valid discussion comment creation event")
        # Process in background to return quickly to Hub.
        background_tasks.add_task(process_webhook_comment, webhook_data)
        return {
            "status": "accepted",
            "message": "Comment processing started",
            "timestamp": datetime.now().isoformat(),
        }

    print(f"ℹ️ Ignoring event: action={action}, scope={scope}")
    return {
        "status": "ignored",
        "reason": "Not a discussion comment creation",
    }
async def _collect_agent_output(agent: Any, prompt: str) -> str:
    """Run one agent conversation and return everything it yielded as text.

    ``agent.run()`` is consumed with ``async for`` (it cannot be awaited
    directly). Each yielded item is stringified and the pieces are joined with
    spaces, the same way the earlier, now commented-out implementation in this
    file did, which also keeps the stored result JSON-serializable.
    """
    pieces: List[str] = []
    async for item in agent.run(prompt):
        pieces.append(str(item))
    return " ".join(pieces)


async def process_webhook_comment(webhook_data: Dict[str, Any]):
    """
    Process a webhook comment to detect and add tags.

    Extracts tags from the comment and the discussion title, records an
    operation entry in ``tag_operations_store`` and asks the agent to handle
    each tag. Progress is reflected in the stored entry's ``status``
    ("processing", "no_tags", "error" or "completed"), ``message`` and
    ``results`` fields.

    Args:
        webhook_data: Parsed webhook payload. Must provide
            ``comment.content``, ``comment.author``, ``discussion.title``,
            ``discussion.num`` and ``repo.name``.

    Returns:
        None. This runs as a background task, so nothing consumes a return
        value; outcomes are logged and stored instead. (The previous
        ``({"error": "invalid JSON"}, 400)`` return on malformed payloads was
        both unused and mislabelled.)
    """
    print("🏷️ Starting process_webhook_comment...")

    # Extract comment and repository information.
    try:
        comment_content = webhook_data["comment"]["content"]
        discussion_title = webhook_data["discussion"]["title"]
        repo_name = webhook_data["repo"]["name"]
        discussion_num = webhook_data["discussion"]["num"]
        comment_author = webhook_data["comment"]["author"].get("id", "unknown")
    except (KeyError, TypeError, AttributeError) as e:
        # Missing key, or a section that is not the expected mapping.
        print(f"❌ Error parsing webhook data: {e}")
        return

    print(f"📝 Comment from {comment_author}: {comment_content}")
    print(f"📰 Discussion: {discussion_title}")
    print(f"📦 Repository: {repo_name}")

    # Extract potential tags from comment and title; drop blanks and sort so
    # processing order is deterministic.
    comment_tags = extract_tags_from_text(comment_content)
    title_tags = extract_tags_from_text(discussion_title)
    all_tags = sorted(
        {tag for tag in comment_tags + title_tags if tag and tag.strip()}
    )
    print(f"🔍 Found tags: {all_tags}")

    # Store operation for monitoring. The same dict is updated in place below,
    # so readers of tag_operations_store see the latest status.
    if len(comment_content) > 100:
        comment_preview = comment_content[:100] + "..."
    else:
        comment_preview = comment_content
    operation: Dict[str, Any] = {
        "timestamp": datetime.now().isoformat(),
        "repo_name": repo_name,
        "discussion_num": discussion_num,
        "comment_author": comment_author,
        "extracted_tags": all_tags,
        "comment_preview": comment_preview,
        "status": "processing",
    }
    tag_operations_store.append(operation)

    if not all_tags:
        operation["status"] = "no_tags"
        operation["message"] = "No recognizable tags found"
        print("❌ No tags found to process")
        return

    # Get MCP agent for tag processing.
    agent = await get_agent()
    if not agent:
        operation["status"] = "error"
        operation["message"] = "Agent not configured (missing HF_TOKEN)"
        print("❌ No agent available")
        return

    # Process each extracted tag; one failing tag does not stop the others.
    results: List[Dict[str, Any]] = []
    operation["results"] = results
    for tag in all_tags:
        try:
            print(f"🤖 Processing tag '{tag}' for repo '{repo_name}'")

            # Prompt for the agent to handle this tag.
            prompt = (
                "\n"
                f"Analyze the repository '{repo_name}' and determine if the tag '{tag}' should be added.\n"
                "First, check the current tags using get_current_tags.\n"
                f"If '{tag}' is not already present and it's a valid tag, add it using add_new_tag.\n"
                f"Repository: {repo_name}\n"
                f"Tag to process: {tag}\n"
                "Provide a clear summary of what was done.\n"
            )

            response = await _collect_agent_output(agent, prompt)
            print(f"🤖 Agent response for '{tag}': {response}")

            results.append(
                {
                    "tag": tag,
                    "response": response,
                    "timestamp": datetime.now().isoformat(),
                }
            )
        except Exception as e:
            print(f"❌ Error processing tag '{tag}': {e}")
            results.append(
                {
                    "tag": tag,
                    "error": str(e),
                    "timestamp": datetime.now().isoformat(),
                }
            )

    operation["status"] = "completed"
    print(f"✅ Completed processing {len(all_tags)} tags")
@app.get("/")
async def root():
    """Root endpoint with basic information.

    Registered at GET /. Returns the service name, a fixed "running" status, a
    short description and the paths of the other endpoints.
    """
    endpoints = {
        "webhook": "/webhook",
        "health": "/health",
        "operations": "/operations",
    }
    return {
        "name": "HF Tagging Bot",
        "status": "running",
        "description": "Webhook listener for automatic model tagging",
        "endpoints": endpoints,
    }
@app.get("/health")
async def health_check():
    """Health check endpoint for monitoring.

    Registered at GET /health, the path advertised by ``root()``. Reports
    whether the webhook secret and HF token are set and whether an agent is
    available. Calling this goes through ``get_agent()``, so it may trigger
    agent creation when none exists yet.
    """
    agent = await get_agent()
    components = {
        "webhook_secret": "configured" if WEBHOOK_SECRET else "missing",
        "hf_token": "configured" if HF_TOKEN else "missing",
        "mcp_agent": "ready" if agent else "not_ready",
    }
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "components": components,
    }
@app.get("/operations")
async def get_operations():
    """Get recent tag operations for monitoring.

    Registered at GET /operations, the path advertised by ``root()``. Returns
    the total number of stored operations and the 50 most recent ones.
    """
    # Slicing an empty list already yields [], so no emptiness check is needed.
    return {
        "total_operations": len(tag_operations_store),
        "recent_operations": tag_operations_store[-50:],
    }
# async def simulate_webhook( | |
# repo_name: str, discussion_title: str, comment_content: str | |
# ) -> str: | |
# """Simulate webhook for testing""" | |
# if not all([repo_name, discussion_title, comment_content]): | |
# return "Please fill in all fields." | |
# | |
# mock_payload = { | |
# "event": {"action": "create", "scope": "discussion"}, | |
# "comment": { | |
# "content": comment_content, | |
# "author": {"id": "test-user-id"}, | |
# "id": "mock-comment-id", | |
# "hidden": False, | |
# }, | |
# "discussion": { | |
# "title": discussion_title, | |
# "num": len(tag_operations_store) + 1, | |
# "id": "mock-discussion-id", | |
# "status": "open", | |
# "isPullRequest": False, | |
# }, | |
# "repo": { | |
# "name": repo_name, | |
# "type": "model", | |
# "private": False, | |
# }, | |
# } | |
# | |
# response = await process_webhook_comment(mock_payload) | |
# return f"β Processed! Results: {response}" | |
def create_gradio_app():
    """Create the Gradio dashboard.

    Builds a Blocks layout with a short description, a form (repository,
    discussion title, comment) with a test button and a result box, and the
    list of recognized tags.

    Returns:
        The ``gr.Blocks`` instance.
    """
    how_it_works = (
        "\n"
        "## How it works:\n"
        "- Monitors HuggingFace Hub discussions\n"
        '- Detects tag mentions in comments (e.g., "tag: pytorch",\n'
        '"#transformers")\n'
        "- Automatically adds recognized tags to the model repository\n"
        "- Supports common ML tags like: pytorch, tensorflow,\n"
        "text-generation, etc.\n"
    )
    # Blank entries are skipped so they cannot produce a stray leading ", ".
    tag_listing = ", ".join(sorted(tag for tag in RECOGNIZED_TAGS if tag.strip()))

    with gr.Blocks(
        title="HF Tagging Bot",
        # theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("# 🏷️ HF Tagging Bot Dashboard")
        gr.Markdown("*Automatically adds tags to models when mentioned in discussions*")
        gr.Markdown(how_it_works)

        with gr.Column():
            sim_repo = gr.Textbox(
                label="Repository",
                value="burtenshaw/play-mcp-repo-bot",
                placeholder="username/model-name",
            )
            sim_title = gr.Textbox(
                label="Discussion Title",
                value="Add pytorch tag",
                placeholder="Discussion title",
            )
            sim_comment = gr.Textbox(
                label="Comment",
                lines=3,
                value="This model should have tags: pytorch, text-generation",
                placeholder="Comment mentioning tags...",
            )
            sim_btn = gr.Button("🏷️ Test Tag Detection")

        with gr.Column():
            sim_result = gr.Textbox(label="Result", lines=8)

        # TODO: the button currently has no click handler, so pressing it does
        # nothing. The wiring below depends on simulate_webhook, which is
        # commented out in this file.
        # sim_btn.click(
        #     fn=simulate_webhook,
        #     inputs=[sim_repo, sim_title, sim_comment],
        #     outputs=sim_result,
        # )

        gr.Markdown(f"\n## Recognized Tags:\n{tag_listing}\n")

    return demo
async def welcome() -> dict:
    """Return a fixed greeting payload.

    NOTE(review): no route decorator is visible for this coroutine; confirm
    which path, if any, it is meant to serve.
    """
    greeting = {"message": "Hello World"}
    return greeting
async def welcome_gradio2() -> dict:
    """Return a fixed "Hello gradio2" payload.

    NOTE(review): no route decorator is visible for this coroutine; confirm
    which path, if any, it is meant to serve.
    """
    greeting = {"message": "Hello gradio2"}
    return greeting
# Build the dashboard once at import time and attach it to the FastAPI app
# under /gradio; `app` is rebound to the object returned by the mount call.
gradio_app = create_gradio_app()
app = gr.mount_gradio_app(
    app,
    gradio_app,
    path="/gradio",
)
if __name__ == "__main__":
    print("🚀 Starting HF Tagging Bot...")
    print("📊 Dashboard: http://localhost:7860/gradio")
    print("🔗 Webhook: http://localhost:7860/webhook")

    # Serve the FastAPI app (with the Gradio dashboard mounted on it) through
    # uvicorn. The app is passed as an import string because reload=True
    # requires one.
    # NOTE(review): "app:app" assumes this module is importable as `app`
    # (i.e. the file is app.py) — confirm.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)

    #
    # QUICK-AND-DIRTY TEST WITHOUT uvicorn (dashboard only):
    #
    # gradio_app = create_gradio_app()
    # gradio_app.launch()
# EOF | |