# hf-tagging-bot / app.py
# Author: asmaa105
# Last commit: a20a584 "Update app.py" (verified)
# File size: 14.4 kB
# app.py - HF Spaces compatible version (Fixed)
import os
import re
import json
from datetime import datetime
from typing import List, Dict, Any
import gradio as gr
# Configuration, read from HF Spaces secrets (environment variables).
# Each value is None when the corresponding secret is not set.
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
HF_TOKEN = os.getenv("HF_TOKEN")

# In-memory history of processed tag operations (appended newest-last);
# it lives only as long as the process does.
tag_operations_store: List[Dict[str, Any]] = []

# Tags the bot recognises when they are mentioned in plain discussion text.
RECOGNIZED_TAGS = {
    # Libraries, frameworks and runtimes
    "pytorch", "tensorflow", "jax", "transformers", "diffusers",
    "llamacpp", "mlx",
    # Tasks
    "text-generation", "text-classification", "question-answering",
    "text-to-image", "image-classification", "object-detection",
    "fill-mask", "token-classification", "translation", "summarization",
    "feature-extraction", "sentence-similarity", "zero-shot-classification",
    "image-to-text", "automatic-speech-recognition", "audio-classification",
    "voice-activity-detection", "depth-estimation", "image-segmentation",
    "video-classification", "reinforcement-learning",
    "tabular-classification", "tabular-regression",
    "time-series-forecasting", "graph-ml", "robotics",
    # Domains and modalities
    "computer-vision", "nlp", "cv", "multimodal",
    # Model file formats
    "gguf", "safetensors", "onnx",
}
# One well-formed tag token: a letter or digit followed by letters, digits,
# "-" or "_" (e.g. "pytorch", "text-generation", "my_tag").
_TAG_TOKEN_RE = re.compile(r"[a-z0-9][a-z0-9_-]*")
# "tag: a" / "tags: a, b". Whitespace (including a newline) may follow the
# colon, but the list itself stops at the end of its line so that prose after
# the list is not swallowed into the last entry.
_EXPLICIT_TAGS_RE = re.compile(r"\btags?:\s*([a-z0-9_, \t-]+)")
# "#hashtag" style mentions.
_HASHTAG_RE = re.compile(r"#([a-z0-9_-]+)")


def _is_tag_token(tag: str) -> bool:
    """Return True if *tag* looks like a single tag (rejects "", "42", "a b")."""
    return _TAG_TOKEN_RE.fullmatch(tag) is not None and not tag.isdigit()


def extract_tags_from_text(text: str, known_tags=None) -> List[str]:
    """Extract candidate repository tags from discussion text.

    Three sources are combined (matching is case-insensitive):

    1. Explicit lists: ``tag: x`` or ``tags: x, y`` (comma separated).
    2. Hashtags: ``#x``.
    3. Known tags mentioned as standalone tokens in the text.

    Entries from (1) and (2) are kept even when they are not known tags, but
    only if they form a single tag token. That token starts with a letter or
    digit, contains only letters, digits, ``-`` and ``_``, and is not purely
    numeric. Empty list items, multi-word fragments and issue references
    such as ``#42`` are therefore ignored.

    Args:
        text: Comment body or discussion title. ``None`` or ``""`` yields ``[]``.
        known_tags: Lower-case vocabulary used for natural-language mentions.
            Defaults to the module-level ``RECOGNIZED_TAGS``.

    Returns:
        Sorted list of unique, lower-cased tags.
    """
    if not text:
        return []
    vocabulary = RECOGNIZED_TAGS if known_tags is None else known_tags
    text_lower = text.lower()

    candidates = set()
    # Pattern 1: "tag: something" or "tags: a, b"
    for listed in _EXPLICIT_TAGS_RE.findall(text_lower):
        candidates.update(part.strip() for part in listed.split(","))
    # Pattern 2: "#hashtag" style
    candidates.update(_HASHTAG_RE.findall(text_lower))
    explicit = {tag for tag in candidates if _is_tag_token(tag)}

    # Pattern 3: known tags mentioned as whole tokens. "-" and "_" count as
    # part of a token, so "jax" does not match inside "ajax", "cv" does not
    # match inside "cvpr", and "text-generation" does not match inside
    # "text-generation-inference".
    mentioned = {
        tag
        for tag in vocabulary
        if re.search(rf"(?<![a-z0-9_-]){re.escape(tag)}(?![a-z0-9_-])", text_lower)
    }
    return sorted(explicit | mentioned)
def _locate_repo(repo_name: str):
    """Resolve *repo_name* to ``(repo_type, repo_info)``.

    Tries model, dataset and space lookups in that order and returns the
    first that succeeds, or ``(None, None)`` if none does. Lookup errors are
    logged and treated as "not this repo type".
    """
    from huggingface_hub import dataset_info, model_info, space_info
    from huggingface_hub.utils import HfHubHTTPError

    lookups = (("model", model_info), ("dataset", dataset_info), ("space", space_info))
    for repo_type, fetch_info in lookups:
        try:
            print(f"πŸ” Trying to access {repo_name} as {repo_type}...")
            info = fetch_info(repo_id=repo_name, token=HF_TOKEN)
        except HfHubHTTPError as e:
            # A 404 only means "not this repo type"; anything else is logged.
            if "404" not in str(e):
                print(f"❌ Error accessing as {repo_type}: {e}")
            continue
        except Exception as e:
            print(f"❌ Unexpected error for {repo_type}: {e}")
            continue
        print(f"βœ… Found repository as {repo_type}")
        return repo_type, info
    return None, None


def _process_tags_sync(all_tags: List[str], repo_name: str) -> List[str]:
    """Blocking worker for ``process_tags_directly``.

    Opens one pull request per tag that the repository does not already
    have. Each PR edits README.md metadata. Per-tag failures are reported as
    messages. Failures before the per-tag loop (imports, client creation)
    propagate to the caller.
    """
    from huggingface_hub import CommitOperationAdd, HfApi, ModelCard, ModelCardData
    from huggingface_hub.utils import HfHubHTTPError

    hf_api = HfApi(token=HF_TOKEN)
    repo_type, repo_info = _locate_repo(repo_name)
    if not repo_type or not repo_info:
        return [f"Error: Repository '{repo_name}' not found"]

    # Tags the Hub reports for the repo; used only to skip tags it already has.
    current_tags = repo_info.tags if repo_info.tags else []
    print(f"🏷️ Current tags: {current_tags}")

    result_messages = []
    for tag in all_tags:
        try:
            if tag in current_tags:
                result_messages.append(f"Tag '{tag}': Already exists")
                continue

            # Load the README for every tag so each PR starts from the current card.
            try:
                card = ModelCard.load(repo_name, token=HF_TOKEN, repo_type=repo_type)
                if getattr(card, "data", None) is None:
                    card.data = ModelCardData()
            except HfHubHTTPError:
                # No readable README: start from an empty card.
                card = ModelCard("")
                card.data = ModelCardData()

            card_dict = card.data.to_dict()
            # Extend the tags declared in the README itself rather than
            # repo_info.tags. The Hub's list typically also contains
            # Hub-generated entries (e.g. "region:us", "license:...") that
            # should not be written back into the card metadata.
            declared = card_dict.get("tags") or []
            card_tags = [declared] if isinstance(declared, str) else list(declared)
            if tag in card_tags:
                result_messages.append(f"Tag '{tag}': Already exists")
                continue
            updated_tags = card_tags + [tag]
            card_dict["tags"] = updated_tags
            card.data = ModelCardData(**card_dict)

            pr_title = f"Add '{tag}' tag"
            pr_description = "\n".join([
                f"## Add tag: {tag}",
                "",
                f"This PR adds the `{tag}` tag to the {repo_type} repository.",
                "",
                "**Changes:**",
                f"- Added `{tag}` to {repo_type} tags",
                f"- Updated from {len(card_tags)} to {len(updated_tags)} tags",
                "",
                f"**Current tags:** {', '.join(card_tags) if card_tags else 'None'}",
                "",
                f"**New tags:** {', '.join(updated_tags)}",
                "",
                "*This PR was created automatically by the HF Tagging Bot.*",
            ])
            commit_info = hf_api.create_commit(
                repo_id=repo_name,
                repo_type=repo_type,
                operations=[
                    CommitOperationAdd(
                        path_in_repo="README.md",
                        path_or_fileobj=str(card).encode("utf-8"),
                    )
                ],
                commit_message=pr_title,
                commit_description=pr_description,
                token=HF_TOKEN,
                create_pr=True,
            )
            pr_url = getattr(commit_info, "pr_url", str(commit_info))
            result_messages.append(f"Tag '{tag}': [PR created]({pr_url})")
        except Exception as tag_error:
            result_messages.append(f"Tag '{tag}': Error - {str(tag_error)}")
    return result_messages


async def process_tags_directly(all_tags: List[str], repo_name: str) -> List[str]:
    """Add *all_tags* to *repo_name*, one pull request per missing tag.

    The huggingface_hub client calls are synchronous. They are run in the
    event loop's default executor so that the server's event loop (which
    also serves the webhook, health route and Gradio UI) is not blocked for
    the duration of every Hub HTTP request.

    Returns:
        One status message per tag. Alternatively, a single-element list
        describing a missing token, an unknown repository, or an unexpected
        failure.
    """
    import asyncio

    print("πŸ”§ Using direct HuggingFace Hub API approach...")
    if not HF_TOKEN:
        error_msg = "No HF_TOKEN configured"
        print(f"❌ {error_msg}")
        return [error_msg]
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(None, _process_tags_sync, all_tags, repo_name)
    except Exception as e:
        return [f"Processing failed: {str(e)}"]
async def process_webhook_comment(webhook_data: Dict[str, Any]):
    """Detect tags in a webhook's comment/discussion and open tagging PRs.

    Tags are extracted from the comment body and the discussion title,
    passed to ``process_tags_directly``, and the outcome is appended to
    ``tag_operations_store``. The store is trimmed to its newest 100 entries.

    Returns:
        A human-readable status string. Any exception is caught, logged and
        returned as an error string instead of being raised.
    """
    try:
        comment = webhook_data["comment"]
        comment_content = comment["content"]
        discussion = webhook_data["discussion"]
        discussion_title = discussion["title"]
        repo_name = webhook_data["repo"]["name"]
        discussion_num = discussion["num"]
        comment_author = comment["author"].get("id", "unknown")

        print(f"πŸ“ Processing comment from {comment_author} on {repo_name}")
        print(f"πŸ“ Comment: {comment_content}")

        # Comment first, then title, de-duplicated.
        from_comment = extract_tags_from_text(comment_content)
        from_title = extract_tags_from_text(discussion_title)
        detected = list(set(from_comment + from_title))
        if not detected:
            return "No recognizable tags found"

        print(f"🏷️ Found tags: {detected}")
        outcome = await process_tags_directly(detected, repo_name)

        tag_operations_store.append(
            dict(
                timestamp=datetime.now().isoformat(),
                repo=repo_name,
                discussion_title=discussion_title,
                discussion_num=discussion_num,
                comment_author=comment_author,
                detected_tags=detected,
                results=outcome,
            )
        )
        # Cap the history at the most recent 100 operations.
        if len(tag_operations_store) > 100:
            del tag_operations_store[0]

        return " | ".join(outcome)
    except Exception as exc:
        error_msg = f"Error processing webhook: {str(exc)}"
        print(f"❌ {error_msg}")
        return error_msg
def create_gradio_interface():
    """Create the Gradio monitoring UI for the tagging bot.

    The UI has three tabs: configuration status, a log of recent tag
    operations, and an interactive tag-detection tester.

    Values that change while the app runs (the processed-operations count and
    the operations log) are given to Gradio as callables. They are therefore
    recomputed on each page load rather than frozen at construction time,
    when ``tag_operations_store`` is still empty.

    Returns:
        gr.Blocks: The assembled interface.
    """

    def status_markdown():
        # Called by Gradio whenever the page loads.
        return f"""
## Bot Configuration
- πŸ”‘ **HF Token**: {'βœ… Configured' if HF_TOKEN else '❌ Missing'}
- πŸ” **Webhook Secret**: {'βœ… Configured' if WEBHOOK_SECRET else '❌ Missing'}
- πŸ“Š **Operations Processed**: {len(tag_operations_store)}
## Setup Instructions
1. **Add webhook to your repository**:
- Go to repository Settings β†’ Webhooks
- Add webhook URL: `https://your-space-name.hf.space/webhook`
- Select "Discussion comments" events
- Add your webhook secret (optional)
2. **In discussions, mention tags**:
- "Please add tags: pytorch, transformers"
- "This needs #pytorch and #text-generation"
- "tag: computer-vision"
## Webhook Endpoint
`POST https://your-space-name.hf.space/webhook`
## Health Check
Visit: `https://your-space-name.hf.space/health`
"""

    def get_recent_operations():
        # Newest-first summary of the last 10 stored operations.
        if not tag_operations_store:
            return "No operations yet. Configure webhooks and post comments with tags to see activity here."
        recent = tag_operations_store[-10:]
        output = []
        for op in reversed(recent):
            output.append(f"""
**{op['repo']}** - {op['timestamp'][:19]}
- πŸ‘€ Author: {op['comment_author']}
- 🏷️ Tags: {', '.join(op['detected_tags']) if op['detected_tags'] else 'None'}
- πŸ“‹ Results: {' | '.join(op['results'][:2])}{'...' if len(op['results']) > 2 else ''}
---""")
        return "\n".join(output)

    def test_tag_detection(text):
        if not text:
            return "Enter some text to test"
        tags = extract_tags_from_text(text)
        if tags:
            return f"Found {len(tags)} tags: {', '.join(tags)}"
        return "No tags detected in this text"

    with gr.Blocks(title="HF Tagging Bot", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🏷️ HuggingFace Tagging Bot")
        gr.Markdown("*Automatically adds tags to repositories when mentioned in discussions*")

        with gr.Tab("🏠 Status"):
            # Callable value: re-evaluated on every page load.
            gr.Markdown(value=status_markdown)

        with gr.Tab("πŸ“ Operations Log"):
            operations_display = gr.Textbox(
                label="Recent Operations",
                value=get_recent_operations,  # callable: refreshed on page load
                lines=15,
                interactive=False,
            )
            refresh_btn = gr.Button("πŸ”„ Refresh Log")
            refresh_btn.click(fn=get_recent_operations, outputs=operations_display)

        with gr.Tab("🏷️ Tags & Testing"):
            gr.Markdown(f"""
## Supported Tags ({len(RECOGNIZED_TAGS)} total)
{', '.join(sorted(RECOGNIZED_TAGS))}
## Tag Detection Examples
- **Explicit**: `tag: pytorch` or `tags: pytorch, transformers`
- **Hashtag**: `#pytorch #transformers`
- **Natural**: "This model uses pytorch and transformers"
""")
            gr.Markdown("### Test Tag Detection")
            test_input = gr.Textbox(
                label="Test Comment",
                placeholder="Enter a comment to test tag detection...",
                lines=3,
                value="This model should have tags: pytorch, text-generation",
            )
            test_output = gr.Textbox(
                label="Detected Tags",
                lines=2,
                interactive=False,
            )
            test_btn = gr.Button("πŸ” Test Detection")
            test_btn.click(fn=test_tag_detection, inputs=test_input, outputs=test_output)

    return interface
# Create the Gradio interface at import time. The webhook and health routes
# below are registered on `demo.app`, so `demo` must exist first.
# NOTE(review): confirm that this Gradio version exposes `demo.app` before
# `launch()` and that `launch()` keeps that same app object. Otherwise the
# custom routes would be missing; mounting Gradio on an explicit FastAPI app
# (`gr.mount_gradio_app`) is the documented alternative.
demo = create_gradio_interface()
# Add webhook handling using Gradio's built-in FastAPI integration
from fastapi import Request
import asyncio
# Strong references to in-flight background tasks. asyncio keeps only weak
# references to tasks, so a task that nothing references can be
# garbage-collected before it finishes.
_background_tasks = set()


@demo.app.post("/webhook")
async def webhook_handler(request: Request):
    """Handle HF Hub webhooks.

    If ``WEBHOOK_SECRET`` is set, the ``X-Webhook-Secret`` header must match it.
    Otherwise the request is rejected with HTTP 401.

    A new discussion or a new comment (scope ``discussion`` or
    ``discussion.comment``, action ``create``) is scheduled for tag
    processing in the background, except on pull requests. All other events
    are acknowledged as ignored. An unreadable payload yields HTTP 400.
    """
    import hmac

    from fastapi.responses import JSONResponse

    if WEBHOOK_SECRET:
        provided = request.headers.get("X-Webhook-Secret") or ""
        # Constant-time comparison; bytes so non-ASCII header values cannot raise.
        if not hmac.compare_digest(provided.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")):
            print("❌ Invalid webhook secret")
            return JSONResponse(status_code=401, content={"error": "Invalid webhook secret"})

    try:
        payload = await request.json()
        print(f"πŸ“₯ Received webhook: {payload.get('event', {})}")
        event = payload.get("event", {})
        scope = event.get("scope")
        action = event.get("action")

        # New discussions arrive as scope "discussion"; replies to existing
        # discussions arrive as "discussion.comment". PRs are skipped.
        if (
            scope in ("discussion", "discussion.comment")
            and action == "create"
            and not payload.get("discussion", {}).get("isPullRequest", False)
        ):
            # Process in the background, keeping a reference until it is done.
            task = asyncio.create_task(process_webhook_comment(payload))
            _background_tasks.add(task)
            task.add_done_callback(_background_tasks.discard)
            return {"status": "processing"}

        return {"status": "ignored"}
    except Exception as e:
        print(f"❌ Webhook error: {e}")
        return JSONResponse(status_code=400, content={"error": str(e)})
@demo.app.get("/health")
async def health_check():
    """Report liveness, which secrets are set, and how many operations are stored."""
    token_present = bool(HF_TOKEN)
    secret_present = bool(WEBHOOK_SECRET)
    stored = len(tag_operations_store)
    return dict(
        status="healthy",
        hf_token_configured=token_present,
        webhook_secret_configured=secret_present,
        operations_processed=stored,
    )
# Script entry point: log the configuration state, then serve on all
# interfaces at port 7860.
if __name__ == "__main__":
    token_state = 'βœ… Configured' if HF_TOKEN else '❌ Missing'
    secret_state = 'βœ… Configured' if WEBHOOK_SECRET else '❌ Missing'
    print("πŸš€ HF Tagging Bot - Gradio interface ready")
    print(f"πŸ”‘ HF_TOKEN: {token_state}")
    print(f"πŸ” Webhook Secret: {secret_state}")
    demo.launch(server_name="0.0.0.0", server_port=7860)