Commit 9b002fb · Parent(s): 995f0f7
generating for too long
app.py
CHANGED
@@ -1,7 +1,6 @@
 import streamlit as st
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import time
 import os
 from dotenv import load_dotenv
 
@@ -20,7 +19,7 @@ st.title("✨ GemmaTextAppeal")
 st.markdown("""
 ### Interactive Demo of Google's Gemma 2-2B-IT Model
 This app demonstrates the text generation capabilities of Google's Gemma 2-2B-IT model.
-Enter a prompt below and see the model generate text
+Enter a prompt below and see the model generate text!
 """)
 
 # Function to load model
@@ -32,31 +31,24 @@ def load_model():
         if not huggingface_token:
             return None, None, "No Hugging Face API token found. Please add your token as a secret named 'HF_TOKEN'."
 
-        #
+        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             "google/gemma-2-2b-it",
             token=huggingface_token
         )
 
-        # Load model
+        # Load model with appropriate configuration
         model_kwargs = {
             "token": huggingface_token,
-            "
+            "device_map": "auto" if torch.cuda.is_available() else None,
+            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
         }
 
-        # Only add device_map if GPU is available
-        if torch.cuda.is_available():
-            model_kwargs["device_map"] = "auto"
-
         model = AutoModelForCausalLM.from_pretrained(
             "google/gemma-2-2b-it",
            **model_kwargs
         )
 
-        # Move model to CPU if no GPU
-        if not torch.cuda.is_available():
-            model = model.to("cpu")
-
         return tokenizer, model, None
     except Exception as e:
         return None, None, str(e)
@@ -158,47 +150,35 @@ def generate_text(prompt, max_new_tokens=300, temperature=0.7):
    # Format the prompt according to Gemma's expected format
    formatted_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

-    # Create the
-    progress_bar = st.progress(0)
+    # Create the status indicator and output area
    status_text = st.empty()
    output_area = st.empty()
    status_text.text("Generating response...")

    # Tokenize the input
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # Display incrementally for visual effect
-    display_text = ""
-    words = generated_text.split()
-    total_words = len(words)
-
-    for i, word in enumerate(words):
-        display_text += word + " "
-        progress = min(1.0, (i + 1) / total_words)
-        progress_bar.progress(progress)
-        output_area.markdown(f"**Generated Response:**\n\n{display_text}")
-        time.sleep(0.05) # Brief delay for visual effect
+    with torch.no_grad():
+        encoding = tokenizer(formatted_prompt, return_tensors="pt")
+
+        # Move to the appropriate device
+        if torch.cuda.is_available():
+            encoding = {k: v.to("cuda") for k, v in encoding.items()}
+
+        # Generate the text - streamlined version
+        output_ids = model.generate(
+            **encoding,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+        # Get only the generated part (exclude the prompt)
+        new_tokens = output_ids[0][encoding["input_ids"].shape[1]:]
+        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

+    # Display the result
+    output_area.markdown(f"**Generated Response:**\n\n{generated_text}")
    status_text.text("Generation complete!")
-    progress_bar.progress(1.0)

    return generated_text
