import os

import requests
import streamlit as st
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
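# Streamlit app for side-by-side evaluation of finance-tuned LLMs: entries
# marked "local" in model_map below are loaded in-process with transformers,
# the rest are queried through the hosted Hugging Face Inference API.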
@st.cache_resource
def load_fingpt_lora():
    # FinGPT ships as a LoRA adapter, so load the LLaMA-2 base model first and
    # then attach the adapter with PEFT. HF_TOKEN is a module-level global
    # resolved at call time, after it has been read from the Space secrets below.
    base_model_id = "meta-llama/Llama-2-7b-hf"
    lora_adapter_id = "FinGPT/fingpt-mt_llama2-7b_lora"
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=HF_TOKEN)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        token=HF_TOKEN,
    )
    model = PeftModel.from_pretrained(base_model, lora_adapter_id, token=HF_TOKEN)
    return model, tokenizer
# Load the access token from the Hugging Face Space secrets (stored under the name "Allie")
HF_TOKEN = os.getenv("Allie", None)
if HF_TOKEN:
    login(HF_TOKEN)
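# The token is required for gated checkpoints such as meta-llama/Llama-2-7b-hf.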
# === Available Models for Selection ===
model_map = {
    "FinGPT LoRA": {"id": "FinGPT/fingpt-mt_llama2-7b_lora", "local": True, "custom_loader": load_fingpt_lora},
    "InvestLM (AWQ)": {"id": "yixuantt/InvestLM-mistral-AWQ", "local": False},
    "FinLLaMA (LLaMA3.1-8B)": {"id": "us4/fin-llama3.1-8b", "local": False},
    "FinanceConnect (13B)": {"id": "ceadar-ie/FinanceConnect-13B", "local": True},
    "Sujet-Finance (8B)": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True},
}
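# "custom_loader" marks entries that cannot be loaded directly with
# AutoModelForCausalLM (the FinGPT repo holds only adapter weights);
# query_local_model uses it instead of load_local_model when present.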
# === Load local models with caching ===
@st.cache_resource
def load_local_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        token=HF_TOKEN,
    )
    return model, tokenizer
# === Build system prompt for discursive answers ===
def build_prompt(user_question):
    return (
        "You are FinGPT, a helpful and knowledgeable financial assistant. "
        "You explain finance, controlling, and tax topics clearly, with examples when useful.\n\n"
        f"User: {user_question.strip()}\n"
        "FinGPT:"
    )
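# For example, build_prompt("What is EBIT?") yields:
#   You are FinGPT, a helpful and knowledgeable financial assistant. You explain
#   finance, controlling, and tax topics clearly, with examples when useful.
#
#   User: What is EBIT?
#   FinGPT: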
# === Clean repeated/extra outputs ===
def clean_output(output_text):
    # Keep only the text after the final "FinGPT:" marker, dropping the echoed prompt.
    parts = output_text.split("FinGPT:")
    return parts[-1].strip() if len(parts) > 1 else output_text.strip()
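# e.g. clean_output("...User: What is EBIT?\nFinGPT: EBIT is operating profit...")
#      -> "EBIT is operating profit..."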
# === Generate with local model ===
def query_local_model(model_id, prompt, loader=None):
    # Use the entry's custom loader when given (FinGPT's base-model + LoRA setup);
    # otherwise load the checkpoint directly. Both paths are cached.
    model, tokenizer = loader() if loader else load_local_model(model_id)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=300,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return clean_output(raw_output)
# === Generate with remote HF Inference API ===
def query_remote_model(model_id, prompt):
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 300}}
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{model_id}",
        headers=headers,
        json=payload,
        timeout=120,
    )
    if response.status_code == 200:
        result = response.json()
        return result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "No output")
    raise RuntimeError(f"API Error {response.status_code}: {response.text}")
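# Note: for text-generation models the Inference API usually returns the full
# text including the prompt, so query_model below runs the remote result
# through clean_output as well.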
# === Unified model query handler ===
def query_model(model_entry, user_question):
    prompt = build_prompt(user_question)
    if model_entry["local"]:
        return query_local_model(model_entry["id"], prompt, model_entry.get("custom_loader"))
    return clean_output(query_remote_model(model_entry["id"], prompt))
# === Streamlit UI Layout ===
st.set_page_config(page_title="Finance LLM Comparison", layout="centered")
st.title("💼 Financial LLM Evaluation Interface")
model_choice = st.selectbox("Select a Financial Model", list(model_map.keys()))
user_question = st.text_area("Enter your financial question:", "What is EBIT vs EBITDA?", height=150)
if st.button("Get Response"):
    with st.spinner("Thinking like a CFO..."):
        try:
            model_entry = model_map[model_choice]
            answer = query_model(model_entry, user_question)
            st.text_area("💬 Response:", value=answer, height=300, disabled=True)
        except Exception as e:
            st.error(f"❌ Error: {e}")