asmaa105 commited on
Commit
25daca9
Β·
verified Β·
1 Parent(s): 4c1b733

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +522 -0
  2. mcp_server.py +180 -0
  3. requirements.txt +77 -11
app.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ from datetime import datetime
5
+ from typing import List, Dict, Any, Optional, Literal
6
+
7
+ from fastapi import FastAPI, Request, BackgroundTasks
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ import gradio as gr
10
+ import uvicorn
11
+ from pydantic import BaseModel
12
+ from huggingface_hub.inference._mcp.agent import Agent
13
+ from dotenv import load_dotenv
14
+
15
+ load_dotenv()
16
+
17
+ # Configuration
18
+ WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET", "716f77a91d0415cd0e3ed9dc8d188fc9ee53b11a8661e161a86f669f598a8016")
19
+ HF_TOKEN = os.getenv("HF_TOKEN")
20
+ HF_MODEL = os.getenv("HF_MODEL", "microsoft/DialoGPT-medium")
21
+ # Use a valid provider literal from the documentation
22
+ DEFAULT_PROVIDER: Literal["hf-inference"] = "hf-inference"
23
+ HF_PROVIDER = os.getenv("HF_PROVIDER", DEFAULT_PROVIDER)
24
+
25
+ # Simple storage for processed tag operations
26
+ tag_operations_store: List[Dict[str, Any]] = []
27
+
28
+ # Agent instance
29
+ agent_instance: Optional[Agent] = None
30
+
31
+ # Common ML tags that we recognize for auto-tagging
32
+ RECOGNIZED_TAGS = {
33
+ "pytorch",
34
+ "tensorflow",
35
+ "jax",
36
+ "transformers",
37
+ "diffusers",
38
+ "text-generation",
39
+ "text-classification",
40
+ "question-answering",
41
+ "text-to-image",
42
+ "image-classification",
43
+ "object-detection",
44
+ " ",
45
+ "fill-mask",
46
+ "token-classification",
47
+ "translation",
48
+ "summarization",
49
+ "feature-extraction",
50
+ "sentence-similarity",
51
+ "zero-shot-classification",
52
+ "image-to-text",
53
+ "automatic-speech-recognition",
54
+ "audio-classification",
55
+ "voice-activity-detection",
56
+ "depth-estimation",
57
+ "image-segmentation",
58
+ "video-classification",
59
+ "reinforcement-learning",
60
+ "tabular-classification",
61
+ "tabular-regression",
62
+ "time-series-forecasting",
63
+ "graph-ml",
64
+ "robotics",
65
+ "computer-vision",
66
+ "nlp",
67
+ "cv",
68
+ "multimodal",
69
+ }
70
+
71
+
72
+ class WebhookEvent(BaseModel):
73
+ event: Dict[str, str]
74
+ comment: Dict[str, Any]
75
+ discussion: Dict[str, Any]
76
+ repo: Dict[str, str]
77
+
78
+
79
+ app = FastAPI(title="HF Tagging Bot")
80
+ app.add_middleware(CORSMiddleware, allow_origins=["*"])
81
+
82
+
83
+ async def get_agent():
84
+ """Get or create Agent instance"""
85
+ print("πŸ€– get_agent() called...")
86
+ global agent_instance
87
+ if agent_instance is None and HF_TOKEN:
88
+ print("πŸ”§ Creating new Agent instance...")
89
+ print(f"πŸ”‘ HF_TOKEN present: {bool(HF_TOKEN)}")
90
+ print(f"πŸ€– Model: {HF_MODEL}")
91
+ print(f"πŸ”— Provider: {DEFAULT_PROVIDER}")
92
+
93
+ try:
94
+ agent_instance = Agent(
95
+ model=HF_MODEL,
96
+ provider=DEFAULT_PROVIDER,
97
+ api_key=HF_TOKEN,
98
+ servers=[
99
+ {
100
+ "type": "stdio",
101
+ "config": {
102
+ "command": "python",
103
+ "args": ["mcp_server.py"],
104
+ "cwd": ".", # Ensure correct working directory
105
+ "env": {"HF_TOKEN": HF_TOKEN} if HF_TOKEN else {},
106
+ },
107
+ }
108
+ ],
109
+ )
110
+ print("βœ… Agent instance created successfully")
111
+ print("πŸ”§ Loading tools...")
112
+ await agent_instance.load_tools()
113
+ print("βœ… Tools loaded successfully")
114
+ except Exception as e:
115
+ print(f"❌ Error creating/loading agent: {str(e)}")
116
+ agent_instance = None
117
+ elif agent_instance is None:
118
+ print("❌ No HF_TOKEN available, cannot create agent")
119
+ else:
120
+ print("βœ… Using existing agent instance")
121
+
122
+ return agent_instance
123
+
124
+
125
+ def extract_tags_from_text(text: str) -> List[str]:
126
+ """Extract potential tags from discussion text"""
127
+ text_lower = text.lower()
128
+
129
+ # Look for explicit tag mentions like "tag: pytorch" or "#pytorch"
130
+ explicit_tags = []
131
+
132
+ # Pattern 1: "tag: something" or "tags: something"
133
+ tag_pattern = r"tags?:\s*([a-zA-Z0-9-_,\s]+)"
134
+ matches = re.findall(tag_pattern, text_lower)
135
+ for match in matches:
136
+ # Split by comma and clean up
137
+ tags = [tag.strip() for tag in match.split(",")]
138
+ explicit_tags.extend(tags)
139
+
140
+ # Pattern 2: "#hashtag" style
141
+ hashtag_pattern = r"#([a-zA-Z0-9-_]+)"
142
+ hashtag_matches = re.findall(hashtag_pattern, text_lower)
143
+ explicit_tags.extend(hashtag_matches)
144
+
145
+ # Pattern 3: Look for recognized tags mentioned in natural text
146
+ mentioned_tags = []
147
+ for tag in RECOGNIZED_TAGS:
148
+ if tag in text_lower:
149
+ mentioned_tags.append(tag)
150
+
151
+ # Combine and deduplicate
152
+ all_tags = list(set(explicit_tags + mentioned_tags))
153
+
154
+ # Filter to only include recognized tags or explicitly mentioned ones
155
+ valid_tags = []
156
+ for tag in all_tags:
157
+ if tag in RECOGNIZED_TAGS or tag in explicit_tags:
158
+ valid_tags.append(tag)
159
+
160
+ return valid_tags
161
+
162
+
163
+ async def process_webhook_comment(webhook_data: Dict[str, Any]):
164
+ """Process webhook to detect and add tags"""
165
+ print("🏷️ Starting process_webhook_comment...")
166
+
167
+ try:
168
+ comment_content = webhook_data["comment"]["content"]
169
+ discussion_title = webhook_data["discussion"]["title"]
170
+ repo_name = webhook_data["repo"]["name"]
171
+ discussion_num = webhook_data["discussion"]["num"]
172
+ # Author is an object with "id" field
173
+ comment_author = webhook_data["comment"]["author"].get("id", "unknown")
174
+
175
+ print(f"πŸ“ Comment content: {comment_content}")
176
+ print(f"πŸ“° Discussion title: {discussion_title}")
177
+ print(f"πŸ“¦ Repository: {repo_name}")
178
+
179
+ # Extract potential tags from the comment and discussion title
180
+ comment_tags = extract_tags_from_text(comment_content)
181
+ title_tags = extract_tags_from_text(discussion_title)
182
+ all_tags = list(set(comment_tags + title_tags))
183
+
184
+ print(f"πŸ” Comment tags found: {comment_tags}")
185
+ print(f"πŸ” Title tags found: {title_tags}")
186
+ print(f"🏷️ All unique tags: {all_tags}")
187
+
188
+ result_messages = []
189
+
190
+ if not all_tags:
191
+ msg = "No recognizable tags found in the discussion."
192
+ print(f"❌ {msg}")
193
+ result_messages.append(msg)
194
+ else:
195
+ print("πŸ€– Getting agent instance...")
196
+ agent = await get_agent()
197
+ if not agent:
198
+ msg = "Error: Agent not configured (missing HF_TOKEN)"
199
+ print(f"❌ {msg}")
200
+ result_messages.append(msg)
201
+ else:
202
+ print("βœ… Agent instance obtained successfully")
203
+
204
+ # Process all tags in a single conversation with the agent
205
+ try:
206
+ # Create a comprehensive prompt for the agent
207
+ user_prompt = f"""
208
+ I need to add the following tags to the repository '{repo_name}': {", ".join(all_tags)}
209
+ For each tag, please:
210
+ 1. Check if the tag already exists on the repository using get_current_tags
211
+ 2. If the tag doesn't exist, add it using add_new_tag
212
+ 3. Provide a summary of what was done for each tag
213
+ Please process all {len(all_tags)} tags: {", ".join(all_tags)}
214
+ """
215
+
216
+ print("πŸ’¬ Sending comprehensive prompt to agent...")
217
+ print(f"πŸ“ Prompt: {user_prompt}")
218
+
219
+ # Let the agent handle the entire conversation
220
+ conversation_result = []
221
+
222
+ try:
223
+ async for item in agent.run(user_prompt):
224
+ # The agent yields different types of items
225
+ item_str = str(item)
226
+ conversation_result.append(item_str)
227
+
228
+ # Log important events
229
+ if (
230
+ "tool_call" in item_str.lower()
231
+ or "function" in item_str.lower()
232
+ ):
233
+ print(f"πŸ”§ Agent using tools: {item_str[:200]}...")
234
+ elif "content" in item_str and len(item_str) < 500:
235
+ print(f"πŸ’­ Agent response: {item_str}")
236
+
237
+ # Extract the final response from the conversation
238
+ full_response = " ".join(conversation_result)
239
+ print(f"πŸ“‹ Agent conversation completed successfully")
240
+
241
+ # Try to extract meaningful results for each tag
242
+ for tag in all_tags:
243
+ tag_mentioned = tag.lower() in full_response.lower()
244
+
245
+ if (
246
+ "already exists" in full_response.lower()
247
+ and tag_mentioned
248
+ ):
249
+ msg = f"Tag '{tag}': Already exists"
250
+ elif (
251
+ "pr" in full_response.lower()
252
+ or "pull request" in full_response.lower()
253
+ ):
254
+ if tag_mentioned:
255
+ msg = f"Tag '{tag}': PR created successfully"
256
+ else:
257
+ msg = (
258
+ f"Tag '{tag}': Processed "
259
+ "(PR may have been created)"
260
+ )
261
+ elif "success" in full_response.lower() and tag_mentioned:
262
+ msg = f"Tag '{tag}': Successfully processed"
263
+ elif "error" in full_response.lower() and tag_mentioned:
264
+ msg = f"Tag '{tag}': Error during processing"
265
+ else:
266
+ msg = f"Tag '{tag}': Processed by agent"
267
+
268
+ print(f"βœ… Result for tag '{tag}': {msg}")
269
+ result_messages.append(msg)
270
+
271
+ except Exception as agent_error:
272
+ print(f"⚠️ Agent streaming failed: {str(agent_error)}")
273
+ print("πŸ”„ Falling back to direct MCP tool calls...")
274
+
275
+ # Import the MCP server functions directly as fallback
276
+ try:
277
+ import sys
278
+ import importlib.util
279
+
280
+ # Load the MCP server module
281
+ spec = importlib.util.spec_from_file_location(
282
+ "mcp_server", "./mcp_server.py"
283
+ )
284
+ mcp_module = importlib.util.module_from_spec(spec)
285
+ spec.loader.exec_module(mcp_module)
286
+
287
+ # Use the MCP tools directly for each tag
288
+ for tag in all_tags:
289
+ try:
290
+ print(
291
+ f"πŸ”§ Directly calling get_current_tags for '{tag}'"
292
+ )
293
+ current_tags_result = mcp_module.get_current_tags(
294
+ repo_name
295
+ )
296
+ print(
297
+ f"πŸ“„ Current tags result: {current_tags_result}"
298
+ )
299
+
300
+ # Parse the JSON result
301
+ import json
302
+
303
+ tags_data = json.loads(current_tags_result)
304
+
305
+ if tags_data.get("status") == "success":
306
+ current_tags = tags_data.get("current_tags", [])
307
+ if tag in current_tags:
308
+ msg = f"Tag '{tag}': Already exists"
309
+ print(f"βœ… {msg}")
310
+ else:
311
+ print(
312
+ f"πŸ”§ Directly calling add_new_tag for '{tag}'"
313
+ )
314
+ add_result = mcp_module.add_new_tag(
315
+ repo_name, tag
316
+ )
317
+ print(f"πŸ“„ Add tag result: {add_result}")
318
+
319
+ add_data = json.loads(add_result)
320
+ if add_data.get("status") == "success":
321
+ pr_url = add_data.get("pr_url", "")
322
+ msg = f"Tag '{tag}': PR created - {pr_url}"
323
+ elif (
324
+ add_data.get("status")
325
+ == "already_exists"
326
+ ):
327
+ msg = f"Tag '{tag}': Already exists"
328
+ else:
329
+ msg = f"Tag '{tag}': {add_data.get('message', 'Processed')}"
330
+ print(f"βœ… {msg}")
331
+ else:
332
+ error_msg = tags_data.get(
333
+ "error", "Unknown error"
334
+ )
335
+ msg = f"Tag '{tag}': Error - {error_msg}"
336
+ print(f"❌ {msg}")
337
+
338
+ result_messages.append(msg)
339
+
340
+ except Exception as direct_error:
341
+ error_msg = f"Tag '{tag}': Direct call error - {str(direct_error)}"
342
+ print(f"❌ {error_msg}")
343
+ result_messages.append(error_msg)
344
+
345
+ except Exception as fallback_error:
346
+ error_msg = (
347
+ f"Fallback approach failed: {str(fallback_error)}"
348
+ )
349
+ print(f"❌ {error_msg}")
350
+ result_messages.append(error_msg)
351
+
352
+ except Exception as e:
353
+ error_msg = f"Error during agent processing: {str(e)}"
354
+ print(f"❌ {error_msg}")
355
+ result_messages.append(error_msg)
356
+
357
+ # Store the interaction
358
+ base_url = "https://huggingface.co"
359
+ discussion_url = f"{base_url}/{repo_name}/discussions/{discussion_num}"
360
+
361
+ interaction = {
362
+ "timestamp": datetime.now().isoformat(),
363
+ "repo": repo_name,
364
+ "discussion_title": discussion_title,
365
+ "discussion_num": discussion_num,
366
+ "discussion_url": discussion_url,
367
+ "original_comment": comment_content,
368
+ "comment_author": comment_author,
369
+ "detected_tags": all_tags,
370
+ "results": result_messages,
371
+ }
372
+
373
+ tag_operations_store.append(interaction)
374
+ final_result = " | ".join(result_messages)
375
+ print(f"πŸ’Ύ Stored interaction and returning result: {final_result}")
376
+ return final_result
377
+
378
+ except Exception as e:
379
+ error_msg = f"❌ Fatal error in process_webhook_comment: {str(e)}"
380
+ print(error_msg)
381
+ return error_msg
382
+
383
+
384
+ @app.post("/webhook")
385
+ async def webhook_handler(request: Request, background_tasks: BackgroundTasks):
386
+ """Handle HF Hub webhooks"""
387
+ webhook_secret = request.headers.get("X-Webhook-Secret")
388
+ if webhook_secret != WEBHOOK_SECRET:
389
+ print("❌ Invalid webhook secret")
390
+ return {"error": "Invalid webhook secret"}
391
+
392
+ payload = await request.json()
393
+ print(f"πŸ“₯ Received webhook payload: {json.dumps(payload, indent=2)}")
394
+
395
+ event = payload.get("event", {})
396
+ scope = event.get("scope")
397
+ action = event.get("action")
398
+
399
+ print(f"πŸ” Event details - scope: {scope}, action: {action}")
400
+
401
+ # Check if this is a discussion comment creation
402
+ scope_check = scope == "discussion"
403
+ action_check = action == "create"
404
+ not_pr = not payload["discussion"]["isPullRequest"]
405
+ scope_check = scope_check and not_pr
406
+ print(f"βœ… not_pr: {not_pr}")
407
+ print(f"βœ… scope_check: {scope_check}")
408
+ print(f"βœ… action_check: {action_check}")
409
+
410
+ if scope_check and action_check:
411
+ # Verify we have the required fields
412
+ required_fields = ["comment", "discussion", "repo"]
413
+ missing_fields = [field for field in required_fields if field not in payload]
414
+
415
+ if missing_fields:
416
+ error_msg = f"Missing required fields: {missing_fields}"
417
+ print(f"❌ {error_msg}")
418
+ return {"error": error_msg}
419
+
420
+ print(f"πŸš€ Processing webhook for repo: {payload['repo']['name']}")
421
+ background_tasks.add_task(process_webhook_comment, payload)
422
+ return {"status": "processing"}
423
+
424
+ print(f"⏭️ Ignoring webhook - scope: {scope}, action: {action}")
425
+ return {"status": "ignored"}
426
+
427
+
428
+ async def simulate_webhook(
429
+ repo_name: str, discussion_title: str, comment_content: str
430
+ ) -> str:
431
+ """Simulate webhook for testing"""
432
+ if not all([repo_name, discussion_title, comment_content]):
433
+ return "Please fill in all fields."
434
+
435
+ mock_payload = {
436
+ "event": {"action": "create", "scope": "discussion"},
437
+ "comment": {
438
+ "content": comment_content,
439
+ "author": {"id": "test-user-id"},
440
+ "id": "mock-comment-id",
441
+ "hidden": False,
442
+ },
443
+ "discussion": {
444
+ "title": discussion_title,
445
+ "num": len(tag_operations_store) + 1,
446
+ "id": "mock-discussion-id",
447
+ "status": "open",
448
+ "isPullRequest": False,
449
+ },
450
+ "repo": {
451
+ "name": repo_name,
452
+ "type": "model",
453
+ "private": False,
454
+ },
455
+ }
456
+
457
+ response = await process_webhook_comment(mock_payload)
458
+ return f"βœ… Processed! Results: {response}"
459
+
460
+
461
+ def create_gradio_app():
462
+ """Create Gradio interface"""
463
+ with gr.Blocks(title="HF Tagging Bot", theme=gr.themes.Soft()) as demo:
464
+ gr.Markdown("# 🏷️ HF Tagging Bot Dashboard")
465
+ gr.Markdown("*Automatically adds tags to models when mentioned in discussions*")
466
+
467
+ gr.Markdown("""
468
+ ## How it works:
469
+ - Monitors HuggingFace Hub discussions
470
+ - Detects tag mentions in comments (e.g., "tag: pytorch",
471
+ "#transformers")
472
+ - Automatically adds recognized tags to the model repository
473
+ - Supports common ML tags like: pytorch, tensorflow,
474
+ text-generation, etc.
475
+ """)
476
+
477
+ with gr.Column():
478
+ sim_repo = gr.Textbox(
479
+ label="Repository",
480
+ value="burtenshaw/play-mcp-repo-bot",
481
+ placeholder="username/model-name",
482
+ )
483
+ sim_title = gr.Textbox(
484
+ label="Discussion Title",
485
+ value="Add pytorch tag",
486
+ placeholder="Discussion title",
487
+ )
488
+ sim_comment = gr.Textbox(
489
+ label="Comment",
490
+ lines=3,
491
+ value="This model should have tags: pytorch, text-generation",
492
+ placeholder="Comment mentioning tags...",
493
+ )
494
+ sim_btn = gr.Button("🏷️ Test Tag Detection")
495
+
496
+ with gr.Column():
497
+ sim_result = gr.Textbox(label="Result", lines=8)
498
+
499
+ sim_btn.click(
500
+ fn=simulate_webhook,
501
+ inputs=[sim_repo, sim_title, sim_comment],
502
+ outputs=sim_result,
503
+ )
504
+
505
+ gr.Markdown(f"""
506
+ ## Recognized Tags:
507
+ {", ".join(sorted(RECOGNIZED_TAGS))}
508
+ """)
509
+
510
+ return demo
511
+
512
+
513
+ # Mount Gradio app
514
+ gradio_app = create_gradio_app()
515
+ app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
516
+
517
+
518
+ if __name__ == "__main__":
519
+ print("πŸš€ Starting HF Tagging Bot...")
520
+ print("πŸ“Š Dashboard: http://localhost:7860/gradio")
521
+ print("πŸ”— Webhook: http://localhost:7860/webhook")
522
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
mcp_server.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simplified MCP Server for HuggingFace Hub Tagging Operations using FastMCP
3
+ """
4
+
5
+ import os
6
+ import json
7
+ from fastmcp import FastMCP
8
+ from huggingface_hub import HfApi, model_info, ModelCard, ModelCardData
9
+ from huggingface_hub.utils import HfHubHTTPError
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv()
13
+
14
+ # Configuration
15
+ HF_TOKEN = os.getenv("HF_TOKEN")
16
+
17
+ # Initialize HF API client
18
+ hf_api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
19
+
20
+ # Create the FastMCP server
21
+ mcp = FastMCP("hf-tagging-bot")
22
+
23
+
24
+ @mcp.tool()
25
+ def get_current_tags(repo_id: str) -> str:
26
+ """Get current tags from a HuggingFace model repository"""
27
+ print(f"πŸ”§ get_current_tags called with repo_id: {repo_id}")
28
+
29
+ if not hf_api:
30
+ error_result = {"error": "HF token not configured"}
31
+ json_str = json.dumps(error_result)
32
+ print(f"❌ No HF API token - returning: {json_str}")
33
+ return json_str
34
+
35
+ try:
36
+ print(f"πŸ“‘ Fetching model info for: {repo_id}")
37
+ info = model_info(repo_id=repo_id, token=HF_TOKEN)
38
+ current_tags = info.tags if info.tags else []
39
+ print(f"🏷️ Found {len(current_tags)} tags: {current_tags}")
40
+
41
+ result = {
42
+ "status": "success",
43
+ "repo_id": repo_id,
44
+ "current_tags": current_tags,
45
+ "count": len(current_tags),
46
+ }
47
+ json_str = json.dumps(result)
48
+ print(f"βœ… get_current_tags returning: {json_str}")
49
+ return json_str
50
+
51
+ except Exception as e:
52
+ print(f"❌ Error in get_current_tags: {str(e)}")
53
+ error_result = {"status": "error", "repo_id": repo_id, "error": str(e)}
54
+ json_str = json.dumps(error_result)
55
+ print(f"❌ get_current_tags error returning: {json_str}")
56
+ return json_str
57
+
58
+
59
+ @mcp.tool()
60
+ def add_new_tag(repo_id: str, new_tag: str) -> str:
61
+ """Add a new tag to a HuggingFace model repository via PR"""
62
+ print(f"πŸ”§ add_new_tag called with repo_id: {repo_id}, new_tag: {new_tag}")
63
+
64
+ if not hf_api:
65
+ error_result = {"error": "HF token not configured"}
66
+ json_str = json.dumps(error_result)
67
+ print(f"❌ No HF API token - returning: {json_str}")
68
+ return json_str
69
+
70
+ try:
71
+ # Get current model info and tags
72
+ print(f"πŸ“‘ Fetching current model info for: {repo_id}")
73
+ info = model_info(repo_id=repo_id, token=HF_TOKEN)
74
+ current_tags = info.tags if info.tags else []
75
+ print(f"🏷️ Current tags: {current_tags}")
76
+
77
+ # Check if tag already exists
78
+ if new_tag in current_tags:
79
+ print(f"⚠️ Tag '{new_tag}' already exists in {current_tags}")
80
+ result = {
81
+ "status": "already_exists",
82
+ "repo_id": repo_id,
83
+ "tag": new_tag,
84
+ "message": f"Tag '{new_tag}' already exists",
85
+ }
86
+ json_str = json.dumps(result)
87
+ print(f"🏷️ add_new_tag (already exists) returning: {json_str}")
88
+ return json_str
89
+
90
+ # Add the new tag to existing tags
91
+ updated_tags = current_tags + [new_tag]
92
+ print(f"πŸ†• Will update tags from {current_tags} to {updated_tags}")
93
+
94
+ # Create model card content with updated tags
95
+ try:
96
+ # Load existing model card
97
+ print(f"πŸ“„ Loading existing model card...")
98
+ card = ModelCard.load(repo_id, token=HF_TOKEN)
99
+ if not hasattr(card, "data") or card.data is None:
100
+ card.data = ModelCardData()
101
+ except HfHubHTTPError:
102
+ # Create new model card if none exists
103
+ print(f"πŸ“„ Creating new model card (none exists)")
104
+ card = ModelCard("")
105
+ card.data = ModelCardData()
106
+
107
+ # Update tags - create new ModelCardData with updated tags
108
+ card_dict = card.data.to_dict()
109
+ card_dict["tags"] = updated_tags
110
+ card.data = ModelCardData(**card_dict)
111
+
112
+ # Create a pull request with the updated model card
113
+ pr_title = f"Add '{new_tag}' tag"
114
+ pr_description = f"""
115
+ ## Add tag: {new_tag}
116
+ This PR adds the `{new_tag}` tag to the model repository.
117
+ **Changes:**
118
+ - Added `{new_tag}` to model tags
119
+ - Updated from {len(current_tags)} to {len(updated_tags)} tags
120
+ **Current tags:** {", ".join(current_tags) if current_tags else "None"}
121
+ **New tags:** {", ".join(updated_tags)}
122
+ """
123
+
124
+ print(f"πŸš€ Creating PR with title: {pr_title}")
125
+
126
+ # Create commit with updated model card using CommitOperationAdd
127
+ from huggingface_hub import CommitOperationAdd
128
+
129
+ commit_info = hf_api.create_commit(
130
+ repo_id=repo_id,
131
+ operations=[
132
+ CommitOperationAdd(
133
+ path_in_repo="README.md", path_or_fileobj=str(card).encode("utf-8")
134
+ )
135
+ ],
136
+ commit_message=pr_title,
137
+ commit_description=pr_description,
138
+ token=HF_TOKEN,
139
+ create_pr=True,
140
+ )
141
+
142
+ # Extract PR URL from commit info
143
+ pr_url_attr = commit_info.pr_url
144
+ pr_url = pr_url_attr if hasattr(commit_info, "pr_url") else str(commit_info)
145
+
146
+ print(f"βœ… PR created successfully! URL: {pr_url}")
147
+
148
+ result = {
149
+ "status": "success",
150
+ "repo_id": repo_id,
151
+ "tag": new_tag,
152
+ "pr_url": pr_url,
153
+ "previous_tags": current_tags,
154
+ "new_tags": updated_tags,
155
+ "message": f"Created PR to add tag '{new_tag}'",
156
+ }
157
+ json_str = json.dumps(result)
158
+ print(f"βœ… add_new_tag success returning: {json_str}")
159
+ return json_str
160
+
161
+ except Exception as e:
162
+ print(f"❌ Error in add_new_tag: {str(e)}")
163
+ print(f"❌ Error type: {type(e)}")
164
+ import traceback
165
+
166
+ print(f"❌ Traceback: {traceback.format_exc()}")
167
+
168
+ error_result = {
169
+ "status": "error",
170
+ "repo_id": repo_id,
171
+ "tag": new_tag,
172
+ "error": str(e),
173
+ }
174
+ json_str = json.dumps(error_result)
175
+ print(f"❌ add_new_tag error returning: {json_str}")
176
+ return json_str
177
+
178
+
179
+ if __name__ == "__main__":
180
+ mcp.run()
requirements.txt CHANGED
@@ -1,11 +1,77 @@
1
- Flask==3.0.3
2
- ipython==8.12.3
3
- numpy==2.0.0
4
- opencv_python==4.9.0.80
5
- opencv_python_headless==4.10.0.84
6
- pandas==2.2.2
7
- Pillow==10.4.0
8
- Requests==2.32.3
9
- roboflow==1.1.34
10
- streamlit==1.36.0
11
- ultralytics==8.0.196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv export --format requirements-txt --no-hashes
3
+ aiofiles==24.1.0
4
+ aiohappyeyeballs==2.6.1
5
+ aiohttp==3.12.2
6
+ aiosignal==1.3.2
7
+ annotated-types==0.7.0
8
+ anyio==4.9.0
9
+ attrs==25.3.0
10
+ audioop-lts==0.2.1 ; python_full_version >= '3.13'
11
+ certifi==2025.4.26
12
+ charset-normalizer==3.4.2
13
+ click==8.2.1
14
+ colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows'
15
+ exceptiongroup==1.3.0
16
+ fastapi==0.115.12
17
+ fastmcp==2.5.1
18
+ ffmpy==0.5.0
19
+ filelock==3.18.0
20
+ frozenlist==1.6.0
21
+ fsspec==2025.5.1
22
+ gradio==5.31.0
23
+ gradio-client==1.10.1
24
+ groovy==0.1.2
25
+ h11==0.16.0
26
+ hf-xet==1.1.2 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
27
+ httpcore==1.0.9
28
+ httptools==0.6.4
29
+ httpx==0.28.1
30
+ httpx-sse==0.4.0
31
+ huggingface-hub==0.32.2
32
+ idna==3.10
33
+ jinja2==3.1.6
34
+ markdown-it-py==3.0.0
35
+ markupsafe==3.0.2
36
+ mcp==1.9.1
37
+ mdurl==0.1.2
38
+ multidict==6.4.4
39
+ numpy==2.2.6
40
+ openapi-pydantic==0.5.1
41
+ orjson==3.10.18
42
+ packaging==25.0
43
+ pandas==2.2.3
44
+ pillow==11.2.1
45
+ propcache==0.3.1
46
+ pydantic==2.11.5
47
+ pydantic-core==2.33.2
48
+ pydantic-settings==2.9.1
49
+ pydub==0.25.1
50
+ pygments==2.19.1
51
+ python-dateutil==2.9.0.post0
52
+ python-dotenv==1.1.0
53
+ python-multipart==0.0.20
54
+ pytz==2025.2
55
+ pyyaml==6.0.2
56
+ requests==2.32.3
57
+ rich==14.0.0
58
+ ruff==0.11.11 ; sys_platform != 'emscripten'
59
+ safehttpx==0.1.6
60
+ semantic-version==2.10.0
61
+ shellingham==1.5.4
62
+ six==1.17.0
63
+ sniffio==1.3.1
64
+ sse-starlette==2.3.5
65
+ starlette==0.46.2
66
+ tomlkit==0.13.2
67
+ tqdm==4.67.1
68
+ typer==0.16.0
69
+ typing-extensions==4.13.2
70
+ typing-inspection==0.4.1
71
+ tzdata==2025.2
72
+ urllib3==2.4.0
73
+ uvicorn==0.34.2
74
+ uvloop==0.21.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
75
+ watchfiles==1.0.5
76
+ websockets==15.0.1
77
+ yarl==1.20.0