"""Fed Savant agent module: Fireworks LLM orchestration, Fed tool execution,
and streaming gradio ChatMessage responses for Federal Reserve / FOMC questions."""
from fireworks import LLM | |
from pydantic import BaseModel | |
import asyncio | |
import json | |
import time | |
from typing import Dict, Any, List | |
from gradio import ChatMessage | |
from src.modules.fed_tools import TOOLS | |
# Fireworks serverless model identifiers, keyed by size alias.
MODELS = {
    "small": "accounts/fireworks/models/gpt-oss-20b",
    "large": "accounts/fireworks/models/gpt-oss-120b"
}
# Current date snapshot (YYYY-MM-DD) taken at import time; injected into prompts.
# NOTE(review): stale in long-running processes that span midnight — confirm acceptable.
TODAY = time.strftime("%Y-%m-%d")
# Concurrency cap for LLM work; acquired in run_multi_llm_completions.
semaphore = asyncio.Semaphore(10)
def get_llm(model: str, api_key: str) -> LLM:
    """Build a serverless Fireworks client for the given model alias ("small" or "large")."""
    model_id = MODELS[model]
    return LLM(model=model_id, api_key=api_key, deployment_type="serverless")
async def get_llm_completion(llm: LLM, prompt_text: str, output_class: BaseModel = None) -> str:
    """Run a single chat completion against *llm*.

    :param llm: Fireworks LLM client (see get_llm).
    :param prompt_text: prompt sent as a single user message.
    :param output_class: optional pydantic model class; when given, the completion
        is constrained to JSON matching the model's schema.
    :return: the Fireworks chat-completion response object.
    """
    # Build the request once; the structured-output constraint is the only
    # difference between the two original, duplicated create() calls.
    request: Dict[str, Any] = {
        "messages": [{"role": "user", "content": prompt_text}],
        "temperature": 0.1,
    }
    if output_class:
        request["response_format"] = {
            "type": "json_object",
            "schema": output_class.model_json_schema(),
        }
    # NOTE(review): the underlying call appears synchronous (nothing is awaited);
    # the coroutine wrapper is kept because callers await this function.
    return llm.chat.completions.create(**request)
async def run_multi_llm_completions(llm: LLM, prompts: list[str], output_class: BaseModel = None) -> list[str]:
    """Run one completion per prompt concurrently, preserving prompt order.

    :param llm: Fireworks LLM client.
    :param prompts: prompts to complete, one request each.
    :param output_class: optional pydantic model for structured JSON output.
        Now defaults to None so plain-text batches need not pass it explicitly
        (backward compatible: existing positional callers are unaffected).
    :return: completion responses in the same order as *prompts*.
    """
    mode = "structured" if output_class else "non-structured"
    print(f"Running LLM with {mode} outputs")

    async def _bounded(prompt: str):
        # Fix: the semaphore used to be held once around the whole batch, which
        # serialized batches while leaving per-request concurrency unbounded.
        # Acquiring it per request caps in-flight completions at the intended 10.
        async with semaphore:
            return await get_llm_completion(llm=llm, prompt_text=prompt, output_class=output_class)

    tasks = [asyncio.create_task(_bounded(prompt)) for prompt in prompts]
    return await asyncio.gather(*tasks)
def get_orchestrator_decision(
    *, user_query: str, history: str, api_key: str, prompt_library: Dict[str, str]
) -> Dict[str, Any]:
    """Ask the orchestrator LLM which Fed tool(s) to call for *user_query*.

    :param user_query: raw user question.
    :param history: prior conversation context injected into the prompt.
    :param api_key: Fireworks API key.
    :param prompt_library: named prompt templates; must contain 'fed_orchestrator'.
    :return: dict with keys: success, message, has_tool_calls, tool_calls.
    :raises KeyError: if the 'fed_orchestrator' template is missing.
    """
    # Fail fast with a clear error: .get() returning None previously surfaced
    # as an opaque AttributeError on None.format(...).
    if 'fed_orchestrator' not in prompt_library:
        raise KeyError("prompt_library is missing the 'fed_orchestrator' template")
    orchestrator_prompt = prompt_library['fed_orchestrator']
    formatted_prompt = orchestrator_prompt.format(
        user_query=user_query, date=TODAY, conversation_context=history
    )
    print("Running function orchestrator")
    llm = get_llm("large", api_key)
    response = llm.chat.completions.create(
        messages=[
            {"role": "system",
             "content": "You are a Federal Reserve tool orchestrator. Always call exactly one function based on the user query analysis."},
            {"role": "user", "content": formatted_prompt}
        ],
        tools=TOOLS,
        temperature=0.1
    )
    # Extract the response message; tool_calls may be None when the model
    # answered directly, so normalize to an empty list for callers.
    message = response.choices[0].message
    return {
        "success": True,
        "message": message,
        "has_tool_calls": bool(message.tool_calls),
        "tool_calls": message.tool_calls or []
    }
def execute_tool_calls(tool_calls: List[Any], fed_tools: Dict[str, callable]) -> List[Dict[str, Any]]:
    """Execute each Fireworks tool call and collect a per-call result record.

    :param tool_calls: tool-call objects exposing .id and .function.name/.arguments.
    :param fed_tools: mapping of function name -> callable implementing the tool.
    :return: one dict per call with keys: tool_call_id, function, parameters,
        result, execution_time, success.
    """
    results = []
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        # Arguments arrive as a JSON string; fall back to no-args on bad JSON.
        try:
            parameters = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError:
            parameters = {}
        start_time = time.time()
        try:
            # Execute the appropriate function
            if function_name in fed_tools:
                result = fed_tools[function_name](**parameters)
            else:
                result = {"success": False, "error": f"Unknown function: {function_name}"}
            # A tool returning a non-dict raises AttributeError here and is
            # recorded as a failure below — same behavior as the original.
            success = result.get("success", False)
        except Exception as e:
            result = {"success": False, "error": str(e)}
            success = False
        # Single append path replaces the original's duplicated success/error blocks.
        results.append({
            "tool_call_id": tool_call.id,
            "function": function_name,
            "parameters": parameters,
            "result": result,
            "execution_time": time.time() - start_time,
            "success": success,
        })
    return results
def _meeting_citation(meeting: Dict[str, Any]) -> Dict[str, str]:
    """Build one citation dict from a meeting record (caller ensures 'url' exists)."""
    return {
        "date": meeting.get("date", "Unknown date"),
        "url": meeting["url"],
        "title": meeting.get("title", f"FOMC Meeting {meeting.get('date', '')}"),
    }


def extract_citations_from_tool_results(tool_results: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Extract unique citations (de-duplicated by URL) from tool results.

    :param tool_results: records produced by execute_tool_calls.
    :return: {date, url, title} dicts in first-seen order.
    """
    citations = []
    for result in tool_results:
        # Only mine results where both the call and the tool reported success.
        if not (result["success"] and result["result"].get("success")):
            continue
        meeting_data = result["result"].get("meeting")
        if meeting_data:
            # A tool may return a single meeting dict or a list of them.
            meetings = meeting_data if isinstance(meeting_data, list) else [meeting_data]
        elif "results" in result["result"]:
            # Search-style tools return their hits under "results".
            meetings = result["result"]["results"]
        else:
            meetings = []
        # Citation construction deduplicated into _meeting_citation (was two
        # copy-pasted dict literals).
        for meeting in meetings:
            if isinstance(meeting, dict) and meeting.get("url"):
                citations.append(_meeting_citation(meeting))
    # Keep only the first citation seen for each URL.
    unique_citations = []
    seen_urls = set()
    for citation in citations:
        if citation["url"] not in seen_urls:
            unique_citations.append(citation)
            seen_urls.add(citation["url"])
    return unique_citations
def format_response_with_citations(response: str, citations: List[Dict[str, str]]) -> str:
    """Return *response* with a markdown source list appended; unchanged if no citations."""
    if not citations:
        return response
    source_lines = [f"β’ [{citation['title']}]({citation['url']})\n" for citation in citations]
    return response + "\n\n**π Sources:**\n" + "".join(source_lines)
def update_tool_messages_with_results(tool_results: List[Dict[str, Any]]) -> List[ChatMessage]:
    """Render one completed-tool ChatMessage per execution record."""
    rendered = []
    for idx, record in enumerate(tool_results):
        elapsed = record["execution_time"]
        status_icon = "β " if record["success"] else "β"
        body = (
            f"{status_icon} {record['function']} completed\n\n"
            f"Execution time: {elapsed:.2f}s\n\n"
            f"Result summary: {str(record['result'])[:200]}..."
        )
        rendered.append(ChatMessage(
            role="assistant",
            content=body,
            metadata={"title": f"π§ Tool {idx+1}: {record['function']}", "status": "done", "duration": elapsed},
        ))
    return rendered
def build_context_from_tool_results(tool_results: List[Dict[str, Any]]) -> str:
    """Concatenate pretty-printed JSON of every successful tool result into one context string."""
    sections = [
        f"\n\nFrom {entry['function']}: {json.dumps(entry['result'], indent=2)}"
        for entry in tool_results
        if entry["success"]
    ]
    return "".join(sections)
def stream_final_response(message: str, system_prompt: str, api_key: str, citations: List[Dict[str, str]]):
    """Stream the final Fed Savant answer from the large model.

    Yields the full accumulated text so far (with citations appended) each time
    a content chunk arrives, so the UI can re-render progressively.
    """
    llm = get_llm("large", api_key)
    accumulated = ""
    stream = llm.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.1,
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            accumulated += delta
            yield format_response_with_citations(accumulated, citations)
def stream_fed_agent_response(
    *,
    message: str,
    api_key: str,
    prompt_library: Dict[str, str],
    fed_tools: Dict[str, callable],
    history: str = "",
):
    """Main orchestrator function that coordinates tools and generates responses with ChatMessage objects.

    Yields progressively updated lists of gradio ChatMessage objects so the UI
    re-renders after each phase: orchestrator decision, tool execution,
    "pondering" indicator, and the streamed final answer.

    :param message: the user's question.
    :param api_key: Fireworks API key; validated before any model call.
    :param prompt_library: named prompt templates ('fed_orchestrator', 'fed_savant_chat').
    :param fed_tools: mapping of tool name -> callable, run when requested.
    :param history: prior conversation context for the orchestrator.
    """
    # Guard: blank question — nothing to do.
    if not message.strip():
        yield [ChatMessage(role="assistant",
                           content="Please enter a question about Federal Reserve policy or FOMC meetings.")]
        return
    # Guard: an API key is required before contacting Fireworks.
    if not api_key.strip():
        yield [ChatMessage(role="assistant", content="β Please set your FIREWORKS_API_KEY environment variable.")]
        return
    messages = []
    try:
        # Phase 1: ask the orchestrator model which tools (if any) to call.
        print("Getting orchestrator decision...")
        orchestrator_result = get_orchestrator_decision(
            user_query=message, api_key=api_key, history=history, prompt_library=prompt_library
        )
        if not orchestrator_result["success"]:
            yield [ChatMessage(role="assistant", content="β Error in planning phase")]
            return
        orchestrator_message = orchestrator_result["message"]
        # Execute tools if any were called
        if orchestrator_result["has_tool_calls"]:
            tool_names = [tc.function.name for tc in orchestrator_result["tool_calls"]]
            # Show initial tools execution with pending status
            tools_summary = f"Executing tools: {', '.join(tool_names)}"
            messages.append(ChatMessage(
                role="assistant",
                content=tools_summary,
                metadata={"title": "π§ Tools Used", "status": "pending"}
            ))
            yield messages
            # Phase 2: run the requested tools and summarize success/timing.
            print(f"Executing the following tools {tool_names}")
            tool_results = execute_tool_calls(orchestrator_result["tool_calls"], fed_tools)
            successful_tools = sum(1 for tr in tool_results if tr["success"])
            total_time = sum(tr["execution_time"] for tr in tool_results)
            updated_summary = f"Executed {len(tool_names)} tools: {', '.join(tool_names)} β ({successful_tools}/{len(tool_results)} successful)"
            # Replace the pending summary (index 0: the first message appended above).
            messages[0] = ChatMessage(
                role="assistant",
                content=updated_summary,
                metadata={"title": "π§ Tools Used", "status": "done", "duration": total_time}
            )
            yield messages
            combined_context = build_context_from_tool_results(tool_results)
            citations = extract_citations_from_tool_results(tool_results)
            # Add thinking indicator with pending status
            messages.append(ChatMessage(
                role="assistant",
                content="Processing Fed data and formulating response...",
                metadata={"title": "Pondering ....", "status": "pending"}
            ))
            yield messages
            # Phase 3: build the final system prompt from the tool data.
            system_prompt_template = prompt_library.get('fed_savant_chat', '')
            system_prompt = system_prompt_template.format(
                fed_data_context=combined_context,
                user_question=message,
                date=TODAY
            )
            # Mark thinking as complete and start response
            messages[-1] = ChatMessage(
                role="assistant",
                content="Analysis complete, generating response...",
                metadata={"title": "Pondering ....", "status": "done"}
            )
            yield messages
            # Phase 4: stream the answer, rewriting the last message per chunk.
            messages.append(ChatMessage(role="assistant", content=""))
            for response_chunk in stream_final_response(message, system_prompt, api_key, citations):
                messages[-1] = ChatMessage(role="assistant", content=response_chunk)
                yield messages
        else:
            # No tools requested: either relay the orchestrator's direct reply...
            if orchestrator_message.content:
                messages.append(ChatMessage(role="assistant", content=orchestrator_message.content))
                yield messages
            else:
                # ...or fall back to a plain streamed answer without tool data.
                system_prompt_template = prompt_library.get('fed_savant_chat', '')
                system_prompt = system_prompt_template.format(
                    fed_data_context="No specific tool data available.",
                    user_question=message,
                    date=TODAY
                )
                messages.append(ChatMessage(role="assistant", content=""))
                for response_chunk in stream_final_response(message, system_prompt, api_key, []):
                    messages[-1] = ChatMessage(role="assistant", content=response_chunk)
                    yield messages
    except Exception as e:
        # Surface any failure to the UI as a chat message instead of crashing
        # the generator mid-stream.
        print(f"Error in stream_fed_agent_response: {str(e)}")
        error_message = ChatMessage(
            role="assistant",
            content=f"β Error generating response: {str(e)}"
        )
        if messages:
            messages.append(error_message)
        else:
            messages = [error_message]
        yield messages