Update app.py
Browse files
app.py
CHANGED
@@ -37,6 +37,7 @@ st.markdown(
|
|
37 |
def load_models():
|
38 |
"""
|
39 |
Lazy-load the required pipelines and store them in session state.
|
|
|
40 |
Pipelines:
|
41 |
1. Captioner: Generates descriptive text from an image using a lighter model.
|
42 |
2. Storyer: Generates a humorous children's story using aspis/gpt2-genre-story-generation.
|
@@ -62,11 +63,10 @@ def load_models():
|
|
62 |
@st.cache_data(show_spinner=False)
def get_caption(image_bytes):
    """
    Produce a caption for an image supplied as raw bytes.

    The bytes are decoded into an RGB image, reduced to a 256x256 maximum
    for faster processing, and passed to the captioning pipeline stored in
    ``st.session_state.captioner``. The ``generated_text`` entry of the
    first result is returned.
    """
    stream = io.BytesIO(image_bytes)
    picture = Image.open(stream).convert("RGB")
    # thumbnail() modifies the image in place; nothing is returned.
    picture.thumbnail((256, 256))
    predictions = st.session_state.captioner(picture)
    first_prediction = predictions[0]
    return first_prediction["generated_text"]
|
@@ -76,46 +76,51 @@ def get_story(caption):
|
|
76 |
"""
|
77 |
Generates a humorous and engaging children's story based on the caption.
|
78 |
Uses a prompt to instruct the model and limits token generation to 80 tokens.
|
|
|
|
|
79 |
"""
|
80 |
prompt = (
|
81 |
f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
|
82 |
f"{caption}\nStory: in third-person narrative, as if the author is playfully describing the scene in the image."
|
83 |
)
|
84 |
-
|
85 |
prompt,
|
86 |
max_new_tokens=80,
|
87 |
do_sample=True,
|
88 |
temperature=0.7,
|
89 |
top_p=0.9,
|
90 |
return_full_text=False
|
91 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
words = raw_story.split()
|
93 |
-
|
|
|
94 |
|
95 |
@st.cache_data(show_spinner=False)
|
96 |
def get_audio(story):
|
97 |
"""
|
98 |
Converts the generated story text into audio.
|
99 |
-
Splits the text into 300-character chunks
|
100 |
-
|
101 |
"""
|
102 |
chunks = textwrap.wrap(story, width=300)
|
103 |
audio_chunks = []
|
104 |
for chunk in chunks:
|
105 |
try:
|
106 |
output = st.session_state.tts(chunk)
|
107 |
-
# Some pipelines return a list; if so, use the first element.
|
108 |
if isinstance(output, list):
|
109 |
output = output[0]
|
110 |
if "audio" in output:
|
111 |
-
# Ensure the audio is a numpy array and squeeze any extra dimensions.
|
112 |
audio_array = np.array(output["audio"]).squeeze()
|
113 |
audio_chunks.append(audio_array)
|
114 |
-
except Exception
|
115 |
-
# Skip any chunk that raises an error.
|
116 |
continue
|
117 |
|
118 |
-
# If no audio was generated, produce 1 second of silence as a fallback.
|
119 |
if not audio_chunks:
|
120 |
sr = st.session_state.tts.model.config.sampling_rate
|
121 |
audio = np.zeros(sr, dtype=np.float32)
|
@@ -133,7 +138,6 @@ if uploaded_file is not None:
|
|
133 |
try:
|
134 |
load_models() # Ensure models are loaded
|
135 |
image_bytes = uploaded_file.getvalue()
|
136 |
-
# Display the uploaded image
|
137 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
138 |
st.image(image, caption="Your Amazing Picture!", use_column_width=True)
|
139 |
st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)
|
@@ -147,7 +151,11 @@ if uploaded_file is not None:
|
|
147 |
with st.spinner("Generating story..."):
|
148 |
story = get_story(caption)
|
149 |
st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
|
150 |
-
|
|
|
|
|
|
|
|
|
151 |
|
152 |
with st.spinner("Generating audio..."):
|
153 |
audio_buffer = get_audio(story)
|
|
|
37 |
def load_models():
|
38 |
"""
|
39 |
Lazy-load the required pipelines and store them in session state.
|
40 |
+
|
41 |
Pipelines:
|
42 |
1. Captioner: Generates descriptive text from an image using a lighter model.
|
43 |
2. Storyer: Generates a humorous children's story using aspis/gpt2-genre-story-generation.
|
|
|
63 |
@st.cache_data(show_spinner=False)
def get_caption(image_bytes):
    """
    Generate a caption for the image contained in ``image_bytes``.

    The bytes are opened as an RGB image, which is then shrunk in place to
    a lower resolution (maximum 256x256). The captioning pipeline held in
    ``st.session_state.captioner`` is called on the reduced image, and the
    ``generated_text`` of its first result is returned.
    """
    buffer = io.BytesIO(image_bytes)
    rgb_image = Image.open(buffer).convert("RGB")
    max_size = (256, 256)
    rgb_image.thumbnail(max_size)
    outputs = st.session_state.captioner(rgb_image)
    return outputs[0]["generated_text"]
|
|
|
76 |
"""
|
77 |
Generates a humorous and engaging children's story based on the caption.
|
78 |
Uses a prompt to instruct the model and limits token generation to 80 tokens.
|
79 |
+
|
80 |
+
If no text is generated, a fallback story is returned.
|
81 |
"""
|
82 |
prompt = (
|
83 |
f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
|
84 |
f"{caption}\nStory: in third-person narrative, as if the author is playfully describing the scene in the image."
|
85 |
)
|
86 |
+
result = st.session_state.storyer(
|
87 |
prompt,
|
88 |
max_new_tokens=80,
|
89 |
do_sample=True,
|
90 |
temperature=0.7,
|
91 |
top_p=0.9,
|
92 |
return_full_text=False
|
93 |
+
)
|
94 |
+
# Log the raw result for debugging (viewable in server logs)
|
95 |
+
print("Story generation raw result:", result)
|
96 |
+
|
97 |
+
raw_story = result[0].get("generated_text", "").strip()
|
98 |
+
if not raw_story:
|
99 |
+
raw_story = "Once upon a time, the park was filled with laughter as children played happily under the bright sun."
|
100 |
words = raw_story.split()
|
101 |
+
story = " ".join(words[:100])
|
102 |
+
return story
|
103 |
|
104 |
@st.cache_data(show_spinner=False)
|
105 |
def get_audio(story):
|
106 |
"""
|
107 |
Converts the generated story text into audio.
|
108 |
+
Splits the text into 300-character chunks, processes each via the TTS pipeline,
|
109 |
+
and concatenates the resulting audio arrays. If no audio is generated, 1 second of silence is used.
|
110 |
"""
|
111 |
chunks = textwrap.wrap(story, width=300)
|
112 |
audio_chunks = []
|
113 |
for chunk in chunks:
|
114 |
try:
|
115 |
output = st.session_state.tts(chunk)
|
|
|
116 |
if isinstance(output, list):
|
117 |
output = output[0]
|
118 |
if "audio" in output:
|
|
|
119 |
audio_array = np.array(output["audio"]).squeeze()
|
120 |
audio_chunks.append(audio_array)
|
121 |
+
except Exception:
|
|
|
122 |
continue
|
123 |
|
|
|
124 |
if not audio_chunks:
|
125 |
sr = st.session_state.tts.model.config.sampling_rate
|
126 |
audio = np.zeros(sr, dtype=np.float32)
|
|
|
138 |
try:
|
139 |
load_models() # Ensure models are loaded
|
140 |
image_bytes = uploaded_file.getvalue()
|
|
|
141 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
142 |
st.image(image, caption="Your Amazing Picture!", use_column_width=True)
|
143 |
st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)
|
|
|
151 |
with st.spinner("Generating story..."):
|
152 |
story = get_story(caption)
|
153 |
st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
|
154 |
+
# If the story is empty (or consists only of whitespace), display a default message.
|
155 |
+
if not story.strip():
|
156 |
+
st.write("No story was generated. Please try again.")
|
157 |
+
else:
|
158 |
+
st.write(story)
|
159 |
|
160 |
with st.spinner("Generating audio..."):
|
161 |
audio_buffer = get_audio(story)
|