asmaa105 committed on
Commit
2d073e8
·
verified ·
1 Parent(s): de10c5f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +466 -0
app.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ from datetime import datetime
5
+ from typing import List, Dict, Any, Optional, Literal
6
+
7
+ from fastapi import FastAPI, Request, BackgroundTasks
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ import gradio as gr
10
+ import uvicorn
11
+ from pydantic import BaseModel
12
+ from dotenv import load_dotenv
13
+
14
+ load_dotenv()
15
+
# Configuration
#
# The webhook secret is read from the environment only. A secret literal
# committed to the repository is public, so it cannot authenticate anything.
# When WEBHOOK_SECRET is unset, a random per-process value is used instead:
# nobody can know it, so requests carrying a secret header are refused until
# a real secret is configured (fail closed).
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET") or os.urandom(32).hex()
if not os.getenv("WEBHOOK_SECRET"):
    print("⚠️ WEBHOOK_SECRET is not set; using a random secret for this process")

# Hub access token; None when the environment variable is absent.
HF_TOKEN = os.getenv("HF_TOKEN")

# Simple storage for processed tag operations (in-memory, lost on restart)
tag_operations_store: List[Dict[str, Any]] = []
22
+
# Tags the bot picks up when they are merely mentioned in free text.
# Grouped by category for readability; membership is all that matters.
RECOGNIZED_TAGS = {
    # Frameworks and libraries
    "pytorch", "tensorflow", "jax", "transformers", "diffusers",
    # Text tasks
    "text-generation", "text-classification", "question-answering",
    "fill-mask", "token-classification", "translation", "summarization",
    "feature-extraction", "sentence-similarity", "zero-shot-classification",
    # Vision and image tasks
    "text-to-image", "image-classification", "object-detection",
    "image-to-text", "depth-estimation", "image-segmentation",
    "video-classification",
    # Audio tasks
    "automatic-speech-recognition", "audio-classification",
    "voice-activity-detection",
    # Other task families
    "reinforcement-learning", "tabular-classification", "tabular-regression",
    "time-series-forecasting", "graph-ml", "robotics",
    # Broad domain labels
    "computer-vision", "nlp", "cv", "multimodal",
}
61
+
62
+
class WebhookEvent(BaseModel):
    """Top-level shape of a webhook payload.

    Each field is one section of the payload, kept as a loosely typed
    mapping. The handlers visible in this file read the raw JSON dict rather
    than this model.
    """

    # Event descriptor; the handlers read its "scope" and "action" keys.
    event: Dict[str, str]
    # Comment section; "content" and "author" are read from it.
    comment: Dict[str, Any]
    # Discussion section; "title", "num" and "isPullRequest" are read from it.
    discussion: Dict[str, Any]
    # Repository section. Values are not all strings: the mock payload built
    # in this file carries a boolean "private" flag, hence Any.
    repo: Dict[str, Any]
68
+
69
+
# ASGI application that serves the webhook endpoint. CORS accepts any origin.
app = FastAPI(title="HF Tagging Bot")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
)
72
+
73
+
def extract_tags_from_text(text: str) -> List[str]:
    """Extract potential tags from discussion text.

    Three sources are combined, all matched case-insensitively (the text is
    lower-cased first):

    1. Explicit lists such as ``tag: pytorch`` or ``tags: pytorch, nlp``.
       Only the comma-separated tokens directly after the colon are taken,
       so trailing prose ("tag: pytorch please") is not absorbed into a tag.
    2. Hashtags such as ``#pytorch``.
    3. Any entry of ``RECOGNIZED_TAGS`` that appears as a whole token, so
       "ajax" no longer yields ``jax``.

    Args:
        text: Free text from a comment or discussion title. Empty or ``None``
            yields an empty list.

    Returns:
        Unique, non-empty tags in order of first discovery (explicit lists,
        then hashtags, then recognized mentions).
    """
    if not text:
        return []

    text_lower = text.lower()
    found: List[str] = []

    # Pattern 1: "tag: something" or "tags: a, b, c"
    list_pattern = r"\btags?:\s*([a-z0-9_-]+(?:\s*,\s*[a-z0-9_-]+)*)"
    for tag_list in re.findall(list_pattern, text_lower):
        found.extend(part.strip() for part in tag_list.split(","))

    # Pattern 2: "#hashtag" style
    found.extend(re.findall(r"#([a-z0-9_-]+)", text_lower))

    # Pattern 3: recognized tags mentioned in natural text. Tokenising on the
    # tag alphabet gives whole-token matching without one regex per tag.
    found.extend(
        token
        for token in re.findall(r"[a-z0-9_-]+", text_lower)
        if token in RECOGNIZED_TAGS
    )

    # dict.fromkeys de-duplicates while keeping first-seen order; blanks are dropped.
    return list(dict.fromkeys(tag for tag in found if tag))
110
+
111
+
async def process_tags_directly(all_tags: List[str], repo_name: str) -> List[str]:
    """Process tags using direct HuggingFace Hub API calls.

    The repository is probed as a model, then a dataset, then a space. For
    each requested tag not already present, a pull request is opened that
    rewrites ``README.md`` with the tag appended to the card metadata.

    The new tag list is built from the tags declared in the card itself.
    ``repo_info.tags`` is used only for the "already exists" check: per the
    huggingface_hub documentation it also contains Hub-computed tags, which
    must not be written back into the card.

    NOTE(review): the Hub calls below are synchronous, so this coroutine
    blocks the event loop while they run.

    Args:
        all_tags: Tags to add. Blank entries and repeats are skipped.
        repo_name: Repository id, e.g. ``"user/repo"``.

    Returns:
        One human-readable message per processed tag, or a single-element
        list holding an error message when nothing could be processed.
    """
    print("πŸ”§ Using direct HuggingFace Hub API approach...")
    result_messages: List[str] = []

    if not HF_TOKEN:
        error_msg = "No HF_TOKEN configured"
        print(f"❌ {error_msg}")
        return [error_msg]

    try:
        from huggingface_hub import HfApi, model_info, dataset_info, space_info, ModelCard, ModelCardData
        from huggingface_hub.utils import HfHubHTTPError
        from huggingface_hub import CommitOperationAdd

        hf_api = HfApi(token=HF_TOKEN)

        # First, determine what type of repository this is. Dict order is the
        # probing order: model, dataset, space.
        info_getters = {
            "model": model_info,
            "dataset": dataset_info,
            "space": space_info,
        }
        repo_type = None
        repo_info = None

        for repo_type_to_try, get_info in info_getters.items():
            try:
                print(f"πŸ” Trying to access {repo_name} as {repo_type_to_try}...")
                repo_info = get_info(repo_id=repo_name, token=HF_TOKEN)
                repo_type = repo_type_to_try
                print(f"βœ… Found repository as {repo_type}")
                break
            except HfHubHTTPError as e:
                # Either way the next repository type is tried.
                if "404" in str(e):
                    print(f"⚠️ Repository not found as {repo_type_to_try}")
                else:
                    print(f"❌ Error accessing as {repo_type_to_try}: {e}")
            except Exception as e:
                print(f"❌ Unexpected error for {repo_type_to_try}: {e}")

        if not repo_type or not repo_info:
            error_msg = f"Repository '{repo_name}' not found as model, dataset, or space"
            print(f"❌ {error_msg}")
            return [f"Error: {error_msg}"]

        print(f"πŸ“‹ Repository type: {repo_type}")
        current_tags = list(repo_info.tags) if repo_info.tags else []
        print(f"🏷️ Current tags: {current_tags}")

        # Drop blanks and repeats so no empty or duplicate PR is opened.
        requested_tags = list(
            dict.fromkeys(tag.strip() for tag in all_tags if tag and tag.strip())
        )

        # Process each tag
        for tag in requested_tags:
            try:
                # Check if tag already exists
                if tag in current_tags:
                    msg = f"Tag '{tag}': Already exists"
                    print(f"βœ… {msg}")
                    result_messages.append(msg)
                    continue

                print(f"πŸ”§ Adding tag '{tag}' to {repo_type} '{repo_name}'")

                try:
                    # Load existing model card
                    print(f"πŸ“„ Loading existing model card...")
                    card = ModelCard.load(repo_name, token=HF_TOKEN, repo_type=repo_type)
                    if not hasattr(card, "data") or card.data is None:
                        card.data = ModelCardData()
                except HfHubHTTPError:
                    # Create new model card if none exists
                    print(f"πŸ“„ Creating new model card (none exists)")
                    card = ModelCard("")
                    card.data = ModelCardData()

                # Start from the tags the card declares, not the Hub's
                # computed list, so only the new tag is added to the metadata.
                card_dict = card.data.to_dict()
                declared_tags = card_dict.get("tags") or []
                if isinstance(declared_tags, str):
                    declared_tags = [declared_tags]
                declared_tags = list(declared_tags)

                if tag in declared_tags:
                    msg = f"Tag '{tag}': Already exists"
                    print(f"βœ… {msg}")
                    result_messages.append(msg)
                    continue

                updated_tags = declared_tags + [tag]
                card_dict["tags"] = updated_tags
                card.data = ModelCardData(**card_dict)

                # Create a pull request with the updated model card
                pr_title = f"Add '{tag}' tag"
                pr_description = f"""
## Add tag: {tag}

This PR adds the `{tag}` tag to the {repo_type} repository.

**Changes:**
- Added `{tag}` to {repo_type} tags
- Updated from {len(declared_tags)} to {len(updated_tags)} tags

**Current tags:** {", ".join(declared_tags) if declared_tags else "None"}
**New tags:** {", ".join(updated_tags)}
"""

                print(f"πŸš€ Creating PR with title: {pr_title}")

                # Create commit with updated model card
                commit_info = hf_api.create_commit(
                    repo_id=repo_name,
                    repo_type=repo_type,
                    operations=[
                        CommitOperationAdd(
                            path_in_repo="README.md",
                            path_or_fileobj=str(card).encode("utf-8"),
                        )
                    ],
                    commit_message=pr_title,
                    commit_description=pr_description,
                    token=HF_TOKEN,
                    create_pr=True,
                )

                # Fall back to the commit info's string form if it has no pr_url.
                pr_url = getattr(commit_info, "pr_url", str(commit_info))

                print(f"βœ… PR created successfully! URL: {pr_url}")
                msg = f"Tag '{tag}': PR created - {pr_url}"
                result_messages.append(msg)

            except Exception as tag_error:
                # One failing tag must not stop the remaining ones.
                error_msg = f"Tag '{tag}': Error - {str(tag_error)}"
                print(f"❌ {error_msg}")
                result_messages.append(error_msg)

        return result_messages

    except Exception as e:
        error_msg = f"Direct API processing failed: {str(e)}"
        print(f"❌ {error_msg}")
        return [error_msg]
251
+
252
+
async def process_webhook_comment(webhook_data: Dict[str, Any]):
    """Detect tags in a webhook payload, apply them, and record the outcome.

    Tags are extracted from the comment body and the discussion title, then
    handed to ``process_tags_directly``. A summary of the interaction is
    appended to ``tag_operations_store``.

    Args:
        webhook_data: Payload dict with "comment", "discussion" and "repo"
            sections.

    Returns:
        The result messages joined with " | ", or an error string if
        anything raised.
    """
    print("🏷️ Starting process_webhook_comment...")

    try:
        body = webhook_data["comment"]["content"]
        title = webhook_data["discussion"]["title"]
        repo = webhook_data["repo"]["name"]
        number = webhook_data["discussion"]["num"]
        author = webhook_data["comment"]["author"].get("id", "unknown")

        print(f"πŸ“ Comment content: {body}")
        print(f"πŸ“° Discussion title: {title}")
        print(f"πŸ“¦ Repository: {repo}")

        # Tags can come from either the comment body or the discussion title.
        tags_in_comment = extract_tags_from_text(body)
        tags_in_title = extract_tags_from_text(title)
        unique_tags = list({*tags_in_comment, *tags_in_title})

        print(f"πŸ” Comment tags found: {tags_in_comment}")
        print(f"πŸ” Title tags found: {tags_in_title}")
        print(f"🏷️ All unique tags: {unique_tags}")

        if unique_tags:
            # Skip agent entirely and use direct API approach
            print("πŸ”§ Using direct HuggingFace Hub API processing...")
            outcomes = await process_tags_directly(unique_tags, repo)
        else:
            nothing_found = "No recognizable tags found in the discussion."
            print(f"❌ {nothing_found}")
            outcomes = [nothing_found]

        # Record what happened
        tag_operations_store.append(
            {
                "timestamp": datetime.now().isoformat(),
                "repo": repo,
                "discussion_title": title,
                "discussion_num": number,
                "discussion_url": f"https://huggingface.co/{repo}/discussions/{number}",
                "original_comment": body,
                "comment_author": author,
                "detected_tags": unique_tags,
                "results": outcomes,
            }
        )

        summary = " | ".join(outcomes)
        print(f"πŸ’Ύ Stored interaction and returning result: {summary}")
        return summary

    except Exception as exc:
        import traceback

        failure = f"❌ Fatal error in process_webhook_comment: {str(exc)}"
        print(failure)
        print(f"❌ Traceback: {traceback.format_exc()}")
        return failure
315
+
316
+
@app.post("/webhook")
async def webhook_handler(request: Request, background_tasks: BackgroundTasks):
    """Handle HF Hub webhooks.

    The request is authenticated with the ``X-Webhook-Secret`` header. A
    newly created, non-pull-request discussion event is queued for
    background processing; everything else is ignored.

    Args:
        request: Incoming HTTP request carrying the JSON payload.
        background_tasks: FastAPI task queue used to run the tagging work
            after the response is sent.

    Returns:
        ``{"status": "processing"}`` or ``{"status": "ignored"}`` on success.
        On failure the body is ``{"error": ...}``, sent with HTTP 401 for a
        bad secret and HTTP 400 for a malformed payload.
    """
    import hmac

    from fastapi.responses import JSONResponse

    # Constant-time comparison; an empty expected secret never authenticates.
    provided_secret = request.headers.get("X-Webhook-Secret") or ""
    expected_secret = WEBHOOK_SECRET or ""
    authorized = bool(expected_secret) and hmac.compare_digest(
        provided_secret.encode("utf-8"), expected_secret.encode("utf-8")
    )
    if not authorized:
        print("❌ Invalid webhook secret")
        return JSONResponse(status_code=401, content={"error": "Invalid webhook secret"})

    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass.
        payload = None
    if not isinstance(payload, dict):
        error_msg = "Invalid JSON payload"
        print(f"❌ {error_msg}")
        return JSONResponse(status_code=400, content={"error": error_msg})

    print(f"πŸ“₯ Received webhook payload: {json.dumps(payload, indent=2)}")

    event = payload.get("event")
    if not isinstance(event, dict):
        event = {}
    scope = event.get("scope")
    action = event.get("action")

    print(f"πŸ” Event details - scope: {scope}, action: {action}")

    # Read the PR flag defensively: payloads without a "discussion" section
    # must fall through to "ignored" or the missing-fields error, not crash.
    discussion = payload.get("discussion")
    is_pull_request = isinstance(discussion, dict) and bool(discussion.get("isPullRequest"))

    # Check if this is a discussion comment creation
    not_pr = not is_pull_request
    scope_check = scope == "discussion" and not_pr
    action_check = action == "create"
    print(f"βœ… not_pr: {not_pr}")
    print(f"βœ… scope_check: {scope_check}")
    print(f"βœ… action_check: {action_check}")

    if scope_check and action_check:
        # Verify we have the required sections, each as a JSON object
        required_fields = ["comment", "discussion", "repo"]
        missing_fields = [
            field for field in required_fields if not isinstance(payload.get(field), dict)
        ]

        if missing_fields:
            error_msg = f"Missing required fields: {missing_fields}"
            print(f"❌ {error_msg}")
            return JSONResponse(status_code=400, content={"error": error_msg})

        print(f"πŸš€ Processing webhook for repo: {payload['repo'].get('name')}")
        background_tasks.add_task(process_webhook_comment, payload)
        return {"status": "processing"}

    print(f"⏭️ Ignoring webhook - scope: {scope}, action: {action}")
    return {"status": "ignored"}
359
+
360
+
async def simulate_webhook(
    repo_name: str, discussion_title: str, comment_content: str
) -> str:
    """Run the tagging pipeline on a hand-built payload, for testing.

    Args:
        repo_name: Repository id to target.
        discussion_title: Title placed in the mock discussion.
        comment_content: Body placed in the mock comment.

    Returns:
        A prompt to fill in all fields if any argument is empty, otherwise
        the processing result prefixed with a success marker.
    """
    if not (repo_name and discussion_title and comment_content):
        return "Please fill in all fields."

    fake_comment = {
        "content": comment_content,
        "author": {"id": "test-user-id"},
        "id": "mock-comment-id",
        "hidden": False,
    }
    fake_discussion = {
        "title": discussion_title,
        # Numbered from how many operations have been stored so far.
        "num": len(tag_operations_store) + 1,
        "id": "mock-discussion-id",
        "status": "open",
        "isPullRequest": False,
    }
    fake_repo = {
        "name": repo_name,
        "type": "model",
        "private": False,
    }
    mock_payload = {
        "event": {"action": "create", "scope": "discussion"},
        "comment": fake_comment,
        "discussion": fake_discussion,
        "repo": fake_repo,
    }

    outcome = await process_webhook_comment(mock_payload)
    return f"βœ… Processed! Results: {outcome}"
392
+
393
+
def create_gradio_app():
    """Create Gradio interface.

    The dashboard has a form that feeds ``simulate_webhook``, the list of
    recognized tags, and a "Recent Operations" section. That section is
    rendered on every page load and after each simulation. Building it once
    at construction time would always show nothing, because
    ``tag_operations_store`` is still empty when the interface is created.

    Returns:
        The configured ``gr.Blocks`` instance.
    """

    def render_recent_operations() -> str:
        """Return Markdown for the last five stored operations ("" if none)."""
        if not tag_operations_store:
            return ""
        sections = ["## Recent Operations"]
        for op in tag_operations_store[-5:]:  # Show last 5 operations
            sections.append(
                f"**{op['repo']}** - {op['timestamp'][:19]}\n"
                f"- Tags: {', '.join(op['detected_tags'])}\n"
                f"- Results: {' | '.join(op['results'][:2])}..."
            )
        return "\n\n".join(sections)

    with gr.Blocks(title="HF Tagging Bot", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏷️ HF Tagging Bot Dashboard")
        gr.Markdown("*Automatically adds tags to models, datasets, and spaces when mentioned in discussions*")

        gr.Markdown("""
        ## How it works:
        - Monitors HuggingFace Hub discussions
        - Detects tag mentions in comments (e.g., "tag: pytorch", "#transformers")
        - Automatically detects repository type (model/dataset/space)
        - Creates pull requests to add recognized tags to the repository
        - Supports common ML tags like: pytorch, tensorflow, text-generation, etc.
        """)

        with gr.Column():
            sim_repo = gr.Textbox(
                label="Repository",
                value="burtenshaw/play-mcp-repo-bot",
                placeholder="username/repo-name (can be model, dataset, or space)",
            )
            sim_title = gr.Textbox(
                label="Discussion Title",
                value="Add pytorch tag",
                placeholder="Discussion title",
            )
            sim_comment = gr.Textbox(
                label="Comment",
                lines=3,
                value="This repository should have tags: pytorch, text-generation",
                placeholder="Comment mentioning tags...",
            )
            sim_btn = gr.Button("🏷️ Test Tag Detection")

        with gr.Column():
            sim_result = gr.Textbox(label="Result", lines=8)

        gr.Markdown(f"""
        ## Recognized Tags:
        {", ".join(sorted(RECOGNIZED_TAGS))}
        """)

        # Recent operations section, filled in by the events wired below.
        recent_ops = gr.Markdown(render_recent_operations())

        # Run the simulation, then refresh the section so the new entry shows.
        sim_btn.click(
            fn=simulate_webhook,
            inputs=[sim_repo, sim_title, sim_comment],
            outputs=sim_result,
        ).then(
            fn=render_recent_operations,
            inputs=None,
            outputs=recent_ops,
        )

        # Refresh on every page load so webhook-driven operations appear too.
        demo.load(fn=render_recent_operations, inputs=None, outputs=recent_ops)

    return demo
453
+
454
+
# Build the dashboard once and attach it to the FastAPI app under /gradio.
gradio_app = create_gradio_app()
app = gr.mount_gradio_app(
    app,
    gradio_app,
    path="/gradio",
)
458
+
459
+
if __name__ == "__main__":
    # Startup banner: local URLs and whether a Hub token is configured.
    startup_banner = (
        "πŸš€ Starting HF Tagging Bot...",
        "πŸ“Š Dashboard: http://localhost:7860/gradio",
        "πŸ”— Webhook: http://localhost:7860/webhook",
        f"πŸ”‘ HF_TOKEN configured: {bool(HF_TOKEN)}",
        "πŸ”§ Using direct HuggingFace Hub API (Windows compatible)",
    )
    for banner_line in startup_banner:
        print(banner_line)

    # Serve on all interfaces, port 7860, with auto-reload enabled.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )