asmaa105 committed on
Commit
de10c5f
·
verified ·
1 Parent(s): 25daca9

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -522
app.py DELETED
@@ -1,522 +0,0 @@
1
import hmac
import json
import os
import re
import secrets
from datetime import datetime
from typing import List, Dict, Any, Optional, Literal

import gradio as gr
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub.inference._mcp.agent import Agent
from pydantic import BaseModel
14
-
15
load_dotenv()

# Configuration
# The webhook secret must come from the environment. No default is baked into
# the source: a secret committed to the repository is public and would let
# anyone authenticate against /webhook. When the variable is missing or empty,
# a random per-process secret is generated so that no caller can match it.
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
if not WEBHOOK_SECRET:
    WEBHOOK_SECRET = secrets.token_hex(32)
    print(
        "⚠️ WEBHOOK_SECRET is not set; using a random secret, "
        "so incoming webhooks will be rejected until it is configured"
    )
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL = os.getenv("HF_MODEL", "microsoft/DialoGPT-medium")
# Use a valid provider literal from the documentation
DEFAULT_PROVIDER: Literal["hf-inference"] = "hf-inference"
HF_PROVIDER = os.getenv("HF_PROVIDER", DEFAULT_PROVIDER)

# Simple in-memory storage for processed tag operations (one dict per
# processed comment; lost when the process exits)
tag_operations_store: List[Dict[str, Any]] = []

# Agent instance, created on first use by get_agent()
agent_instance: Optional[Agent] = None
30
-
31
# Common ML tags that we recognize for auto-tagging.
# Every entry must be a non-blank string: tags are detected by substring
# search, so a blank entry such as " " would match any text containing a space
# and be reported as a tag.
RECOGNIZED_TAGS = {
    "pytorch",
    "tensorflow",
    "jax",
    "transformers",
    "diffusers",
    "text-generation",
    "text-classification",
    "question-answering",
    "text-to-image",
    "image-classification",
    "object-detection",
    "fill-mask",
    "token-classification",
    "translation",
    "summarization",
    "feature-extraction",
    "sentence-similarity",
    "zero-shot-classification",
    "image-to-text",
    "automatic-speech-recognition",
    "audio-classification",
    "voice-activity-detection",
    "depth-estimation",
    "image-segmentation",
    "video-classification",
    "reinforcement-learning",
    "tabular-classification",
    "tabular-regression",
    "time-series-forecasting",
    "graph-ml",
    "robotics",
    "computer-vision",
    "nlp",
    "cv",
    "multimodal",
}
70
-
71
-
72
class WebhookEvent(BaseModel):
    """Top-level shape of a Hub webhook payload as consumed by this app."""

    # Event descriptor; this file reads its "scope" and "action" keys.
    event: Dict[str, str]
    # Comment data; this file reads "content" and "author".
    comment: Dict[str, Any]
    # Discussion data; this file reads "title", "num" and "isPullRequest".
    discussion: Dict[str, Any]
    # Repository data; this file reads "name". Values are not all strings
    # (the payload built in simulate_webhook carries a boolean "private"),
    # so the value type must be Any rather than str.
    repo: Dict[str, Any]
77
-
78
-
79
# FastAPI application that hosts the webhook endpoint; the Gradio dashboard is
# mounted onto it further down in this module.
app = FastAPI(title="HF Tagging Bot")

# CORS middleware is registered with every origin allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
)
81
-
82
-
83
async def get_agent():
    """Return the shared Agent, creating it and loading its tools on first use.

    The agent is cached in the module-level ``agent_instance``. It is built
    with the model from ``HF_MODEL`` and the provider from ``HF_PROVIDER``
    (which falls back to ``DEFAULT_PROVIDER`` when the environment variable is
    unset), and is given one stdio MCP server started as
    ``python mcp_server.py``.

    Returns:
        The cached or newly created agent, or ``None`` when ``HF_TOKEN`` is
        not set or when creating the agent / loading its tools fails.
    """
    print("🤖 get_agent() called...")
    global agent_instance

    if agent_instance is not None:
        print("✅ Using existing agent instance")
        return agent_instance

    if not HF_TOKEN:
        print("❌ No HF_TOKEN available, cannot create agent")
        return agent_instance

    print("🔧 Creating new Agent instance...")
    print(f"🔑 HF_TOKEN present: {bool(HF_TOKEN)}")
    print(f"🤖 Model: {HF_MODEL}")
    print(f"🔗 Provider: {HF_PROVIDER}")

    try:
        agent_instance = Agent(
            model=HF_MODEL,
            # Honor the HF_PROVIDER environment override instead of always
            # passing the hard-coded default.
            provider=HF_PROVIDER,
            api_key=HF_TOKEN,
            servers=[
                {
                    "type": "stdio",
                    "config": {
                        "command": "python",
                        "args": ["mcp_server.py"],
                        "cwd": ".",  # Ensure correct working directory
                        "env": {"HF_TOKEN": HF_TOKEN},
                    },
                }
            ],
        )
        print("✅ Agent instance created successfully")
        print("🔧 Loading tools...")
        await agent_instance.load_tools()
        print("✅ Tools loaded successfully")
    except Exception as e:
        # Drop a half-initialized agent so the next call retries from scratch.
        print(f"❌ Error creating/loading agent: {str(e)}")
        agent_instance = None

    return agent_instance
123
-
124
-
125
def extract_tags_from_text(text: str) -> List[str]:
    """Extract potential tags from discussion text.

    Three sources are combined, all matched against the lower-cased text:

    1. Explicit lists such as "tag: pytorch" or "tags: pytorch, nlp"
       (split on commas).
    2. Hashtags such as "#pytorch".
    3. Any entry of ``RECOGNIZED_TAGS`` that occurs as a substring.

    Args:
        text: Free-form comment or title text.

    Returns:
        De-duplicated tags in a deterministic order: explicit tags in order
        of appearance, then recognized tags in sorted order. Empty and
        whitespace-only candidates (e.g. from a trailing comma) are dropped.
    """
    text_lower = text.lower()

    # Pattern 1: "tag: something" or "tags: something, something-else"
    explicit_tags: List[str] = []
    for match in re.findall(r"tags?:\s*([a-zA-Z0-9-_,\s]+)", text_lower):
        explicit_tags.extend(piece.strip() for piece in match.split(","))

    # Pattern 2: "#hashtag" style
    explicit_tags.extend(re.findall(r"#([a-zA-Z0-9-_]+)", text_lower))

    # Pattern 3: recognized tags mentioned in natural text. Iterating in
    # sorted order keeps the result independent of set iteration order.
    mentioned_tags = [tag for tag in sorted(RECOGNIZED_TAGS) if tag in text_lower]

    # dict.fromkeys de-duplicates while preserving first-seen order. Every
    # candidate is either explicit or recognized by construction, so the only
    # filtering needed is to discard blank entries.
    candidates = dict.fromkeys(explicit_tags + mentioned_tags)
    return [tag for tag in candidates if tag.strip()]
161
-
162
-
163
def _summarize_agent_response(full_response: str, all_tags: List[str]) -> List[str]:
    """Derive one status message per tag from the agent's transcript text."""
    response_lower = full_response.lower()

    # "PR" must match as a whole word. A plain substring test also fires on
    # words like "processed" or "provide", which would label almost every
    # response as "PR created" and hide the success/error branches below.
    mentions_pr = (
        re.search(r"\bprs?\b", response_lower) is not None
        or "pull request" in response_lower
    )

    messages = []
    for tag in all_tags:
        tag_mentioned = tag.lower() in response_lower

        if "already exists" in response_lower and tag_mentioned:
            msg = f"Tag '{tag}': Already exists"
        elif mentions_pr:
            if tag_mentioned:
                msg = f"Tag '{tag}': PR created successfully"
            else:
                msg = f"Tag '{tag}': Processed (PR may have been created)"
        elif "success" in response_lower and tag_mentioned:
            msg = f"Tag '{tag}': Successfully processed"
        elif "error" in response_lower and tag_mentioned:
            msg = f"Tag '{tag}': Error during processing"
        else:
            msg = f"Tag '{tag}': Processed by agent"

        print(f"✅ Result for tag '{tag}': {msg}")
        messages.append(msg)

    return messages


def _add_tags_via_mcp_module(repo_name: str, all_tags: List[str]) -> List[str]:
    """Fallback: load ./mcp_server.py and call its tag functions per tag.

    Raises whatever loading the module raises; failures for an individual tag
    are caught and turned into a message for that tag.
    """
    import importlib.util

    # Load the MCP server module
    spec = importlib.util.spec_from_file_location("mcp_server", "./mcp_server.py")
    if spec is None or spec.loader is None:
        raise ImportError("Could not load ./mcp_server.py")
    mcp_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mcp_module)

    messages = []
    for tag in all_tags:
        try:
            print(f"🔧 Directly calling get_current_tags for '{tag}'")
            current_tags_result = mcp_module.get_current_tags(repo_name)
            print(f"📄 Current tags result: {current_tags_result}")

            # Both tool functions are parsed as JSON strings here.
            tags_data = json.loads(current_tags_result)

            if tags_data.get("status") == "success":
                current_tags = tags_data.get("current_tags", [])
                if tag in current_tags:
                    msg = f"Tag '{tag}': Already exists"
                    print(f"✅ {msg}")
                else:
                    print(f"🔧 Directly calling add_new_tag for '{tag}'")
                    add_result = mcp_module.add_new_tag(repo_name, tag)
                    print(f"📄 Add tag result: {add_result}")

                    add_data = json.loads(add_result)
                    if add_data.get("status") == "success":
                        pr_url = add_data.get("pr_url", "")
                        msg = f"Tag '{tag}': PR created - {pr_url}"
                    elif add_data.get("status") == "already_exists":
                        msg = f"Tag '{tag}': Already exists"
                    else:
                        msg = f"Tag '{tag}': {add_data.get('message', 'Processed')}"
                    print(f"✅ {msg}")
            else:
                error_msg = tags_data.get("error", "Unknown error")
                msg = f"Tag '{tag}': Error - {error_msg}"
                print(f"❌ {msg}")

            messages.append(msg)

        except Exception as direct_error:
            error_msg = f"Tag '{tag}': Direct call error - {str(direct_error)}"
            print(f"❌ {error_msg}")
            messages.append(error_msg)

    return messages


async def _run_agent_for_tags(
    agent: Any, repo_name: str, all_tags: List[str]
) -> List[str]:
    """Ask the agent to add all tags; fall back to direct calls if it fails."""
    # Create a comprehensive prompt for the agent
    tag_list = ", ".join(all_tags)
    user_prompt = "\n".join(
        [
            "",
            f"I need to add the following tags to the repository '{repo_name}': {tag_list}",
            "For each tag, please:",
            "1. Check if the tag already exists on the repository using get_current_tags",
            "2. If the tag doesn't exist, add it using add_new_tag",
            "3. Provide a summary of what was done for each tag",
            f"Please process all {len(all_tags)} tags: {tag_list}",
            "",
        ]
    )

    print("💬 Sending comprehensive prompt to agent...")
    print(f"📝 Prompt: {user_prompt}")

    try:
        # Let the agent handle the entire conversation
        conversation_result = []
        async for item in agent.run(user_prompt):
            # The agent yields different types of items
            item_str = str(item)
            conversation_result.append(item_str)

            # Log important events
            item_lower = item_str.lower()
            if "tool_call" in item_lower or "function" in item_lower:
                print(f"🔧 Agent using tools: {item_str[:200]}...")
            elif "content" in item_str and len(item_str) < 500:
                print(f"💭 Agent response: {item_str}")

        full_response = " ".join(conversation_result)
        print("📋 Agent conversation completed successfully")
        return _summarize_agent_response(full_response, all_tags)

    except Exception as agent_error:
        print(f"⚠️ Agent streaming failed: {str(agent_error)}")
        print("🔄 Falling back to direct MCP tool calls...")

        try:
            return _add_tags_via_mcp_module(repo_name, all_tags)
        except Exception as fallback_error:
            error_msg = f"Fallback approach failed: {str(fallback_error)}"
            print(f"❌ {error_msg}")
            return [error_msg]


async def process_webhook_comment(webhook_data: Dict[str, Any]):
    """Process webhook to detect and add tags.

    Extracts tags from the comment content and the discussion title, asks the
    agent to add them to the repository (falling back to calling the MCP
    server functions directly when the agent run fails), and appends a record
    of the interaction to ``tag_operations_store``.

    Args:
        webhook_data: Payload dict; reads ``comment.content``,
            ``comment.author`` (its ``id``, defaulting to "unknown"),
            ``discussion.title``, ``discussion.num`` and ``repo.name``.

    Returns:
        The per-tag result messages joined with " | ", or a message starting
        with "❌ Fatal error" when anything outside the agent step raises
        (for example a missing payload key). Exceptions are not propagated.
    """
    print("🏷️ Starting process_webhook_comment...")

    try:
        comment_content = webhook_data["comment"]["content"]
        discussion_title = webhook_data["discussion"]["title"]
        repo_name = webhook_data["repo"]["name"]
        discussion_num = webhook_data["discussion"]["num"]
        # Author is an object with "id" field
        comment_author = webhook_data["comment"]["author"].get("id", "unknown")

        print(f"📝 Comment content: {comment_content}")
        print(f"📰 Discussion title: {discussion_title}")
        print(f"📦 Repository: {repo_name}")

        # Extract potential tags from the comment and discussion title
        comment_tags = extract_tags_from_text(comment_content)
        title_tags = extract_tags_from_text(discussion_title)
        all_tags = list(set(comment_tags + title_tags))

        print(f"🔍 Comment tags found: {comment_tags}")
        print(f"🔍 Title tags found: {title_tags}")
        print(f"🏷️ All unique tags: {all_tags}")

        result_messages: List[str] = []

        if not all_tags:
            msg = "No recognizable tags found in the discussion."
            print(f"❌ {msg}")
            result_messages.append(msg)
        else:
            print("🤖 Getting agent instance...")
            agent = await get_agent()
            if not agent:
                msg = "Error: Agent not configured (missing HF_TOKEN)"
                print(f"❌ {msg}")
                result_messages.append(msg)
            else:
                print("✅ Agent instance obtained successfully")

                # Process all tags in a single conversation with the agent
                try:
                    result_messages.extend(
                        await _run_agent_for_tags(agent, repo_name, all_tags)
                    )
                except Exception as e:
                    error_msg = f"Error during agent processing: {str(e)}"
                    print(f"❌ {error_msg}")
                    result_messages.append(error_msg)

        # Store the interaction
        base_url = "https://huggingface.co"
        discussion_url = f"{base_url}/{repo_name}/discussions/{discussion_num}"

        interaction = {
            "timestamp": datetime.now().isoformat(),
            "repo": repo_name,
            "discussion_title": discussion_title,
            "discussion_num": discussion_num,
            "discussion_url": discussion_url,
            "original_comment": comment_content,
            "comment_author": comment_author,
            "detected_tags": all_tags,
            "results": result_messages,
        }

        tag_operations_store.append(interaction)
        final_result = " | ".join(result_messages)
        print(f"💾 Stored interaction and returning result: {final_result}")
        return final_result

    except Exception as e:
        error_msg = f"❌ Fatal error in process_webhook_comment: {str(e)}"
        print(error_msg)
        return error_msg
382
-
383
-
384
@app.post("/webhook")
async def webhook_handler(request: Request, background_tasks: BackgroundTasks):
    """Handle HF Hub webhooks.

    Authenticates the request with the ``X-Webhook-Secret`` header, then
    schedules ``process_webhook_comment`` as a background task when the event
    has scope "discussion", action "create", and is not a pull request.

    Returns:
        ``{"status": "processing"}`` when the payload was queued,
        ``{"status": "ignored"}`` for any other event, or ``{"error": ...}``
        for a bad secret, an unparsable body, or missing required fields.
    """
    webhook_secret = request.headers.get("X-Webhook-Secret")
    # Constant-time comparison so response timing does not leak how much of
    # the secret matched. A missing header is always rejected.
    secret_ok = webhook_secret is not None and hmac.compare_digest(
        webhook_secret.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")
    )
    if not secret_ok:
        print("❌ Invalid webhook secret")
        return {"error": "Invalid webhook secret"}

    try:
        payload = await request.json()
    except ValueError:
        payload = None
    if not isinstance(payload, dict):
        error_msg = "Invalid JSON payload"
        print(f"❌ {error_msg}")
        return {"error": error_msg}

    print(f"📥 Received webhook payload: {json.dumps(payload, indent=2)}")

    event = payload.get("event")
    if not isinstance(event, dict):
        event = {}
    scope = event.get("scope")
    action = event.get("action")

    print(f"🔍 Event details - scope: {scope}, action: {action}")

    # Check if this is a discussion comment creation. The discussion object is
    # read defensively: payloads without one must be ignored or reported as
    # missing fields below, not crash with KeyError.
    discussion = payload.get("discussion")
    is_pull_request = isinstance(discussion, dict) and bool(
        discussion.get("isPullRequest")
    )
    not_pr = not is_pull_request
    scope_check = scope == "discussion" and not_pr
    action_check = action == "create"
    print(f"✅ not_pr: {not_pr}")
    print(f"✅ scope_check: {scope_check}")
    print(f"✅ action_check: {action_check}")

    if scope_check and action_check:
        # Verify we have the required fields
        required_fields = ["comment", "discussion", "repo"]
        missing_fields = [field for field in required_fields if field not in payload]

        if missing_fields:
            error_msg = f"Missing required fields: {missing_fields}"
            print(f"❌ {error_msg}")
            return {"error": error_msg}

        print(f"🚀 Processing webhook for repo: {payload['repo']['name']}")
        background_tasks.add_task(process_webhook_comment, payload)
        return {"status": "processing"}

    print(f"⏭️ Ignoring webhook - scope: {scope}, action: {action}")
    return {"status": "ignored"}
426
-
427
-
428
async def simulate_webhook(
    repo_name: str, discussion_title: str, comment_content: str
) -> str:
    """Run the tagging pipeline on a hand-built payload, for testing.

    Returns a prompt to fill in all fields when any argument is empty;
    otherwise builds a mock "discussion create" payload, passes it to
    ``process_webhook_comment`` and returns its result prefixed with a
    success marker.
    """
    if not (repo_name and discussion_title and comment_content):
        return "Please fill in all fields."

    comment = {
        "content": comment_content,
        "author": {"id": "test-user-id"},
        "id": "mock-comment-id",
        "hidden": False,
    }
    discussion = {
        "title": discussion_title,
        # Number the mock discussion after the operations recorded so far.
        "num": len(tag_operations_store) + 1,
        "id": "mock-discussion-id",
        "status": "open",
        "isPullRequest": False,
    }
    repo = {
        "name": repo_name,
        "type": "model",
        "private": False,
    }
    mock_payload = {
        "event": {"action": "create", "scope": "discussion"},
        "comment": comment,
        "discussion": discussion,
        "repo": repo,
    }

    outcome = await process_webhook_comment(mock_payload)
    return f"✅ Processed! Results: {outcome}"
459
-
460
-
461
def create_gradio_app():
    """Build and return the Gradio Blocks dashboard.

    The dashboard shows a short explanation, a form (repository, discussion
    title, comment) whose button calls ``simulate_webhook`` and writes the
    result into a textbox, and the sorted list of ``RECOGNIZED_TAGS``.
    """
    how_it_works_md = "\n".join(
        [
            "",
            "## How it works:",
            "- Monitors HuggingFace Hub discussions",
            '- Detects tag mentions in comments (e.g., "tag: pytorch",',
            '"#transformers")',
            "- Automatically adds recognized tags to the model repository",
            "- Supports common ML tags like: pytorch, tensorflow,",
            "text-generation, etc.",
            "",
        ]
    )
    recognized_tags_md = "\n".join(
        [
            "",
            "## Recognized Tags:",
            ", ".join(sorted(RECOGNIZED_TAGS)),
            "",
        ]
    )

    with gr.Blocks(title="HF Tagging Bot", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏷️ HF Tagging Bot Dashboard")
        gr.Markdown("*Automatically adds tags to models when mentioned in discussions*")
        gr.Markdown(how_it_works_md)

        # Input form, pre-filled with an example
        with gr.Column():
            sim_repo = gr.Textbox(
                label="Repository",
                value="burtenshaw/play-mcp-repo-bot",
                placeholder="username/model-name",
            )
            sim_title = gr.Textbox(
                label="Discussion Title",
                value="Add pytorch tag",
                placeholder="Discussion title",
            )
            sim_comment = gr.Textbox(
                label="Comment",
                lines=3,
                value="This model should have tags: pytorch, text-generation",
                placeholder="Comment mentioning tags...",
            )
            sim_btn = gr.Button("🏷️ Test Tag Detection")

        # Output area
        with gr.Column():
            sim_result = gr.Textbox(label="Result", lines=8)

        # Wire the button to the simulated webhook
        sim_btn.click(
            fn=simulate_webhook,
            inputs=[sim_repo, sim_title, sim_comment],
            outputs=sim_result,
        )

        gr.Markdown(recognized_tags_md)

    return demo
511
-
512
-
513
# Mount Gradio app
# The dashboard is built at import time and served by the FastAPI app under
# the /gradio path; `app` is rebound to the object returned by the mount call.
gradio_app = create_gradio_app()
app = gr.mount_gradio_app(app, gradio_app, path="/gradio")


if __name__ == "__main__":
    # Announce the local URLs, then start uvicorn with auto-reload enabled.
    for startup_line in (
        "🚀 Starting HF Tagging Bot...",
        "📊 Dashboard: http://localhost:7860/gradio",
        "🔗 Webhook: http://localhost:7860/webhook",
    ):
        print(startup_line)
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)