BryanBradfo committed on
Commit 9b002fb · 1 Parent(s): 995f0f7

generating for too long

Files changed (1)
  1. app.py +27 -47
app.py CHANGED
@@ -1,7 +1,6 @@
 import streamlit as st
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import time
 import os
 from dotenv import load_dotenv
 
@@ -20,7 +19,7 @@ st.title("✨ GemmaTextAppeal")
 st.markdown("""
 ### Interactive Demo of Google's Gemma 2-2B-IT Model
 This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
-Enter a prompt below and see the model generate text in real-time!
+Enter a prompt below and see the model generate text!
 """)
 
 # Function to load model
@@ -32,31 +31,24 @@ def load_model():
         if not huggingface_token:
             return None, None, "No Hugging Face API token found. Please add your token as a secret named 'HF_TOKEN'."
 
-        # Attempt to download model with explicit token
+        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             "google/gemma-2-2b-it",
             token=huggingface_token
         )
 
-        # Load model - use CPU configuration if no GPU available
+        # Load model with appropriate configuration
        model_kwargs = {
             "token": huggingface_token,
-            "torch_dtype": torch.float16
+            "device_map": "auto" if torch.cuda.is_available() else None,
+            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
         }
 
-        # Only add device_map if GPU is available
-        if torch.cuda.is_available():
-            model_kwargs["device_map"] = "auto"
-
         model = AutoModelForCausalLM.from_pretrained(
             "google/gemma-2-2b-it",
             **model_kwargs
         )
 
-        # Move model to CPU if no GPU
-        if not torch.cuda.is_available():
-            model = model.to("cpu")
-
         return tokenizer, model, None
     except Exception as e:
         return None, None, str(e)
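Taken together, the new model_kwargs pick device placement and precision in one expression instead of the removed if/else blocks: float16 with device_map="auto" when a GPU is present, float32 on CPU with no explicit move. A minimal standalone sketch of that loading path, mirroring the diff above (HF_TOKEN is the secret name the app already expects):

import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Read the token from the environment, as the app does via its HF_TOKEN secret.
token = os.getenv("HF_TOKEN")

# Same selection logic as the updated load_model(): half precision and automatic
# device placement on GPU, full precision on CPU (no manual .to("cpu") needed).
model_kwargs = {
    "token": token,
    "device_map": "auto" if torch.cuda.is_available() else None,
    "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
}

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=token)
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", **model_kwargs)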
@@ -158,47 +150,35 @@ def generate_text(prompt, max_new_tokens=300, temperature=0.7):
     # Format the prompt according to Gemma's expected format
     formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
 
-    # Create the progress bar and status indicators
-    progress_bar = st.progress(0)
+    # Create the status indicator and output area
     status_text = st.empty()
     output_area = st.empty()
     status_text.text("Generating response...")
 
     # Tokenize the input
-    encoding = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
-    input_ids = encoding["input_ids"]
-
-    # Ensure we have a proper attention mask
-    attention_mask = torch.ones_like(input_ids)
-
-    # Simple approach - generate all at once
-    output_ids = model.generate(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        pad_token_id=tokenizer.eos_token_id
-    )
-
-    # Get only the generated part (exclude the prompt)
-    new_tokens = output_ids[0][input_ids.shape[1]:]
-    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
-
-    # Display incrementally for visual effect
-    display_text = ""
-    words = generated_text.split()
-    total_words = len(words)
-
-    for i, word in enumerate(words):
-        display_text += word + " "
-        progress = min(1.0, (i + 1) / total_words)
-        progress_bar.progress(progress)
-        output_area.markdown(f"**Generated Response:**\n\n{display_text}")
-        time.sleep(0.05)  # Brief delay for visual effect
+    with torch.no_grad():
+        encoding = tokenizer(formatted_prompt, return_tensors="pt")
+
+        # Move to the appropriate device
+        if torch.cuda.is_available():
+            encoding = {k: v.to("cuda") for k, v in encoding.items()}
+
+        # Generate the text - streamlined version
+        output_ids = model.generate(
+            **encoding,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+        # Get only the generated part (exclude the prompt)
+        new_tokens = output_ids[0][encoding["input_ids"].shape[1]:]
+        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
 
+    # Display the result
+    output_area.markdown(f"**Generated Response:**\n\n{generated_text}")
     status_text.text("Generation complete!")
-    progress_bar.progress(1.0)
 
     return generated_text
 
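With the word-by-word display loop and its time.sleep() calls removed, the whole response now comes from a single generate() call and is rendered once. A rough sketch of the same generation flow outside Streamlit, assuming tokenizer and model were loaded as in the sketch above (the prompt text is only an example):

# Build the Gemma chat-format prompt, exactly as generate_text() does.
prompt = "Explain what a context window is in two sentences."
formatted = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

with torch.no_grad():
    encoding = tokenizer(formatted, return_tensors="pt")
    if torch.cuda.is_available():
        encoding = {k: v.to("cuda") for k, v in encoding.items()}

    output_ids = model.generate(
        **encoding,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

# Slice off the echoed prompt so only the newly generated tokens are decoded.
new_tokens = output_ids[0][encoding["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))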
 
 