from fireworks import LLM
from pydantic import BaseModel
import asyncio
import json
import time
from typing import Dict, Any, List
from gradio import ChatMessage

MODELS = {
    "small": "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
    "large": "accounts/fireworks/models/kimi-k2-instruct"
}
TODAY = time.strftime("%Y-%m-%d")
semaphore = asyncio.Semaphore(10)


def get_llm(model: str, api_key: str) -> LLM:
    return LLM(model=MODELS[model], api_key=api_key, deployment_type="serverless")

async def get_llm_completion(llm: LLM, prompt_text: str, output_class: type[BaseModel] | None = None):
    """
    Get a single (non-streaming) completion, optionally constrained to a JSON schema.

    :param llm: The LLM instance
    :param prompt_text: The user's input message
    :param output_class: Optional Pydantic model class used to request structured JSON output
    :return: The completion response object
    """
    if output_class:
        return llm.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt_text
                },
            ],
            temperature=0.1,
            response_format={
                "type": "json_object",
                "schema": output_class.model_json_schema(),
            },
        )
    return llm.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt_text
            },
        ],
        temperature=0.1
    )
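
# Illustrative usage sketch (not called anywhere in this module): request a completion
# constrained to a Pydantic JSON schema. _ExampleRateDecision, its fields, and the prompt
# text are assumptions for the example, not part of the app.
class _ExampleRateDecision(BaseModel):
    decision: str
    basis_points: int


def _example_structured_call(api_key: str) -> _ExampleRateDecision:
    # Ask for JSON matching the schema, then validate the returned content against the model
    llm = get_llm("small", api_key)
    response = asyncio.run(get_llm_completion(
        llm, "Did the FOMC cut rates at its last meeting?", _ExampleRateDecision
    ))
    return _ExampleRateDecision.model_validate_json(response.choices[0].message.content)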

async def get_streaming_completion(llm: LLM, prompt_text: str, system_prompt: str = None):
    """
    Get streaming completion from LLM for real-time responses

    :param llm: The LLM instance
    :param prompt_text: The user's input message
    :param system_prompt: Optional system prompt for context
    :return: Async generator yielding response chunks
    """
    messages = []
    if system_prompt:
        messages.append({
            "role": "system",
            "content": system_prompt
        })
    messages.append({
        "role": "user",
        "content": prompt_text
    })
    try:
        response = llm.chat.completions.create(
            messages=messages,
            temperature=0.2,
            stream=True,
            max_tokens=1000
        )
        for chunk in response:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
    except Exception as e:
        yield f"Error generating response: {str(e)}"

async def run_multi_llm_completions(llm: LLM, prompts: list[str], output_class: type[BaseModel] | None = None) -> list:
    """
    Run multiple LLM completions in parallel

    :param llm: The LLM instance used for every prompt
    :param prompts: List of prompt strings to complete
    :param output_class: Optional Pydantic model class for structured JSON output
    :return: List of completion responses, in the same order as the prompts
    """
    async with semaphore:
        if output_class:
            print("Running LLM with structured outputs")
            tasks = [
                asyncio.create_task(
                    get_llm_completion(llm=llm, prompt_text=prompt, output_class=output_class)
                ) for prompt in prompts
            ]
        else:
            print("Running LLM with non-structured outputs")
            tasks = [
                asyncio.create_task(
                    get_llm_completion(llm=llm, prompt_text=prompt)
                ) for prompt in prompts
            ]
        return await asyncio.gather(*tasks)
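
# Illustrative usage sketch (not called by this module): fan several prompts out in
# parallel. The prompt strings are assumptions for the example.
def _example_batch_completions(api_key: str) -> list[str]:
    llm = get_llm("small", api_key)
    prompts = [
        "One-sentence summary of the federal funds rate.",
        "One-sentence summary of quantitative tightening.",
        "One-sentence summary of the Fed's dual mandate.",
    ]
    responses = asyncio.run(run_multi_llm_completions(llm=llm, prompts=prompts))
    return [r.choices[0].message.content for r in responses]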

def get_orchestrator_decision(user_query: str, api_key: str, prompt_library: Dict[str, str]) -> Dict[str, Any]:
    """Use orchestrator LLM to decide which tools to use"""
    try:
        orchestrator_prompt = prompt_library.get('fed_orchestrator', '')
        formatted_prompt = orchestrator_prompt.format(user_query=user_query, date=TODAY)

        llm = get_llm("large", api_key)
        response = llm.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a tool orchestrator. Always respond with valid JSON."},
                {"role": "user", "content": formatted_prompt}
            ],
            temperature=0.1,
            max_tokens=500
        )

        # Parse JSON response
        result = json.loads(response.choices[0].message.content)
        return {"success": True, "decision": result}
    except Exception as e:
        print(f"Error in orchestrator: {e}")
        # Fallback to simple logic
        return {
            "success": False,
            "decision": {
                "tools_needed": [{"function": "get_latest_meeting", "parameters": {}, "reasoning": "Fallback to latest meeting"}],
                "query_analysis": f"Error occurred, using fallback for: {user_query}"
            }
        }
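
# For reference, execute_fed_tools below expects the orchestrator decision to follow the
# shape sketched here (mirroring the fallback above); the reasoning and analysis strings
# are illustrative, and this is not an exhaustive list of supported tools.
_EXAMPLE_ORCHESTRATOR_DECISION: Dict[str, Any] = {
    "tools_needed": [
        {
            "function": "get_latest_meeting",
            "parameters": {},
            "reasoning": "User asked about the most recent FOMC meeting",
        },
    ],
    "query_analysis": "Needs data from the latest FOMC meeting",
}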

def execute_fed_tools(tools_decision: Dict[str, Any], fed_tools: Dict[str, callable]) -> List[Dict[str, Any]]:
    """Execute the tools determined by the orchestrator"""
    results = []
    for tool in tools_decision.get("tools_needed", []):
        function_name = tool.get("function", "")
        parameters = tool.get("parameters", {})
        reasoning = tool.get("reasoning", "")

        start_time = time.time()
        try:
            # Execute the appropriate function
            if function_name in fed_tools:
                tool_func = fed_tools[function_name]
                result = tool_func(**parameters)
            else:
                result = {"success": False, "error": f"Unknown function: {function_name}"}

            execution_time = time.time() - start_time
            results.append({
                "function": function_name,
                "parameters": parameters,
                "reasoning": reasoning,
                "result": result,
                "execution_time": execution_time,
                "success": result.get("success", False)
            })
        except Exception as e:
            execution_time = time.time() - start_time
            results.append({
                "function": function_name,
                "parameters": parameters,
                "reasoning": reasoning,
                "result": {"success": False, "error": str(e)},
                "execution_time": execution_time,
                "success": False
            })
    return results
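
# Illustrative fed_tools registry sketch (not used by this module): each entry maps an
# orchestrator function name to a plain callable returning a dict with at least a
# "success" key. _fake_latest_meeting and its payload are placeholders, not the app's
# real Fed data source; only the "get_latest_meeting" name comes from the fallback above.
def _fake_latest_meeting() -> Dict[str, Any]:
    return {"success": True, "meeting_date": "1900-01-01", "statement": "Placeholder statement text."}


_EXAMPLE_FED_TOOLS: Dict[str, callable] = {
    "get_latest_meeting": _fake_latest_meeting,
}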

def stream_fed_agent_response(
    message: str,
    api_key: str,
    prompt_library: Dict[str, str],
    fed_tools: Dict[str, callable]
):
    """Main orchestrator function that coordinates tools and generates responses with ChatMessage objects"""
    if not message.strip():
        yield [ChatMessage(role="assistant", content="Please enter a question about Federal Reserve policy or FOMC meetings.")]
        return
    if not api_key.strip():
        yield [ChatMessage(role="assistant", content="❌ Please set your FIREWORKS_API_KEY environment variable.")]
        return
    messages = []
    try:
        # Step 1: Use orchestrator to determine tools needed
        messages.append(ChatMessage(
            role="assistant",
            content="Analyzing your query...",
            metadata={"title": "🧠 Planning", "status": "pending"}
        ))
        yield messages

        orchestrator_result = get_orchestrator_decision(message, api_key, prompt_library)
        tools_decision = orchestrator_result["decision"]

        # Update planning message
        messages[0] = ChatMessage(
            role="assistant",
            content=f"Query Analysis: {tools_decision.get('query_analysis', 'Analyzing Fed data requirements')}\n\nTools needed: {len(tools_decision.get('tools_needed', []))}",
            metadata={"title": "🧠 Planning", "status": "done"}
        )
        yield messages

        # Step 2: Execute the determined tools
        if tools_decision.get("tools_needed"):
            for i, tool in enumerate(tools_decision["tools_needed"]):
                tool_msg = ChatMessage(
                    role="assistant",
                    content=f"Executing: {tool['function']}({', '.join([f'{k}={v}' for k, v in tool['parameters'].items()])})\n\nReasoning: {tool['reasoning']}",
                    metadata={"title": f"🔧 Tool {i+1}: {tool['function']}", "status": "pending"}
                )
                messages.append(tool_msg)
                yield messages

            # Execute all tools
            tool_results = execute_fed_tools(tools_decision, fed_tools)

            # Update tool messages with results
            for i, (tool_result, tool_msg) in enumerate(zip(tool_results, messages[1:])):
                execution_time = tool_result["execution_time"]
                success_status = "✅" if tool_result["success"] else "❌"
                messages[i+1] = ChatMessage(
                    role="assistant",
                    content=f"{success_status} {tool_result['function']} completed\n\nExecution time: {execution_time:.2f}s\n\nResult summary: {str(tool_result['result'])[:200]}...",
                    metadata={"title": f"🔧 Tool {i+1}: {tool_result['function']}", "status": "done", "duration": execution_time}
                )
            yield messages
            # Step 3: Use results to generate final response
            combined_context = ""
            for result in tool_results:
                if result["success"]:
                    combined_context += f"\n\nFrom {result['function']}: {json.dumps(result['result'], indent=2)}"

            # Generate Fed Savant response using tool results
            system_prompt_template = prompt_library.get('fed_savant_chat', '')
            system_prompt = system_prompt_template.format(
                fed_data_context=combined_context,
                user_question=message,
                date=TODAY
            )

            # Initialize LLM and get streaming response
            llm = get_llm("large", api_key)
            final_response = ""
            for chunk in llm.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": message}
                ],
                temperature=0.2,
                stream=True,
                max_tokens=1000
            ):
                if chunk.choices[0].delta.content:
                    final_response += chunk.choices[0].delta.content
                    # Update the answer message in place once it exists; before the first chunk,
                    # messages holds only the planning message plus one entry per tool
                    if len(messages) > len(tool_results) + 1:
                        messages[-1] = ChatMessage(role="assistant", content=final_response)
                    else:
                        messages.append(ChatMessage(role="assistant", content=final_response))
                    yield messages
        else:
            # No tools needed, direct response
            messages.append(ChatMessage(role="assistant", content="No specific tools required. Providing general Fed information."))
            yield messages
    except Exception as e:
        messages.append(ChatMessage(role="assistant", content=f"Error generating response: {str(e)}"))
        yield messages
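
# Illustrative Gradio wiring sketch (this module does not build the UI itself):
# stream_fed_agent_response yields lists of ChatMessage objects, which can feed a
# gr.Chatbot configured with type="messages". The prompt_library contents and the
# _EXAMPLE_FED_TOOLS registry used here are placeholders, not the real app config.
def _example_build_demo(api_key: str, prompt_library: Dict[str, str]):
    import gradio as gr

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages", label="Fed Savant")
        question = gr.Textbox(label="Ask about Fed policy or FOMC meetings")

        def respond(user_message):
            # Re-yield each intermediate message list so the chat panel updates live
            yield from stream_fed_agent_response(user_message, api_key, prompt_library, _EXAMPLE_FED_TOOLS)

        question.submit(respond, inputs=question, outputs=chatbot)
    return demo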