from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, List, Optional, Union, Any
from pydantic import BaseModel, Field
from datetime import datetime
import logging
import json
import os
from dotenv import load_dotenv
from dify_client_python.dify_client import models
from sse_starlette.sse import EventSourceResponse
import httpx
from json_parser import SSEParser, MessageState
from logger_config import setup_logger
from fastapi.responses import JSONResponse, StreamingResponse
from response_formatter import ResponseFormatter
import traceback

# Load environment variables
load_dotenv()
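
# Environment variables read by this module:
#   DIFY_API_KEY  - bearer token used to authenticate against the Dify backend
#   API_BASE_URL  - base URL of the Dify API (falls back to the URL hard-coded below)
#   PORT          - port for the uvicorn server (defaults to 7860)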

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class AgentOutput(BaseModel):
    """Structured output from agent processing"""
    thought_content: str
    observation: Optional[str]
    tool_outputs: List[Dict]
    citations: List[Dict]
    metadata: Dict
    raw_response: str


class AgentRequest(BaseModel):
    """Enhanced request model with additional parameters"""
    query: str
    conversation_id: Optional[str] = None
    stream: bool = True
    inputs: Dict = {}
    files: List = []
    user: str = "default_user"
    response_mode: str = "streaming"
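
# Example request body matching AgentRequest (illustrative values only):
#   {
#       "query": "Draw a flowchart of the signup process",
#       "conversation_id": null,
#       "stream": true,
#       "inputs": {},
#       "files": [],
#       "user": "default_user",
#       "response_mode": "streaming"
#   }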


class AgentProcessor:
    def __init__(self, api_key: str):
        self.api_key = api_key
        # Read the API base from the environment, with a hosted fallback
        self.api_base = os.getenv(
            "API_BASE_URL",
            "https://severian.a.pinggy.link/v1"
        )
        self.formatter = ResponseFormatter()
        self.client = httpx.AsyncClient(timeout=60.0)
        self.logger = setup_logger("agent_processor")
        # Initialize headers
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "text/event-stream"
        }

    def prepare_request(self, request: AgentRequest) -> Dict:
        """Prepare request payload for API"""
        return {
            "query": request.query,
            "inputs": request.inputs,
            "response_mode": "streaming" if request.stream else "blocking",
            "user": request.user,
            "conversation_id": request.conversation_id,
            "files": request.files
        }

    async def log_request_details(
        self,
        request: AgentRequest,
        start_time: datetime
    ) -> None:
        """Log detailed request information"""
        self.logger.debug(
            "Request details: \n"
            f"Query: {request.query}\n"
            f"User: {request.user}\n"
            f"Conversation ID: {request.conversation_id}\n"
            f"Stream mode: {request.stream}\n"
            f"Start time: {start_time}\n"
            f"Inputs: {request.inputs}\n"
            f"Files: {len(request.files)} files attached"
        )

    async def log_error(
        self,
        error: Exception,
        context: Optional[Dict] = None
    ) -> None:
        """Log detailed error information"""
        error_msg = (
            f"Error type: {type(error).__name__}\n"
            f"Error message: {str(error)}\n"
            f"Stack trace:\n{traceback.format_exc()}\n"
        )
        if context:
            error_msg += f"Context:\n{json.dumps(context, indent=2)}"
        self.logger.error(error_msg)

    async def cleanup(self):
        """Cleanup method to properly close client"""
        await self.client.aclose()

    async def process_stream(self, request: AgentRequest):
        """Process streaming request and format for frontend"""
        parser = SSEParser()
        formatter = ResponseFormatter()

        async def event_generator():
            try:
                async with self.client.stream(
                    "POST",
                    f"{self.api_base}/chat-messages",
                    headers=self.headers,
                    json=self.prepare_request(request)
                ) as response:
                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue

                        # Parse the event
                        parsed = parser.parse_sse_event(line)
                        if not parsed:
                            continue

                        event_type = parsed.get("type")

                        # Format based on type
                        if event_type == "message":
                            formatted = formatter.format_message(
                                message=parsed["content"],
                                message_id=parsed["message_id"]
                            )
                        elif event_type == "thought":
                            formatted = formatter.format_thought(
                                thought=parsed["content"]["thought"],
                                observation=parsed["content"]["observation"],
                                message_id=parsed["message_id"]
                            )
                        elif event_type == "tool_output":
                            # Special handling for tool outputs
                            formatted = formatter.format_thought(
                                thought="",
                                observation="",
                                tool_outputs=[{
                                    "type": parsed["tool"],
                                    "content": parsed["content"]
                                }],
                                message_id=parsed["message_id"]
                            )
                        else:
                            continue

                        if formatted:
                            _, xml_output = formatted
                            # EventSourceResponse adds the SSE "data:" framing
                            # itself, so yield only the payload
                            yield xml_output
            except Exception as e:
                self.logger.error(f"Stream error: {str(e)}")
                error_msg = formatter.format_error(str(e))
                if error_msg:
                    _, xml_output = error_msg
                    yield xml_output

        return EventSourceResponse(event_generator())

    def format_terminal_output(
        self,
        response: Dict,
        citations: List[Dict] = None,
        metadata: Dict = None,
        tool_outputs: List[Dict] = None
    ) -> Optional[str]:
        """Format response for terminal output"""
        event_type = response.get("event")

        if event_type == "agent_thought":
            thought = response.get("thought", "")
            observation = response.get("observation", "")
            terminal_output, _ = self.formatter.format_thought(
                thought,
                observation,
                citations=citations,
                metadata=metadata,
                tool_outputs=tool_outputs
            )
            return terminal_output

        elif event_type == "agent_message":
            message = response.get("answer", "")
            terminal_output, _ = self.formatter.format_message(message)
            return terminal_output

        elif event_type == "error":
            error = response.get("error", "Unknown error")
            terminal_output, _ = self.formatter.format_error(error)
            return terminal_output

        return None

    def clean_response(self, response: Dict) -> Optional[Dict]:
        """Clean and transform the response for frontend consumption"""
        try:
            event_type = response.get("event")
            if not event_type:
                return None

            # Handle different event types
            if event_type == "agent_thought":
                thought = response.get("thought", "")
                observation = response.get("observation", "")
                tool = response.get("tool", "")

                # Handle mermaid diagram observations
                if tool == "mermaid_diagram" and observation:
                    try:
                        # First check if observation is error message
                        if isinstance(observation, str):
                            obs_data = json.loads(observation)
                            if "mermaid_diagram" in obs_data:
                                if obs_data["mermaid_diagram"].startswith("tool invoke error"):
                                    self.logger.warning(
                                        f"Mermaid diagram tool error: {obs_data['mermaid_diagram']}"
                                    )
                                    return None

                        # Handle successful mermaid diagram
                        if isinstance(observation, dict):
                            mermaid_data = observation.get("mermaid_diagram", "")
                        else:
                            obs_data = json.loads(observation)
                            mermaid_data = obs_data.get("mermaid_diagram", "")

                        if mermaid_data:
                            # Handle nested JSON structure
                            if isinstance(mermaid_data, str):
                                mermaid_data = json.loads(mermaid_data)

                            # Extract diagram from either format
                            if isinstance(mermaid_data, dict):
                                diagram = mermaid_data.get("mermaid_diagram", "")
                            else:
                                diagram = mermaid_data

                            # Clean up the diagram code
                            if isinstance(diagram, str):
                                if "tool response:" in diagram:
                                    diagram = diagram.split("tool response:")[0]
                                if diagram.startswith('{"mermaid_diagram": "'):
                                    diagram = json.loads(diagram)["mermaid_diagram"]
                                # Strip a surrounding ```mermaid code fence
                                if diagram.startswith("```mermaid\n"):
                                    diagram = diagram[len("```mermaid\n"):]
                                if diagram.endswith("\n```"):
                                    diagram = diagram[:-len("\n```")]

                            return {
                                "type": "mermaid_diagram",
                                "content": diagram.strip()
                            }

                    except (json.JSONDecodeError, KeyError) as e:
                        self.logger.error(f"Failed to parse mermaid diagram data: {str(e)}")
                        self.logger.debug(f"Raw observation: {observation}")
                        return None

                # Handle regular thought
                _, xml_output = self.formatter.format_thought(thought, observation)
                return {
                    "type": "thought",
                    "content": xml_output
                }

            elif event_type == "agent_message":
                message = response.get("answer", "")
                _, xml_output = self.formatter.format_message(message)
                return {
                    "type": "message",
                    "content": xml_output
                }

            elif event_type == "error":
                error = response.get("error", "Unknown error")
                _, xml_output = self.formatter.format_error(error)
                return {
                    "type": "error",
                    "content": xml_output
                }

            return None
        except Exception as e:
            self.logger.error(f"Error cleaning response: {str(e)}")
            return None


# Initialize FastAPI app
app = FastAPI()
agent_processor = None

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
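# NOTE: browsers reject wildcard origins combined with allow_credentials=True
# under the CORS spec; restrict allow_origins if credentialed requests are needed.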


@app.on_event("startup")
async def startup_event():
    global agent_processor
    api_key = os.getenv("DIFY_API_KEY")
    agent_processor = AgentProcessor(api_key=api_key)


@app.on_event("shutdown")
async def shutdown_event():
    global agent_processor
    if agent_processor:
        await agent_processor.cleanup()


# Route path assumed here; adjust it to whatever path the frontend calls.
@app.post("/chat")
async def process_agent_request(request: AgentRequest):
    try:
        logger.info(f"Processing agent request: {request.query}")
        return await agent_processor.process_stream(request)
    except Exception as e:
        logger.error(f"Error in agent request processing: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    try:
        response = await call_next(request)
        return response
    except Exception as e:
        logger.error(f"Unhandled error: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": "Internal server error occurred"}
        )


# Run with uvicorn on the configured host and port
if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=port,
        reload=True
    )
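
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): with the service running on the configured
# PORT (7860 by default) and the chat route mounted at the assumed "/chat"
# path above, the SSE stream can be exercised with curl:
#
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Hello", "user": "default_user", "stream": true}'
#
# Each event arrives as an SSE "data:" frame whose payload is the XML
# produced by ResponseFormatter.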