import streamlit as st
import torch
import requests
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Read the Hugging Face access token from the environment variable named "Allie"
HF_TOKEN = os.getenv("Allie", None)

if HF_TOKEN:
    login(HF_TOKEN)


# Define model map with access type
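# "local": True  -> weights are downloaded and run in-process with transformers
# "local": False -> prompts are sent to the hosted Hugging Face Inference API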
model_map = {
    "FinGPT": {"id": "OpenFinAL/GPT2_FINGPT_QA", "local": True},
    "InvestLM": {"id": "yixuantt/InvestLM-mistral-AWQ", "local": False},
    "FinLLaMA": {"id": "us4/fin-llama3.1-8b", "local": False},
    "FinanceConnect": {"id": "ceadar-ie/FinanceConnect-13B", "local": True},
    "Sujet-Finance": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True}
}

# Cache local models
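# st.cache_resource keeps the loaded model/tokenizer alive across Streamlit reruns,
# so each local model is downloaded and initialized only once per process.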
@st.cache_resource
def load_local_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        token=HF_TOKEN
    )
    return model, tokenizer

# Local model querying
def query_local_model(model_id, prompt):
    model, tokenizer = load_local_model(model_id)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # pad_token_id is set explicitly so models without a pad token (e.g. GPT-2) do not warn during generation
    outputs = model.generate(**inputs, max_new_tokens=150, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Remote model querying (via Inference API)
def query_remote_model(model_id, prompt):
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 150}}
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{model_id}",
        headers=headers,
        json=payload,
        timeout=120,  # avoid hanging indefinitely if the Inference API stalls
    )
    if response.status_code == 200:
        result = response.json()
        return result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "No output")
    else:
        raise RuntimeError(f"Failed to call remote model: {response.text}")

# Unified query dispatcher
def query_model(model_entry, prompt):
    if model_entry["local"]:
        return query_local_model(model_entry["id"], prompt)
    else:
        return query_remote_model(model_entry["id"], prompt)
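# For illustration, a direct (non-UI) call would look like:
#   answer = query_model(model_map["FinGPT"], "What is EBITDA?")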

# --- Streamlit UI ---
st.title("💼 Financial LLM Evaluation Interface")

model_choice = st.selectbox("Select a Financial Model", list(model_map.keys()))
user_question = st.text_area("Enter your financial question:", "What is EBITDA?")

if st.button("Get Response"):
    with st.spinner("Generating response..."):
        try:
            model_entry = model_map[model_choice]
            answer = query_model(model_entry, user_question)
            st.subheader(f"Response from {model_choice}:")
            st.write(answer)
        except Exception as e:
            st.error(f"Something went wrong: {e}")
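
# Usage sketch (assumes this file is saved as app.py and the token secret uses the
# env var name "Allie" expected above):
#   export Allie=hf_xxxxxxxxxxxxxxxx
#   streamlit run app.py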