Szeyu commited on
Commit
08b5abc
·
verified ·
1 Parent(s): ed3e053

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -25
app.py CHANGED
@@ -54,11 +54,8 @@ def load_models():
54
  model="aspis/gpt2-genre-story-generation"
55
  )
56
  except Exception as e:
57
- logging.warning(f"Failed to load aspis/gpt2-genre-story-generation: {e}. Falling back to gpt2.")
58
- st.session_state.storyer = pipeline(
59
- "text-generation",
60
- model="gpt2"
61
- )
62
  if "tts" not in st.session_state:
63
  st.session_state.tts = pipeline(
64
  "text-to-speech",
@@ -80,48 +77,55 @@ def get_caption(image_bytes):
80
  @st.cache_data(show_spinner=False)
81
  def get_story(caption):
82
  """
83
- Generates a humorous and engaging children's story based on the caption.
84
- Uses a simplified prompt and robust output parsing.
85
  """
86
  prompt = (
87
- f"Create a funny, warm children's story (50-100 words) for ages 3-10 based on: {caption}. "
88
- f"Use third-person narrative, as if playfully describing the scene."
89
  )
90
  try:
91
  result = st.session_state.storyer(
92
  prompt,
93
- max_new_tokens=150, # Increased to allow more room for story
94
  do_sample=True,
95
- temperature=0.8, # Slightly higher for creativity
96
- top_p=0.9,
 
97
  return_full_text=False
98
  )
99
  logging.info(f"Story generation raw result: {result}")
100
 
101
- # Extract generated text
102
  raw_story = result[0].get("generated_text", "").strip()
103
 
104
- # If no meaningful output, generate a fallback story
105
- if not raw_story or len(raw_story.split()) < 10:
 
106
  logging.warning("Generated story too short or empty. Using fallback.")
107
  raw_story = (
108
- f"Once upon a time, in a land of {caption}, a silly squirrel named Sammy "
109
- f"found a shiny treasure! He danced with joy, but oh no! It was a magic acorn! "
110
- f"It grew into a giant tree, and Sammy climbed to the top, giggling all the way. "
111
- f"The tree sang funny songs, and all the animals joined in for a big party!"
112
  )
 
113
 
114
- # Truncate to 100 words
115
- words = raw_story.split()
116
  story = " ".join(words[:100])
 
 
 
 
117
  return story
118
  except Exception as e:
119
  logging.error(f"Story generation failed: {e}")
120
- # Fallback story in case of errors
121
  return (
122
- f"Once upon a time, in a land of {caption}, a silly squirrel named Sammy "
123
- f"found a shiny treasure! He danced with joy, but oh no! It was a magic acorn! "
124
- f"It grew into a giant tree, and Sammy climbed to the top, giggling all the way."
 
125
  )
126
 
127
  @st.cache_data(show_spinner=False)
 
54
  model="aspis/gpt2-genre-story-generation"
55
  )
56
  except Exception as e:
57
+ logging.error(f"Failed to load aspis/gpt2-genre-story-generation: {e}")
58
+ raise Exception("Text generation model could not be loaded.")
 
 
 
59
  if "tts" not in st.session_state:
60
  st.session_state.tts = pipeline(
61
  "text-to-speech",
 
77
  @st.cache_data(show_spinner=False)
78
  def get_story(caption):
79
  """
80
+ Generates a humorous children's story (50-100 words) based on the caption.
81
+ Optimized for faster generation with aspis/gpt2-genre-story-generation.
82
  """
83
  prompt = (
84
+ f"Write a funny, warm children's story (50-100 words) for ages 3-10 based on: {caption}. "
85
+ f"Third-person narrative, playful tone."
86
  )
87
  try:
88
  result = st.session_state.storyer(
89
  prompt,
90
+ max_new_tokens=80, # Reduced for faster generation
91
  do_sample=True,
92
+ temperature=0.7, # Balanced creativity
93
+ top_k=40, # Faster sampling with smaller k
94
+ top_p=0.85, # Tighter sampling for coherence
95
  return_full_text=False
96
  )
97
  logging.info(f"Story generation raw result: {result}")
98
 
99
+ # Extract and clean generated text
100
  raw_story = result[0].get("generated_text", "").strip()
101
 
102
+ # Ensure story is 50-100 words
103
+ words = raw_story.split()
104
+ if len(words) < 50 or not raw_story:
105
  logging.warning("Generated story too short or empty. Using fallback.")
106
  raw_story = (
107
+ f"In a land of {caption}, a silly bunny named Bouncy found a shiny star! "
108
+ f"It sparkled, making Bouncy giggle and hop high. The star said, 'Dance!' "
109
+ f"So Bouncy twirled with squirrels and birds. They threw a forest party, "
110
+ f"singing silly songs under the twinkling sky, laughing all night."
111
  )
112
+ words = raw_story.split()
113
 
114
+ # Truncate to 100 words if too long
 
115
  story = " ".join(words[:100])
116
+ # Pad if too short
117
+ if len(words) < 50:
118
+ story += " And they all lived happily ever after!"
119
+
120
  return story
121
  except Exception as e:
122
  logging.error(f"Story generation failed: {e}")
123
+ # Fallback story
124
  return (
125
+ f"In a land of {caption}, a silly bunny named Bouncy found a shiny star! "
126
+ f"It sparkled, making Bouncy giggle and hop high. The star said, 'Dance!' "
127
+ f"So Bouncy twirled with squirrels and birds. They threw a forest party, "
128
+ f"singing silly songs under the twinkling sky, laughing all night."
129
  )
130
 
131
  @st.cache_data(show_spinner=False)