# Spaces:
# No application file
# No application file
import hmac
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

import gradio as gr
import uvicorn
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

load_dotenv()
# --- Configuration ----------------------------------------------------------
# Shared secret compared against the "X-Webhook-Secret" request header.
# NOTE(review): the fallback value is a secret committed to source control;
# anyone who can read this file can forge webhooks when the environment
# variable is unset. Consider requiring WEBHOOK_SECRET to be set explicitly.
WEBHOOK_SECRET = os.getenv(
    "WEBHOOK_SECRET",
    "716f77a91d0415cd0e3ed9dc8d188fc9ee53b11a8661e161a86f669f598a8016",
)
# Hub access token; None when the HF_TOKEN environment variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# In-memory record of processed tag operations (lost on restart).
tag_operations_store: List[Dict[str, Any]] = []

# Tags picked up for auto-tagging when they are mentioned in free text.
RECOGNIZED_TAGS = {
    # Frameworks and libraries
    "pytorch",
    "tensorflow",
    "jax",
    "transformers",
    "diffusers",
    # Text tasks
    "text-generation",
    "text-classification",
    "question-answering",
    "fill-mask",
    "token-classification",
    "translation",
    "summarization",
    "feature-extraction",
    "sentence-similarity",
    "zero-shot-classification",
    # Image and video tasks
    "text-to-image",
    "image-classification",
    "object-detection",
    "image-to-text",
    "depth-estimation",
    "image-segmentation",
    "video-classification",
    # Audio tasks
    "automatic-speech-recognition",
    "audio-classification",
    "voice-activity-detection",
    # Other tasks
    "reinforcement-learning",
    "tabular-classification",
    "tabular-regression",
    "time-series-forecasting",
    "graph-ml",
    "robotics",
    # Broad domain labels
    "computer-vision",
    "nlp",
    "cv",
    "multimodal",
}
class WebhookEvent(BaseModel):
    """Top-level sections of a Hub webhook payload as used by this bot.

    Each section is kept as a loosely typed mapping; only the keys the
    handlers actually read are relied upon.
    """

    # Event descriptor; the handlers read its "scope" and "action" entries.
    event: Dict[str, str]
    # Comment section; "content" and "author" are read from it.
    comment: Dict[str, Any]
    # Discussion section; "title", "num" and "isPullRequest" are read from it.
    discussion: Dict[str, Any]
    # Repository section. Values are not all strings (the payload built in
    # this file carries a boolean "private" flag), so the value type must be
    # Any rather than str or validation would reject such payloads.
    repo: Dict[str, Any]
# FastAPI application object; the Gradio dashboard is mounted onto it later.
app = FastAPI(title="HF Tagging Bot")
# CORS middleware configured to accept any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
)
def extract_tags_from_text(text: str) -> List[str]:
    """Extract candidate tags from free-form discussion text.

    Three sources are combined, case-insensitively:

    1. Explicit lists such as ``tag: pytorch`` or ``tags: pytorch, nlp``.
    2. Hashtags such as ``#pytorch``.
    3. Any entry of ``RECOGNIZED_TAGS`` mentioned as a whole word.

    Args:
        text: Raw comment or title text. Empty or ``None`` yields ``[]``.

    Returns:
        Lower-cased, de-duplicated tags in a deterministic order: explicit
        tags in order of appearance, then recognized tags alphabetically.
        Empty strings are never returned.
    """
    if not text:
        return []

    text_lower = text.lower()
    found: List[str] = []

    # Pattern 1: "tag: a, b" / "tags: a, b". The list is limited to a single
    # line (spaces and tabs only) so the text that follows on the next line
    # is not swallowed into the last tag.
    for tag_list in re.findall(r"\btags?:\s*([a-z0-9_,\- \t]+)", text_lower):
        found.extend(part.strip() for part in tag_list.split(","))

    # Pattern 2: "#hashtag" style.
    found.extend(re.findall(r"#([a-z0-9_-]+)", text_lower))

    # Pattern 3: recognized tags mentioned in running text. Matching on word
    # boundaries avoids false positives such as "cv" inside "recv" or "jax"
    # inside "ajax". Sorting keeps the result order stable across runs.
    for tag in sorted(RECOGNIZED_TAGS):
        if re.search(rf"(?<!\w){re.escape(tag)}(?!\w)", text_lower):
            found.append(tag)

    # De-duplicate while preserving order and drop empty fragments produced
    # by stray or trailing commas.
    return [tag for tag in dict.fromkeys(found) if tag]
async def process_tags_directly(all_tags: List[str], repo_name: str) -> List[str]:
    """Open one Hub pull request per missing tag on ``repo_name``.

    The repository is looked up as a model, then a dataset, then a space.
    For every requested tag that the repository does not already carry, the
    README metadata is updated and submitted as a pull request.

    Args:
        all_tags: Tags to add.
        repo_name: Repository id in ``owner/name`` form.

    Returns:
        One human-readable message per tag (already exists, PR created, or
        error), or a single-element list describing a fatal problem such as
        a missing token or an unknown repository. Exceptions are reported in
        the returned messages rather than raised.
    """
    print("π§ Using direct HuggingFace Hub API approach...")
    result_messages = []

    if not HF_TOKEN:
        error_msg = "No HF_TOKEN configured"
        print(f"β {error_msg}")
        return [error_msg]

    try:
        from huggingface_hub import (
            CommitOperationAdd,
            HfApi,
            ModelCard,
            ModelCardData,
            dataset_info,
            model_info,
            space_info,
        )
        from huggingface_hub.utils import HfHubHTTPError

        # NOTE(review): the huggingface_hub calls below appear to be
        # synchronous, so they would block the event loop while they run;
        # confirm and consider moving them to a worker thread.
        hf_api = HfApi(token=HF_TOKEN)

        # Determine the repository type by probing each kind in turn.
        lookups = {
            "model": model_info,
            "dataset": dataset_info,
            "space": space_info,
        }
        repo_type = None
        repo_info = None
        for candidate, lookup in lookups.items():
            try:
                print(f"π Trying to access {repo_name} as {candidate}...")
                repo_info = lookup(repo_id=repo_name, token=HF_TOKEN)
                repo_type = candidate
                print(f"β Found repository as {repo_type}")
                break
            except HfHubHTTPError as e:
                # Either way the next repository type is tried; only the
                # log line differs.
                if "404" in str(e):
                    print(f"β οΈ Repository not found as {candidate}")
                else:
                    print(f"β Error accessing as {candidate}: {e}")
            except Exception as e:
                print(f"β Unexpected error for {candidate}: {e}")

        if not repo_type or not repo_info:
            error_msg = f"Repository '{repo_name}' not found as model, dataset, or space"
            print(f"β {error_msg}")
            return [f"Error: {error_msg}"]

        print(f"π Repository type: {repo_type}")
        # Tags as reported by the Hub. This list also contains tags the Hub
        # computes itself, so it is used only for the "already tagged" check
        # and never written back into the README.
        current_tags = repo_info.tags if repo_info.tags else []
        print(f"π·οΈ Current tags: {current_tags}")

        for tag in all_tags:
            try:
                if tag in current_tags:
                    msg = f"Tag '{tag}': Already exists"
                    print(f"β {msg}")
                    result_messages.append(msg)
                    continue

                print(f"π§ Adding tag '{tag}' to {repo_type} '{repo_name}'")

                try:
                    print(f"π Loading existing model card...")
                    card = ModelCard.load(repo_name, token=HF_TOKEN, repo_type=repo_type)
                    if not hasattr(card, "data") or card.data is None:
                        card.data = ModelCardData()
                except HfHubHTTPError:
                    # No README yet: start from an empty card.
                    print(f"π Creating new model card (none exists)")
                    card = ModelCard("")
                    card.data = ModelCardData()

                # Extend the tags declared in the card itself rather than the
                # Hub-reported list, so Hub-computed tags are not copied into
                # the README metadata.
                card_dict = card.data.to_dict()
                card_tags = list(card_dict.get("tags") or [])
                if tag in card_tags:
                    msg = f"Tag '{tag}': Already exists"
                    print(f"β {msg}")
                    result_messages.append(msg)
                    continue
                updated_tags = card_tags + [tag]
                card_dict["tags"] = updated_tags
                card.data = ModelCardData(**card_dict)

                pr_title = f"Add '{tag}' tag"
                pr_description = "\n".join(
                    [
                        "",
                        f"## Add tag: {tag}",
                        f"This PR adds the `{tag}` tag to the {repo_type} repository.",
                        "**Changes:**",
                        f"- Added `{tag}` to {repo_type} tags",
                        f"- Updated from {len(card_tags)} to {len(updated_tags)} tags",
                        f"**Current tags:** {', '.join(card_tags) if card_tags else 'None'}",
                        f"**New tags:** {', '.join(updated_tags)}",
                        "",
                    ]
                )
                print(f"π Creating PR with title: {pr_title}")

                # Commit the regenerated README as a pull request.
                commit_info = hf_api.create_commit(
                    repo_id=repo_name,
                    repo_type=repo_type,
                    operations=[
                        CommitOperationAdd(
                            path_in_repo="README.md",
                            path_or_fileobj=str(card).encode("utf-8"),
                        )
                    ],
                    commit_message=pr_title,
                    commit_description=pr_description,
                    token=HF_TOKEN,
                    create_pr=True,
                )

                # Fall back to the commit info's string form when it exposes
                # no pr_url attribute.
                pr_url = getattr(commit_info, "pr_url", str(commit_info))
                print(f"β PR created successfully! URL: {pr_url}")
                result_messages.append(f"Tag '{tag}': PR created - {pr_url}")
            except Exception as tag_error:
                # A failure on one tag must not stop the remaining tags.
                error_msg = f"Tag '{tag}': Error - {str(tag_error)}"
                print(f"β {error_msg}")
                result_messages.append(error_msg)

        return result_messages
    except Exception as e:
        error_msg = f"Direct API processing failed: {str(e)}"
        print(f"β {error_msg}")
        return [error_msg]
async def process_webhook_comment(webhook_data: Dict[str, Any]):
    """Detect tags in a webhook payload, apply them, and record the outcome.

    Reads the comment content, discussion title/number, repository name and
    comment author from ``webhook_data``, extracts tags from the comment and
    the title, hands them to ``process_tags_directly`` and appends a summary
    of the interaction to ``tag_operations_store``.

    Returns:
        The result messages joined with " | ", or a "Fatal error" string when
        anything raises (for example a missing payload key).
    """
    print("π·οΈ Starting process_webhook_comment...")
    try:
        comment = webhook_data["comment"]
        body = comment["content"]
        discussion = webhook_data["discussion"]
        title = discussion["title"]
        repo_name = webhook_data["repo"]["name"]
        number = discussion["num"]
        author = comment["author"].get("id", "unknown")

        print(f"π Comment content: {body}")
        print(f"π° Discussion title: {title}")
        print(f"π¦ Repository: {repo_name}")

        # Tags may come from either the comment body or the discussion title.
        comment_tags = extract_tags_from_text(body)
        title_tags = extract_tags_from_text(title)
        all_tags = list(set(comment_tags + title_tags))
        print(f"π Comment tags found: {comment_tags}")
        print(f"π Title tags found: {title_tags}")
        print(f"π·οΈ All unique tags: {all_tags}")

        if all_tags:
            print("π§ Using direct HuggingFace Hub API processing...")
            result_messages = await process_tags_directly(all_tags, repo_name)
        else:
            msg = "No recognizable tags found in the discussion."
            print(f"β {msg}")
            result_messages = [msg]

        # Record what was seen and what was done.
        tag_operations_store.append(
            {
                "timestamp": datetime.now().isoformat(),
                "repo": repo_name,
                "discussion_title": title,
                "discussion_num": number,
                "discussion_url": f"https://huggingface.co/{repo_name}/discussions/{number}",
                "original_comment": body,
                "comment_author": author,
                "detected_tags": all_tags,
                "results": result_messages,
            }
        )

        final_result = " | ".join(result_messages)
        print(f"πΎ Stored interaction and returning result: {final_result}")
        return final_result
    except Exception as e:
        import traceback

        error_msg = f"β Fatal error in process_webhook_comment: {str(e)}"
        print(error_msg)
        print(f"β Traceback: {traceback.format_exc()}")
        return error_msg
@app.post("/webhook")
async def webhook_handler(request: Request, background_tasks: BackgroundTasks):
    """Handle HF Hub webhooks posted to ``/webhook``.

    The request must carry the shared secret in the ``X-Webhook-Secret``
    header. Only "create" events with scope "discussion" that are not pull
    requests are processed; processing is scheduled as a background task.

    Returns:
        ``{"error": ...}`` for a bad secret or a payload missing required
        sections, ``{"status": "processing"}`` when work was scheduled, and
        ``{"status": "ignored"}`` for every other event.
    """
    provided_secret = request.headers.get("X-Webhook-Secret")
    # Constant-time comparison so the secret cannot be probed via timing.
    if provided_secret is None or not hmac.compare_digest(
        provided_secret.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")
    ):
        print("β Invalid webhook secret")
        return {"error": "Invalid webhook secret"}

    payload = await request.json()
    print(f"π₯ Received webhook payload: {json.dumps(payload, indent=2)}")

    event = payload.get("event") or {}
    scope = event.get("scope")
    action = event.get("action")
    print(f"π Event details - scope: {scope}, action: {action}")

    # Events of other scopes may carry no "discussion" section at all, so it
    # is read defensively instead of indexing into the payload.
    discussion = payload.get("discussion")
    is_pull_request = isinstance(discussion, dict) and bool(
        discussion.get("isPullRequest")
    )
    not_pr = not is_pull_request
    scope_check = scope == "discussion" and not_pr
    action_check = action == "create"
    print(f"β not_pr: {not_pr}")
    print(f"β scope_check: {scope_check}")
    print(f"β action_check: {action_check}")

    if scope_check and action_check:
        # Verify we have the required sections before scheduling any work.
        required_fields = ["comment", "discussion", "repo"]
        missing_fields = [field for field in required_fields if field not in payload]
        if missing_fields:
            error_msg = f"Missing required fields: {missing_fields}"
            print(f"β {error_msg}")
            return {"error": error_msg}

        print(f"π Processing webhook for repo: {payload['repo']['name']}")
        background_tasks.add_task(process_webhook_comment, payload)
        return {"status": "processing"}

    print(f"βοΈ Ignoring webhook - scope: {scope}, action: {action}")
    return {"status": "ignored"}
async def simulate_webhook(
    repo_name: str, discussion_title: str, comment_content: str
) -> str:
    """Run the tagging pipeline on a hand-built payload, for manual testing.

    Returns a prompt to fill in the fields when any argument is empty;
    otherwise the outcome of ``process_webhook_comment`` prefixed with a
    "Processed!" marker.
    """
    if not (repo_name and discussion_title and comment_content):
        return "Please fill in all fields."

    comment_section = {
        "content": comment_content,
        "author": {"id": "test-user-id"},
        "id": "mock-comment-id",
        "hidden": False,
    }
    discussion_section = {
        "title": discussion_title,
        # Next sequence number based on how many operations were recorded.
        "num": len(tag_operations_store) + 1,
        "id": "mock-discussion-id",
        "status": "open",
        "isPullRequest": False,
    }
    repo_section = {
        "name": repo_name,
        "type": "model",
        "private": False,
    }
    mock_payload = {
        "event": {"action": "create", "scope": "discussion"},
        "comment": comment_section,
        "discussion": discussion_section,
        "repo": repo_section,
    }

    response = await process_webhook_comment(mock_payload)
    return f"β Processed! Results: {response}"
def create_gradio_app():
    """Build the Gradio dashboard.

    The dashboard contains an explanation of the bot, a form that feeds
    ``simulate_webhook``, the list of recognized tags, and a "Recent
    Operations" section.

    The recent-operations section is produced by a callable so that it is
    evaluated when the page loads. Rendering it while the interface is being
    built would always see an empty ``tag_operations_store``, because the
    interface is created at import time, before any webhook has been handled.

    Returns:
        The ``gr.Blocks`` instance.
    """

    def _render_recent_operations() -> str:
        """Return Markdown for the last five operations ('' when none)."""
        if not tag_operations_store:
            return ""
        sections = ["## Recent Operations"]
        for op in tag_operations_store[-5:]:
            sections.append(
                f"\n**{op['repo']}** - {op['timestamp'][:19]}\n"
                f"- Tags: {', '.join(op['detected_tags'])}\n"
                f"- Results: {' | '.join(op['results'][:2])}...\n"
            )
        return "\n".join(sections)

    with gr.Blocks(title="HF Tagging Bot", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# π·οΈ HF Tagging Bot Dashboard")
        gr.Markdown("*Automatically adds tags to models, datasets, and spaces when mentioned in discussions*")
        gr.Markdown(
            "\n"
            "## How it works:\n"
            "- Monitors HuggingFace Hub discussions\n"
            '- Detects tag mentions in comments (e.g., "tag: pytorch", "#transformers")\n'
            "- Automatically detects repository type (model/dataset/space)\n"
            "- Creates pull requests to add recognized tags to the repository\n"
            "- Supports common ML tags like: pytorch, tensorflow, text-generation, etc.\n"
        )

        # Input side of the manual test form.
        with gr.Column():
            sim_repo = gr.Textbox(
                label="Repository",
                value="burtenshaw/play-mcp-repo-bot",
                placeholder="username/repo-name (can be model, dataset, or space)",
            )
            sim_title = gr.Textbox(
                label="Discussion Title",
                value="Add pytorch tag",
                placeholder="Discussion title",
            )
            sim_comment = gr.Textbox(
                label="Comment",
                lines=3,
                value="This repository should have tags: pytorch, text-generation",
                placeholder="Comment mentioning tags...",
            )
            sim_btn = gr.Button("π·οΈ Test Tag Detection")

        # Output side of the manual test form.
        with gr.Column():
            sim_result = gr.Textbox(label="Result", lines=8)

        sim_btn.click(
            fn=simulate_webhook,
            inputs=[sim_repo, sim_title, sim_comment],
            outputs=sim_result,
        )

        recognized = ", ".join(sorted(RECOGNIZED_TAGS))
        gr.Markdown(f"\n## Recognized Tags:\n{recognized}\n")

        # A callable value is evaluated on page load, so operations recorded
        # after start-up become visible.
        gr.Markdown(value=_render_recent_operations)

    return demo
# Attach the Gradio dashboard to the FastAPI application under /gradio.
gradio_app = create_gradio_app()
app = gr.mount_gradio_app(app, gradio_app, path="/gradio")

if __name__ == "__main__":
    # Start-up banner, then serve the "app" object of module "app" with
    # auto-reload enabled.
    startup_banner = (
        "π Starting HF Tagging Bot...",
        "π Dashboard: http://localhost:7860/gradio",
        "π Webhook: http://localhost:7860/webhook",
        f"π HF_TOKEN configured: {bool(HF_TOKEN)}",
        "π§ Using direct HuggingFace Hub API (Windows compatible)",
    )
    for banner_line in startup_banner:
        print(banner_line)
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)