BryanBradfo committed
Commit f6f2d18 · 1 Parent(s): f871f1a

handling error

Files changed (2):
  1. app.py +154 -90
  2. requirements.txt +2 -1
app.py CHANGED
@@ -23,6 +23,21 @@ This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT
 Enter a prompt below and see the model generate text in real-time!
 """)
 
+# Check for Hugging Face Token
+huggingface_token = os.getenv("HF_TOKEN")
+if not huggingface_token:
+    st.warning("""
+    ⚠️ **No Hugging Face API token detected**
+
+    The Gemma models require accepting a license and authentication to use.
+
+    To make this app work:
+    1. Create a Hugging Face account
+    2. Accept the model license at: https://huggingface.co/google/gemma-2-2b-it
+    3. Create a HF token at: https://huggingface.co/settings/tokens
+    4. Add your token as a secret named 'HF_TOKEN' in your Space settings
+    """)
+
 # Sidebar with information
 with st.sidebar:
     st.header("About Gemma")
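
**Note on the token check:** `os.getenv("HF_TOKEN")` only sees variables already present in the environment, which is how Spaces injects secrets. Since `requirements.txt` pins `python-dotenv`, local runs would presumably load a `.env` file first. A minimal sketch of that assumption (not part of this commit):

```python
# Sketch for local development (assumed, not committed code):
# python-dotenv is already in requirements.txt, so a .env file
# holding HF_TOKEN=... can stand in for a Space secret locally.
import os
from dotenv import load_dotenv

load_dotenv()  # no-op if no .env file is present
huggingface_token = os.getenv("HF_TOKEN")
```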
@@ -66,6 +81,8 @@ if 'generation_complete' not in st.session_state:
     st.session_state.generation_complete = False
 if 'generated_text' not in st.session_state:
     st.session_state.generated_text = ""
+if 'error_message' not in st.session_state:
+    st.session_state.error_message = None
 
 # Model parameters
 col1, col2 = st.columns(2)
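
**Note on session-state defaults:** the three `if ... not in st.session_state` guards work, but `st.session_state` is dict-like, so the defaults can live in one table. A compact equivalent (a sketch, not the committed code):

```python
import streamlit as st

# One place to declare every default; setdefault only writes a key
# when it is missing, matching the original per-key guards.
_DEFAULTS = {
    "generation_complete": False,
    "generated_text": "",
    "error_message": None,
}
for key, value in _DEFAULTS.items():
    st.session_state.setdefault(key, value)
```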
@@ -83,110 +100,157 @@ user_input = st.text_area("Enter your prompt:",
     placeholder="e.g., Write a short story about a robot discovering emotions")
 
 # Function to load model and generate text
-@st.cache_resource
+@st.cache_resource(show_spinner=False)
 def load_model():
-    # Get API Token
-    huggingface_token = os.getenv("HF_TOKEN")
-    if not huggingface_token:
-        st.warning("No Hugging Face API token found. Some models may not be accessible.")
-
-    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=huggingface_token)
-    model = AutoModelForCausalLM.from_pretrained(
-        "google/gemma-2-2b-it",
-        token=huggingface_token,
-        torch_dtype=torch.float16,
-        device_map="auto"
-    )
-    return tokenizer, model
+    try:
+        # Get API Token
+        huggingface_token = os.getenv("HF_TOKEN")
+        if not huggingface_token:
+            raise ValueError("No Hugging Face API token found. Please add your token as a secret named 'HF_TOKEN'.")
+
+        # Attempt to download model with explicit token
+        tokenizer = AutoTokenizer.from_pretrained(
+            "google/gemma-2-2b-it",
+            token=huggingface_token,
+            use_fast=True
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2-2b-it",
+            token=huggingface_token,
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
+        return tokenizer, model
+    except Exception as e:
+        # Re-raise the exception to be handled in the calling function
+        raise e
 
 def generate_text(prompt, max_new_tokens=300, temperature=0.7):
-    tokenizer, model = load_model()
-
-    # Format the prompt according to Gemma's expected format
-    formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
-
-    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
-
-    # Create the progress bar
-    progress_bar = st.progress(0)
-    status_text = st.empty()
-    output_area = st.empty()
-
-    tokens_generated = 0
-    generated_text = ""
-
-    # Generate with streaming
-    streamer_output = ""
-
-    # Generate with step-by-step tracking for the progress bar
-    generate_kwargs = dict(
-        inputs=inputs["input_ids"],
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        do_sample=True,
-        pad_token_id=tokenizer.eos_token_id
-    )
-
-    status_text.text("Generating response...")
-
-    with torch.no_grad():
-        # Generate text step by step
-        for i in range(max_new_tokens):
-            if i == 0:
-                outputs = model.generate(
-                    **generate_kwargs,
-                    max_new_tokens=1,
-                )
-                generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
-            else:
-                input_ids = torch.cat([inputs["input_ids"], generated_ids], dim=1)
-                outputs = model.generate(
-                    input_ids=input_ids,
-                    max_new_tokens=1,
-                    do_sample=True,
-                    temperature=temperature,
-                    pad_token_id=tokenizer.eos_token_id
-                )
-                new_token = outputs[0][-1].unsqueeze(0)
-                generated_ids = torch.cat([generated_ids, new_token], dim=0)
-
-            # Decode text
-            current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
-
-            # Update streaming output
-            streamer_output = current_text
-
-            # Update progress and output
-            progress = min(1.0, (i + 1) / max_new_tokens)
-            progress_bar.progress(progress)
-
-            # Update display
-            output_area.markdown(f"**Generated Response:**\n\n{streamer_output}")
-
-            # Check if we've reached an end token
-            if generated_ids[-1].item() == tokenizer.eos_token_id:
-                break
-
-            # Add a small delay to simulate typing
-            time.sleep(0.01)
-
-    status_text.text("Generation complete!")
-    progress_bar.progress(1.0)
-
-    return streamer_output
+    try:
+        with st.spinner("Loading model... (this may take a minute on first run)"):
+            tokenizer, model = load_model()
+
+        # Format the prompt according to Gemma's expected format
+        formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+
+        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+        # Create the progress bar
+        progress_bar = st.progress(0)
+        status_text = st.empty()
+        output_area = st.empty()
+
+        tokens_generated = 0
+        generated_text = ""
+
+        # Generate with streaming
+        streamer_output = ""
+
+        # Generate with step-by-step tracking for the progress bar
+        generate_kwargs = dict(
+            inputs=inputs["input_ids"],
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+        status_text.text("Generating response...")
+
+        with torch.no_grad():
+            # Generate text step by step
+            for i in range(max_new_tokens):
+                if i == 0:
+                    outputs = model.generate(
+                        **generate_kwargs,
+                        max_new_tokens=1,
+                    )
+                    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
+                else:
+                    input_ids = torch.cat([inputs["input_ids"], generated_ids], dim=1)
+                    outputs = model.generate(
+                        input_ids=input_ids,
+                        max_new_tokens=1,
+                        do_sample=True,
+                        temperature=temperature,
+                        pad_token_id=tokenizer.eos_token_id
+                    )
+                    new_token = outputs[0][-1].unsqueeze(0)
+                    generated_ids = torch.cat([generated_ids, new_token], dim=0)
+
+                # Decode text
+                current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+                # Update streaming output
+                streamer_output = current_text
+
+                # Update progress and output
+                progress = min(1.0, (i + 1) / max_new_tokens)
+                progress_bar.progress(progress)
+
+                # Update display
+                output_area.markdown(f"**Generated Response:**\n\n{streamer_output}")
+
+                # Check if we've reached an end token
+                if generated_ids[-1].item() == tokenizer.eos_token_id:
+                    break
+
+                # Add a small delay to simulate typing
+                time.sleep(0.01)
+
+        status_text.text("Generation complete!")
+        progress_bar.progress(1.0)
+
+        return streamer_output
+
+    except Exception as e:
+        st.session_state.error_message = str(e)
+        return None
+
+# Show any existing error
+if st.session_state.error_message:
+    st.error(f"Error: {st.session_state.error_message}")
+
+# Add troubleshooting information
+with st.expander("Troubleshooting Information"):
+    st.markdown("""
+    ### Common Issues:
+
+    1. **Missing Hugging Face Token**: The Gemma model requires authentication. Add your token as a secret named 'HF_TOKEN' in the Space settings.
+
+    2. **License Acceptance**: You need to accept the model license on the [Gemma model page](https://huggingface.co/google/gemma-2-2b-it).
+
+    3. **Internet Connection**: The model needs to be downloaded the first time the app runs. Ensure your Space has internet access.
+
+    4. **Resource Constraints**: The Gemma model requires significant resources. Consider upgrading your Space's hardware if you're encountering memory issues.
+
+    ### How to Fix:
+
+    1. Create a [Hugging Face account](https://huggingface.co/join)
+    2. Visit the [Gemma model page](https://huggingface.co/google/gemma-2-2b-it) and accept the license
+    3. Create a token at https://huggingface.co/settings/tokens
+    4. Add your token to the Space: Settings → Secrets → New Secret (HF_TOKEN)
+    """)
+
 # Generate button
 if st.button("Generate Text"):
-    if user_input:
+    # Reset any previous errors
+    st.session_state.error_message = None
+
+    if not huggingface_token:
+        st.error("Hugging Face token is required! Please add your token as described above.")
+    elif user_input:
         st.session_state.user_prompt = user_input
-        with st.spinner("Generating text..."):
-            st.session_state.generated_text = generate_text(user_input, max_length, temperature)
+        result = generate_text(user_input, max_length, temperature)
+        if result is not None:  # Only set if no error occurred
+            st.session_state.generated_text = result
         st.session_state.generation_complete = True
     else:
         st.error("Please enter a prompt first!")
 
 # Display results
-if st.session_state.generation_complete:
+if st.session_state.generation_complete and not st.session_state.error_message:
     st.markdown("### Generated Text")
     st.markdown(st.session_state.generated_text)
 
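**Review notes on this hunk:** in `load_model`, `raise e` works but a bare `raise` would preserve the original traceback more cleanly. The new error handling also wraps, without fixing, two pre-existing bugs in the token-by-token loop: `model.generate(**generate_kwargs, max_new_tokens=1)` passes `max_new_tokens` twice (a `TypeError`, since `generate_kwargs` already contains it), and `torch.cat([inputs["input_ids"], generated_ids], dim=1)` mixes a 2-D tensor with the 1-D `generated_ids` (a `RuntimeError`). The loop also re-encodes the whole sequence on every step. `transformers` ships `TextIteratorStreamer` for exactly this stream-into-a-UI case; a sketch of how the core of `generate_text` could use it (the helper name `stream_generate` is hypothetical, and `output_area` follows the app's own placeholder; this is an illustration, not the committed code):

```python
from threading import Thread

from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, formatted_prompt,
                    max_new_tokens, temperature, output_area):
    """Stream tokens into a Streamlit placeholder via TextIteratorStreamer."""
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # skip_prompt=True yields only the newly generated text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)

    # generate() blocks, so run it on a worker thread and consume chunks here
    worker = Thread(target=model.generate, kwargs=dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    ))
    worker.start()

    text = ""
    for chunk in streamer:
        text += chunk
        output_area.markdown(f"**Generated Response:**\n\n{text}")
    worker.join()
    return text
```

This keeps the model's KV cache across steps (one forward pass per new token instead of re-running the prompt each iteration) and removes the manual `torch.cat` bookkeeping entirely.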
@@ -207,6 +271,6 @@ st.markdown("---")
 st.markdown("""
 <div style="text-align: center">
     <p>Created with ❤️ | Powered by Gemma 2-2B-IT and Hugging Face</p>
-    <p>Code available on <a href="https://huggingface.co/spaces/your-username/GemmaTextAppeal">Hugging Face Spaces</a></p>
+    <p>Code available on <a href="https://huggingface.co/spaces" target="_blank">Hugging Face Spaces</a></p>
 </div>
 """, unsafe_allow_html=True)
 
requirements.txt CHANGED
@@ -1,5 +1,6 @@
+
 streamlit==1.24.0
 torch>=2.0.0
-transformers>=4.31.0
+transformers>=4.34.0
 python-dotenv==1.0.0
 accelerate>=0.20.0
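
**Note on the `transformers` floor:** bumping `4.31.0` to `4.34.0` still predates Gemma 2; support for the `gemma2` architecture landed in `transformers` v4.42.0, so loading `google/gemma-2-2b-it` works here only because `>=4.34.0` happens to resolve to a newer release at install time. A floor of `transformers>=4.42.0` (a suggested pin, not part of this commit) would encode the actual requirement.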