4lli39421 committed
Commit 7e807e4 · verified · Parent: ee61e04

Update app.py

Files changed (1)
  1. app.py +25 -26
app.py CHANGED
@@ -5,21 +5,21 @@ import os
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import login
 
-# Load Hugging Face token from secrets
+# Load token from Hugging Face Space secrets
 HF_TOKEN = os.getenv("Allie", None)
 if HF_TOKEN:
     login(HF_TOKEN)
 
-# All available models
+# === Available Models for Selection ===
 model_map = {
-    "FinGPT": {"id": "OpenFinAL/GPT2_FINGPT_QA", "local": True},
-    "InvestLM": {"id": "yixuantt/InvestLM-mistral-AWQ", "local": False},
-    "FinLLaMA": {"id": "us4/fin-llama3.1-8b", "local": False},
-    "FinanceConnect": {"id": "ceadar-ie/FinanceConnect-13B", "local": True},
-    "Sujet-Finance": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True}
+    "FinGPT (GPT2)": {"id": "OpenFinAL/GPT2_FINGPT_QA", "local": True},
+    "InvestLM (AWQ)": {"id": "yixuantt/InvestLM-mistral-AWQ", "local": False},
+    "FinLLaMA (LLaMA3.1-8B)": {"id": "us4/fin-llama3.1-8b", "local": False},
+    "FinanceConnect (13B)": {"id": "ceadar-ie/FinanceConnect-13B", "local": True},
+    "Sujet-Finance (8B)": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True}
 }
 
-# Load local model
+# === Load local models with caching ===
 @st.cache_resource
 def load_local_model(model_id):
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
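
Note: the hunk above elides the body of load_local_model (old lines 26-30). A minimal sketch of what a loader with this signature typically looks like; the torch_dtype and device_map values are assumptions, not part of this commit, and recent transformers releases accept token= in place of the deprecated use_auth_token=:

import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

HF_TOKEN = os.getenv("Allie", None)  # as in the app above

@st.cache_resource
def load_local_model(model_id):
    # token= supersedes the deprecated use_auth_token= kwarg
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=HF_TOKEN,
        torch_dtype=torch.float16,  # assumption: half precision to fit Space memory
        device_map="auto",          # assumption: requires the accelerate package
    )
    return model, tokenizer
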
@@ -31,27 +31,27 @@ def load_local_model(model_id):
     )
     return model, tokenizer
 
-# Build discursive prompt
+# === Build system prompt for discursive answers ===
 def build_prompt(user_question):
     return (
-        "You are a helpful and knowledgeable financial assistant named FinGPT. "
-        "You explain financial terms and concepts clearly, with examples when useful.\n\n"
+        "You are FinGPT, a helpful and knowledgeable financial assistant. "
+        "You explain finance, controlling, and tax topics clearly, with examples when useful.\n\n"
         f"User: {user_question.strip()}\n"
         "FinGPT:"
     )
 
-# Clean up repeated parts
+# === Clean repeated/extra outputs ===
 def clean_output(output_text):
     parts = output_text.split("FinGPT:")
     return parts[-1].strip() if len(parts) > 1 else output_text.strip()
 
-# Local inference
+# === Generate with local model ===
 def query_local_model(model_id, prompt):
     model, tokenizer = load_local_model(model_id)
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     outputs = model.generate(
         **inputs,
-        max_new_tokens=200,
+        max_new_tokens=300,
         temperature=0.7,
         top_k=50,
         top_p=0.95,
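
One behavioral note on the sampling arguments in this hunk: transformers ignores temperature, top_k, and top_p unless sampling is enabled, and falls back to greedy decoding with a warning. If the elided lines (old 58-62) do not already set it, the call needs do_sample=True for these parameters to take effect. A sketch; pad_token_id is my addition, since GPT-2-family models define no pad token:

outputs = model.generate(
    **inputs,
    max_new_tokens=300,
    do_sample=True,                       # required for temperature/top_k/top_p to apply
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id,  # assumption: silences the missing-pad-token warning
)
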
@@ -63,10 +63,10 @@ def query_local_model(model_id, prompt):
     raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return clean_output(raw_output)
 
-# Remote inference
+# === Generate with remote HF API ===
 def query_remote_model(model_id, prompt):
     headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
-    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 200}}
+    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 300}}
     response = requests.post(
         f"https://api-inference.huggingface.co/models/{model_id}",
         headers=headers,
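
For the remote path, one thing this payload does not handle: the hosted Inference API answers 503 while a cold model is still loading. The documented options field lets the request block until the model is ready instead. A self-contained sketch with that flag (the timeout value is an assumption):

import requests

def query_remote_model(model_id, prompt, hf_token=None):
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": 300},
        "options": {"wait_for_model": True},  # wait out cold starts instead of failing with 503
    }
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{model_id}",
        headers=headers,
        json=payload,
        timeout=120,  # assumption: generous timeout to cover model load time
    )
    response.raise_for_status()
    return response.json()
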
@@ -76,31 +76,30 @@ def query_remote_model(model_id, prompt):
         result = response.json()
         return result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "No output")
     else:
-        raise RuntimeError(f"API Error: {response.status_code} {response.text}")
+        raise RuntimeError(f"API Error {response.status_code}: {response.text}")
 
-# Unified query handler
+# === Unified model query handler ===
 def query_model(model_entry, user_question):
     prompt = build_prompt(user_question)
     if model_entry["local"]:
         return query_local_model(model_entry["id"], prompt)
     else:
-        return query_remote_model(model_entry["id"], prompt)
+        return clean_output(query_remote_model(model_entry["id"], prompt))
 
-# Streamlit UI
-st.set_page_config(page_title="Financial LLM Interface", layout="centered")
+# === Streamlit UI Layout ===
+st.set_page_config(page_title="Finance LLM Comparison", layout="centered")
 st.title("💼 Financial LLM Evaluation Interface")
 
 model_choice = st.selectbox("Select a Financial Model", list(model_map.keys()))
-user_question = st.text_area("Enter your financial question:", "What is CAP in finance?")
+user_question = st.text_area("Enter your financial question:", "What is EBIT vs EBITDA?", height=150)
 
 if st.button("Get Response"):
-    with st.spinner("Generating discursive response..."):
+    with st.spinner("Thinking like a CFO..."):
         try:
             model_entry = model_map[model_choice]
             answer = query_model(model_entry, user_question)
-            st.markdown("### 🧠 Response:")
-            st.text_area("💬 Response from FinGPT:", value=answer, height=200, disabled=True)
-
+            st.text_area("💬 Response:", value=answer, height=300, disabled=True)
         except Exception as e:
             st.error(f"❌ Error: {e}")
 
+
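
A final note on configuration rather than the diff itself: in a Hugging Face Space, secrets are exposed to the app as environment variables, so os.getenv("Allie") only works because the secret is literally named Allie. A more conventional setup (the name HF_TOKEN is my suggestion, not part of the commit):

import os

# Prefer a conventionally named secret, falling back to the existing one
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("Allie")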