import streamlit as st
import torch
import requests
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, HfApi
HF_TOKEN = os.getenv("Allie", None)
if HF_TOKEN:
    login(HF_TOKEN)
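# Note: the token above is read from an environment variable / Space secret
# named "Allie"; if your deployment stores the Hugging Face token under a
# different name, adjust the os.getenv() call accordingly.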
# Define model map with access type
model_map = {
    "FinGPT": {"id": "OpenFinAL/GPT2_FINGPT_QA", "local": True},
    "InvestLM": {"id": "yixuantt/InvestLM-mistral-AWQ", "local": False},
    "FinLLaMA": {"id": "us4/fin-llama3.1-8b", "local": False},
    "FinanceConnect": {"id": "ceadar-ie/FinanceConnect-13B", "local": True},
    "Sujet-Finance": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True}
}
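# "local": True  -> the checkpoint is downloaded and run in-process with transformers
# "local": False -> the model is queried through the hosted Inference API instead
# A new model can be added with an entry in the same shape, e.g.
# "MyModel": {"id": "org/model-name", "local": False}  (hypothetical example)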
# Cache local models
@st.cache_resource
def load_local_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        token=HF_TOKEN
    )
    return model, tokenizer
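# st.cache_resource keeps the returned (model, tokenizer) pair in memory across
# Streamlit reruns, so each local checkpoint is downloaded and loaded only once
# per server process.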
# Local model querying
def query_local_model(model_id, prompt):
    model, tokenizer = load_local_model(model_id)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=150)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
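# Generation above uses the model's default decoding settings with a 150-token
# cap; sampling options (e.g. do_sample=True, temperature=0.7) could be passed
# to model.generate() for more varied answers -- not part of the original call.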
# Remote model querying (via Inference API)
def query_remote_model(model_id, prompt):
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 150}}
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{model_id}",
        headers=headers,
        json=payload
    )
    if response.status_code == 200:
        result = response.json()
        return result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "No output")
    else:
        raise RuntimeError(f"Failed to call remote model: {response.text}")
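# The hosted Inference API can return a non-200 status while a model is still
# loading (a cold start); callers could catch the RuntimeError above and retry
# after a short delay -- not implemented in this app.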
# Unified query dispatcher
def query_model(model_entry, prompt):
    if model_entry["local"]:
        return query_local_model(model_entry["id"], prompt)
    else:
        return query_remote_model(model_entry["id"], prompt)
# --- Streamlit UI ---
st.title("💼 Financial LLM Evaluation Interface")
model_choice = st.selectbox("Select a Financial Model", list(model_map.keys()))
user_question = st.text_area("Enter your financial question:", "What is EBITDA?")
if st.button("Get Response"):
    with st.spinner("Generating response..."):
        try:
            model_entry = model_map[model_choice]
            answer = query_model(model_entry, user_question)
            st.subheader(f"Response from {model_choice}:")
            st.write(answer)
        except Exception as e:
            st.error(f"Something went wrong: {e}")
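# To run the app outside of Spaces (assuming Streamlit is installed and this
# file is saved as app.py): streamlit run app.py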