Szeyu committed on
Commit
ed3e053
·
verified ·
1 Parent(s): 03cd04b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -44
app.py CHANGED
@@ -2,11 +2,15 @@ import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
  import io, textwrap, numpy as np, soundfile as sf
 
 
 
 
5
 
6
  # ------------------ Streamlit Page Configuration ------------------
7
  st.set_page_config(
8
- page_title="Picture to Story Magic", # App title on browser tab
9
- page_icon="🦄", # Fun unicorn icon
10
  layout="centered"
11
  )
12
 
@@ -15,7 +19,7 @@ st.markdown(
15
  """
16
  <style>
17
  body {
18
- background-color: #FDEBD0; /* A soft pastel color */
19
  }
20
  </style>
21
  """,
@@ -37,11 +41,6 @@ st.markdown(
37
  def load_models():
38
  """
39
  Lazy-load the required pipelines and store them in session state.
40
-
41
- Pipelines:
42
- 1. Captioner: Generates descriptive text from an image using a lighter model.
43
- 2. Storyer: Generates a humorous children's story using aspis/gpt2-genre-story-generation.
44
- 3. TTS: Converts text into audio.
45
  """
46
  if "captioner" not in st.session_state:
47
  st.session_state.captioner = pipeline(
@@ -49,10 +48,17 @@ def load_models():
49
  model="Salesforce/blip-image-captioning-large"
50
  )
51
  if "storyer" not in st.session_state:
52
- st.session_state.storyer = pipeline(
53
- "text-generation",
54
- model="aspis/gpt2-genre-story-generation"
55
- )
 
 
 
 
 
 
 
56
  if "tts" not in st.session_state:
57
  st.session_state.tts = pipeline(
58
  "text-to-speech",
@@ -75,46 +81,54 @@ def get_caption(image_bytes):
75
  def get_story(caption):
76
  """
77
  Generates a humorous and engaging children's story based on the caption.
78
- Uses a prompt to instruct the model and limits token generation.
79
  """
80
  prompt = (
81
- f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
82
- f"{caption}\nStory: in third-person narrative, as if the author is playfully describing the scene in the image."
83
- )
84
- result = st.session_state.storyer(
85
- prompt,
86
- max_new_tokens=120, # Increased from 80 to 120 for more continuation space
87
- do_sample=True,
88
- temperature=0.7,
89
- top_p=0.9,
90
- return_full_text=False
91
  )
92
-
93
- # Log the raw result for debugging (viewable in server logs)
94
- print("Story generation raw result:", result)
95
-
96
- raw_story = result[0].get("generated_text", "").strip()
97
-
98
- # If the generated text starts with the prompt, remove it only if there is substantial extra content.
99
- if raw_story.startswith(prompt):
100
- # Compute the extra part after the prompt.
101
- extra_text = raw_story[len(prompt):].strip()
102
- # Only use the extra text if it is longer than a threshold (e.g. 20 characters).
103
- if len(extra_text) > 20:
104
- raw_story = extra_text
105
- else:
106
- # If not, use the full raw_story instead.
107
- raw_story = raw_story
108
- words = raw_story.split()
109
- story = " ".join(words[:100])
110
- return story
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  @st.cache_data(show_spinner=False)
113
  def get_audio(story):
114
  """
115
  Converts the generated story text into audio.
116
  Splits the text into 300-character chunks to reduce repeated TTS calls.
117
- Checks each chunk; if no valid audio is produced, creates 1 second of silence.
118
  """
119
  chunks = textwrap.wrap(story, width=300)
120
  audio_chunks = []
@@ -173,4 +187,4 @@ if uploaded_file is not None:
173
  )
174
  except Exception as e:
175
  st.error("Oops! Something went wrong. Please try a different picture or check the file format!")
176
- st.error(f"Error details: {e}")
 
2
  from transformers import pipeline
3
  from PIL import Image
4
  import io, textwrap, numpy as np, soundfile as sf
5
+ import logging
6
+
7
+ # Set up logging for debugging
8
+ logging.basicConfig(level=logging.INFO)
9
 
10
  # ------------------ Streamlit Page Configuration ------------------
11
  st.set_page_config(
12
+ page_title="Picture to Story Magic",
13
+ page_icon="🦄",
14
  layout="centered"
15
  )
16
 
 
19
  """
20
  <style>
21
  body {
22
+ background-color: #FDEBD0;
23
  }
24
  </style>
25
  """,
 
41
  def load_models():
42
  """
43
  Lazy-load the required pipelines and store them in session state.
 
 
 
 
 
44
  """
45
  if "captioner" not in st.session_state:
46
  st.session_state.captioner = pipeline(
 
48
  model="Salesforce/blip-image-captioning-large"
49
  )
50
  if "storyer" not in st.session_state:
51
+ try:
52
+ st.session_state.storyer = pipeline(
53
+ "text-generation",
54
+ model="aspis/gpt2-genre-story-generation"
55
+ )
56
+ except Exception as e:
57
+ logging.warning(f"Failed to load aspis/gpt2-genre-story-generation: {e}. Falling back to gpt2.")
58
+ st.session_state.storyer = pipeline(
59
+ "text-generation",
60
+ model="gpt2"
61
+ )
62
  if "tts" not in st.session_state:
63
  st.session_state.tts = pipeline(
64
  "text-to-speech",
 
81
  def get_story(caption):
82
  """
83
  Generates a humorous and engaging children's story based on the caption.
84
+ Uses a simplified prompt and robust output parsing.
85
  """
86
  prompt = (
87
+ f"Create a funny, warm children's story (50-100 words) for ages 3-10 based on: {caption}. "
88
+ f"Use third-person narrative, as if playfully describing the scene."
 
 
 
 
 
 
 
 
89
  )
90
+ try:
91
+ result = st.session_state.storyer(
92
+ prompt,
93
+ max_new_tokens=150, # Increased to allow more room for story
94
+ do_sample=True,
95
+ temperature=0.8, # Slightly higher for creativity
96
+ top_p=0.9,
97
+ return_full_text=False
98
+ )
99
+ logging.info(f"Story generation raw result: {result}")
100
+
101
+ # Extract generated text
102
+ raw_story = result[0].get("generated_text", "").strip()
103
+
104
+ # If no meaningful output, generate a fallback story
105
+ if not raw_story or len(raw_story.split()) < 10:
106
+ logging.warning("Generated story too short or empty. Using fallback.")
107
+ raw_story = (
108
+ f"Once upon a time, in a land of {caption}, a silly squirrel named Sammy "
109
+ f"found a shiny treasure! He danced with joy, but oh no! It was a magic acorn! "
110
+ f"It grew into a giant tree, and Sammy climbed to the top, giggling all the way. "
111
+ f"The tree sang funny songs, and all the animals joined in for a big party!"
112
+ )
113
+
114
+ # Truncate to 100 words
115
+ words = raw_story.split()
116
+ story = " ".join(words[:100])
117
+ return story
118
+ except Exception as e:
119
+ logging.error(f"Story generation failed: {e}")
120
+ # Fallback story in case of errors
121
+ return (
122
+ f"Once upon a time, in a land of {caption}, a silly squirrel named Sammy "
123
+ f"found a shiny treasure! He danced with joy, but oh no! It was a magic acorn! "
124
+ f"It grew into a giant tree, and Sammy climbed to the top, giggling all the way."
125
+ )
126
 
127
  @st.cache_data(show_spinner=False)
128
  def get_audio(story):
129
  """
130
  Converts the generated story text into audio.
131
  Splits the text into 300-character chunks to reduce repeated TTS calls.
 
132
  """
133
  chunks = textwrap.wrap(story, width=300)
134
  audio_chunks = []
 
187
  )
188
  except Exception as e:
189
  st.error("Oops! Something went wrong. Please try a different picture or check the file format!")
190
+ st.error(f"Error details: {e}")