Spaces:
Running
Running
File size: 4,268 Bytes
af978d7 c4c99cc e09f725 ee1c031 20d39a9 e09f725 20d39a9 e09f725 20d39a9 e09f725 20d39a9 e09f725 20d39a9 e09f725 c4c99cc 20d39a9 c4c99cc e09f725 c4c99cc 20d39a9 c4c99cc e09f725 c4c99cc e09f725 c4c99cc 20d39a9 1b149d3 20d39a9 c4c99cc e09f725 c4c99cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import streamlit as st
import yfinance as yf
import requests
import pandas as pd
from langchain.agents import initialize_agent, AgentType
from langchain.tools import Tool
from langchain_huggingface import HuggingFacePipeline
import os
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from langchain.prompts import PromptTemplate
# Read configuration from the environment (after loading any .env file).
load_dotenv()
# Both values are None when the corresponding variable is not set.
NEWSAPI_KEY = os.environ.get("NEWSAPI_KEY")  # passed to the news fetcher
access_token = os.environ.get("API_KEY")  # passed as `token=` when loading the model
# Initialize model and pipeline.
# Streamlit re-executes this script from top to bottom on every widget
# interaction; caching the loader keeps the tokenizer/model from being
# re-created on each rerun.
@st.cache_resource
def _load_generation_stack(model_id, token):
    """Load the tokenizer, causal LM and text-generation pipeline.

    Args:
        model_id: Model identifier passed to both ``from_pretrained`` calls.
        token: Access token forwarded as ``token=`` (may be ``None``).

    Returns:
        A ``(tokenizer, model, pipe)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        token=token,
    )
    generation_pipe = pipeline(
        "text-generation", model=loaded_model, tokenizer=loaded_tokenizer
    )
    return loaded_tokenizer, loaded_model, generation_pipe


# Same module-level names as before, so the rest of the script is unaffected.
tokenizer, model, pipe = _load_generation_stack("google/gemma-2b-it", access_token)
# Prompt text: lists the three tools and spells out the
# Question / Thought / Action / Action Input / Observation / Final Answer
# format. `{input}` is the only placeholder.
_AGENT_PROMPT_TEXT = """Answer the following question as best you can. You have access to the following tools:
Stock Data Fetcher(ticker) - Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple).
Stock News Fetcher(ticker) - Fetch recent news articles about a stock ticker.
Moving Average Calculator(ticker, window=5) - Calculate the moving average of a stock over a 5-day window.
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [Stock Data Fetcher, Stock News Fetcher, Moving Average Calculator]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Strictly follow this format. Do not provide a Final Answer until all Observations are collected.
Begin!
Question: {input}
"""

prompt_template = PromptTemplate(
    template=_AGENT_PROMPT_TEXT,
    input_variables=["input"],
)
# Helper functions
def fetch_stock_data(ticker):
    """Return the last five rows of one month of price history for *ticker*.

    The rows are returned as the dict produced by ``.tail(5).to_dict()`` on
    the history object. If the history is empty, or any exception is raised
    while fetching it, a ``{"error": <message>}`` dict is returned instead.
    """
    try:
        history = yf.Ticker(ticker).history(period="1mo")
        if not history.empty:
            return history.tail(5).to_dict()
        return {"error": f"No data found for ticker {ticker}"}
    except Exception as exc:
        # Report the failure as data rather than raising to the caller.
        return {"error": str(exc)}
def fetch_stock_news(ticker, NEWSAPI_KEY):
    """Fetch up to five news articles matching *ticker* from NewsAPI.

    Args:
        ticker: Search term sent as the ``q`` query parameter.
        NEWSAPI_KEY: API key sent as the ``apiKey`` query parameter.

    Returns:
        A list of at most five ``{"title": ..., "description": ...}`` dicts.
        When the request fails (network error, timeout, non-200 status, or a
        body that is not valid JSON) the list ``[{"error": "Unable to fetch
        news."}]`` is returned instead.
    """
    failure = [{"error": "Unable to fetch news."}]
    try:
        response = requests.get(
            "https://newsapi.org/v2/everything",
            # Passing params lets requests URL-encode the values, so input
            # containing spaces, '&' or '#' cannot corrupt the query string.
            params={"q": ticker, "apiKey": NEWSAPI_KEY},
            # Without a timeout a stalled connection would block indefinitely.
            timeout=10,
        )
    except requests.RequestException:
        return failure

    if response.status_code != 200:
        return failure

    try:
        # `or []` also covers an explicit null "articles" value.
        articles = response.json().get("articles") or []
    except ValueError:
        # Body was not valid JSON.
        return failure

    # .get() tolerates an article object that lacks one of the fields.
    return [
        {"title": article.get("title"), "description": article.get("description")}
        for article in articles[:5]
    ]
def calculate_moving_average(ticker, window=5):
    """Compute a rolling mean of the closing price for *ticker*.

    Args:
        ticker: Symbol passed to ``yf.Ticker``.
        window: Number of rows in the rolling window (default 5).

    Returns:
        The last five rows of the one-month history, restricted to the
        ``"Close"`` column and a ``"<window>-day MA"`` column. If the history
        is empty or any step raises, a ``{"error": <message>}`` dict is
        returned instead, mirroring ``fetch_stock_data``.
    """
    try:
        hist = yf.Ticker(ticker).history(period="1mo")
        if hist.empty:
            return {"error": f"No data found for ticker {ticker}"}
        ma_column = f"{window}-day MA"
        hist[ma_column] = hist["Close"].rolling(window=window).mean()
        return hist[["Close", ma_column]].tail(5)
    except Exception as e:
        # Report the failure as data so the caller gets a readable message
        # instead of an exception.
        return {"error": str(e)}
# Tools: one (name, callable, description) row per tool exposed to the agent.
_TOOL_SPECS = (
    (
        "Stock Data Fetcher",
        fetch_stock_data,
        "Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple).",
    ),
    (
        "Stock News Fetcher",
        # The tool is called with the ticker only; supply the API key here.
        lambda ticker: fetch_stock_news(ticker, NEWSAPI_KEY),
        "Fetch recent news articles about a stock ticker.",
    ),
    (
        "Moving Average Calculator",
        calculate_moving_average,
        "Calculate the moving average of a stock over a 5-day window.",
    ),
)

stock_data_tool, stock_news_tool, moving_average_tool = [
    Tool(name=tool_name, func=tool_func, description=tool_description)
    for tool_name, tool_func, tool_description in _TOOL_SPECS
]
# Wire the language model and the three tools into a single agent.
llm = HuggingFacePipeline(pipeline=pipe)
tools = [stock_data_tool, stock_news_tool, moving_average_tool]
# NOTE(review): `prompt=` may be forwarded to the executor rather than to the
# agent itself, in which case the custom template would not take effect;
# agent-level customisation is usually passed via `agent_kwargs` — confirm
# against the installed LangChain version.
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    prompt=prompt_template,
    handle_parsing_errors=True,
    verbose=True,
)
# Streamlit interface: title, a query box and a Submit button.
st.title("Trading Helper Agent")
query = st.text_input("Enter your query:")
submitted = st.button("Submit")

if submitted and not query:
    # Button pressed with an empty query box.
    st.warning("Please enter a query.")
elif submitted:
    with st.spinner("Processing..."):
        try:
            response = agent.run(query)
            st.success("Response:")
            st.write(response)
        except Exception as err:
            # Show any failure from the agent run or rendering in the page.
            st.error(f"An error occurred: {err}")
|