rohansampath committed on
Commit
614dffd
·
verified ·
1 Parent(s): 9775899

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -23
app.py CHANGED
@@ -4,12 +4,13 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import evaluate
5
  import re
6
  import matplotlib
7
- matplotlib.use('Agg') # for non-interactive envs
8
  import matplotlib.pyplot as plt
9
  import io
10
  import base64
11
  import os
12
  from huggingface_hub import login
 
13
 
14
  # Read token and login
15
  hf_token = os.getenv("HF_TOKEN_READ_WRITE")
@@ -18,28 +19,26 @@ if hf_token:
18
  else:
19
  print("⚠️ No HF_TOKEN_READ_WRITE found in environment")
20
 
21
- # Check GPU availability
22
- if torch.cuda.is_available():
23
- print("✅ GPU is available")
24
- print("GPU Name:", torch.cuda.get_device_name(0))
25
- else:
26
- print("❌ No GPU available")
27
-
28
  # ---------------------------------------------------------------------------
29
- # 1. Define model name and load model/tokenizer
30
  # ---------------------------------------------------------------------------
31
  model_name = "mistralai/Mistral-7B-Instruct-v0.3"
 
 
32
 
33
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
34
- device = "cuda" if torch.cuda.is_available() else "cpu"
35
- model = AutoModelForCausalLM.from_pretrained(
36
- model_name,
37
- token=hf_token,
38
- torch_dtype=torch.float16,
39
- device_map="auto"
40
- )
41
-
42
- print(f"✅ Model loaded on {device}")
 
 
 
43
 
44
  # ---------------------------------------------------------------------------
45
  # 2. Test dataset
@@ -58,14 +57,17 @@ accuracy_metric = evaluate.load("accuracy")
58
  # ---------------------------------------------------------------------------
59
  # 4. Inference helper functions
60
  # ---------------------------------------------------------------------------
 
61
  def generate_answer(question):
62
  """
63
  Generates an answer using Mistral's instruction format.
64
  """
 
 
65
  # Mistral instruction format
66
  prompt = f"""<s>[INST] {question} [/INST]"""
67
 
68
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
69
  with torch.no_grad():
70
  outputs = model.generate(
71
  **inputs,
@@ -91,6 +93,7 @@ def parse_answer(model_output):
91
  # ---------------------------------------------------------------------------
92
  # 5. Evaluation routine
93
  # ---------------------------------------------------------------------------
 
94
  def run_evaluation():
95
  predictions = []
96
  references = []
@@ -125,10 +128,10 @@ def run_evaluation():
125
  accuracy = results["accuracy"]
126
 
127
  # Create visualization
 
128
  correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
129
  incorrect_count = len(test_data) - correct_count
130
 
131
- fig, ax = plt.subplots(figsize=(8, 6))
132
  bars = ax.bar(["Correct", "Incorrect"],
133
  [correct_count, incorrect_count],
134
  color=["#2ecc71", "#e74c3c"])
@@ -142,7 +145,7 @@ def run_evaluation():
142
 
143
  ax.set_title("Evaluation Results")
144
  ax.set_ylabel("Count")
145
- ax.set_ylim([0, len(test_data) + 0.5]) # Add some padding at top
146
 
147
  # Convert plot to base64
148
  buf = io.BytesIO()
@@ -176,7 +179,6 @@ def run_evaluation():
176
 
177
  details_html += "</table></div>"
178
 
179
- # Combine plot and details
180
  full_html = f"""
181
  <div>
182
  <img src="data:image/png;base64,{data}" style="width:100%; max-width:600px;">
 
4
  import evaluate
5
  import re
6
  import matplotlib
7
+ matplotlib.use('Agg')
8
  import matplotlib.pyplot as plt
9
  import io
10
  import base64
11
  import os
12
  from huggingface_hub import login
13
+ import spaces
14
 
15
  # Read token and login
16
  hf_token = os.getenv("HF_TOKEN_READ_WRITE")
 
19
  else:
20
  print("⚠️ No HF_TOKEN_READ_WRITE found in environment")
21
 
 
 
 
 
 
 
 
22
  # ---------------------------------------------------------------------------
23
+ # 1. Model and tokenizer setup
24
  # ---------------------------------------------------------------------------
25
  model_name = "mistralai/Mistral-7B-Instruct-v0.3"
26
+ tokenizer = None
27
+ model = None
28
 
29
@spaces.GPU
def load_model():
    """
    Lazily create the shared tokenizer and model and return them.

    The module-level ``tokenizer`` and ``model`` globals are populated on
    first use from ``model_name`` (authenticating with ``hf_token``); later
    calls reuse whatever is already stored there.

    Returns:
        A ``(model, tokenizer)`` tuple -- note the order.

    NOTE(review): this function is wrapped by ``spaces.GPU``; whether the
    globals assigned here remain visible to later calls depends on how that
    decorator executes the function -- confirm against the ZeroGPU docs.
    """
    global tokenizer, model

    # Tokenizer: only fetched when nothing has been cached yet.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

    # Model: loaded in half precision, then moved onto the CUDA device.
    if model is None:
        load_kwargs = {"token": hf_token, "torch_dtype": torch.float16}
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        model.to('cuda')

    return model, tokenizer
42
 
43
  # ---------------------------------------------------------------------------
44
  # 2. Test dataset
 
57
  # ---------------------------------------------------------------------------
58
  # 4. Inference helper functions
59
  # ---------------------------------------------------------------------------
60
+ @spaces.GPU
61
  def generate_answer(question):
62
  """
63
  Generates an answer using Mistral's instruction format.
64
  """
65
+ model, tokenizer = load_model()
66
+
67
  # Mistral instruction format
68
  prompt = f"""<s>[INST] {question} [/INST]"""
69
 
70
+ inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
71
  with torch.no_grad():
72
  outputs = model.generate(
73
  **inputs,
 
93
  # ---------------------------------------------------------------------------
94
  # 5. Evaluation routine
95
  # ---------------------------------------------------------------------------
96
+ @spaces.GPU(duration=120) # Allow up to 2 minutes for full evaluation
97
  def run_evaluation():
98
  predictions = []
99
  references = []
 
128
  accuracy = results["accuracy"]
129
 
130
  # Create visualization
131
+ fig, ax = plt.subplots(figsize=(8, 6))
132
  correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
133
  incorrect_count = len(test_data) - correct_count
134
 
 
135
  bars = ax.bar(["Correct", "Incorrect"],
136
  [correct_count, incorrect_count],
137
  color=["#2ecc71", "#e74c3c"])
 
145
 
146
  ax.set_title("Evaluation Results")
147
  ax.set_ylabel("Count")
148
+ ax.set_ylim([0, len(test_data) + 0.5])
149
 
150
  # Convert plot to base64
151
  buf = io.BytesIO()
 
179
 
180
  details_html += "</table></div>"
181
 
 
182
  full_html = f"""
183
  <div>
184
  <img src="data:image/png;base64,{data}" style="width:100%; max-width:600px;">