cyrilzakka HF Staff commited on
Commit
106aab5
·
verified ·
1 Parent(s): d9ba15b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +468 -0
app.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import json
4
+ from typing import List, Dict, Any, Union
5
+ from contextlib import AsyncExitStack
6
+
7
+ import gradio as gr
8
+ from gradio.components.chatbot import ChatMessage
9
+ from mcp import ClientSession, StdioServerParameters
10
+ from mcp.client.stdio import stdio_client
11
+ from mcp.client.sse import sse_client
12
+ from anthropic import Anthropic
13
+ from datasets import load_dataset
14
+ import pandas as pd
15
+
16
+ loop = asyncio.new_event_loop()
17
+ asyncio.set_event_loop(loop)
18
+
19
+ class MCPClientWrapper:
20
+ def __init__(self):
21
+ self.session = None
22
+ self.exit_stack = None
23
+ self.anthropic = None
24
+ self.tools = []
25
+ self.dataset = None
26
+ self.validation_results = []
27
+
28
+ def set_api_key(self, api_key: str) -> str:
29
+ """Set the Anthropic API key and initialize the client"""
30
+ if not api_key or not api_key.strip():
31
+ return "Please enter a valid Anthropic API key"
32
+
33
+ try:
34
+ self.anthropic = Anthropic(api_key=api_key.strip())
35
+ return "API key set successfully ✅"
36
+ except Exception as e:
37
+ return f"Failed to set API key: {str(e)}"
38
+
39
+ def connect(self, server_input: str) -> str:
40
+ if not self.anthropic:
41
+ return "Please set your Anthropic API key first"
42
+ return loop.run_until_complete(self._connect(server_input))
43
+
44
+ async def _connect(self, server_input: str) -> str:
45
+ if self.exit_stack:
46
+ await self.exit_stack.aclose()
47
+
48
+ self.exit_stack = AsyncExitStack()
49
+
50
+ try:
51
+ # Check if input is a URL (starts with http:// or https://)
52
+ if server_input.startswith(('http://', 'https://')):
53
+ # Connect via SSE
54
+ read, write = await self.exit_stack.enter_async_context(
55
+ sse_client(server_input)
56
+ )
57
+ connection_type = "SSE URL"
58
+ else:
59
+ # Connect via stdio (local file)
60
+ is_python = server_input.endswith('.py')
61
+ command = "python" if is_python else "node"
62
+
63
+ server_params = StdioServerParameters(
64
+ command=command,
65
+ args=[server_input],
66
+ env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
67
+ )
68
+
69
+ read, write = await self.exit_stack.enter_async_context(
70
+ stdio_client(server_params)
71
+ )
72
+ connection_type = "Local script"
73
+
74
+ self.session = await self.exit_stack.enter_async_context(
75
+ ClientSession(read, write)
76
+ )
77
+ await self.session.initialize()
78
+
79
+ response = await self.session.list_tools()
80
+ self.tools = [{
81
+ "name": tool.name,
82
+ "description": tool.description,
83
+ "input_schema": tool.inputSchema
84
+ } for tool in response.tools]
85
+
86
+ tool_names = [tool["name"] for tool in self.tools]
87
+ return f"Connected to MCP server via {connection_type}. Available tools: {', '.join(tool_names)}"
88
+
89
+ except Exception as e:
90
+ return f"Connection failed: {str(e)}"
91
+
92
+ def load_dataset(self) -> tuple:
93
+ """Load the TAAIC Phase1 validation dataset"""
94
+ try:
95
+ self.dataset = load_dataset("aitxchallenge/Phase1_Model_Validator", split="train")
96
+ dataset_info = f"Dataset loaded successfully! {len(self.dataset)} validation cases available."
97
+
98
+ # Create a preview of the dataset
99
+ df = pd.DataFrame(self.dataset)
100
+ preview = df.head().to_string()
101
+
102
+ return (
103
+ dataset_info,
104
+ gr.Button("🔍 Validate", interactive=True),
105
+ gr.Textbox(value=f"Dataset Preview:\n{preview}", visible=True)
106
+ )
107
+ except Exception as e:
108
+ return (
109
+ f"Failed to load dataset: {str(e)}",
110
+ gr.Button("📥 Load Dataset", interactive=True),
111
+ gr.Textbox(visible=False)
112
+ )
113
+
114
+ def validate_tools(self) -> str:
115
+ """Run validation on all dataset cases"""
116
+ if not self.anthropic:
117
+ return "Please set your Anthropic API key first."
118
+
119
+ if not self.dataset:
120
+ return "Please load the dataset first."
121
+
122
+ if not self.session:
123
+ return "Please connect to an MCP server first."
124
+
125
+ return loop.run_until_complete(self._run_validation())
126
+
127
+ async def _run_validation(self) -> str:
128
+ """Async validation runner"""
129
+ self.validation_results = []
130
+ total_cases = len(self.dataset)
131
+ passed = 0
132
+ failed = 0
133
+
134
+ for i, case in enumerate(self.dataset):
135
+ try:
136
+ # Extract test case information
137
+ query = case.get('query', case.get('question', ''))
138
+ expected_output = case.get('expected_output', case.get('expected', ''))
139
+ test_id = case.get('id', f'test_{i}')
140
+
141
+ # Run the query through the MCP tools
142
+ result = await self._validate_single_case(query, expected_output, test_id)
143
+ self.validation_results.append(result)
144
+
145
+ if result['passed']:
146
+ passed += 1
147
+ else:
148
+ failed += 1
149
+
150
+ except Exception as e:
151
+ failed += 1
152
+ self.validation_results.append({
153
+ 'test_id': test_id,
154
+ 'query': query,
155
+ 'error': str(e),
156
+ 'passed': False
157
+ })
158
+
159
+ # Generate validation report
160
+ report = f"""
161
+ VALIDATION COMPLETE
162
+ ==================
163
+ Total Cases: {total_cases}
164
+ Passed: {passed}
165
+ Failed: {failed}
166
+ Success Rate: {(passed/total_cases)*100:.1f}%
167
+
168
+ DETAILED RESULTS:
169
+ """
170
+
171
+ for result in self.validation_results:
172
+ status = "✅ PASS" if result['passed'] else "❌ FAIL"
173
+ report += f"\n{status} [{result['test_id']}] {result['query'][:50]}..."
174
+ if not result['passed'] and 'error' in result:
175
+ report += f"\n Error: {result['error']}"
176
+
177
+ return report
178
+
179
+ async def _validate_single_case(self, query: str, expected_output: str, test_id: str) -> Dict[str, Any]:
180
+ """Validate a single test case"""
181
+ try:
182
+ # Send query to Claude with MCP tools
183
+ claude_messages = [{"role": "user", "content": query}]
184
+
185
+ response = self.anthropic.messages.create(
186
+ model="claude-3-5-sonnet-20241022",
187
+ max_tokens=1000,
188
+ messages=claude_messages,
189
+ tools=self.tools
190
+ )
191
+
192
+ # Process tool calls if any
193
+ actual_output = ""
194
+ for content in response.content:
195
+ if content.type == 'text':
196
+ actual_output += content.text
197
+ elif content.type == 'tool_use':
198
+ tool_result = await self.session.call_tool(content.name, content.input)
199
+ actual_output += str(tool_result.content)
200
+
201
+ # Simple validation logic - you may want to customize this
202
+ passed = self._validate_output(actual_output, expected_output)
203
+
204
+ return {
205
+ 'test_id': test_id,
206
+ 'query': query,
207
+ 'expected': expected_output,
208
+ 'actual': actual_output,
209
+ 'passed': passed
210
+ }
211
+
212
+ except Exception as e:
213
+ return {
214
+ 'test_id': test_id,
215
+ 'query': query,
216
+ 'error': str(e),
217
+ 'passed': False
218
+ }
219
+
220
+ def _validate_output(self, actual: str, expected: str) -> bool:
221
+ """Basic output validation - customize based on your needs"""
222
+ # This is a simple implementation - you may want more sophisticated validation
223
+ if not expected:
224
+ return True # If no expected output specified, consider it passed
225
+
226
+ # You can implement more sophisticated matching here
227
+ # For now, using simple substring matching
228
+ return expected.lower() in actual.lower()
229
+
230
+ def process_message(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]) -> tuple:
231
+ if not self.anthropic:
232
+ return history + [
233
+ {"role": "user", "content": message},
234
+ {"role": "assistant", "content": "Please set your Anthropic API key first."}
235
+ ], gr.Textbox(value="")
236
+
237
+ if not self.session:
238
+ return history + [
239
+ {"role": "user", "content": message},
240
+ {"role": "assistant", "content": "Please connect to an MCP server first."}
241
+ ], gr.Textbox(value="")
242
+
243
+ new_messages = loop.run_until_complete(self._process_query(message, history))
244
+ return history + [{"role": "user", "content": message}] + new_messages, gr.Textbox(value="")
245
+
246
+ async def _process_query(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]):
247
+ claude_messages = []
248
+ for msg in history:
249
+ if isinstance(msg, ChatMessage):
250
+ role, content = msg.role, msg.content
251
+ else:
252
+ role, content = msg.get("role"), msg.get("content")
253
+
254
+ if role in ["user", "assistant", "system"]:
255
+ claude_messages.append({"role": role, "content": content})
256
+
257
+ claude_messages.append({"role": "user", "content": message})
258
+
259
+ response = self.anthropic.messages.create(
260
+ model="claude-3-5-sonnet-20241022",
261
+ max_tokens=1000,
262
+ messages=claude_messages,
263
+ tools=self.tools
264
+ )
265
+
266
+ result_messages = []
267
+
268
+ for content in response.content:
269
+ if content.type == 'text':
270
+ result_messages.append({
271
+ "role": "assistant",
272
+ "content": content.text
273
+ })
274
+
275
+ elif content.type == 'tool_use':
276
+ tool_name = content.name
277
+ tool_args = content.input
278
+
279
+ result_messages.append({
280
+ "role": "assistant",
281
+ "content": f"I'll only use the {tool_name} tool to help answer your question.",
282
+ "metadata": {
283
+ "title": f"Using tool: {tool_name}",
284
+ "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
285
+ "status": "pending",
286
+ "id": f"tool_call_{tool_name}"
287
+ }
288
+ })
289
+
290
+ result_messages.append({
291
+ "role": "assistant",
292
+ "content": "```json\n" + json.dumps(tool_args, indent=2, ensure_ascii=True) + "\n```",
293
+ "metadata": {
294
+ "parent_id": f"tool_call_{tool_name}",
295
+ "id": f"params_{tool_name}",
296
+ "title": "Tool Parameters"
297
+ }
298
+ })
299
+
300
+ try:
301
+ result = await self.session.call_tool(tool_name, tool_args)
302
+
303
+ if result_messages and "metadata" in result_messages[-2]:
304
+ result_messages[-2]["metadata"]["status"] = "done"
305
+
306
+ result_messages.append({
307
+ "role": "assistant",
308
+ "content": "Here are the results from the tool:",
309
+ "metadata": {
310
+ "title": f"Tool Result for {tool_name}",
311
+ "status": "done",
312
+ "id": f"result_{tool_name}"
313
+ }
314
+ })
315
+
316
+ result_content = result.content
317
+ if isinstance(result_content, list):
318
+ result_content = "\n".join(str(item) for item in result_content)
319
+
320
+ try:
321
+ result_json = json.loads(result_content)
322
+ if isinstance(result_json, dict) and "type" in result_json:
323
+ if result_json["type"] == "image" and "url" in result_json:
324
+ result_messages.append({
325
+ "role": "assistant",
326
+ "content": {"path": result_json["url"], "alt_text": result_json.get("message", "Generated image")},
327
+ "metadata": {
328
+ "parent_id": f"result_{tool_name}",
329
+ "id": f"image_{tool_name}",
330
+ "title": "Generated Image"
331
+ }
332
+ })
333
+ else:
334
+ result_messages.append({
335
+ "role": "assistant",
336
+ "content": "```\n" + result_content + "\n```",
337
+ "metadata": {
338
+ "parent_id": f"result_{tool_name}",
339
+ "id": f"raw_result_{tool_name}",
340
+ "title": "Raw Output"
341
+ }
342
+ })
343
+ except:
344
+ result_messages.append({
345
+ "role": "assistant",
346
+ "content": "```\n" + result_content + "\n```",
347
+ "metadata": {
348
+ "parent_id": f"result_{tool_name}",
349
+ "id": f"raw_result_{tool_name}",
350
+ "title": "Raw Output"
351
+ }
352
+ })
353
+
354
+ claude_messages.append({"role": "user", "content": f"Tool result for {tool_name}: {result_content}"})
355
+ next_response = self.anthropic.messages.create(
356
+ model="claude-3-5-sonnet-20241022",
357
+ max_tokens=1000,
358
+ messages=claude_messages,
359
+ )
360
+
361
+ if next_response.content and next_response.content[0].type == 'text':
362
+ result_messages.append({
363
+ "role": "assistant",
364
+ "content": next_response.content[0].text
365
+ })
366
+
367
+ except Exception as e:
368
+ result_messages.append({
369
+ "role": "assistant",
370
+ "content": f"Error calling tool {tool_name}: {str(e)}",
371
+ "metadata": {
372
+ "title": f"Error - {tool_name}",
373
+ "status": "error",
374
+ "id": f"error_{tool_name}"
375
+ }
376
+ })
377
+
378
+ return result_messages
379
+
380
+ client = MCPClientWrapper()
381
+
382
+ def gradio_interface():
383
+ with gr.Blocks(title="TAAIC Tool Validation") as demo:
384
+ gr.Markdown("# TAAIC Tool Validation")
385
+ gr.Markdown("Connect your Gradio MCP Tool for validation for the TAAIC challenge.")
386
+
387
+ # API Key input section
388
+ with gr.Row(equal_height=True):
389
+ with gr.Column(scale=4):
390
+ api_key_input = gr.Textbox(
391
+ label="Anthropic API Key",
392
+ placeholder="Enter your Anthropic API key (sk-ant-...)",
393
+ type="password"
394
+ )
395
+ with gr.Column(scale=1):
396
+ api_key_btn = gr.Button("Set API Key")
397
+
398
+ api_key_status = gr.Textbox(label="API Key Status", interactive=False)
399
+
400
+ # MCP Server connection section
401
+ with gr.Row(equal_height=True):
402
+ with gr.Column(scale=4):
403
+ server_input = gr.Textbox(
404
+ label="MCP Server URL or Script Path",
405
+ placeholder="Enter URL (e.g., https://cyrilzakka-clinical-trials.hf.space/gradio_api/mcp/sse) or local script path (e.g., weather.py)",
406
+ value="https://cyrilzakka-clinical-trials.hf.space/gradio_api/mcp/sse"
407
+ )
408
+ with gr.Column(scale=1):
409
+ connect_btn = gr.Button("Connect")
410
+
411
+ status = gr.Textbox(label="Connection Status", interactive=False)
412
+
413
+ # Dataset loading section
414
+ with gr.Row(equal_height=True):
415
+ with gr.Column(scale=3):
416
+ dataset_status = gr.Textbox(
417
+ label="Dataset Status",
418
+ value="Click 'Load Dataset' to load validation cases",
419
+ interactive=False
420
+ )
421
+ with gr.Column(scale=1):
422
+ dataset_btn = gr.Button("📥 Load Dataset", interactive=True)
423
+
424
+ dataset_preview = gr.Textbox(
425
+ label="Dataset Preview",
426
+ visible=False,
427
+ interactive=False,
428
+ max_lines=10
429
+ )
430
+
431
+ # Validation results
432
+ validation_results = gr.Textbox(
433
+ label="Validation Results",
434
+ visible=False,
435
+ interactive=False,
436
+ max_lines=20
437
+ )
438
+
439
+ # Event handlers
440
+ api_key_btn.click(client.set_api_key, inputs=api_key_input, outputs=api_key_status)
441
+ connect_btn.click(client.connect, inputs=server_input, outputs=status)
442
+
443
+ dataset_btn.click(
444
+ client.load_dataset,
445
+ outputs=[dataset_status, dataset_btn, dataset_preview]
446
+ )
447
+
448
+ def run_validation():
449
+ results = client.validate_tools()
450
+ return gr.Textbox(value=results, visible=True)
451
+
452
+ dataset_btn.click(
453
+ lambda: client.validate_tools() if client.dataset else "Please load dataset first.",
454
+ outputs=validation_results,
455
+ show_progress=True
456
+ ).then(
457
+ lambda: gr.Textbox(visible=True),
458
+ outputs=validation_results
459
+ )
460
+
461
+ # msg.submit(client.process_message, [msg, chatbot], [chatbot, msg])
462
+ # clear_btn.click(lambda: [], None, chatbot)
463
+
464
+ return demo
465
+
466
+ if __name__ == "__main__":
467
+ interface = gradio_interface()
468
+ interface.launch(debug=True)