# Hugging Face Spaces page header (status: Running) - not part of the program.
import os, json, time, random | |
from collections import defaultdict | |
from datetime import date, datetime, timedelta | |
from dotenv import load_dotenv | |
import gradio as gr | |
import pandas as pd | |
import finnhub | |
from openai import OpenAI | |
from io import StringIO | |
import requests | |
# Pull settings from a local .env file into the process environment before
# any of them are read below.
load_dotenv()

# ---------- 0 CONFIG ---------------------------------------------------------
# Chat model name; falls back to "gpt-4o-mini" when OPENAI_MODEL is unset.
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
# Market-data API keys. ALPHA_KEY is only validated when price data is fetched.
FINNHUB_KEY = os.environ.get("FINNHUB_API_KEY")
ALPHA_KEY = os.environ.get("ALPHAVANTAGE_API_KEY")

# The Finnhub key is mandatory: abort at import time when it is missing/empty.
if FINNHUB_KEY is None or FINNHUB_KEY == "":
    raise RuntimeError("FINNHUB_API_KEY not set")

# Shared API clients, created once at import time.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
finnhub_client = finnhub.Client(api_key=FINNHUB_KEY)
# System message sent with every chat request: sets the analyst persona and
# names the three sections the model must return.
SYSTEM_PROMPT = "\n".join(
    [
        "You are a seasoned stock-market analyst. "
        "Given recent company news and optional basic financials, "
        "return:",
        "[Positive Developments] – 2-4 bullets",
        "[Potential Concerns] – 2-4 bullets",
        "[Prediction & Analysis] – a one-week price outlook with rationale.",
    ]
)
# ---------- 1 DATE / UTILITY HELPERS ---------------------------------------- | |
def today() -> str:
    """Return the current local date formatted as ``YYYY-MM-DD``."""
    return f"{date.today():%Y-%m-%d}"
def n_weeks_before(date_string: str, n: int) -> str:
    """Shift a ``YYYY-MM-DD`` date back by ``n`` weeks.

    A negative ``n`` moves the date forward instead. The result uses the same
    ``YYYY-MM-DD`` format as the input.
    """
    anchor = datetime.strptime(date_string, "%Y-%m-%d")
    shifted = anchor - timedelta(weeks=n)
    return shifted.strftime("%Y-%m-%d")
# ---------- 2 DATA FETCHING -------------------------------------------------- | |
def get_stock_data(symbol: str, steps: list[str]) -> pd.DataFrame:
    """Fetch daily closes from Alpha Vantage and summarise them per period.

    Args:
        symbol: Ticker symbol to query.
        steps: Boundary dates as ``YYYY-MM-DD`` strings, in ascending order;
            each consecutive pair defines one period (both ends inclusive).

    Returns:
        A DataFrame with one row per period and the columns ``Start Date``,
        ``End Date``, ``Start Price`` and ``End Price``, taken from the first
        and last price rows that fall inside the period.

    Raises:
        RuntimeError: If ``ALPHAVANTAGE_API_KEY`` is not configured, no usable
            response is obtained after three attempts, Alpha Vantage answers
            with a JSON payload instead of CSV, or a period has no price rows.
    """
    if not ALPHA_KEY:
        raise RuntimeError("ALPHAVANTAGE_API_KEY is Missing")

    # Free endpoint: TIME_SERIES_DAILY. The query is passed via ``params`` so
    # that requests URL-encodes the (user-supplied) symbol.
    endpoint = "https://www.alphavantage.co/query"
    params = {
        "function": "TIME_SERIES_DAILY",
        "symbol": symbol,
        "apikey": ALPHA_KEY,
        "datatype": "csv",
        "outputsize": "full",
    }

    # Retry up to 3 times; network failures are retried like HTTP errors.
    text = None
    failure = "empty response"
    for _attempt in range(3):
        try:
            resp = requests.get(endpoint, params=params, timeout=10)
        except requests.RequestException as exc:
            # Keep only the exception type: its message may contain the
            # request URL and therefore the API key.
            failure = type(exc).__name__
            time.sleep(1)
            continue
        if not resp.ok:
            failure = f"HTTP {resp.status_code}"
            time.sleep(1)
            continue
        text = resp.text.strip()
        if text.startswith("{"):
            # CSV was requested, so a JSON body is an API-level message.
            info = resp.json()
            msg = info.get("Note") or info.get("Error Message") or str(info)
            raise RuntimeError(f"Alpha Vantage Return Error:{msg}")
        break
    if not text:
        # Never echo the request URL here: it carries the API key.
        raise RuntimeError(
            f"Alpha Vantage Connection Error: no data for {symbol} ({failure})"
        ) from None

    df = pd.read_csv(StringIO(text))
    # Prefer the "timestamp" column; otherwise assume the first column is the date.
    date_col = "timestamp" if "timestamp" in df.columns else df.columns[0]
    df[date_col] = pd.to_datetime(df[date_col])
    df = df.sort_values(date_col).set_index(date_col)

    data = {"Start Date": [], "End Date": [], "Start Price": [], "End Price": []}
    for start, end in zip(steps, steps[1:]):
        seg = df.loc[pd.to_datetime(start):pd.to_datetime(end)]
        if seg.empty:
            raise RuntimeError(
                f"Alpha Vantage 无法获取 {symbol} 在 {start} – {end} 的数据"
            )
        data["Start Date"].append(seg.index[0])
        data["Start Price"].append(seg["close"].iloc[0])
        data["End Date"].append(seg.index[-1])
        data["End Price"].append(seg["close"].iloc[-1])

    # Rate limit noted by the original author: 5 calls/min. Only one request
    # is made per invocation, so a single pause is enough.
    time.sleep(12)
    return pd.DataFrame(data)
def current_basics(symbol: str, curday: str) -> dict:
    """Return the latest quarterly basic financials reported on or before ``curday``.

    Args:
        symbol: Ticker symbol passed to Finnhub.
        curday: Cut-off date as a ``YYYY-MM-DD`` string. Periods are compared
            to it as strings (assumes Finnhub periods use the same format).

    Returns:
        A dict mapping metric name to value for the most recent period, plus a
        ``"period"`` key holding that period. Empty when Finnhub has no
        quarterly series or no period on or before ``curday``.
    """
    raw = finnhub_client.company_basic_financials(symbol, "all")
    # Tolerate a missing/empty "series" or "quarterly" section instead of
    # failing with KeyError.
    quarterly = (raw.get("series") or {}).get("quarterly") or {}
    if not quarterly:
        return {}

    # Regroup {metric: [{period, v}, ...]} into {period: {metric: value}}.
    merged = defaultdict(dict)
    for metric, vals in quarterly.items():
        for v in vals:
            merged[v["period"]][metric] = v["v"]

    latest = max((p for p in merged if p <= curday), default=None)
    if latest is None:
        return {}
    return {**merged[latest], "period": latest}
def attach_news(symbol: str, df: pd.DataFrame) -> pd.DataFrame:
    """Add a ``News`` column holding each period's company news as JSON.

    For every row, Finnhub company news between ``Start Date`` and
    ``End Date`` is fetched, reduced to date/headline/summary, sorted by date
    and serialised with ``json.dumps``. ``df`` is modified in place and also
    returned.
    """
    serialized = []
    for _, period in df.iterrows():
        from_day = period["Start Date"].strftime("%Y-%m-%d")
        to_day = period["End Date"].strftime("%Y-%m-%d")
        time.sleep(1)  # Finnhub QPM guard
        articles = finnhub_client.company_news(symbol, _from=from_day, to=to_day)
        entries = sorted(
            (
                {
                    "date": datetime.fromtimestamp(item["datetime"]).strftime("%Y%m%d%H%M%S"),
                    "headline": item["headline"],
                    "summary": item["summary"],
                }
                for item in articles
            ),
            key=lambda entry: entry["date"],
        )
        serialized.append(json.dumps(entries))
    df["News"] = serialized
    return df
# ---------- 3 PROMPT CONSTRUCTION ------------------------------------------- | |
def sample_news(news: list[str], k: int = 5) -> list[str]:
    """Pick at most ``k`` items from ``news`` at random, keeping their order.

    When ``news`` has ``k`` items or fewer, the same list object is returned
    unchanged.
    """
    if len(news) > k:
        chosen = random.sample(range(len(news)), k)
        chosen.sort()  # restore the original relative order
        return [news[idx] for idx in chosen]
    return news
def make_prompt(symbol: str, df: pd.DataFrame, curday: str, use_basics=False) -> str:
    """Build the user prompt sent to the language model.

    The prompt consists of a company introduction (from the Finnhub profile),
    one paragraph per row of ``df`` describing the price move plus up to five
    sampled news items, a basic-financials section, and the final question.

    Args:
        symbol: Ticker symbol.
        df: One row per period with ``Start Date``, ``End Date``,
            ``Start Price``, ``End Price`` and ``News`` (a JSON list of items
            with ``headline`` and ``summary``).
        curday: Reference date as ``YYYY-MM-DD``; the forecast horizon is the
            week starting on it.
        use_basics: When true, include the latest basic financials.

    Returns:
        The assembled prompt text.

    Raises:
        RuntimeError: If Finnhub returns an empty company profile.
    """
    # Company profile
    prof = finnhub_client.company_profile2(symbol=symbol)
    if not prof:
        # Fail with a clear message rather than a KeyError on the first field.
        raise RuntimeError(f"Finnhub returned no company profile for {symbol}")
    company_blurb = (
        f"[Company Introduction]:\n{prof['name']} operates in the "
        f"{prof['finnhubIndustry']} sector ({prof['country']}). "
        f"Founded {prof['ipo']}, market cap {prof['marketCapitalization']:.1f} "
        f"{prof['currency']}; ticker {symbol} on {prof['exchange']}.\n"
    )

    # Past weeks block
    period_blocks = []
    for _, row in df.iterrows():
        term = "increased" if row["End Price"] > row["Start Price"] else "decreased"
        head = (f"From {row['Start Date']:%Y-%m-%d} to {row['End Date']:%Y-%m-%d}, "
                f"{symbol}'s stock price {term} from "
                f"{row['Start Price']:.2f} to {row['End Price']:.2f}.")
        news_items = json.loads(row["News"])
        # Drop items whose summary starts with the boilerplate teaser text.
        summaries = [
            f"[Headline] {n['headline']}\n[Summary] {n['summary']}\n"
            for n in news_items
            if not n["summary"].startswith("Looking for stock market analysis")
        ]
        period_blocks.append("\n" + head + "\n" + "".join(sample_news(summaries, 5)))
    past_block = "".join(period_blocks)

    # Optional basic financials
    if use_basics:
        basics = current_basics(symbol, curday)
        if basics:
            basics_txt = "\n".join(f"{k}: {v}" for k, v in basics.items() if k != "period")
            basics_block = (f"\n[Basic Financials] (reported {basics['period']}):\n{basics_txt}\n")
        else:
            basics_block = "\n[Basic Financials]: not available\n"
    else:
        basics_block = "\n[Basic Financials]: not requested\n"

    # A negative week count moves forward: the horizon ends one week after curday.
    horizon = f"{curday} to {n_weeks_before(curday, -1)}"
    # Every fragment that interpolates a value must be an f-string; a plain
    # literal here would send the text "{symbol}" to the model verbatim.
    return (
        company_blurb
        + past_block
        + basics_block
        + f"\nBased on all information before {curday}, analyse positive "
        f"developments and potential concerns for {symbol}, then predict its "
        f"price movement for next week ({horizon})."
    )
# ---------- 4 LLM CALL ------------------------------------------------------- | |
def chat_completion(prompt: str,
                    model: str = OPENAI_MODEL,
                    temperature: float = 0.3,
                    stream: bool = False) -> str:
    """Send ``prompt`` to the chat model together with ``SYSTEM_PROMPT``.

    Args:
        prompt: User message content.
        model: Model name passed to the API.
        temperature: Sampling temperature passed to the API.
        stream: When true, print the reply to stdout as chunks arrive and
            return the concatenated text; otherwise return the content of the
            first choice.

    Returns:
        The model's reply text.
    """
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    response = client.chat.completions.create(
        model=model,
        temperature=temperature,
        stream=stream,
        messages=conversation,
    )

    # Non-streaming: the full reply is already available.
    if not stream:
        return response.choices[0].message.content

    # Streaming: echo each piece as it arrives and accumulate the full text.
    pieces = []
    for chunk in response:
        piece = chunk.choices[0].delta.content or ""
        print(piece, end="", flush=True)
        pieces.append(piece)
    print()
    return "".join(pieces)
# ---------- 5 MAIN ENTRY (CLI test) ----------------------------------------- | |
def predict(symbol: str = "AAPL",
            curday: "str | None" = None,
            n_weeks: int = 3,
            use_basics: bool = False,
            stream: bool = False) -> tuple[str, str]:
    """Run the full pipeline: prices, news, prompt, model call.

    Args:
        symbol: Ticker symbol.
        curday: Reference date as ``YYYY-MM-DD``. ``None`` (the default) means
            the date on which the call is made.
        n_weeks: Number of past weekly periods to include.
        use_basics: Whether to add basic financials to the prompt.
        stream: Whether to stream the model reply to stdout.

    Returns:
        A ``(prompt, answer)`` tuple.
    """
    # Resolve the default per call. A ``today()`` default in the signature is
    # evaluated once at import time and goes stale in a long-running process.
    if curday is None:
        curday = today()

    # Weekly boundaries, oldest first, ending at curday.
    steps = [n_weeks_before(curday, n) for n in range(n_weeks, -1, -1)]
    df = get_stock_data(symbol, steps)
    df = attach_news(symbol, df)
    prompt_info = make_prompt(symbol, df, curday, use_basics)
    answer = chat_completion(prompt_info, stream=stream)
    return prompt_info, answer
# ---------- 6 SETUP HF ----------------------------------------- | |
def hf_predict(symbol, n_weeks, use_basics):
    """Gradio click handler: normalise the UI inputs and run ``predict``.

    Args:
        symbol: Ticker text from the textbox; surrounding whitespace is
            removed and the value is upper-cased.
        n_weeks: Slider value, converted to ``int``.
        use_basics: Checkbox value, converted to ``bool``.

    Returns:
        A ``(prompt, answer)`` tuple for the two output textboxes.

    Raises:
        gr.Error: If the ticker is empty after stripping whitespace.
    """
    # 1. normalise the ticker; a stray space would otherwise reach the APIs
    ticker = (symbol or "").strip().upper()
    if not ticker:
        raise gr.Error("Please enter a ticker symbol.")
    # 2. get curday via the shared helper
    curday = today()
    # 3. call predict
    prompt, answer = predict(
        symbol=ticker,
        curday=curday,
        n_weeks=int(n_weeks),
        use_basics=bool(use_basics),
        stream=False
    )
    return prompt, answer
# Gradio UI: three inputs in one row, two output boxes, and a run button.
with gr.Blocks() as demo:
    gr.Markdown(value="FinRobot_Forecaster")

    # Input controls.
    with gr.Row():
        symbol = gr.Textbox(value="AAPL", label="Ticker(eg. AAPL)")
        n_weeks = gr.Slider(minimum=1, maximum=6, step=1, value=3,
                            label="Trace Back Weeks")
        use_basics = gr.Checkbox(value=False, label="Add Basic Financials")

    # Output panes: the prompt sent to the model and the model's reply.
    output_prompt = gr.Textbox(lines=8, label="Model Prompt")
    output_answer = gr.Textbox(lines=12, label="Model Output")

    # Wire the button to the handler.
    btn = gr.Button(value="Run Forecaster")
    btn.click(
        fn=hf_predict,
        inputs=[symbol, n_weeks, use_basics],
        outputs=[output_prompt, output_answer],
    )

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()