# Fed-AI-Savant β€” src/modules/llm_completions.py
# Author: RobertoBarrosoLuque β€” commit 1b4db50 ("Fix tool use frontend implementation")
from fireworks import LLM
from pydantic import BaseModel
import asyncio
import json
import time
from typing import Dict, Any, List
from gradio import ChatMessage
from src.modules.fed_tools import TOOLS
# Fireworks serverless model identifiers, keyed by size tier.
MODELS = {
    "small": "accounts/fireworks/models/gpt-oss-20b",
    "large": "accounts/fireworks/models/gpt-oss-120b"
}
# Today's date (YYYY-MM-DD), captured once at import time and interpolated into prompts.
TODAY = time.strftime("%Y-%m-%d")
# Caps the number of concurrent LLM completions issued by run_multi_llm_completions.
semaphore = asyncio.Semaphore(10)
def get_llm(model: str, api_key: str) -> LLM:
    """Build a serverless Fireworks LLM client for the given size tier ("small" or "large")."""
    model_id = MODELS[model]
    return LLM(
        model=model_id,
        api_key=api_key,
        deployment_type="serverless",
    )
async def get_llm_completion(llm: LLM, prompt_text: str, output_class: BaseModel = None) -> str:
    """Request a single chat completion from *llm* for *prompt_text*.

    :param llm: Fireworks LLM client (see get_llm).
    :param prompt_text: prompt sent as a single user message.
    :param output_class: optional pydantic model class; when provided, the
        response is constrained to JSON matching its schema via response_format.
    :return: the raw completion response.

    Fix: the structured and unstructured branches previously duplicated the
    entire create() call; they now share one call site and differ only in
    whether response_format is attached.
    NOTE(review): this coroutine never awaits β€” llm.chat.completions.create
    appears to be a blocking call; confirm whether the SDK offers an async
    variant so concurrent callers actually overlap.
    """
    request: Dict[str, Any] = {
        "messages": [
            {
                "role": "user",
                "content": prompt_text
            },
        ],
        "temperature": 0.1,
    }
    if output_class:
        # Constrain the model output to JSON conforming to the pydantic schema.
        request["response_format"] = {
            "type": "json_object",
            "schema": output_class.model_json_schema(),
        }
    return llm.chat.completions.create(**request)
async def run_multi_llm_completions(llm: LLM, prompts: list[str], output_class: BaseModel) -> list[str]:
    """
    Run multiple LLM completions concurrently, bounded by the module semaphore.

    Fix: the semaphore used to be acquired ONCE around the whole gather, which
    never limited per-request concurrency; each request now acquires it, so at
    most 10 completions are in flight at a time. Results are returned in the
    same order as *prompts*.

    :param llm: Fireworks LLM client shared by all requests.
    :param prompts: list of user prompts, one completion each.
    :param output_class: optional pydantic model class forwarded to
        get_llm_completion for structured output (may be None).
    :return: list of completion responses, ordered like *prompts*.
    """
    if output_class:
        print("Running LLM with structured outputs")
    else:
        print("Running LLM with non-structured outputs")

    async def _bounded_completion(prompt: str):
        # Hold the semaphore for the duration of one request only.
        async with semaphore:
            return await get_llm_completion(llm=llm, prompt_text=prompt, output_class=output_class)

    tasks = [asyncio.create_task(_bounded_completion(prompt)) for prompt in prompts]
    return await asyncio.gather(*tasks)
def get_orchestrator_decision(
    *, user_query: str, history: str, api_key: str, prompt_library: Dict[str, str]
) -> Dict[str, Any]:
    """Ask the orchestrator LLM which tool(s) to invoke for *user_query*.

    Returns a dict with the raw assistant message, a has_tool_calls flag, and
    the (possibly empty) list of tool calls.
    """
    template = prompt_library.get('fed_orchestrator')
    prompt = template.format(user_query=user_query, date=TODAY, conversation_context=history)
    print("Running function orchestrator")
    system_message = {
        "role": "system",
        "content": "You are a Federal Reserve tool orchestrator. Always call exactly one function based on the user query analysis.",
    }
    llm = get_llm("large", api_key)
    response = llm.chat.completions.create(
        messages=[system_message, {"role": "user", "content": prompt}],
        tools=TOOLS,
        temperature=0.1,
    )
    decision = response.choices[0].message
    return {
        "success": True,
        "message": decision,
        "has_tool_calls": bool(decision.tool_calls),
        "tool_calls": decision.tool_calls or [],
    }
def execute_tool_calls(tool_calls: List[Any], fed_tools: Dict[str, callable]) -> List[Dict[str, Any]]:
    """Execute the tool calls from Fireworks function calling.

    Each returned entry records the tool-call id, function name, parsed
    parameters, raw result, wall-clock execution time, and a success flag.
    Unknown functions, bad argument JSON, and raised exceptions are captured
    as failed entries instead of aborting the batch.

    Fixes: the two results.append branches were duplicates and are merged;
    the success check now tolerates tools that return non-dict payloads
    (previously that raised AttributeError and discarded the payload).
    """
    results = []
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        # Arguments arrive as a JSON string; fall back to no-args on bad JSON.
        try:
            parameters = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError:
            parameters = {}
        start_time = time.time()
        try:
            if function_name in fed_tools:
                result = fed_tools[function_name](**parameters)
            else:
                result = {"success": False, "error": f"Unknown function: {function_name}"}
        except Exception as e:
            # Tool raised: record the error instead of propagating.
            result = {"success": False, "error": str(e)}
        execution_time = time.time() - start_time
        results.append({
            "tool_call_id": tool_call.id,
            "function": function_name,
            "parameters": parameters,
            "result": result,
            "execution_time": execution_time,
            # Non-dict results count as failures but keep the payload visible.
            "success": result.get("success", False) if isinstance(result, dict) else False,
        })
    return results
def extract_citations_from_tool_results(tool_results: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Extract unique citations (date, url, title) from successful tool results.

    Looks for either a "meeting" payload (single dict or list of dicts) or a
    "results" list of meetings; entries without a URL are skipped and duplicate
    URLs are dropped (first occurrence wins).

    Fix: the citation-dict construction was copy-pasted for both payload
    shapes; it now lives in one helper.
    """
    citations = []
    for result in tool_results:
        if not (result["success"] and result["result"].get("success")):
            continue
        payload = result["result"]
        meeting_data = payload.get("meeting")
        if meeting_data:
            # Handle both a single meeting object and a list of meetings.
            meetings = meeting_data if isinstance(meeting_data, list) else [meeting_data]
        elif "results" in payload:
            meetings = payload["results"]
        else:
            meetings = []
        for meeting in meetings:
            citation = _citation_from_meeting(meeting)
            if citation:
                citations.append(citation)
    # De-duplicate by URL, preserving first-seen order.
    unique_citations = []
    seen_urls = set()
    for citation in citations:
        if citation["url"] not in seen_urls:
            seen_urls.add(citation["url"])
            unique_citations.append(citation)
    return unique_citations


def _citation_from_meeting(meeting: Any) -> Dict[str, str]:
    """Build a citation dict from one meeting payload, or None if it has no URL."""
    if isinstance(meeting, dict) and meeting.get("url"):
        return {
            "date": meeting.get("date", "Unknown date"),
            "url": meeting["url"],
            "title": meeting.get("title", f"FOMC Meeting {meeting.get('date', '')}")
        }
    return None
def format_response_with_citations(response: str, citations: List[Dict[str, str]]) -> str:
    """Return *response* with a markdown "Sources" section of citation links appended.

    The response is returned unchanged when there are no citations.
    """
    if not citations:
        return response
    source_lines = [f"β€’ [{cite['title']}]({cite['url']})\n" for cite in citations]
    return response + "\n\n**πŸ“š Sources:**\n" + "".join(source_lines)
def update_tool_messages_with_results(tool_results: List[Dict[str, Any]]) -> List[ChatMessage]:
    """Render one completed ChatMessage per executed tool, summarizing its outcome."""
    rendered = []
    for idx, outcome in enumerate(tool_results, start=1):
        duration = outcome["execution_time"]
        icon = "βœ…" if outcome["success"] else "❌"
        summary = str(outcome["result"])[:200]
        body = (
            f"{icon} {outcome['function']} completed\n\n"
            f"Execution time: {duration:.2f}s\n\n"
            f"Result summary: {summary}..."
        )
        rendered.append(ChatMessage(
            role="assistant",
            content=body,
            metadata={"title": f"πŸ”§ Tool {idx}: {outcome['function']}", "status": "done", "duration": duration}
        ))
    return rendered
def build_context_from_tool_results(tool_results: List[Dict[str, Any]]) -> str:
    """Concatenate pretty-printed JSON of each successful tool result into one context string."""
    sections = [
        f"\n\nFrom {entry['function']}: {json.dumps(entry['result'], indent=2)}"
        for entry in tool_results
        if entry["success"]
    ]
    return "".join(sections)
def stream_final_response(message: str, system_prompt: str, api_key: str, citations: List[Dict[str, str]]):
    """Stream the final Fed Savant response, yielding the running text with citations appended.

    Generator: each yield is the full accumulated answer so far, re-formatted
    with the citation footer, so the frontend can replace its display in place.
    """
    llm = get_llm("large", api_key)
    accumulated = ""
    stream = llm.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message}
        ],
        temperature=0.1,
        stream=True
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            accumulated += delta
            yield format_response_with_citations(accumulated, citations)
def stream_fed_agent_response(
    *,
    message: str,
    api_key: str,
    prompt_library: Dict[str, str],
    fed_tools: Dict[str, callable],
    history: str = "",
):
    """Main orchestrator function that coordinates tools and generates responses with ChatMessage objects.

    Generator: yields the running list of ChatMessage objects after every state
    change (tools pending, tools done, thinking, streamed answer chunks) so the
    frontend can re-render progressively. Note the same list object is mutated
    and re-yielded each time.

    :param message: user question; blank input short-circuits with a prompt to retry.
    :param api_key: Fireworks API key; blank input yields an error message.
    :param prompt_library: named prompt templates ('fed_orchestrator', 'fed_savant_chat').
    :param fed_tools: mapping of tool name -> callable, forwarded to execute_tool_calls.
    :param history: prior conversation context passed to the orchestrator.
    """
    # Guard clauses: bail out early with a single explanatory message.
    if not message.strip():
        yield [ChatMessage(role="assistant",
                           content="Please enter a question about Federal Reserve policy or FOMC meetings.")]
        return
    if not api_key.strip():
        yield [ChatMessage(role="assistant", content="❌ Please set your FIREWORKS_API_KEY environment variable.")]
        return
    messages = []
    try:
        print("Getting orchestrator decision...")
        orchestrator_result = get_orchestrator_decision(
            user_query=message, api_key=api_key, history=history, prompt_library=prompt_library
        )
        if not orchestrator_result["success"]:
            yield [ChatMessage(role="assistant", content="❌ Error in planning phase")]
            return
        orchestrator_message = orchestrator_result["message"]
        # Execute tools if any were called
        if orchestrator_result["has_tool_calls"]:
            tool_names = [tc.function.name for tc in orchestrator_result["tool_calls"]]
            # Show initial tools execution with pending status
            tools_summary = f"Executing tools: {', '.join(tool_names)}"
            messages.append(ChatMessage(
                role="assistant",
                content=tools_summary,
                metadata={"title": "πŸ”§ Tools Used", "status": "pending"}
            ))
            yield messages
            print(f"Executing the following tools {tool_names}")
            tool_results = execute_tool_calls(orchestrator_result["tool_calls"], fed_tools)
            successful_tools = sum(1 for tr in tool_results if tr["success"])
            total_time = sum(tr["execution_time"] for tr in tool_results)
            # Replace the pending summary (index 0) with the completed one.
            updated_summary = f"Executed {len(tool_names)} tools: {', '.join(tool_names)} βœ… ({successful_tools}/{len(tool_results)} successful)"
            messages[0] = ChatMessage(
                role="assistant",
                content=updated_summary,
                metadata={"title": "πŸ”§ Tools Used", "status": "done", "duration": total_time}
            )
            yield messages
            combined_context = build_context_from_tool_results(tool_results)
            citations = extract_citations_from_tool_results(tool_results)
            # Add thinking indicator with pending status
            messages.append(ChatMessage(
                role="assistant",
                content="Processing Fed data and formulating response...",
                metadata={"title": "Pondering ....", "status": "pending"}
            ))
            yield messages
            system_prompt_template = prompt_library.get('fed_savant_chat', '')
            system_prompt = system_prompt_template.format(
                fed_data_context=combined_context,
                user_question=message,
                date=TODAY
            )
            # Mark thinking as complete and start response
            messages[-1] = ChatMessage(
                role="assistant",
                content="Analysis complete, generating response...",
                metadata={"title": "Pondering ....", "status": "done"}
            )
            yield messages
            # Stream the answer, rewriting the trailing message on every chunk.
            messages.append(ChatMessage(role="assistant", content=""))
            for response_chunk in stream_final_response(message, system_prompt, api_key, citations):
                messages[-1] = ChatMessage(role="assistant", content=response_chunk)
                yield messages
        else:
            # No tool calls: either relay the orchestrator's direct answer...
            if orchestrator_message.content:
                messages.append(ChatMessage(role="assistant", content=orchestrator_message.content))
                yield messages
            else:
                # ...or fall back to a plain chat completion with no tool data.
                system_prompt_template = prompt_library.get('fed_savant_chat', '')
                system_prompt = system_prompt_template.format(
                    fed_data_context="No specific tool data available.",
                    user_question=message,
                    date=TODAY
                )
                messages.append(ChatMessage(role="assistant", content=""))
                for response_chunk in stream_final_response(message, system_prompt, api_key, []):
                    messages[-1] = ChatMessage(role="assistant", content=response_chunk)
                    yield messages
    except Exception as e:
        # Surface any failure as a chat message instead of crashing the stream.
        print(f"Error in stream_fed_agent_response: {str(e)}")
        error_message = ChatMessage(
            role="assistant",
            content=f"❌ Error generating response: {str(e)}"
        )
        if messages:
            messages.append(error_message)
        else:
            messages = [error_message]
        yield messages