File size: 4,268 Bytes
af978d7
c4c99cc
 
 
 
 
 
 
 
e09f725
ee1c031
20d39a9
 
e09f725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20d39a9
 
 
 
e09f725
20d39a9
 
 
 
 
 
 
 
 
 
 
 
 
e09f725
20d39a9
e09f725
20d39a9
 
 
 
e09f725
c4c99cc
20d39a9
 
 
 
 
 
 
 
 
c4c99cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e09f725
c4c99cc
 
 
20d39a9
c4c99cc
 
 
 
 
 
 
 
 
 
 
 
 
 
e09f725
c4c99cc
e09f725
c4c99cc
 
 
 
 
20d39a9
1b149d3
20d39a9
c4c99cc
 
e09f725
c4c99cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
import yfinance as yf
import requests
import pandas as pd
from langchain.agents import initialize_agent, AgentType
from langchain.tools import Tool
from langchain_huggingface import HuggingFacePipeline
import os
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from langchain.prompts import PromptTemplate

# Load environment variables from a local .env file into os.environ.
load_dotenv()
NEWSAPI_KEY = os.getenv("NEWSAPI_KEY")
# Hugging Face Hub token (read from API_KEY) passed to from_pretrained below.
access_token = os.getenv("API_KEY")

MODEL_ID = "google/gemma-2b-it"


@st.cache_resource
def _load_generation_pipeline(model_id, token):
    """Load the tokenizer, model and text-generation pipeline for *model_id*.

    Streamlit re-executes this whole script on every widget interaction, so
    loading the weights at module level would reload the model on each click.
    ``st.cache_resource`` keeps a single shared instance per
    ``(model_id, token)`` for the lifetime of the server process.

    Args:
        model_id: Hugging Face Hub model identifier.
        token: Hub access token, or ``None`` for anonymous access.

    Returns:
        tuple: ``(tokenizer, model, pipeline)``.
    """
    tok = AutoTokenizer.from_pretrained(model_id, token=token)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        token=token,
    )
    return tok, mdl, pipeline("text-generation", model=mdl, tokenizer=tok)


# Initialize model and pipeline once; later reruns reuse the cached objects.
tokenizer, model, pipe = _load_generation_pipeline(MODEL_ID, access_token)

# ReAct-style instructions listing the three tools defined further below.
# NOTE(review): initialize_agent() does not take a `prompt` parameter, so
# confirm this template actually reaches the agent before relying on it.
_PROMPT_TEXT = """Answer the following question as best you can. You have access to the following tools:

Stock Data Fetcher(ticker) - Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple).
Stock News Fetcher(ticker) - Fetch recent news articles about a stock ticker.
Moving Average Calculator(ticker, window=5) - Calculate the moving average of a stock over a 5-day window.

Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [Stock Data Fetcher, Stock News Fetcher, Moving Average Calculator]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Strictly follow this format. Do not provide a Final Answer until all Observations are collected.

Begin!
Question: {input}
"""

# Single template variable: the user's question.
prompt_template = PromptTemplate(
    template=_PROMPT_TEXT,
    input_variables=["input"],
)

# Helper functions
def fetch_stock_data(ticker):
    """Return the last five rows of one month of price history for *ticker*.

    Args:
        ticker: Ticker symbol such as ``"AAPL"``. Surrounding whitespace and
            quote characters (common in LLM-generated tool input) are removed.

    Returns:
        dict: ``DataFrame.to_dict()`` of the last five history rows
        (column -> {index -> value}), or ``{"error": message}`` when the
        ticker is empty, no data is returned, or the lookup raises.
    """
    symbol = str(ticker).strip().strip("'\"").strip()
    if not symbol:
        return {"error": "No ticker symbol provided"}
    try:
        hist = yf.Ticker(symbol).history(period="1mo")
    except Exception as e:  # tool boundary: report failures to the agent instead of crashing
        return {"error": str(e)}
    if hist.empty:
        return {"error": f"No data found for ticker {symbol}"}
    return hist.tail(5).to_dict()

def fetch_stock_news(ticker, NEWSAPI_KEY):
    """Return up to five NewsAPI articles matching *ticker*.

    Args:
        ticker: Search term, normally a ticker symbol. Surrounding whitespace
            and quote characters are removed.
        NEWSAPI_KEY: NewsAPI key. The parameter name mirrors the module-level
            setting and is kept for backward compatibility with callers.

    Returns:
        list[dict]: At most five ``{"title", "description"}`` dicts, or a
        single-element list ``[{"error": message}]`` on any failure.
    """
    query = str(ticker).strip().strip("'\"").strip()
    if not query:
        return [{"error": "No ticker symbol provided."}]
    if not NEWSAPI_KEY:
        return [{"error": "NEWSAPI_KEY is not configured."}]
    try:
        # `params` URL-encodes the query, so tickers like "BRK.B" or "^GSPC"
        # are safe to send. The key travels in a header rather than in the URL.
        # The timeout keeps a stalled request from hanging the app.
        response = requests.get(
            "https://newsapi.org/v2/everything",
            params={"q": query},
            headers={"X-Api-Key": NEWSAPI_KEY},
            timeout=10,
        )
        if response.status_code != 200:
            return [{"error": "Unable to fetch news."}]
        articles = response.json().get("articles") or []
    except (requests.RequestException, ValueError):
        # Network failures and non-JSON bodies are reported like HTTP errors.
        return [{"error": "Unable to fetch news."}]
    return [
        {"title": article.get("title"), "description": article.get("description")}
        for article in articles[:5]
    ]

def calculate_moving_average(ticker, window=5):
    """Return the last five closing prices of *ticker* with a rolling mean.

    Only one calendar month of history is fetched, so a *window* longer than
    the number of trading days in that month yields NaN averages.

    Args:
        ticker: Ticker symbol. Surrounding whitespace and quote characters
            are removed.
        window: Number of rows in the rolling mean; must be at least 1.

    Returns:
        pandas.DataFrame with ``"Close"`` and ``"{window}-day MA"`` columns
        for the last five rows, or ``{"error": message}`` when the input is
        invalid, no data is returned, or the computation raises. The error
        dict matches the convention used by ``fetch_stock_data``.
    """
    symbol = str(ticker).strip().strip("'\"").strip()
    if not symbol:
        return {"error": "No ticker symbol provided"}
    try:
        # Validate before the network call. A non-comparable window raises
        # here and is reported through the same error path.
        if window < 1:
            return {"error": f"window must be a positive integer, got {window!r}"}
        hist = yf.Ticker(symbol).history(period="1mo")
        if hist.empty or "Close" not in hist.columns:
            return {"error": f"No data found for ticker {symbol}"}
        column = f"{window}-day MA"
        hist[column] = hist["Close"].rolling(window=window).mean()
        return hist[["Close", column]].tail(5)
    except Exception as e:  # tool boundary: report failures to the agent instead of crashing
        return {"error": str(e)}

# Tools handed to the agent: each wraps one helper and carries the name and
# description the agent sees when choosing an action.
def _news_for_ticker(ticker):
    # Bind the module-level NewsAPI key so the tool takes a single input.
    return fetch_stock_news(ticker, NEWSAPI_KEY)


stock_data_tool = Tool(
    func=fetch_stock_data,
    name="Stock Data Fetcher",
    description="Fetch recent stock data for a valid stock ticker symbol (e.g., AAPL for Apple).",
)

stock_news_tool = Tool(
    func=_news_for_ticker,
    name="Stock News Fetcher",
    description="Fetch recent news articles about a stock ticker.",
)

moving_average_tool = Tool(
    func=calculate_moving_average,
    name="Moving Average Calculator",
    description="Calculate the moving average of a stock over a 5-day window.",
)

# Initialize agent
tools = [stock_data_tool, stock_news_tool, moving_average_tool]
llm = HuggingFacePipeline(pipeline=pipe)

# initialize_agent() has no `prompt` parameter: extra keyword arguments are
# forwarded to AgentExecutor, so a prompt passed that way never reaches the
# LLM. ZeroShotAgent builds its own ReAct prompt from the tools' names and
# descriptions. Extra instructions therefore go in through
# `agent_kwargs["suffix"]`, which must keep the {input} and
# {agent_scratchpad} placeholders so intermediate steps stay visible.
_AGENT_SUFFIX = (
    "Strictly follow this format. Do not provide a Final Answer until all "
    "Observations are collected.\n\n"
    "Begin!\n\n"
    "Question: {input}\n"
    "Thought:{agent_scratchpad}"
)

agent = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    agent_kwargs={"suffix": _AGENT_SUFFIX},
    verbose=True,
    # Feed unparseable LLM output back to the model instead of raising.
    handle_parsing_errors=True,
)

# Streamlit interface
st.title("Trading Helper Agent")

query = st.text_input("Enter your query:")

if st.button("Submit"):
    # Whitespace-only input is treated as empty rather than sent to the model.
    if query and query.strip():
        with st.spinner("Processing..."):
            try:
                # Chain.run() is deprecated; AgentExecutor.invoke() returns a
                # dict whose "output" key holds the final answer.
                response = agent.invoke({"input": query.strip()})["output"]
                st.success("Response:")
                st.write(response)
            except Exception as e:  # UI boundary: surface any agent/tool failure to the user
                st.error(f"An error occurred: {e}")
    else:
        st.warning("Please enter a query.")