4lli39421 committed on
Commit 98ae1f9 · verified · 1 Parent(s): 777ca73

Update app.py

Files changed (1)
  1. app.py +48 -16
app.py CHANGED
@@ -1,35 +1,66 @@
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from huggingface_hub import login
 import torch
+import requests
 import os
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from huggingface_hub import login, HfApi
 
+# Optional: Login if you want access to gated/private models
+HF_TOKEN = os.getenv("HF_TOKEN", None)
+if HF_TOKEN:
+    login(HF_TOKEN)
 
-# Set model map
+# Define model map with access type
 model_map = {
-    "FinGPT": "AI4Finance/FinGPT",
-    "FinanceConnect": "ceadar-ie/FinanceConnect-13B",
-    "Sujet-Finance": "sujet-ai/Sujet-Finance-8B-v0.1"
+    "FinGPT": {"id": "AI4Finance/FinGPT", "local": True},
+    "InvestLM": {"id": "mrm8488/investLM-7B", "local": False},  # example ID, update if needed
+    "FinLLaMA": {"id": "HuggingFaceH4/fin-llama", "local": False},
+    "FinanceConnect": {"id": "ceadar-ie/FinanceConnect-13B", "local": True},
+    "Sujet-Finance": {"id": "sujet-ai/Sujet-Finance-8B-v0.1", "local": True}
 }
 
-# Cache model loading for performance
+# Cache local models
 @st.cache_resource
-def load_model_and_tokenizer(model_id):
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+def load_local_model(model_id):
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto" if torch.cuda.is_available() else None,
+        use_auth_token=HF_TOKEN
     )
     return model, tokenizer
 
-# Query model
-def query_model(model_id, question):
-    model, tokenizer = load_model_and_tokenizer(model_id)
-    inputs = tokenizer(question, return_tensors="pt")
+# Local model querying
+def query_local_model(model_id, prompt):
+    model, tokenizer = load_local_model(model_id)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     outputs = model.generate(**inputs, max_new_tokens=150)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-# Streamlit app layout
+# Remote model querying (via Inference API)
+def query_remote_model(model_id, prompt):
+    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
+    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 150}}
+    response = requests.post(
+        f"https://api-inference.huggingface.co/models/{model_id}",
+        headers=headers,
+        json=payload
+    )
+    if response.status_code == 200:
+        result = response.json()
+        return result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "No output")
+    else:
+        raise RuntimeError(f"Failed to call remote model: {response.text}")
+
+# Unified query dispatcher
+def query_model(model_entry, prompt):
+    if model_entry["local"]:
+        return query_local_model(model_entry["id"], prompt)
+    else:
+        return query_remote_model(model_entry["id"], prompt)
+
+# --- Streamlit UI ---
 st.title("💼 Financial LLM Evaluation Interface")
 
 model_choice = st.selectbox("Select a Financial Model", list(model_map.keys()))
@@ -38,7 +69,8 @@ user_question = st.text_area("Enter your financial question:", "What is EBITDA?")
 if st.button("Get Response"):
     with st.spinner("Generating response..."):
         try:
-            answer = query_model(model_map[model_choice], user_question)
+            model_entry = model_map[model_choice]
+            answer = query_model(model_entry, user_question)
             st.subheader(f"Response from {model_choice}:")
             st.write(answer)
         except Exception as e:
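
The remote path added by this commit can be exercised outside Streamlit. A minimal standalone sketch of the same Inference API request that query_remote_model makes — the model choice and prompt are illustrative (taken from the remote-flagged map entry and the text_area default), not part of this commit:

import os
import requests

# Standalone check of the request query_remote_model sends; endpoint and
# parameters match app.py. HF_TOKEN is optional for public models.
HF_TOKEN = os.getenv("HF_TOKEN")
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
payload = {"inputs": "What is EBITDA?", "parameters": {"max_new_tokens": 150}}
response = requests.post(
    "https://api-inference.huggingface.co/models/HuggingFaceH4/fin-llama",
    headers=headers,
    json=payload,
)
print(response.json() if response.ok else response.text)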