4lli39421 committed
Commit 78cc306 · verified · 1 Parent(s): a0ec0f2

Update app.py

Files changed (1)
  1. app.py +47 -21
app.py CHANGED
@@ -3,16 +3,14 @@ import torch
 import requests
 import os
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from huggingface_hub import login, HfApi
+from huggingface_hub import login
 
+# Load Hugging Face token from secrets
 HF_TOKEN = os.getenv("Allie", None)
-
 if HF_TOKEN:
-    from huggingface_hub import login
     login(HF_TOKEN)
 
-
-# Define model map with access type
+# All available models
 model_map = {
     "FinGPT": {"id": "OpenFinAL/GPT2_FINGPT_QA", "local": True},
     "InvestLM": {"id": "yixuantt/InvestLM-mistral-AWQ", "local": False},
@@ -21,29 +19,54 @@ model_map = {
     "Sujet-Finance": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True}
 }
 
-# Cache local models
+# Load local model
 @st.cache_resource
 def load_local_model(model_id):
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        torch_dtype=torch.float32,
         device_map="auto" if torch.cuda.is_available() else None,
         use_auth_token=HF_TOKEN
     )
     return model, tokenizer
 
-# Local model querying
+# Build discursive prompt
+def build_prompt(user_question):
+    return (
+        "You are a helpful and knowledgeable financial assistant named FinGPT. "
+        "You explain financial terms and concepts clearly, with examples when useful.\n\n"
+        f"User: {user_question.strip()}\n"
+        "FinGPT:"
+    )
+
+# Clean up repeated parts
+def clean_output(output_text):
+    parts = output_text.split("FinGPT:")
+    return parts[-1].strip() if len(parts) > 1 else output_text.strip()
+
+# Local inference
 def query_local_model(model_id, prompt):
     model, tokenizer = load_local_model(model_id)
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(**inputs, max_new_tokens=150)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=200,
+        temperature=0.7,
+        top_k=50,
+        top_p=0.95,
+        repetition_penalty=1.2,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id,
+        eos_token_id=tokenizer.eos_token_id
+    )
+    raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return clean_output(raw_output)
 
-# Remote model querying (via Inference API)
+# Remote inference
 def query_remote_model(model_id, prompt):
     headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
-    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 150}}
+    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 200}}
     response = requests.post(
         f"https://api-inference.huggingface.co/models/{model_id}",
         headers=headers,
@@ -53,27 +76,30 @@ def query_remote_model(model_id, prompt):
         result = response.json()
         return result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "No output")
     else:
-        raise RuntimeError(f"Failed to call remote model: {response.text}")
+        raise RuntimeError(f"API Error: {response.status_code} {response.text}")
 
-# Unified query dispatcher
-def query_model(model_entry, prompt):
+# Unified query handler
+def query_model(model_entry, user_question):
+    prompt = build_prompt(user_question)
     if model_entry["local"]:
         return query_local_model(model_entry["id"], prompt)
     else:
         return query_remote_model(model_entry["id"], prompt)
 
-# --- Streamlit UI ---
+# Streamlit UI
+st.set_page_config(page_title="Financial LLM Interface", layout="centered")
 st.title("💼 Financial LLM Evaluation Interface")
 
 model_choice = st.selectbox("Select a Financial Model", list(model_map.keys()))
-user_question = st.text_area("Enter your financial question:", "What is EBITDA?")
+user_question = st.text_area("Enter your financial question:", "What is CAP in finance?")
 
 if st.button("Get Response"):
-    with st.spinner("Generating response..."):
+    with st.spinner("Generating discursive response..."):
         try:
             model_entry = model_map[model_choice]
             answer = query_model(model_entry, user_question)
-            st.subheader(f"Response from {model_choice}:")
-            st.write(answer)
+            st.markdown("### 🧠 Response:")
+            st.markdown(f"```text\n{answer}\n```")
         except Exception as e:
-            st.error(f"Something went wrong: {e}")
+            st.error(f" Error: {e}")
+
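
The prompt round trip added by this commit can be sanity-checked without loading a model: build_prompt wraps the question in a FinGPT persona, and clean_output keeps only the text after the last "FinGPT:" marker, since a causal LM echoes the prompt before its completion. A minimal sketch, with both helpers copied from the new app.py; the simulated_output string is invented for illustration:

# Both helpers copied verbatim from the new app.py.
def build_prompt(user_question):
    return (
        "You are a helpful and knowledgeable financial assistant named FinGPT. "
        "You explain financial terms and concepts clearly, with examples when useful.\n\n"
        f"User: {user_question.strip()}\n"
        "FinGPT:"
    )

def clean_output(output_text):
    # generate() echoes the prompt, so keep only what follows the last "FinGPT:"
    parts = output_text.split("FinGPT:")
    return parts[-1].strip() if len(parts) > 1 else output_text.strip()

prompt = build_prompt("What is CAP in finance?")
# Simulate a causal LM, which returns prompt + completion (invented text):
simulated_output = prompt + " A cap is an upper limit, e.g. a ceiling on a floating interest rate."
print(clean_output(simulated_output))
# -> A cap is an upper limit, e.g. a ceiling on a floating interest rate.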
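
The diff context elides the middle of query_remote_model (old lines 50-52), so the json= argument and the status check are not shown. A hedged, self-contained sketch of how the remote path fits together, inferred from the visible lines; query_remote_model_sketch is a hypothetical name, not the committed function:

import requests

def query_remote_model_sketch(model_id, prompt, hf_token=None):
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 200}}
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{model_id}",
        headers=headers,
        json=payload,  # assumption: the elided lines send the payload as JSON
    )
    if response.status_code == 200:  # assumption: check implied by the else: branch
        result = response.json()
        return result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "No output")
    raise RuntimeError(f"API Error: {response.status_code} {response.text}")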