# Spaces: Running
import asyncio
import json
import logging
import os
from contextlib import AsyncExitStack
from typing import List, Dict, Any, Union

import gradio as gr
import pandas as pd
from anthropic import Anthropic
from datasets import load_dataset
from gradio.components.chatbot import ChatMessage
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
# A single module-level event loop. The synchronous entry points in this file
# drive their coroutines on it with ``loop.run_until_complete(...)``.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
class MCPClientWrapper:
    """Glue between an Anthropic client, one MCP server session and the UI.

    The public methods are synchronous and return plain strings / Gradio
    components; the ones that need the MCP session run their coroutine on the
    module-level ``loop`` via ``loop.run_until_complete``.

    NOTE(review): the emoji in the user-facing strings below were mojibake in
    the captured source and have been restored to the most plausible glyphs;
    confirm against the original file.
    """

    # Model and token budget used for every Anthropic request in this class.
    _MODEL = "claude-3-5-sonnet-20241022"
    _MAX_TOKENS = 1000
    _log = logging.getLogger(__name__)

    def __init__(self):
        self.session = None  # active MCP ClientSession, None while disconnected
        self.exit_stack = None  # AsyncExitStack owning the transport + session
        self.anthropic = None  # Anthropic client, created by set_api_key()
        self.tools = []  # tool descriptors passed as ``tools=`` to messages.create
        self.dataset = None  # validation cases, set by load_dataset()
        self.validation_results = []  # per-case result dicts of the last run

    def set_api_key(self, api_key: str) -> str:
        """Set the Anthropic API key and initialize the client.

        Returns a human-readable status line; never raises.
        """
        if not api_key or not api_key.strip():
            return "Please enter a valid Anthropic API key"
        try:
            self.anthropic = Anthropic(api_key=api_key.strip())
            return "API key set successfully ✅"
        except Exception as e:
            return f"Failed to set API key: {str(e)}"

    def connect(self, server_input: str) -> str:
        """Connect to the MCP server named by ``server_input`` (URL or script path).

        Requires the API key to be set first. Returns a status line.
        """
        if not self.anthropic:
            return "Please set your Anthropic API key first"
        return loop.run_until_complete(self._connect(server_input))

    async def _close_connection(self) -> None:
        """Forget the current session/tools and close the exit stack, best effort."""
        stack, self.exit_stack = self.exit_stack, None
        self.session = None
        self.tools = []
        if stack is None:
            return
        try:
            await stack.aclose()
        except Exception:
            # Tearing down a dead transport must not prevent a new connection
            # attempt; record the failure instead of propagating it.
            self._log.warning("Error while closing previous MCP connection", exc_info=True)

    async def _connect(self, server_input: str) -> str:
        """Open a transport, start a session and cache the server's tool list.

        ``http://`` / ``https://`` inputs use the SSE client; anything else is
        treated as a local script started over stdio (``python`` for ``.py``
        files, ``node`` otherwise). On failure the wrapper is left in the
        disconnected state and the error text is returned.
        """
        # Drop any previous connection first so a failed reconnect cannot leave
        # a closed session behind that later calls would try to use.
        await self._close_connection()
        stack = AsyncExitStack()
        self.exit_stack = stack
        try:
            if server_input.startswith(('http://', 'https://')):
                # Connect via SSE
                read, write = await stack.enter_async_context(
                    sse_client(server_input)
                )
                connection_type = "SSE URL"
            else:
                # Connect via stdio (local file)
                is_python = server_input.endswith('.py')
                command = "python" if is_python else "node"
                server_params = StdioServerParameters(
                    command=command,
                    args=[server_input],
                    env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
                )
                read, write = await stack.enter_async_context(
                    stdio_client(server_params)
                )
                connection_type = "Local script"

            self.session = await stack.enter_async_context(
                ClientSession(read, write)
            )
            await self.session.initialize()
            response = await self.session.list_tools()
            self.tools = [{
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema,
            } for tool in response.tools]
            tool_names = [tool["name"] for tool in self.tools]
            return f"Connected to MCP server via {connection_type}. Available tools: {', '.join(tool_names)}"
        except Exception as e:
            await self._close_connection()
            return f"Connection failed: {str(e)}"

    def load_dataset(self) -> tuple:
        """Load the TAAIC Phase1 validation dataset.

        Returns a 3-tuple ``(status text, button update, preview textbox update)``.
        """
        try:
            # Inside a method the bare name resolves to the module-level
            # ``load_dataset`` imported from ``datasets``, not to this method.
            self.dataset = load_dataset("aitxchallenge/Phase1_Model_Validator", split="train")
            dataset_info = f"Dataset loaded successfully! {len(self.dataset)} validation cases available."
            # Create a preview of the dataset
            df = pd.DataFrame(self.dataset)
            preview = df.head().to_string()
            return (
                dataset_info,
                gr.Button("🔍 Validate", interactive=True),
                gr.Textbox(value=f"Dataset Preview:\n{preview}", visible=True),
            )
        except Exception as e:
            return (
                f"Failed to load dataset: {str(e)}",
                gr.Button("📥 Load Dataset", interactive=True),
                gr.Textbox(visible=False),
            )

    def validate_tools(self) -> str:
        """Run validation on all dataset cases and return the text report.

        Returns an instruction string instead when the API key, the dataset or
        the MCP session is missing.
        """
        if not self.anthropic:
            return "Please set your Anthropic API key first."
        if not self.dataset:
            return "Please load the dataset first."
        if not self.session:
            return "Please connect to an MCP server first."
        return loop.run_until_complete(self._run_validation())

    async def _run_validation(self) -> str:
        """Validate every dataset case in order and build the summary report.

        Per-case results are stored in ``self.validation_results``.
        """
        self.validation_results = []
        total_cases = len(self.dataset)
        passed = 0
        for i, case in enumerate(self.dataset):
            # Bind fallbacks before the try so the except branch can always
            # report this case, even when reading the case itself fails.
            test_id, query = f'test_{i}', ''
            try:
                # Extract test case information
                query = case.get('query', case.get('question', ''))
                expected_output = case.get('expected_output', case.get('expected', ''))
                test_id = case.get('id', f'test_{i}')
                # Run the query through the MCP tools
                result = await self._validate_single_case(query, expected_output, test_id)
            except Exception as e:
                result = {
                    'test_id': test_id,
                    'query': query,
                    'error': str(e),
                    'passed': False,
                }
            self.validation_results.append(result)
            if result['passed']:
                passed += 1
        failed = total_cases - passed
        # An empty dataset reports 0.0% rather than dividing by zero.
        success_rate = (passed / total_cases) * 100 if total_cases else 0.0

        # Generate validation report
        parts = ["\n".join([
            "",
            "VALIDATION COMPLETE",
            "==================",
            f"Total Cases: {total_cases}",
            f"Passed: {passed}",
            f"Failed: {failed}",
            f"Success Rate: {success_rate:.1f}%",
            "DETAILED RESULTS:",
            "",
        ])]
        for result in self.validation_results:
            status = "✅ PASS" if result['passed'] else "❌ FAIL"
            # The query may be None or a non-string dataset value.
            query_text = str(result.get('query') or '')
            parts.append(f"\n{status} [{result['test_id']}] {query_text[:50]}...")
            if not result['passed'] and 'error' in result:
                parts.append(f"\n Error: {result['error']}")
        return "".join(parts)

    @staticmethod
    def _tool_result_text(result: Any) -> str:
        """Flatten a tool call result's ``content`` into plain text.

        List items that expose a string ``text`` attribute contribute that
        text; any other item falls back to ``str(item)``. Items are joined
        with newlines. Non-list content is stringified; missing content
        yields ``""``.
        """
        content = getattr(result, "content", None)
        if isinstance(content, list):
            return "\n".join(
                item.text if isinstance(getattr(item, "text", None), str) else str(item)
                for item in content
            )
        return "" if content is None else str(content)

    async def _validate_single_case(self, query: str, expected_output: str, test_id: str) -> Dict[str, Any]:
        """Validate a single test case.

        Sends ``query`` to the model with the MCP tools attached, executes any
        requested tool calls and checks the concatenated text against
        ``expected_output``. Never raises: failures are returned as a result
        dict with ``'error'`` set and ``'passed'`` False.
        """
        try:
            # Send query to Claude with MCP tools.
            # NOTE(review): messages.create is called synchronously from inside
            # a coroutine here.
            response = self.anthropic.messages.create(
                model=self._MODEL,
                max_tokens=self._MAX_TOKENS,
                messages=[{"role": "user", "content": query}],
                tools=self.tools,
            )
            # Process tool calls if any
            pieces = []
            for content in response.content:
                if content.type == 'text':
                    pieces.append(content.text)
                elif content.type == 'tool_use':
                    tool_result = await self.session.call_tool(content.name, content.input)
                    pieces.append(self._tool_result_text(tool_result))
            actual_output = "".join(pieces)
            return {
                'test_id': test_id,
                'query': query,
                'expected': expected_output,
                'actual': actual_output,
                'passed': self._validate_output(actual_output, expected_output),
            }
        except Exception as e:
            return {
                'test_id': test_id,
                'query': query,
                'error': str(e),
                'passed': False,
            }

    def _validate_output(self, actual: str, expected: str) -> bool:
        """Basic output validation: case-insensitive substring match.

        An empty ``expected`` always passes.
        """
        if not expected:
            return True  # If no expected output specified, consider it passed
        return expected.lower() in actual.lower()

    def process_message(self, message: str, history: "List[Union[Dict[str, Any], ChatMessage]]") -> tuple:
        """Handle one chat turn.

        Returns ``(updated history, gr.Textbox(value=""))``; the second element
        is an update that empties a textbox.
        """
        user_turn = {"role": "user", "content": message}
        if not self.anthropic:
            notice = "Please set your Anthropic API key first."
        elif not self.session:
            notice = "Please connect to an MCP server first."
        else:
            notice = None
        if notice is not None:
            return history + [user_turn, {"role": "assistant", "content": notice}], gr.Textbox(value="")
        new_messages = loop.run_until_complete(self._process_query(message, history))
        return history + [user_turn] + new_messages, gr.Textbox(value="")

    @staticmethod
    def _render_tool_output(tool_name: str, text: str) -> Dict[str, Any]:
        """Build the chat message that displays a tool's output.

        JSON of the form ``{"type": "image", "url": ...}`` becomes an image
        message; everything else (including non-JSON text) is shown verbatim in
        a code fence so no tool output is silently dropped.
        """
        try:
            parsed = json.loads(text)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            parsed = None
        if isinstance(parsed, dict) and parsed.get("type") == "image" and "url" in parsed:
            return {
                "role": "assistant",
                "content": {"path": parsed["url"], "alt_text": parsed.get("message", "Generated image")},
                "metadata": {
                    "parent_id": f"result_{tool_name}",
                    "id": f"image_{tool_name}",
                    "title": "Generated Image",
                },
            }
        return {
            "role": "assistant",
            "content": "```\n" + text + "\n```",
            "metadata": {
                "parent_id": f"result_{tool_name}",
                "id": f"raw_result_{tool_name}",
                "title": "Raw Output",
            },
        }

    async def _run_tool_call(self, tool_name: str, tool_args: Any, claude_messages: list) -> list:
        """Execute one requested tool call and return the chat messages for it.

        Appends the tool result to ``claude_messages`` and asks the model for a
        follow-up answer. If anything fails, the messages produced so far are
        kept and an error message is appended.
        """
        call_message = {
            "role": "assistant",
            "content": f"I'll only use the {tool_name} tool to help answer your question.",
            "metadata": {
                "title": f"Using tool: {tool_name}",
                "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
                "status": "pending",
                "id": f"tool_call_{tool_name}",
            },
        }
        produced = [
            call_message,
            {
                "role": "assistant",
                "content": "```json\n" + json.dumps(tool_args, indent=2, ensure_ascii=True) + "\n```",
                "metadata": {
                    "parent_id": f"tool_call_{tool_name}",
                    "id": f"params_{tool_name}",
                    "title": "Tool Parameters",
                },
            },
        ]
        try:
            result = await self.session.call_tool(tool_name, tool_args)
            call_message["metadata"]["status"] = "done"
            produced.append({
                "role": "assistant",
                "content": "Here are the results from the tool:",
                "metadata": {
                    "title": f"Tool Result for {tool_name}",
                    "status": "done",
                    "id": f"result_{tool_name}",
                },
            })
            result_content = self._tool_result_text(result)
            produced.append(self._render_tool_output(tool_name, result_content))

            claude_messages.append({"role": "user", "content": f"Tool result for {tool_name}: {result_content}"})
            next_response = self.anthropic.messages.create(
                model=self._MODEL,
                max_tokens=self._MAX_TOKENS,
                messages=claude_messages,
            )
            if next_response.content and next_response.content[0].type == 'text':
                produced.append({
                    "role": "assistant",
                    "content": next_response.content[0].text,
                })
        except Exception as e:
            produced.append({
                "role": "assistant",
                "content": f"Error calling tool {tool_name}: {str(e)}",
                "metadata": {
                    "title": f"Error - {tool_name}",
                    "status": "error",
                    "id": f"error_{tool_name}",
                },
            })
        return produced

    async def _process_query(self, message: str, history: "List[Union[Dict[str, Any], ChatMessage]]"):
        """Send the conversation to the model and return the new assistant messages.

        ``history`` entries may be ``ChatMessage`` objects or dicts; only their
        role and content are forwarded. Text blocks in the reply become plain
        assistant messages; ``tool_use`` blocks are executed against the MCP
        session.
        """
        claude_messages = []
        for msg in history:
            if isinstance(msg, ChatMessage):
                role, content = msg.role, msg.content
            else:
                role, content = msg.get("role"), msg.get("content")
            # NOTE(review): "system" entries are forwarded inside ``messages``;
            # confirm the Messages API accepts that role there.
            if role in ["user", "assistant", "system"]:
                claude_messages.append({"role": role, "content": content})
        claude_messages.append({"role": "user", "content": message})

        response = self.anthropic.messages.create(
            model=self._MODEL,
            max_tokens=self._MAX_TOKENS,
            messages=claude_messages,
            tools=self.tools,
        )

        result_messages = []
        for content in response.content:
            if content.type == 'text':
                result_messages.append({
                    "role": "assistant",
                    "content": content.text,
                })
            elif content.type == 'tool_use':
                result_messages.extend(
                    await self._run_tool_call(content.name, content.input, claude_messages)
                )
        return result_messages
# Single shared wrapper instance; the Gradio callbacks below are bound to it.
client = MCPClientWrapper()
def gradio_interface():
    """Build and return the Gradio Blocks UI for the validation tool.

    The page has three steps: set the Anthropic API key, connect to an MCP
    server, then load the dataset and validate. Loading and validating share
    one button: it loads while no dataset is present and validates afterwards.

    NOTE(review): the source this was reconstructed from had lost its
    indentation, so the exact nesting of components inside rows/columns is a
    best-effort reconstruction; confirm the layout.
    """
    with gr.Blocks(title="TAAIC Tool Validation") as demo:
        gr.Markdown("# TAAIC Tool Validation")
        gr.Markdown("Connect your Gradio MCP Tool for validation for the TAAIC challenge.")

        # API Key input section
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                api_key_input = gr.Textbox(
                    label="Anthropic API Key",
                    placeholder="Enter your Anthropic API key (sk-ant-...)",
                    type="password",
                )
            with gr.Column(scale=1):
                api_key_btn = gr.Button("Set API Key")
        api_key_status = gr.Textbox(label="API Key Status", interactive=False)

        # MCP Server connection section
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                server_input = gr.Textbox(
                    label="MCP Server URL or Script Path",
                    placeholder="Enter URL (e.g., https://cyrilzakka-clinical-trials.hf.space/gradio_api/mcp/sse) or local script path (e.g., weather.py)",
                    value="https://cyrilzakka-clinical-trials.hf.space/gradio_api/mcp/sse",
                )
            with gr.Column(scale=1):
                connect_btn = gr.Button("Connect")
        status = gr.Textbox(label="Connection Status", interactive=False)

        # Dataset loading section
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                dataset_status = gr.Textbox(
                    label="Dataset Status",
                    value="Click 'Load Dataset' to load validation cases",
                    interactive=False,
                )
            with gr.Column(scale=1):
                # NOTE(review): emoji restored from mojibake in the captured
                # source; confirm the glyph.
                dataset_btn = gr.Button("📥 Load Dataset", interactive=True)
        dataset_preview = gr.Textbox(
            label="Dataset Preview",
            visible=False,
            interactive=False,
            max_lines=10,
        )

        # Validation results
        validation_results = gr.Textbox(
            label="Validation Results",
            visible=False,
            interactive=False,
            max_lines=20,
        )

        # Event handlers
        api_key_btn.click(client.set_api_key, inputs=api_key_input, outputs=api_key_status)
        connect_btn.click(client.connect, inputs=server_input, outputs=status)

        def on_dataset_click():
            """Load the dataset on the first click, validate on later clicks.

            Registering two separate listeners on the same button would run
            both on every click (validation before any dataset exists, and a
            dataset reload before every validation), so a single dispatcher
            decides which action applies. ``gr.update()`` leaves a component
            unchanged.
            """
            if not client.dataset:
                info, button, preview = client.load_dataset()
                return info, button, preview, gr.update()
            report = client.validate_tools()
            return gr.update(), gr.update(), gr.update(), gr.Textbox(value=report, visible=True)

        dataset_btn.click(
            on_dataset_click,
            outputs=[dataset_status, dataset_btn, dataset_preview, validation_results],
            show_progress=True,
        )

        # NOTE: client.process_message is not wired to any component here;
        # this UI defines no chatbot or message box.
    return demo
if __name__ == "__main__":
    # Build the UI and serve it; debug=True is passed through to Gradio's launch().
    interface = gradio_interface()
    interface.launch(debug=True)