BryanBradfo committed on
Commit 1960e32 · 1 Parent(s): f9ee089

fix mask issue

Files changed (1): app.py (+35 -54)
app.py CHANGED
@@ -131,74 +131,55 @@ def generate_text(prompt, max_new_tokens=300, temperature=0.7):
         with st.spinner("Loading model... (this may take a minute on first run)"):
             tokenizer, model = load_model()

+        # Simpler approach: use the model's built-in text generation capabilities
         # Format the prompt according to Gemma's expected format
         formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

-        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
-
-        # Create the progress bar
+        # Create the progress bar and status indicators
         progress_bar = st.progress(0)
         status_text = st.empty()
         output_area = st.empty()
+        status_text.text("Generating response...")
+
+        # Tokenize the input with attention mask explicitly set
+        encoding = tokenizer(formatted_prompt, return_tensors="pt")
+        input_ids = encoding["input_ids"].to(model.device)
+
+        # Create an attention mask of ones (attend to all tokens)
+        attention_mask = torch.ones_like(input_ids)
+
+        # Generate the full text at once (simpler and more reliable)
+        generated_ids = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            pad_token_id=tokenizer.eos_token_id,
+        )

-        streamer_output = ""
+        # Get only the newly generated tokens (exclude input prompt)
+        generated_text = tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

-        status_text.text("Generating response...")
+        # Simulate token-by-token generation for visual effect
+        words = generated_text.split()
+        displayed_text = ""

-        with torch.no_grad():
-            # Generate text step by step to show progress
-            # Start with 1 token generation
-            input_ids = inputs["input_ids"]
-            generated_ids = None
+        for i, word in enumerate(words):
+            displayed_text += word + " "
+
+            # Update progress and display
+            progress = min(1.0, (i + 1) / len(words))
+            progress_bar.progress(progress)
+            output_area.markdown(f"**Generated Response:**\n\n{displayed_text}")

-            for i in range(max_new_tokens):
-                if i == 0:
-                    # First token generation
-                    outputs = model.generate(
-                        input_ids=input_ids,
-                        max_new_tokens=1,
-                        do_sample=True,
-                        temperature=temperature,
-                        pad_token_id=tokenizer.eos_token_id
-                    )
-                    # Extract only the newly generated token(s)
-                    generated_ids = outputs[0][input_ids.shape[1]:].unsqueeze(0)
-                else:
-                    # For subsequent tokens, concatenate previous results
-                    current_input_ids = torch.cat([input_ids, generated_ids], dim=1)
-                    outputs = model.generate(
-                        input_ids=current_input_ids,
-                        max_new_tokens=1,
-                        do_sample=True,
-                        temperature=temperature,
-                        pad_token_id=tokenizer.eos_token_id
-                    )
-                    # Extract only the newly generated token
-                    new_token = outputs[0][-1].unsqueeze(0).unsqueeze(0)
-                    generated_ids = torch.cat([generated_ids, new_token], dim=1)
-
-                # Decode the current state
-                current_output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-                streamer_output = current_output
-
-                # Update progress and output
-                progress = min(1.0, (i + 1) / max_new_tokens)
-                progress_bar.progress(progress)
-
-                # Update display
-                output_area.markdown(f"**Generated Response:**\n\n{streamer_output}")
-
-                # Check if we've reached an EOS token in the latest output
-                if outputs[0][-1].item() == tokenizer.eos_token_id:
-                    break
-
-                # Add a small delay to simulate typing
-                time.sleep(0.01)
+            # Small delay for visual effect
+            time.sleep(0.05)

         status_text.text("Generation complete!")
         progress_bar.progress(1.0)

-        return streamer_output
+        return generated_text

     except Exception as e:
         st.session_state.error_message = str(e)
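
Note on the fix (reviewer sketch, not part of the commit): the "mask issue" is most likely the transformers warning emitted when generate() is called without an explicit attention_mask, which this commit addresses by building a mask of ones and passing it in. Below is a minimal standalone sketch of the same pattern, assuming a Gemma-style chat checkpoint loaded with transformers; the model id is a placeholder and may differ from what this Space's load_model() returns. It also shows that the mask the tokenizer already returns can be passed straight through, which for a single unpadded prompt is identical to torch.ones_like(input_ids).

# Sketch only: explicit attention_mask passed to generate(), outside Streamlit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "google/gemma-2b-it"  # placeholder checkpoint; adjust to the Space's model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Same chat format as the app's formatted_prompt.
prompt = "<bos><start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n"

# The tokenizer returns an attention_mask alongside input_ids;
# for one unpadded prompt it is all ones, matching torch.ones_like.
encoding = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    generated_ids = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_new_tokens=300,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode only the newly generated tokens, excluding the prompt.
new_tokens = generated_ids[0][encoding["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))

Since generation now happens in a single call, the word-by-word loop in the diff only animates text that is already complete; true incremental streaming would typically use transformers' TextIteratorStreamer, which this commit does not adopt.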