Spaces:
Running
Running
| import os | |
| import streamlit as st | |
| import yfinance as yf | |
| import pandas as pd | |
| from langchain.agents import create_csv_agent, AgentType | |
| from langchain.chat_models import ChatOpenAI | |
| from htmlTemplates import css, user_template, bot_template | |
# No API key is passed to ChatOpenAI below, so it has to come from the
# OPENAI_API_KEY environment variable. Fail fast with an actionable message
# instead of letting a missing key surface later as an obscure error.
if not os.environ.get('OPENAI_API_KEY'):
    raise RuntimeError("OPENAI_API_KEY environment variable is not set.")

# Chat model shared by every CSV-agent query made from this app.
llm = ChatOpenAI(
    model='gpt-3.5-turbo',
    max_tokens=500,
    temperature=0.7,
)
def init_ses_states():
    """Create an empty per-session chat history if this session has none yet."""
    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []
def relative_returns(df):
    """Return cumulative relative returns for each column (or a Series) of prices.

    Each value equals ``price_t / first_valid_price - 1``, with gaps filled by
    the last known price. Rows before an asset's first valid price, and the first
    valid row itself, are reported as 0.

    Args:
        df: DataFrame (one column per asset) or Series of prices.

    Returns:
        Object of the same type and index as *df* holding cumulative returns.
    """
    # Carry the last known price across missing rows (e.g. dates on which one
    # asset has no quote) so those rows count as a 0% move instead of a break.
    filled = df.ffill()
    # Same result as pct_change() with its default padding, written out
    # explicitly because that implicit padding is deprecated in pandas >= 2.1.
    period_returns = filled / filled.shift(1) - 1
    cumret = ((1 + period_returns).cumprod() - 1).fillna(0)
    return cumret
def display_convo():
    """Render the session's chat history, newest message first.

    History entries are the strings appended after each answer, prefixed with
    "USER: " or "AI: ". Each entry is HTML-escaped before being substituted into
    the templates. The templates are rendered with ``unsafe_allow_html=True``,
    so markup in a query or in model output would otherwise be live HTML.
    """
    import html

    with st.container():
        for message in reversed(st.session_state.chat_history):
            # Pick the speaker from the prefix rather than from list position.
            template = user_template if message.startswith("USER:") else bot_template
            safe_message = html.escape(message, quote=False)
            st.markdown(template.replace("{{MSG}}", safe_message), unsafe_allow_html=True)
def _ask_agent(df, chat_prompt):
    """Answer *chat_prompt* with a LangChain CSV agent built over *df*.

    The data is written to a CSV inside a fresh temporary directory. A fixed
    file name in the working directory would be shared by every session the
    app serves concurrently. The directory is removed when this call returns
    or raises.

    Args:
        df: DataFrame or Series of prices/returns to expose to the agent.
        chat_prompt: Fully formatted prompt passed to ``agent.run``.

    Returns:
        The value returned by ``agent.run``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, 'data.csv')
        pd.DataFrame(df).to_csv(csv_path)
        agent = create_csv_agent(
            llm,
            csv_path,
            verbose=True,
            agent_type=AgentType.OPENAI_FUNCTIONS,
        )
        # Run while the temporary CSV still exists.
        return agent.run(chat_prompt)


def main():
    """Render the Stock Price AI Bot page.

    The page has sidebar controls, optional price/return charts, and a chat box
    whose questions are answered by a CSV agent over the selected data. The
    conversation is kept in ``st.session_state.chat_history`` and re-rendered on
    every run.
    """
    st.set_page_config(page_title="Stock Price AI Bot", page_icon=":chart:")
    st.write(css, unsafe_allow_html=True)
    init_ses_states()
    st.title("Stock Price AI Bot")
    st.caption("Visualizations and OpenAI Chatbot for Multiple Stocks Over A Specified Period")

    with st.sidebar:
        asset_tickers = sorted(['DOW', 'NVDA', 'TSLA', 'GOOGL', 'AMZN', 'AI', 'NIO', 'LCID', 'F', 'LYFT', 'AAPL', 'MSFT', 'BTC-USD', 'ETH-USD'])
        asset_dropdown = st.multiselect('Pick Assets:', asset_tickers)
        metric_tickers = ['Adj. Close', 'Relative Returns']
        metric_dropdown = st.selectbox("Metric", metric_tickers)
        viz_tickers = ['Line Chart', 'Area Chart']
        viz_dropdown = st.multiselect("Pick Charts:", viz_tickers)
        start = st.date_input('Start', value=pd.to_datetime('2023-01-01'))
        end = st.date_input('End', value=pd.to_datetime('today'))

    # Stays None until at least one asset is picked; the chat section checks it.
    df = None
    if asset_dropdown:
        # auto_adjust=False keeps the 'Adj Close' column. Newer yfinance releases
        # default to auto_adjust=True, which omits it and breaks this lookup.
        df = yf.download(asset_dropdown, start, end, auto_adjust=False)['Adj Close']
        if metric_dropdown == 'Relative Returns':
            df = relative_returns(df)
        if viz_dropdown:
            with st.expander("Data Visualizations", expanded=True):
                if "Line Chart" in viz_dropdown:
                    st.line_chart(df)
                if "Area Chart" in viz_dropdown:
                    st.area_chart(df)

    st.header("Chat with your Data")
    query = st.text_input("Enter a query:")
    # Built before the new exchange is appended, so CHAT HISTORY holds only prior turns.
    chat_prompt = f'''
You are an AI ChatBot intended to help with user stock data.
\nDATA MODE: {metric_dropdown}
\nSTOCKS: {asset_dropdown}
\nTIME PERIOD: {start} to {end}
\nCHAT HISTORY: {st.session_state.chat_history}
\nUSER MESSAGE: {query}
\nAI RESPONSE HERE:
'''
    if st.button("Execute") and query:
        if df is None:
            st.warning("Pick at least one asset before asking a question.")
        else:
            with st.spinner('Generating response...'):
                try:
                    answer = _ask_agent(df, chat_prompt)
                except Exception as e:
                    # UI boundary: report any data/agent/LLM failure on the page.
                    st.error(f"An error occurred: {str(e)}")
                else:
                    st.session_state.chat_history.append(f"USER: {query}\n")
                    st.session_state.chat_history.append(f"AI: {answer}\n")

    # Render on every run so the conversation survives reruns triggered by other widgets.
    display_convo()
if __name__ == "__main__":
    # Entry point when this file is executed as a script.
    main()