mtyrrell's picture
deterministic query/file handling
1a69cf5
# EUDR ORCHESTRATOR
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from typing import Optional, Dict, Any, List
from typing_extensions import TypedDict
from pydantic import BaseModel
from gradio_client import Client, file
import uvicorn
import os
from datetime import datetime
import logging
from contextlib import asynccontextmanager
import threading
from langchain_core.runnables import RunnableLambda
import tempfile
from utils import getconfig
# Load service endpoints and limits from params.cfg.
# NOTE(review): getconfig is a project helper; the get(section, option, fallback=...)
# calls suggest a configparser-style object — confirm in utils.
config = getconfig("params.cfg")
# Targets handed to gradio_client.Client by the retrieve / generate / ingest nodes.
RETRIEVER = config.get("retriever", "RETRIEVER")
GENERATOR = config.get("generator", "GENERATOR")
INGESTOR = config.get("ingestor", "INGESTOR")
# Character budget for the context passed to the generator; falls back to 8000.
MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS", fallback="8000"))
# Collection name forwarded to the retriever's /retrieve endpoint.
COLLECTION_NAME = config.get("retriever", "COLLECTION_NAME")
# Root logging configuration; every function in this module logs through `logger`.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Models
class GraphState(TypedDict):
    """State dictionary passed between the LangGraph workflow nodes."""
    # User question; sent to the retriever and the generator.
    query: str
    # Context returned by the retriever ("" until retrieve_node runs or when it fails).
    context: str
    # Output of the ingestor for an uploaded file ("" when no file or on failure).
    ingestor_context: str
    # Final answer produced by generate_node or file_only_node.
    result: str
    # Country used as the retriever metadata filter; "" means no filter.
    country: str
    # Raw bytes of the uploaded file, if any.
    file_content: Optional[bytes]
    # Original name of the uploaded file; its extension is reused for the temp file.
    filename: Optional[str]
    # Accumulated per-request bookkeeping (timings, success flags, errors).
    metadata: Optional[Dict[str, Any]]
class ChatFedInput(TypedDict, total=False):
    """Input schema of the /chatfed LangServe route (total=False: keys may be omitted)."""
    # Read with input_data["query"] by process_query_langserve, so it must be present.
    query: str
    # Optional country filter for retrieval.
    country: Optional[str]
    # Optional caller-supplied identifiers, copied into the response metadata.
    session_id: Optional[str]
    user_id: Optional[str]
    # Optional uploaded file (raw bytes) and its name.
    file_content: Optional[bytes]
    filename: Optional[str]
class ChatFedOutput(TypedDict):
    """Response schema shared by the /chatfed and /chatfed-with-file endpoints."""
    # Answer text, or a string starting with "Error:" when the pipeline failed.
    result: str
    # Timing / success bookkeeping collected while processing the request.
    metadata: Dict[str, Any]
class ChatUIInput(BaseModel):
    """Input schema of the /chatfed-ui-stream route; consumed by chatui_adapter."""
    # The user's message; used as the pipeline query.
    text: str
# Module functions
def ingest_node(state: GraphState) -> GraphState:
    """Process the uploaded file through the ingestor service, if a file is provided.

    The file bytes are written to a temporary file (keeping the original
    extension) and sent to the ingestor's ``/ingest`` endpoint.

    Args:
        state: Current graph state; reads ``file_content``, ``filename`` and
            ``metadata``.

    Returns:
        A partial state update with ``ingestor_context`` (``""`` when there is
        no file or ingestion fails) and an updated copy of ``metadata``
        carrying ``ingestion_duration`` / ``ingestion_success`` and, on
        failure, ``ingestion_error``. Errors are recorded, never raised.
    """
    start_time = datetime.now()
    # Work on a copy and tolerate metadata=None (the field is Optional).
    metadata = dict(state.get("metadata") or {})

    file_content = state.get("file_content")
    filename = state.get("filename")
    # If no file provided, skip this step
    if not file_content or not filename:
        logger.info("No file provided, skipping ingestion")
        return {"ingestor_context": "", "metadata": metadata}

    logger.info(f"Ingesting file: {filename}")
    tmp_file_path = None
    try:
        client = Client(INGESTOR)
        # The ingestor takes a file path, so spill the bytes to disk first.
        suffix = os.path.splitext(filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_file_path = tmp_file.name
            tmp_file.write(file_content)

        # Call the ingestor's ingest endpoint
        ingestor_context = client.predict(
            file(tmp_file_path),
            api_name="/ingest"
        )
        logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")

        # The ingestor reports failures in-band as an "Error:"-prefixed string.
        if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
            raise RuntimeError(ingestor_context)
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Ingestion failed: {str(e)}")
        metadata.update({
            "ingestion_duration": duration,
            "ingestion_success": False,
            "ingestion_error": str(e)
        })
        return {"ingestor_context": "", "metadata": metadata}
    finally:
        # Clean up temporary file; a cleanup failure must not fail the ingestion.
        if tmp_file_path is not None:
            try:
                os.unlink(tmp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_file_path}: {cleanup_error}")

    duration = (datetime.now() - start_time).total_seconds()
    metadata.update({
        "ingestion_duration": duration,
        "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
        "ingestion_success": True,
        "analysis_type": "whisp_geojson"
    })
    return {
        "ingestor_context": ingestor_context,
        "metadata": metadata
    }
def retrieve_node(state: GraphState) -> GraphState:
    """Fetch context for the query from the retriever service.

    Args:
        state: Current graph state; reads ``query``, ``country`` and
            ``metadata``.

    Returns:
        A partial state update with ``context`` (``""`` on failure) and an
        updated copy of ``metadata`` carrying ``retrieval_duration`` /
        ``retrieval_success`` and, on failure, ``retrieval_error``. Errors are
        recorded, never raised.
    """
    start_time = datetime.now()
    # Work on a copy and tolerate metadata=None (the field is Optional).
    metadata = dict(state.get("metadata") or {})
    # Tolerate a missing or None country; blank means "no filter".
    country = (state.get("country") or "").strip()
    country_label = country or "All"
    logger.info(f"Retrieval: {state['query'][:50]}... Country: {country_label}")
    try:
        client = Client(RETRIEVER)
        # Create metadata filter for country if specified
        filter_metadata = {'country': country} if country else None
        context = client.predict(
            query=state["query"],
            collection_name=COLLECTION_NAME,  # collection configured in params.cfg
            filter_metadata=filter_metadata,
            api_name="/retrieve"
        )
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Retrieval failed: {str(e)}")
        metadata.update({
            "retrieval_duration": duration,
            "retrieval_success": False,
            "retrieval_error": str(e)
        })
        return {"context": "", "metadata": metadata}

    duration = (datetime.now() - start_time).total_seconds()
    metadata.update({
        "retrieval_duration": duration,
        "context_length": len(context) if context else 0,
        "retrieval_success": True,
        # Report "All" (not "") when no filter was applied.
        "country_filter": country_label
    })
    return {"context": context, "metadata": metadata}
def _build_combined_context(retrieved_context, ingestor_context, max_chars):
    """Merge retrieved and uploaded-document context, truncating each part to its budget."""
    if ingestor_context and retrieved_context:
        # Each source gets half the budget; slicing is a no-op for shorter inputs.
        half = max_chars // 2
        return (
            f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:half]}"
            f"\n\n=== RETRIEVED CONTEXT ===\n{retrieved_context[:half]}"
        )
    if ingestor_context:
        return f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:max_chars]}"
    if retrieved_context:
        return retrieved_context[:max_chars]
    return ""


def generate_node(state: GraphState) -> GraphState:
    """Ask the generator service to answer the query using the available context.

    Args:
        state: Current graph state; reads ``query``, ``context``,
            ``ingestor_context`` and ``metadata``.

    Returns:
        A partial state update with ``result`` (an ``"Error: ..."`` string on
        failure) and an updated copy of ``metadata`` carrying
        ``generation_duration`` / ``generation_success`` and, on failure,
        ``generation_error``. Errors are recorded, never raised.
    """
    start_time = datetime.now()
    # Work on a copy and tolerate metadata=None, so the error path cannot itself raise.
    metadata = dict(state.get("metadata") or {})
    logger.info(f"Generation: {state['query'][:50]}...")
    try:
        # Combine retriever context with ingestor context, limited in size
        # to prevent token overflow.
        combined_context = _build_combined_context(
            state.get("context", ""),
            state.get("ingestor_context", ""),
            MAX_CONTEXT_CHARS,
        )
        client = Client(GENERATOR)
        result = client.predict(
            query=state["query"],
            context=combined_context,
            api_name="/generate"
        )
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Generation failed: {str(e)}")
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e)
        })
        return {"result": f"Error: {str(e)}", "metadata": metadata}

    duration = (datetime.now() - start_time).total_seconds()
    metadata.update({
        "generation_duration": duration,
        "result_length": len(result) if result else 0,
        "combined_context_length": len(combined_context),
        "generation_success": True
    })
    return {"result": result, "metadata": metadata}
def file_only_node(state: GraphState) -> GraphState:
    """Return ingestor result directly without calling generator.

    When ingest_node recorded a failure (``ingestion_success`` is ``False``),
    the result is an ``"Error: ..."`` string built from ``ingestion_error``
    instead of a silent empty string, matching how generate_node reports
    failures.

    Args:
        state: Current graph state; reads ``ingestor_context`` and ``metadata``.

    Returns:
        A partial state update with ``result`` and an updated copy of
        ``metadata`` tagged with ``processing_type`` and ``result_source``.
    """
    logger.info("File-only processing: returning ingestor result directly")
    # Work on a copy and tolerate metadata=None (the field is Optional).
    metadata = dict(state.get("metadata") or {})
    metadata.update({
        "processing_type": "file_only",
        "result_source": "ingestor"
    })

    if metadata.get("ingestion_success") is False:
        # Surface the failure to callers that only see the result string.
        error_text = str(metadata.get("ingestion_error") or "file ingestion failed")
        # The ingestor's own messages may already carry the "Error:" prefix.
        result = error_text if error_text.startswith("Error:") else f"Error: {error_text}"
    else:
        result = state.get("ingestor_context", "")

    return {
        "result": result,
        "metadata": metadata
    }
# Create separate workflows for different processing types
def create_file_workflow():
    """Build and compile the upload graph: START -> ingest -> file_only -> END.

    Retrieval and generation are not part of this graph.
    """
    graph = StateGraph(GraphState)

    # Register the nodes first, then wire them in execution order.
    for node_name, handler in (("ingest", ingest_node), ("file_only", file_only_node)):
        graph.add_node(node_name, handler)

    for source, target in ((START, "ingest"), ("ingest", "file_only"), ("file_only", END)):
        graph.add_edge(source, target)

    return graph.compile()
def create_query_workflow():
    """Build and compile the query graph: START -> retrieve -> generate -> END.

    Ingestion is not part of this graph.
    """
    graph = StateGraph(GraphState)

    # Register the nodes first, then wire them in execution order.
    for node_name, handler in (("retrieve", retrieve_node), ("generate", generate_node)):
        graph.add_node(node_name, handler)

    for source, target in ((START, "retrieve"), ("retrieve", "generate"), ("generate", END)):
        graph.add_edge(source, target)

    return graph.compile()
# Compile workflows once at import time; process_query_core picks one per request
# depending on whether a file was supplied.
file_workflow = create_file_workflow()
query_workflow = create_query_workflow()
def process_query_core(
    query: str,
    country: str = "",
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    file_content: Optional[bytes] = None,
    filename: Optional[str] = None,
    return_metadata: bool = False
):
    """Run one request through the file workflow or the query workflow.

    A request with both non-empty ``file_content`` and a ``filename`` goes
    through ingest -> file_only; every other request goes through
    retrieve -> generate.

    Args:
        query: The user's question.
        country: Optional country filter for retrieval ("" or None = no filter).
        session_id: Caller-supplied session id; a timestamp-based one is
            generated when omitted.
        user_id: Optional user identifier recorded in the metadata.
        file_content: Raw bytes of an uploaded file, if any.
        filename: Name of the uploaded file, if any.
        return_metadata: When true, return ``{"result", "metadata"}`` instead
            of the bare result string.

    Returns:
        The result string, or a dict with ``result`` and ``metadata`` when
        ``return_metadata`` is true. Failures are reported as an
        ``"Error: ..."`` result rather than raised.
    """
    start_time = datetime.now()
    if not session_id:
        session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"

    # One flag drives both the routing and the metadata, so they cannot disagree
    # (e.g. empty bytes, or content without a filename).
    has_file = bool(file_content and filename)

    try:
        initial_state = {
            "query": query,
            "context": "",
            "ingestor_context": "",
            "result": "",
            "country": country or "",
            "file_content": file_content,
            "filename": filename,
            "metadata": {
                "session_id": session_id,
                "user_id": user_id,
                "start_time": start_time.isoformat(),
                "has_geojson_attachment": has_file,
                "country_filter": country or "All"
            }
        }
        # Choose workflow based on whether file is provided
        if has_file:
            logger.info("File provided - using file workflow (ingest -> file_only)")
            final_state = file_workflow.invoke(initial_state)
        else:
            logger.info("No file provided - using query workflow (retrieve -> generate)")
            final_state = query_workflow.invoke(initial_state)

        total_duration = (datetime.now() - start_time).total_seconds()
        # Tolerate a workflow that left metadata unset or None.
        response_metadata = dict(final_state.get("metadata") or {})
        response_metadata.update({
            "total_duration": total_duration,
            "end_time": datetime.now().isoformat(),
            "pipeline_success": True
        })
        result = final_state["result"]
    except Exception as e:
        total_duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Pipeline failed: {str(e)}")
        result = f"Error: {str(e)}"
        response_metadata = {
            "session_id": session_id,
            "total_duration": total_duration,
            "pipeline_success": False,
            "error": str(e)
        }

    if return_metadata:
        return {"result": result, "metadata": response_metadata}
    return result
def process_query_gradio(query: str, file_upload, country: str = "") -> str:
    """Gradio interface function with GeoJSON file upload support.

    Args:
        query: The user's question.
        file_upload: Value of the Gradio file component: ``None``, an object
            exposing the file path as ``.name``, or a plain path.
        country: Optional country filter for retrieval.

    Returns:
        The pipeline's result string, or an ``"Error reading file: ..."``
        message when the upload cannot be read.
    """
    file_content = None
    filename = None
    if file_upload is not None:
        # Accept both file-like wrappers (path in .name) and bare path strings.
        upload_path = getattr(file_upload, "name", file_upload)
        try:
            with open(upload_path, 'rb') as f:
                file_content = f.read()
            filename = os.path.basename(upload_path)
            logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Error reading uploaded file: {str(e)}")
            return f"Error reading file: {str(e)}"
    return process_query_core(
        query=query,
        country=country,
        file_content=file_content,
        filename=filename,
        session_id=f"gradio_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        return_metadata=False
    )
def chatui_adapter(data) -> str:
    """Adapt a ChatUI payload to the core pipeline.

    ``data`` may be an object with a ``text`` attribute or a dict with a
    ``'text'`` key; anything else yields an error string. Exceptions are
    logged and returned as ``"Error: ..."`` rather than raised.
    """
    try:
        has_text_attr = hasattr(data, 'text')
        has_text_key = not has_text_attr and isinstance(data, dict) and 'text' in data

        # Guard clause: neither supported shape matched.
        if not (has_text_attr or has_text_key):
            logger.error(f"Unexpected input structure: {data}")
            return "Error: Invalid input format. Expected 'text' field."

        user_text = data.text if has_text_attr else data['text']
        generated_session = f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        return process_query_core(
            query=user_text,
            session_id=generated_session,
            return_metadata=False
        )
    except Exception as failure:
        logger.error(f"ChatUI error: {str(failure)}")
        return f"Error: {str(failure)}"
def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
    """LangServe entry point: unpack a ChatFedInput and run the core pipeline.

    ``query`` is required (a missing key raises ``KeyError``); all other
    fields are optional. Always requests metadata from the core.
    """
    # Collect the core arguments first; "query" is looked up before the optional keys.
    core_arguments = {
        "query": input_data["query"],
        "country": input_data.get("country", ""),
        "session_id": input_data.get("session_id"),
        "user_id": input_data.get("user_id"),
        "file_content": input_data.get("file_content"),
        "filename": input_data.get("filename"),
    }
    outcome = process_query_core(return_metadata=True, **core_arguments)
    return ChatFedOutput(result=outcome["result"], metadata=outcome["metadata"])
def create_gradio_interface():
    """Build the Gradio Blocks UI: query, file and country inputs wired to process_query_gradio.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="EUDR Orchestrator") as demo:
        gr.Markdown("# EUDR Orchestrator")
        gr.Markdown("Upload GeoJSON files for WHISP API analysis alongside EUDR compliance queries. MCP endpoints available at `/gradio_api/mcp/sse`")
        with gr.Row():
            # Left column: inputs and the submit button.
            with gr.Column():
                query_input = gr.Textbox(
                    label="Query",
                    lines=2,
                    placeholder="Ask about EUDR compliance or upload GeoJSON for deforestation analysis...",
                    info="Enter your EUDR-related question"
                )
                # NOTE(review): confirm the installed Gradio version's gr.File accepts `info` — TODO verify.
                file_input = gr.File(
                    label="Upload GeoJSON",
                    file_types=[".geojson", ".json"],
                    info="Upload GeoJSON file for geographic deforestation analysis"
                )
                # The empty choice means "no country filter".
                country_input = gr.Dropdown(
                    choices=["", "Ecuador", "Guatemala"],
                    label="Country Filter (Optional)",
                    value="",
                    info="Filter EUDR document retrieval by country"
                )
                submit_btn = gr.Button("Submit", variant="primary")
            # Right column: the response text.
            with gr.Column():
                output = gr.Textbox(label="Response", lines=15, show_copy_button=True)
        # Input order must match process_query_gradio(query, file_upload, country).
        submit_btn.click(
            fn=process_query_gradio,
            inputs=[query_input, file_input, country_input],
            outputs=output
        )
    return demo
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: log on startup, yield while serving, log on shutdown."""
    logger.info("ChatFed Orchestrator starting up...")
    # Control returns here when the application is shutting down.
    yield
    logger.info("Orchestrator shutting down...")
# FastAPI application; the interactive docs pages are disabled (docs_url / redoc_url = None).
app = FastAPI(
    title="ChatFed Orchestrator",
    version="1.0.0",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None
)
@app.get("/health")
async def health_check():
    """Liveness probe: always reports a healthy status."""
    status_payload = {"status": "healthy"}
    return status_payload
@app.get("/")
async def root():
    """Landing endpoint: returns the API name and a map of available endpoints."""
    # Pairs keep the original key order of the endpoint map.
    endpoint_map = dict([
        ("health", "/health"),
        ("chatfed", "/chatfed"),
        ("chatfed-ui-stream", "/chatfed-ui-stream"),
        ("chatfed-with-file", "/chatfed-with-file"),
    ])
    return {"message": "ChatFed Orchestrator API", "endpoints": endpoint_map}
@app.post("/chatfed-with-file")
async def chatfed_with_file(
    query: str = Form(...),
    file: Optional[UploadFile] = File(None),
    country: Optional[str] = Form(""),
    session_id: Optional[str] = Form(None),
    user_id: Optional[str] = Form(None)
):
    """Endpoint for queries with optional file attachments.

    Args:
        query: The user's question (required form field).
        file: Optional uploaded file; its bytes and name are passed to the pipeline.
        country: Optional country filter for retrieval.
        session_id: Optional caller-supplied session id.
        user_id: Optional user identifier.

    Returns:
        A ``ChatFedOutput`` with the pipeline ``result`` and ``metadata``.
    """
    # Local import keeps this block self-contained; asyncio is standard library.
    import asyncio

    file_content = None
    filename = None
    if file is not None:
        file_content = await file.read()
        filename = file.filename

    # The pipeline makes blocking calls to remote services; run it in a worker
    # thread so the event loop (and e.g. /health) stays responsive.
    result = await asyncio.to_thread(
        process_query_core,
        query=query,
        country=country,
        file_content=file_content,
        filename=filename,
        session_id=session_id,
        user_id=user_id,
        return_metadata=True
    )
    return ChatFedOutput(result=result["result"], metadata=result["metadata"])
# LangServe routes
# /chatfed: structured input (ChatFedInput) -> result plus metadata (ChatFedOutput).
add_routes(
    app,
    RunnableLambda(process_query_langserve),
    path="/chatfed",
    input_type=ChatFedInput,
    output_type=ChatFedOutput
)
# /chatfed-ui-stream: ChatUI-style input ({"text": ...}) -> plain string answer.
add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)
def run_gradio_server():
    """Build the Gradio UI and serve it on 0.0.0.0:7861 with the MCP server enabled."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7861,
        "mcp_server": True,
        "show_error": True,
        "share": False,
        "quiet": True,
    }
    create_gradio_interface().launch(**launch_options)
if __name__ == "__main__":
    # Start Gradio in a daemon thread first: uvicorn.run below blocks the main thread,
    # and a daemon thread does not keep the process alive after uvicorn exits.
    gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
    gradio_thread.start()
    # Logged right after the thread is started, not after Gradio confirms it is serving.
    logger.info("Gradio MCP server started on port 7861")
    # FastAPI bind address is configurable through the HOST / PORT environment variables.
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    logger.info(f"Starting FastAPI server on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)