|
import os |
|
import io |
|
import sys |
|
import re |
|
import traceback |
|
import subprocess |
|
import warnings |
|
import gradio as gr |
|
import pandas as pd |
|
from dotenv import load_dotenv |
|
from crewai import Crew, Agent, Task, Process, LLM |
|
from crewai_tools import FileReadTool |
|
from pydantic import BaseModel, Field |
|
|
|
|
|
warnings.filterwarnings('ignore', category=FutureWarning, module='yfinance.*') |
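# The filter above keeps yfinance's FutureWarnings (e.g., about changing download defaults)
# out of the captured agent output shown in the UI.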
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') |
|
if not OPENAI_API_KEY: |
|
raise ValueError("OPENAI_API_KEY environment variable not set") |
|
|
|
llm = LLM( |
|
model="openai/gpt-4o", |
|
api_key=OPENAI_API_KEY, |
|
temperature=0.7 |
|
) |
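# A single LLM instance (GPT-4o via CrewAI's LLM wrapper) is shared by all three agents below.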
|
|
|
|
|
query_parser_agent = Agent( |
|
role="Stock Data Analyst", |
|
goal="Extract stock details and fetch required data from this user query: {query}.", |
|
backstory="You are a financial analyst specializing in stock market data retrieval.", |
|
llm=llm, |
|
verbose=True, |
|
memory=True, |
|
) |
|
|
|
|
|
class QueryAnalysisOutput(BaseModel): |
|
"""Structured output for the query analysis task.""" |
|
symbols: list[str] = Field( |
|
..., |
|
json_schema_extra={"description": "List of stock ticker symbols (e.g., ['TSLA', 'AAPL'])."} |
|
) |
|
timeframe: str = Field( |
|
..., |
|
json_schema_extra={"description": "Time period (e.g., '1d', '1mo', '1y')."} |
|
) |
|
action: str = Field( |
|
..., |
|
json_schema_extra={"description": "Action to be performed (e.g., 'fetch', 'plot')."} |
|
) |
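# Illustrative example of the structured output the parsing task should produce
# for a query like "Plot TSLA and AAPL for the last year":
#   QueryAnalysisOutput(symbols=["TSLA", "AAPL"], timeframe="1y", action="plot")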
|
|
|
|
|
query_parsing_task = Task( |
|
description="Analyze the user query and extract stock details.", |
|
expected_output="A dictionary with keys: 'symbol', 'timeframe', 'action'.", |
|
output_pydantic=QueryAnalysisOutput, |
|
agent=query_parser_agent, |
|
) |
|
|
|
|
|
code_writer_agent = Agent( |
|
role="Senior Python Developer", |
|
goal="Write Python code to visualize stock data.", |
|
backstory="""You are a Senior Python developer specializing in stock market data visualization. |
|
You are also a Pandas, Matplotlib and yfinance library expert. |
|
You are skilled at writing production-ready Python code. |
|
Ensure the code handles potential variations in the DataFrame structure returned by yfinance, |
|
especially for different timeframes or delisted stocks. |
|
Crucially, ensure the generated script saves any generated plot as 'plot.png' using `plt.savefig('plot.png')` before the script ends.""", |
|
llm=llm, |
|
verbose=True, |
|
) |
|
|
|
code_writer_task = Task( |
|
description="""Write Python code to visualize stock data based on the inputs from the stock analyst |
|
where you would find stock symbol, timeframe and action.""", |
|
expected_output="A clean and executable Python script file (.py) for stock visualization.", |
|
agent=code_writer_agent, |
|
) |
|
|
|
|
|
code_output_agent = Agent( |
|
role="Python Code Presenter", |
|
goal="Present the generated Python code for stock visualization.", |
|
backstory="You are an expert in presenting Python code in a clear and readable format.", |
|
allow_delegation=False, |
|
llm=llm, |
|
verbose=True, |
|
) |
|
|
|
code_output_task = Task( |
|
description="""Receive the Python code for stock visualization from the code writer agent and present it.""", |
|
expected_output="The complete Python script for stock visualization.", |
|
agent=code_output_agent, |
|
) |
|
|
|
crew = Crew( |
|
agents=[query_parser_agent, code_writer_agent, code_output_agent], |
|
tasks=[query_parsing_task, code_writer_task, code_output_task], |
|
process=Process.sequential |
|
) |
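# Sequential process: parse the query, write the visualization script, then present it.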
|
|
|
|
|
def run_crewai_process(user_query, model, temperature): |
|
""" |
|
Runs the CrewAI process, captures agent thoughts, gets generated code, |
|
executes the code, and returns results, including plot. |
|
|
|
    Args:
        user_query (str): The user's query for the CrewAI process.
        model (str): Model name selected in the UI (currently not used to reconfigure the LLM).
        temperature (float): Temperature selected in the UI (currently not used to reconfigure the LLM).

    Yields:
        tuple: (chat messages, agent thoughts, generated code, execution output,
        plot figure or None, plot image file path or None).
    """
|
|
|
output_buffer = io.StringIO() |
|
original_stdout = sys.stdout |
|
sys.stdout = output_buffer |
|
agent_thoughts = "" |
|
generated_code = "" |
|
execution_output = "" |
|
generated_plot_path = None |
|
final_answer_chat = [{"role": "user", "content": user_query}] |
|
|
|
try: |
|
|
|
initial_message = {"role": "assistant", "content": "Starting CrewAI process..."} |
|
final_answer_chat = [{"role": "user", "content": str(user_query)}, initial_message] |
|
yield final_answer_chat, agent_thoughts, generated_code, execution_output, None, None |
|
|
|
|
|
final_result = crew.kickoff(inputs={"query": user_query}) |
|
|
|
|
|
agent_thoughts = output_buffer.getvalue() |
|
|
|
|
|
processing_message = {"role": "assistant", "content": "Processing complete. Generating code..."} |
|
final_answer_chat = [{"role": "user", "content": str(user_query)}, processing_message] |
|
yield final_answer_chat, agent_thoughts, generated_code, execution_output, None, None |
|
|
|
|
|
generated_code_raw = str(final_result).strip() |
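        # The crew's final answer usually wraps the script in a ```python fence; extract just the code when present.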
|
|
|
|
|
code_match = re.search(r"```python\n(.*?)\n```", generated_code_raw, re.DOTALL) |
|
if code_match: |
|
generated_code = code_match.group(1).strip() |
|
else: |
|
|
|
generated_code = generated_code_raw |
|
        if not generated_code.strip():
            execution_output = "CrewAI process completed, but no code was generated."
            final_answer_chat.append({"role": "assistant", "content": execution_output})
            yield final_answer_chat, agent_thoughts, generated_code, execution_output, None, None
            return
|
|
|
|
|
code_gen_message = {"role": "assistant", "content": "Code generation complete. See the 'Generated Code' box. Attempting to execute code..."} |
|
final_answer_chat = [{"role": "user", "content": str(user_query)}, code_gen_message] |
|
yield final_answer_chat, agent_thoughts, generated_code, execution_output, None, None |
|
|
|
|
|
|
|
|
|
|
|
|
if generated_code: |
|
try: |
|
|
|
temp_script_path = "generated_script.py" |
|
with open(temp_script_path, "w") as f: |
|
f.write(generated_code) |
|
|
|
|
|
with open(temp_script_path, 'r') as f: |
|
script_content = f.read() |
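                # yfinance's yf.download() has shifted its auto_adjust default across releases; the helper
                # below rewrites yf.download(...) calls to pass auto_adjust=True explicitly so the returned
                # column layout stays predictable.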
|
|
|
|
|
def add_auto_adjust(match): |
|
|
|
args = match.group(1).strip() |
|
if 'auto_adjust' not in args: |
|
|
|
if args.endswith(','): |
|
return f'yf.download({args} auto_adjust=True)' |
|
elif args: |
|
return f'yf.download({args}, auto_adjust=True)' |
|
else: |
|
return 'yf.download(auto_adjust=True)' |
|
return match.group(0) |
|
|
|
|
|
script_content = re.sub( |
|
r'yf\.download\(([^)]*)\)', |
|
add_auto_adjust, |
|
script_content |
|
) |
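                # Helper code injected into the generated script: a tolerant column lookup for yfinance
                # DataFrames (multi-index or differently-cased columns) and a headless plot renderer.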
|
|
|
|
|
helpers = """ |
|
# Standard plot filename to use |
|
PLOT_FILENAME = 'generated_plot.png' |
|
|
|
# Helper functions for data processing |
|
def safe_get_column(df, column): |
|
# Handle case where column is a tuple (e.g., from multi-index) |
|
if isinstance(column, tuple): |
|
column = column[0] # Take the first element of the tuple |
|
|
|
# Convert column to string in case it's not |
|
column = str(column) |
|
|
|
# Try exact match first |
|
if column in df.columns: |
|
return df[column] |
|
|
|
# Try case-insensitive match |
|
try: |
|
col_lower = column.lower() |
|
for col in df.columns: |
|
if str(col).lower() == col_lower: |
|
return df[col] |
|
except (AttributeError, TypeError): |
|
pass # Skip case-insensitive matching if not applicable |
|
|
|
# If not found, try common variations |
|
variations = { |
|
'close': ['Close', 'Adj Close', 'close', 'adj close', 'CLOSE', 'Adj. Close'], |
|
'adj close': ['Adj Close', 'adj close', 'ADJ CLOSE', 'Close', 'close', 'CLOSE', 'Adj. Close'] |
|
} |
|
|
|
for var_list in variations.values(): |
|
for var in var_list: |
|
if var in df.columns: |
|
return df[var] |
|
|
|
# If still not found, try to find any column containing 'close' |
|
for col in df.columns: |
|
if 'close' in str(col).lower(): |
|
return df[col] |
|
|
|
# If still not found, raise a helpful error |
|
raise KeyError(f"Column '{column}' not found in DataFrame. Available columns: {list(df.columns)}") |
|
|
|
def show_plot(plt): |
|
try: |
|
# Use a non-interactive backend to avoid GUI issues |
|
import matplotlib |
|
matplotlib.use('Agg') |
|
|
|
# Create a temporary file to store the plot |
|
import io |
|
import base64 |
|
|
|
# Save plot to a bytes buffer |
|
buf = io.BytesIO() |
|
plt.savefig(buf, format='png', bbox_inches='tight', dpi=100) |
|
plt.close() |
|
|
|
# Convert to base64 for display |
|
buf.seek(0) |
|
img_str = base64.b64encode(buf.read()).decode('utf-8') |
|
buf.close() |
|
|
|
# Return HTML to display the image |
|
return f'<img src="data:image/png;base64,{img_str}" />' |
|
except Exception as e: |
|
print(f"[ERROR] Failed to display plot: {str(e)}") |
|
return None |
|
|
|
# Monkey patch DataFrame to add safe column access |
|
import pandas as pd |
|
pd.DataFrame.safe_get = safe_get_column |
|
""" |
|
|
|
if 'import ' in script_content: |
|
|
|
last_import = script_content.rfind('import ') |
|
insert_pos = script_content.find('\n', last_import) + 1 |
|
script_content = script_content[:insert_pos] + '\n' + helpers + script_content[insert_pos:] |
|
else: |
|
|
|
script_content = helpers + '\n' + script_content |
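                # Route hard-coded ['Close'] / ['Adj Close'] subscripting through the tolerant safe_get helper injected above.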
|
|
|
|
|
script_content = script_content.replace("['Adj Close']", ".safe_get('close')") |
|
script_content = script_content.replace("['Close']", ".safe_get('close')") |
|
script_content = script_content.replace("['close']", ".safe_get('close')") |
|
|
|
|
|
                # Replace interactive plt.show() calls with the injected show_plot() helper,
                # which renders the figure headlessly and prints it as a base64 <img> tag.
                script_content = re.sub(
                    r'plt\.show\(\s*\)',
                    r'print(show_plot(plt))',
                    script_content
                )
|
|
|
|
|
                if 'show_plot(plt)' not in script_content:
|
script_content += "\n# Display the plot if any figures exist\nif 'plt' in locals() and len(plt.get_fignums()) > 0:\n print(show_plot(plt))\n" |
|
|
|
|
|
with open(temp_script_path, 'w') as f: |
|
f.write(script_content) |
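                # Run the rewritten script in a separate Python process and capture stdout and stderr.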
|
|
|
|
|
|
|
process = subprocess.run( |
|
["python3", temp_script_path], |
|
capture_output=True, |
|
text=True, |
|
check=False |
|
) |
|
execution_output = process.stdout + process.stderr |
|
|
|
|
|
if "KeyError" in execution_output: |
|
execution_output += "\n\nPotential Issue: The generated script encountered a KeyError. This might mean the script tried to access a column or data point that wasn't available for the specified stock or timeframe. Please try a different query or timeframe." |
|
elif "FileNotFoundError: [Errno 2] No such file or directory: 'plot.png'" in execution_output and "Figure(" in execution_output: |
|
execution_output += "\n\nPlot Generation Issue: The script seems to have created a plot but did not save it to 'plot.png'. Please ensure the generated code includes `plt.savefig('plot.png')`." |
|
elif "FileNotFoundError: [Errno 2] No such file or directory: 'plot.png'" in execution_output: |
|
execution_output += "\n\nPlot Generation Issue: The script ran, but the plot file was not created. Ensure the generated code includes commands to save the plot to 'plot.png'." |
|
|
|
|
|
generated_plot_path = None |
|
plot_found = False |
|
|
|
|
|
                plot_file_paths = ['generated_plot.png', 'plot.png', 'output.png', 'figure.png',
                                   'META_plot.png', 'AAPL_plot.png', 'MSFT_plot.png', 'GOOG_plot.png', 'TSLA_plot.png']
|
|
|
|
|
current_dir = os.path.abspath('.') |
|
png_files = [f for f in os.listdir(current_dir) |
|
if f.endswith('.png') and not f.startswith('gradio_')] |
|
|
|
|
|
plot_file_paths.extend(png_files) |
|
|
|
|
|
plot_file_paths = list(dict.fromkeys([os.path.abspath(f) for f in plot_file_paths])) |
|
|
|
print(f"[DEBUG] Looking for plot files in: {plot_file_paths}") |
|
|
|
                for plot_file in plot_file_paths:
                    try:
                        if os.path.exists(plot_file) and os.path.getsize(plot_file) > 0:
                            print(f"[DEBUG] Found plot file: {plot_file}")
                            generated_plot_path = plot_file
                            plot_found = True

                            # Embed the plot in the execution output as an inline base64 <img> tag
                            # (same format as the injected show_plot helper).
                            try:
                                import base64
                                with open(plot_file, 'rb') as img_file:
                                    img_str = base64.b64encode(img_file.read()).decode('utf-8')
                                execution_output += f'\n\n<img src="data:image/png;base64,{img_str}" />'
                            except Exception as e:
                                execution_output += f"\n\nNote: Could not embed plot in output: {str(e)}"

                            break
                    except Exception as e:
                        print(f"[DEBUG] Error checking plot file {plot_file}: {e}")
|
|
|
if not plot_found: |
|
|
|
current_dir = os.path.abspath('.') |
|
png_files = [f for f in os.listdir(current_dir) if f.endswith('.png') and not f.startswith('gradio_')] |
|
if png_files: |
|
|
|
plot_abs_path = os.path.abspath(png_files[0]) |
|
generated_plot_path = plot_abs_path |
|
print(f"Using plot file found at: {plot_abs_path}") |
|
|
|
|
|
try: |
|
import base64 |
|
with open(plot_abs_path, 'rb') as img_file: |
|
img_str = base64.b64encode(img_file.read()).decode('utf-8') |
|
execution_output += f"\n\n" |
|
except Exception as e: |
|
execution_output += f"\n\nNote: Could not embed plot in output: {str(e)}" |
|
else: |
|
print(f"No plot file found in {current_dir}") |
|
execution_output += "\nNo plot file was found after execution.\n\nMake sure the generated code includes:\n1. `plt.savefig('plot.png')` to save the plot\n2. `plt.close()` to close the figure after saving" |
|
|
|
except Exception as e: |
|
traceback_str = traceback.format_exc() |
|
execution_output = f"An error occurred during code execution: {e}\n{traceback_str}" |
|
|
|
finally: |
|
|
|
if os.path.exists(temp_script_path): |
|
os.remove(temp_script_path) |
|
|
|
else: |
|
execution_output = "No code was generated to execute." |
|
|
|
|
|
execution_complete_msg = "Code execution finished. See 'Execution Output'." |
|
if generated_plot_path: |
|
plot_msg = f"Plot generated successfully at: {generated_plot_path}" |
|
final_answer_chat = [ |
|
{"role": "user", "content": str(user_query)}, |
|
{"role": "assistant", "content": execution_complete_msg}, |
|
{"role": "assistant", "content": plot_msg} |
|
] |
|
|
|
yield final_answer_chat, agent_thoughts, generated_code, execution_output, None, generated_plot_path |
|
else: |
|
no_plot_msg = "No plot was generated. Make sure your query includes a request for a visualization. Check the 'Execution Output' tab for any errors." |
|
final_answer_chat = [ |
|
{"role": "user", "content": str(user_query)}, |
|
{"role": "assistant", "content": execution_complete_msg}, |
|
{"role": "assistant", "content": no_plot_msg} |
|
] |
|
|
|
yield final_answer_chat, agent_thoughts, generated_code, execution_output, None, None |
|
|
|
|
|
|
except Exception as e: |
|
|
|
traceback_str = traceback.format_exc() |
|
agent_thoughts += f"\nAn error occurred during CrewAI process: {e}\n{traceback_str}" |
|
error_message = f"An error occurred during CrewAI process: {e}" |
|
final_answer_chat = [ |
|
{"role": "user", "content": str(user_query)}, |
|
{"role": "assistant", "content": error_message} |
|
] |
|
yield final_answer_chat, agent_thoughts, generated_code, execution_output, None, None |
|
|
|
finally: |
|
|
|
sys.stdout = original_stdout |
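# Illustrative (not executed): the generator can also be consumed directly, e.g.
#   for chat, thoughts, code, output, plot_fig, plot_path in run_crewai_process("Plot TSLA for 1y", "gpt-4o", 0.7):
#       ...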
|
|
|
|
|
def create_interface(): |
|
"""Create and return the Gradio interface.""" |
|
with gr.Blocks(title="Financial Analytics Agent", theme=gr.themes.Soft()) as interface: |
|
gr.Markdown("# 📊 Financial Analytics Agent") |
|
gr.Markdown("Enter your financial query to analyze stock data and generate visualizations.") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
user_query_input = gr.Textbox( |
|
label="Enter your financial query", |
|
placeholder="e.g., Show me the stock performance of AAPL and MSFT for the last year", |
|
lines=3 |
|
) |
|
submit_btn = gr.Button("Analyze", variant="primary") |
|
|
|
with gr.Accordion("Advanced Options", open=False): |
|
gr.Markdown("### Model Settings") |
|
model_dropdown = gr.Dropdown( |
|
["gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"], |
|
value="gpt-4o", |
|
label="Model" |
|
) |
|
temperature = gr.Slider( |
|
minimum=0.1, |
|
maximum=1.0, |
|
value=0.7, |
|
step=0.1, |
|
label="Creativity (Temperature)" |
|
) |
|
|
|
with gr.Column(scale=3): |
|
with gr.Tabs(): |
|
with gr.TabItem("Analysis"): |
|
final_answer_chat = gr.Chatbot( |
|
label="Analysis Results", |
|
height=300, |
|
show_copy_button=True, |
|
type="messages" |
|
) |
|
|
|
with gr.TabItem("Agent Thoughts"): |
|
agent_thoughts = gr.Textbox( |
|
label="Agent Thinking Process", |
|
interactive=False, |
|
lines=15, |
|
max_lines=30, |
|
show_copy_button=True |
|
) |
|
|
|
with gr.TabItem("Generated Code"): |
|
generated_code = gr.Code( |
|
label="Generated Python Code", |
|
language="python", |
|
interactive=False, |
|
lines=15 |
|
) |
|
|
|
with gr.TabItem("Execution Output"): |
|
execution_output = gr.Textbox( |
|
label="Code Execution Output", |
|
interactive=False, |
|
lines=10, |
|
show_copy_button=True |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
plot_output = gr.Plot( |
|
label="Generated Visualization", |
|
visible=False |
|
) |
|
image_output = gr.Image( |
|
label="Generated Plot", |
|
type="filepath", |
|
visible=False |
|
) |
|
|
|
|
|
inputs = [user_query_input, model_dropdown, temperature] |
|
outputs = [ |
|
final_answer_chat, |
|
agent_thoughts, |
|
generated_code, |
|
execution_output, |
|
plot_output, |
|
image_output |
|
] |
|
|
|
def process_results(chat, thoughts, code, output, plot_path): |
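            # Pass chat/thoughts/code/output through unchanged and toggle the plot widgets
            # based on whether a plot file actually exists on disk.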
|
|
|
|
|
return [ |
|
chat, |
|
thoughts, |
|
code, |
|
output, |
|
                gr.update(visible=False),  # the gr.Plot component is never populated by the pipeline; keep it hidden
|
gr.update(value=plot_path if (plot_path and os.path.exists(plot_path)) else None, |
|
visible=plot_path is not None and os.path.exists(plot_path)) |
|
] |
|
|
|
|
|
click_event = submit_btn.click( |
|
fn=run_crewai_process, |
|
inputs=inputs, |
|
outputs=outputs, |
|
api_name="analyze" |
|
) |
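        # After the streaming run finishes, post-process the final outputs to show or hide the plot widgets.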
|
|
|
|
|
click_event.then( |
|
fn=process_results, |
|
inputs=[final_answer_chat, agent_thoughts, generated_code, execution_output, image_output], |
|
outputs=outputs |
|
) |
|
|
|
return interface |
|
|
|
|
|
def main(): |
|
"""Run the Gradio interface.""" |
|
interface = create_interface() |
|
interface.launch(share=False, server_name="0.0.0.0", server_port=7860) |
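    # Listens on all interfaces; open http://localhost:7860 locally to use the app.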
|
|
|
|
|
if __name__ == "__main__": |
|
main() |