Szeyu committed on
Commit
e1351c4
·
verified ·
1 Parent(s): b540ff3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -14
app.py CHANGED
@@ -37,6 +37,7 @@ st.markdown(
37
  def load_models():
38
  """
39
  Lazy-load the required pipelines and store them in session state.
 
40
  Pipelines:
41
  1. Captioner: Generates descriptive text from an image using a lighter model.
42
  2. Storyer: Generates a humorous children's story using aspis/gpt2-genre-story-generation.
@@ -62,11 +63,10 @@ def load_models():
62
  @st.cache_data(show_spinner=False)
63
  def get_caption(image_bytes):
64
  """
65
- Converts image bytes into a lower resolution image (256x256 maximum)
66
  and generates a caption.
67
  """
68
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
69
- # Resize image to 256x256 maximum for faster processing
70
  image.thumbnail((256, 256))
71
  caption = st.session_state.captioner(image)[0]["generated_text"]
72
  return caption
@@ -76,46 +76,51 @@ def get_story(caption):
76
  """
77
  Generates a humorous and engaging children's story based on the caption.
78
  Uses a prompt to instruct the model and limits token generation to 80 tokens.
 
 
79
  """
80
  prompt = (
81
  f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
82
  f"{caption}\nStory: in third-person narrative, as if the author is playfully describing the scene in the image."
83
  )
84
- raw_story = st.session_state.storyer(
85
  prompt,
86
  max_new_tokens=80,
87
  do_sample=True,
88
  temperature=0.7,
89
  top_p=0.9,
90
  return_full_text=False
91
- )[0]["generated_text"].strip()
 
 
 
 
 
 
92
  words = raw_story.split()
93
- return " ".join(words[:100])
 
94
 
95
  @st.cache_data(show_spinner=False)
96
  def get_audio(story):
97
  """
98
  Converts the generated story text into audio.
99
- Splits the text into 300-character chunks to reduce repeated TTS calls.
100
- Checks each chunk, and if no valid audio is produced, creates a brief default silent audio.
101
  """
102
  chunks = textwrap.wrap(story, width=300)
103
  audio_chunks = []
104
  for chunk in chunks:
105
  try:
106
  output = st.session_state.tts(chunk)
107
- # Some pipelines return a list; if so, use the first element.
108
  if isinstance(output, list):
109
  output = output[0]
110
  if "audio" in output:
111
- # Ensure the audio is a numpy array and squeeze any extra dimensions.
112
  audio_array = np.array(output["audio"]).squeeze()
113
  audio_chunks.append(audio_array)
114
- except Exception as e:
115
- # Skip any chunk that raises an error.
116
  continue
117
 
118
- # If no audio was generated, produce 1 second of silence as a fallback.
119
  if not audio_chunks:
120
  sr = st.session_state.tts.model.config.sampling_rate
121
  audio = np.zeros(sr, dtype=np.float32)
@@ -133,7 +138,6 @@ if uploaded_file is not None:
133
  try:
134
  load_models() # Ensure models are loaded
135
  image_bytes = uploaded_file.getvalue()
136
- # Display the uploaded image
137
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
138
  st.image(image, caption="Your Amazing Picture!", use_column_width=True)
139
  st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)
@@ -147,7 +151,11 @@ if uploaded_file is not None:
147
  with st.spinner("Generating story..."):
148
  story = get_story(caption)
149
  st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
150
- st.write(story)
 
 
 
 
151
 
152
  with st.spinner("Generating audio..."):
153
  audio_buffer = get_audio(story)
 
37
  def load_models():
38
  """
39
  Lazy-load the required pipelines and store them in session state.
40
+
41
  Pipelines:
42
  1. Captioner: Generates descriptive text from an image using a lighter model.
43
  2. Storyer: Generates a humorous children's story using aspis/gpt2-genre-story-generation.
 
@st.cache_data(show_spinner=False)
def get_caption(image_bytes):
    """
    Decode raw image bytes, downscale the picture to at most 256x256,
    and run the captioning pipeline to produce a short description.

    The raw bytes (rather than a PIL object) are the argument so that
    Streamlit's cache can hash the input reliably.
    """
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # thumbnail() resizes in place and preserves aspect ratio, keeping
    # both dimensions within 256 pixels for faster inference.
    pil_image.thumbnail((256, 256))
    return st.session_state.captioner(pil_image)[0]["generated_text"]
 
76
  """
77
  Generates a humorous and engaging children's story based on the caption.
78
  Uses a prompt to instruct the model and limits token generation to 80 tokens.
79
+
80
+ If no text is generated, a fallback story is returned.
81
  """
82
  prompt = (
83
  f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
84
  f"{caption}\nStory: in third-person narrative, as if the author is playfully describing the scene in the image."
85
  )
86
+ result = st.session_state.storyer(
87
  prompt,
88
  max_new_tokens=80,
89
  do_sample=True,
90
  temperature=0.7,
91
  top_p=0.9,
92
  return_full_text=False
93
+ )
94
+ # Log the raw result for debugging (viewable in server logs)
95
+ print("Story generation raw result:", result)
96
+
97
+ raw_story = result[0].get("generated_text", "").strip()
98
+ if not raw_story:
99
+ raw_story = "Once upon a time, the park was filled with laughter as children played happily under the bright sun."
100
  words = raw_story.split()
101
+ story = " ".join(words[:100])
102
+ return story
103
 
104
  @st.cache_data(show_spinner=False)
105
  def get_audio(story):
106
  """
107
  Converts the generated story text into audio.
108
+ Splits the text into 300-character chunks, processes each via the TTS pipeline,
109
+ and concatenates the resulting audio arrays. If no audio is generated, 1 second of silence is used.
110
  """
111
  chunks = textwrap.wrap(story, width=300)
112
  audio_chunks = []
113
  for chunk in chunks:
114
  try:
115
  output = st.session_state.tts(chunk)
 
116
  if isinstance(output, list):
117
  output = output[0]
118
  if "audio" in output:
 
119
  audio_array = np.array(output["audio"]).squeeze()
120
  audio_chunks.append(audio_array)
121
+ except Exception:
 
122
  continue
123
 
 
124
  if not audio_chunks:
125
  sr = st.session_state.tts.model.config.sampling_rate
126
  audio = np.zeros(sr, dtype=np.float32)
 
138
  try:
139
  load_models() # Ensure models are loaded
140
  image_bytes = uploaded_file.getvalue()
 
141
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
142
  st.image(image, caption="Your Amazing Picture!", use_column_width=True)
143
  st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)
 
151
  with st.spinner("Generating story..."):
152
  story = get_story(caption)
153
  st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
154
+ # If the story is empty (or consists only of whitespace), display a default message.
155
+ if not story.strip():
156
+ st.write("No story was generated. Please try again.")
157
+ else:
158
+ st.write(story)
159
 
160
  with st.spinner("Generating audio..."):
161
  audio_buffer = get_audio(story)