# (Hugging Face Spaces status banner — "Spaces: Sleeping" — scraped into the
# source by mistake; not part of the application code.)
# EUDR ORCHESTRATOR | |
import gradio as gr | |
from fastapi import FastAPI, UploadFile, File, Form | |
from langserve import add_routes | |
from langgraph.graph import StateGraph, START, END | |
from typing import Optional, Dict, Any, List | |
from typing_extensions import TypedDict | |
from pydantic import BaseModel | |
from gradio_client import Client, file | |
import uvicorn | |
import os | |
from datetime import datetime | |
import logging | |
from contextlib import asynccontextmanager | |
import threading | |
from langchain_core.runnables import RunnableLambda | |
import tempfile | |
from utils import getconfig | |
# Load service endpoints and limits from the project configuration file.
config = getconfig("params.cfg")
RETRIEVER = config.get("retriever", "RETRIEVER")  # Gradio endpoint of the retrieval service
GENERATOR = config.get("generator", "GENERATOR")  # Gradio endpoint of the generation service
INGESTOR = config.get("ingestor", "INGESTOR")  # Gradio endpoint of the file-ingestion service
# Cap on combined context characters passed to the generator (token-overflow guard).
MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS", fallback="8000"))
COLLECTION_NAME = config.get("retriever", "COLLECTION_NAME")  # vector-store collection queried by the retriever
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Models | |
class GraphState(TypedDict):
    """Shared state threaded through the LangGraph workflow nodes."""
    query: str  # user question
    context: str  # text retrieved from the vector store
    ingestor_context: str  # text produced by ingesting an uploaded file
    result: str  # final answer returned to the caller
    country: str  # optional country filter for retrieval ("" = no filter)
    file_content: Optional[bytes]  # raw bytes of an uploaded file, if any
    filename: Optional[str]  # original name of the uploaded file, if any
    metadata: Optional[Dict[str, Any]]  # timing/diagnostic info accumulated by the nodes
class ChatFedInput(TypedDict, total=False):
    """Input schema for the /chatfed LangServe route; all keys optional at the type level."""
    query: str
    country: Optional[str]
    session_id: Optional[str]
    user_id: Optional[str]
    file_content: Optional[bytes]
    filename: Optional[str]
class ChatFedOutput(TypedDict):
    """Output schema: generated answer plus pipeline diagnostics."""
    result: str
    metadata: Dict[str, Any]
class ChatUIInput(BaseModel):
    """Input schema for the /chatfed-ui-stream route (ChatUI posts a bare text field)."""
    text: str
# Module functions | |
def ingest_node(state: GraphState) -> GraphState:
    """Process an uploaded file through the ingestor service.

    Stages the in-memory file bytes in a temporary file, forwards it to the
    ingestor's /ingest endpoint, and returns the extracted context. If no file
    is attached the step is skipped. Failures are recorded in metadata rather
    than raised, so the workflow keeps running.
    """
    start_time = datetime.now()
    # No file attached: nothing to ingest.
    if not state.get("file_content") or not state.get("filename"):
        logger.info("No file provided, skipping ingestion")
        return {"ingestor_context": "", "metadata": state.get("metadata", {})}
    logger.info(f"Ingesting file: {state['filename']}")
    # Work on a copy instead of mutating the shared state dict in place.
    metadata = dict(state.get("metadata") or {})
    try:
        client = Client(INGESTOR)
        # The ingestor expects a file on disk, so stage the bytes in a temp file
        # that keeps the original extension.
        suffix = os.path.splitext(state["filename"])[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(state["file_content"])
            tmp_file_path = tmp_file.name
        try:
            ingestor_context = client.predict(
                file(tmp_file_path),
                api_name="/ingest"
            )
            logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
            # The ingestor signals failure via an "Error:" string, not an exception.
            if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
                raise RuntimeError(ingestor_context)
        finally:
            # Always remove the staged temp file, even if the call failed.
            os.unlink(tmp_file_path)
        duration = (datetime.now() - start_time).total_seconds()
        metadata.update({
            "ingestion_duration": duration,
            "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
            "ingestion_success": True,
            "analysis_type": "whisp_geojson"
        })
        return {
            "ingestor_context": ingestor_context,
            "metadata": metadata
        }
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Ingestion failed: {str(e)}")
        metadata.update({
            "ingestion_duration": duration,
            "ingestion_success": False,
            "ingestion_error": str(e)
        })
        return {"ingestor_context": "", "metadata": metadata}
def retrieve_node(state: GraphState) -> GraphState:
    """Query the retriever service and store the returned context in state."""
    started = datetime.now()
    logger.info(f"Retrieval: {state['query'][:50]}... Country: {state.get('country', 'All')}")
    meta = state.get("metadata", {})
    try:
        # Restrict retrieval to a single country when the caller asked for one.
        country = state.get("country", "").strip()
        country_filter = {'country': country} if country else None
        retriever = Client(RETRIEVER)
        fetched = retriever.predict(
            query=state["query"],
            collection_name=COLLECTION_NAME,
            filter_metadata=country_filter,
            api_name="/retrieve"
        )
        meta.update({
            "retrieval_duration": (datetime.now() - started).total_seconds(),
            "context_length": len(fetched) if fetched else 0,
            "retrieval_success": True,
            "country_filter": state.get("country", "All")
        })
        return {"context": fetched, "metadata": meta}
    except Exception as e:
        logger.error(f"Retrieval failed: {str(e)}")
        meta.update({
            "retrieval_duration": (datetime.now() - started).total_seconds(),
            "retrieval_success": False,
            "retrieval_error": str(e)
        })
        return {"context": "", "metadata": meta}
def _combine_contexts(retrieved_context: str, ingestor_context: str) -> str:
    """Merge retrieved and ingested context, capped at MAX_CONTEXT_CHARS.

    When both sources are present each gets half of the character budget,
    with the uploaded-document context listed first; a single source gets
    the full budget. Returns "" when neither is present.
    """
    if ingestor_context and retrieved_context:
        half = MAX_CONTEXT_CHARS // 2
        return (
            f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:half]}\n\n"
            f"=== RETRIEVED CONTEXT ===\n{retrieved_context[:half]}"
        )
    if ingestor_context:
        return f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:MAX_CONTEXT_CHARS]}"
    return retrieved_context[:MAX_CONTEXT_CHARS] if retrieved_context else ""

def generate_node(state: GraphState) -> GraphState:
    """Call the generator service with the combined context and record the answer.

    On failure the result field carries an "Error: ..." string and the
    metadata records the error, so downstream consumers always get a result.
    """
    start_time = datetime.now()
    logger.info(f"Generation: {state['query'][:50]}...")
    try:
        # Combine retriever context with ingestor context, bounded to avoid
        # token overflow at the generator.
        combined_context = _combine_contexts(
            state.get("context", ""), state.get("ingestor_context", "")
        )
        client = Client(GENERATOR)
        result = client.predict(
            query=state["query"],
            context=combined_context,
            api_name="/generate"
        )
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "result_length": len(result) if result else 0,
            "combined_context_length": len(combined_context),
            "generation_success": True
        })
        return {"result": result, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Generation failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e)
        })
        return {"result": f"Error: {str(e)}", "metadata": metadata}
def file_only_node(state: GraphState) -> GraphState:
    """Pass the ingestor output straight through as the final result (no generator call)."""
    logger.info("File-only processing: returning ingestor result directly")
    meta = state.get("metadata", {})
    meta.update({
        "processing_type": "file_only",
        "result_source": "ingestor"
    })
    return {"result": state.get("ingestor_context", ""), "metadata": meta}
# Create separate workflows for different processing types | |
def create_file_workflow():
    """Build the upload pipeline: START -> ingest -> file_only -> END (no retrieval/generation)."""
    graph = StateGraph(GraphState)
    graph.add_node("ingest", ingest_node)
    graph.add_node("file_only", file_only_node)
    graph.add_edge(START, "ingest")
    graph.add_edge("ingest", "file_only")
    graph.add_edge("file_only", END)
    return graph.compile()
def create_query_workflow():
    """Build the plain-query pipeline: START -> retrieve -> generate -> END (no ingestion)."""
    graph = StateGraph(GraphState)
    graph.add_node("retrieve", retrieve_node)
    graph.add_node("generate", generate_node)
    graph.add_edge(START, "retrieve")
    graph.add_edge("retrieve", "generate")
    graph.add_edge("generate", END)
    return graph.compile()
# Compile both workflows once at import time; they are stateless and reused per request.
file_workflow = create_file_workflow()
query_workflow = create_query_workflow()
def process_query_core(
    query: str,
    country: str = "",
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    file_content: Optional[bytes] = None,
    filename: Optional[str] = None,
    return_metadata: bool = False
):
    """Run one request through the appropriate workflow.

    A request with an attached file goes through ingest -> file_only; a plain
    query goes through retrieve -> generate. Returns the result string, or a
    {"result", "metadata"} dict when return_metadata is True. Errors are
    reported inside the return value rather than raised.
    """
    start_time = datetime.now()
    session_id = session_id or f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"
    try:
        state = {
            "query": query,
            "context": "",
            "ingestor_context": "",
            "result": "",
            "country": country or "",
            "file_content": file_content,
            "filename": filename,
            "metadata": {
                "session_id": session_id,
                "user_id": user_id,
                "start_time": start_time.isoformat(),
                "has_geojson_attachment": file_content is not None,
                "country_filter": country or "All"
            }
        }
        # A file attachment selects the ingestion-only pipeline.
        if file_content and filename:
            logger.info("File provided - using file workflow (ingest -> file_only)")
            final_state = file_workflow.invoke(state)
        else:
            logger.info("No file provided - using query workflow (retrieve -> generate)")
            final_state = query_workflow.invoke(state)
        final_metadata = final_state.get("metadata", {})
        final_metadata.update({
            "total_duration": (datetime.now() - start_time).total_seconds(),
            "end_time": datetime.now().isoformat(),
            "pipeline_success": True
        })
        if return_metadata:
            return {"result": final_state["result"], "metadata": final_metadata}
        return final_state["result"]
    except Exception as e:
        logger.error(f"Pipeline failed: {str(e)}")
        if not return_metadata:
            return f"Error: {str(e)}"
        return {
            "result": f"Error: {str(e)}",
            "metadata": {
                "session_id": session_id,
                "total_duration": (datetime.now() - start_time).total_seconds(),
                "pipeline_success": False,
                "error": str(e)
            }
        }
def process_query_gradio(query: str, file_upload, country: str = "") -> str:
    """Gradio interface function with GeoJSON file upload support.

    Reads the uploaded temp file (if any) into memory and delegates to
    process_query_core. Returns the pipeline answer, or an error string if
    the uploaded file cannot be read.
    """
    file_content = None
    filename = None
    if file_upload is not None:
        try:
            with open(file_upload.name, 'rb') as f:
                file_content = f.read()
            filename = os.path.basename(file_upload.name)
            # Fix: this log line previously printed a literal placeholder
            # instead of the actual filename.
            logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
        except Exception as e:
            logger.error(f"Error reading uploaded file: {str(e)}")
            return f"Error reading file: {str(e)}"
    return process_query_core(
        query=query,
        country=country,
        file_content=file_content,
        filename=filename,
        session_id=f"gradio_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        return_metadata=False
    )
def chatui_adapter(data) -> str:
    """Adapter for ChatUI payloads: accepts either an object or a dict carrying 'text'."""
    try:
        # The route may hand us a pydantic model (attribute access) or a plain dict.
        if hasattr(data, 'text'):
            text = data.text
        elif isinstance(data, dict) and 'text' in data:
            text = data['text']
        else:
            logger.error(f"Unexpected input structure: {data}")
            return "Error: Invalid input format. Expected 'text' field."
        return process_query_core(
            query=text,
            session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            return_metadata=False
        )
    except Exception as e:
        logger.error(f"ChatUI error: {str(e)}")
        return f"Error: {str(e)}"
def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
    """LangServe entry point: unpack ChatFedInput and return answer plus metadata."""
    outcome = process_query_core(
        query=input_data["query"],
        country=input_data.get("country", ""),
        session_id=input_data.get("session_id"),
        user_id=input_data.get("user_id"),
        file_content=input_data.get("file_content"),
        filename=input_data.get("filename"),
        return_metadata=True
    )
    return ChatFedOutput(result=outcome["result"], metadata=outcome["metadata"])
def create_gradio_interface():
    """Build the Gradio Blocks UI: query box, optional GeoJSON upload, country filter.

    Returns the (unlaunched) Blocks app; run_gradio_server() launches it.
    """
    with gr.Blocks(title="EUDR Orchestrator") as demo:
        gr.Markdown("# EUDR Orchestrator")
        gr.Markdown("Upload GeoJSON files for WHISP API analysis alongside EUDR compliance queries. MCP endpoints available at `/gradio_api/mcp/sse`")
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                query_input = gr.Textbox(
                    label="Query",
                    lines=2,
                    placeholder="Ask about EUDR compliance or upload GeoJSON for deforestation analysis...",
                    info="Enter your EUDR-related question"
                )
                # NOTE(review): gr.File may not accept an `info` kwarg in all
                # Gradio versions (it is a form-component parameter) — confirm
                # against the pinned Gradio release.
                file_input = gr.File(
                    label="Upload GeoJSON",
                    file_types=[".geojson", ".json"],
                    info="Upload GeoJSON file for geographic deforestation analysis"
                )
                country_input = gr.Dropdown(
                    choices=["", "Ecuador", "Guatemala"],
                    label="Country Filter (Optional)",
                    value="",
                    info="Filter EUDR document retrieval by country"
                )
                submit_btn = gr.Button("Submit", variant="primary")
            # Right column: response output.
            with gr.Column():
                output = gr.Textbox(label="Response", lines=15, show_copy_button=True)
        submit_btn.click(
            fn=process_query_gradio,
            inputs=[query_input, file_input, country_input],
            outputs=output
        )
    return demo
@asynccontextmanager
async def lifespan(app: "FastAPI"):
    """FastAPI lifespan context: log startup before yield, shutdown after.

    Fix: FastAPI's ``lifespan=`` parameter expects an async context manager
    factory; the bare async generator function was missing the
    @asynccontextmanager decorator (imported above but previously unused).
    """
    logger.info("ChatFed Orchestrator starting up...")
    yield
    logger.info("Orchestrator shutting down...")
# FastAPI application; interactive API docs are disabled for this deployment.
app = FastAPI(
    title="ChatFed Orchestrator",
    version="1.0.0",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None
)
# NOTE(review): no @app.get("/health") decorator is visible here, so this
# handler does not appear to be registered with FastAPI even though root()
# advertises /health — confirm whether the decorator was lost or registration
# happens elsewhere.
async def health_check():
    """Liveness probe payload."""
    return {"status": "healthy"}
# NOTE(review): no @app.get("/") decorator is visible here, so this handler
# does not appear to be registered with FastAPI — confirm whether the
# decorator was lost or registration happens elsewhere.
async def root():
    """Describe the API and list the endpoints it advertises."""
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed": "/chatfed",
            "chatfed-ui-stream": "/chatfed-ui-stream",
            "chatfed-with-file": "/chatfed-with-file"
        }
    }
async def chatfed_with_file(
    query: str = Form(...),
    file: Optional[UploadFile] = File(None),
    country: Optional[str] = Form(""),
    session_id: Optional[str] = Form(None),
    user_id: Optional[str] = Form(None)
):
    """Endpoint for queries with optional file attachments.

    NOTE(review): no @app.post route decorator is visible here, although
    root() advertises /chatfed-with-file — confirm this handler is actually
    registered with the app.
    """
    # Pull the upload into memory, if one was sent.
    file_content, filename = None, None
    if file:
        file_content = await file.read()
        filename = file.filename
    outcome = process_query_core(
        query=query,
        country=country,
        file_content=file_content,
        filename=filename,
        session_id=session_id,
        user_id=user_id,
        return_metadata=True
    )
    return ChatFedOutput(result=outcome["result"], metadata=outcome["metadata"])
# LangServe routes
# /chatfed: full-featured entry point with typed input/output schemas.
add_routes(
    app,
    RunnableLambda(process_query_langserve),
    path="/chatfed",
    input_type=ChatFedInput,
    output_type=ChatFedOutput
)
# /chatfed-ui-stream: simplified text-in/text-out route for the ChatUI frontend.
add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)
def run_gradio_server():
    """Launch the Gradio UI (with MCP server enabled) on port 7861; blocks its thread."""
    ui = create_gradio_interface()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7861,
        "mcp_server": True,
        "show_error": True,
        "share": False,
        "quiet": True,
    }
    ui.launch(**launch_options)
if __name__ == "__main__":
    # Gradio (with MCP) runs on port 7861 in a daemon thread; FastAPI takes the
    # main thread on PORT (default 7860).
    gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
    gradio_thread.start()
    logger.info("Gradio MCP server started on port 7861")
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    logger.info(f"Starting FastAPI server on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)