Szeyu committed on
Commit
03cd04b
·
verified ·
1 Parent(s): 2e2fdf8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -75,7 +75,7 @@ def get_caption(image_bytes):
75
  def get_story(caption):
76
  """
77
  Generates a humorous and engaging children's story based on the caption.
78
- Uses a prompt to instruct the model and limits token generation to 80 tokens.
79
  """
80
  prompt = (
81
  f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
@@ -83,22 +83,28 @@ def get_story(caption):
83
  )
84
  result = st.session_state.storyer(
85
  prompt,
86
- max_new_tokens=80,
87
  do_sample=True,
88
  temperature=0.7,
89
  top_p=0.9,
90
  return_full_text=False
91
  )
92
 
93
- # Log the raw result for debugging (this is viewable in the server logs)
94
  print("Story generation raw result:", result)
95
 
96
  raw_story = result[0].get("generated_text", "").strip()
97
 
98
- # Remove the prompt from the output if it is included.
99
  if raw_story.startswith(prompt):
100
- raw_story = raw_story[len(prompt):].strip()
101
-
 
 
 
 
 
 
102
  words = raw_story.split()
103
  story = " ".join(words[:100])
104
  return story
@@ -107,8 +113,8 @@ def get_story(caption):
107
  def get_audio(story):
108
  """
109
  Converts the generated story text into audio.
110
- Splits the text into 300-character chunks, processes each via the TTS pipeline,
111
- and concatenates the resulting audio arrays. If no audio is generated, 1 second of silence is used.
112
  """
113
  chunks = textwrap.wrap(story, width=300)
114
  audio_chunks = []
@@ -153,7 +159,6 @@ if uploaded_file is not None:
153
  with st.spinner("Generating story..."):
154
  story = get_story(caption)
155
  st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
156
- # If the story is empty (or consists only of whitespace), display a default message.
157
  if not story.strip():
158
  st.write("No story was generated. Please try again.")
159
  else:
 
75
  def get_story(caption):
76
  """
77
  Generates a humorous and engaging children's story based on the caption.
78
+ Uses a prompt to instruct the model and limits token generation.
79
  """
80
  prompt = (
81
  f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
 
83
  )
84
  result = st.session_state.storyer(
85
  prompt,
86
+ max_new_tokens=120, # Increased from 80 to 120 for more continuation space
87
  do_sample=True,
88
  temperature=0.7,
89
  top_p=0.9,
90
  return_full_text=False
91
  )
92
 
93
+ # Log the raw result for debugging (viewable in server logs)
94
  print("Story generation raw result:", result)
95
 
96
  raw_story = result[0].get("generated_text", "").strip()
97
 
98
+ # If the generated text starts with the prompt, remove it only if there is substantial extra content.
99
  if raw_story.startswith(prompt):
100
+ # Compute the extra part after the prompt.
101
+ extra_text = raw_story[len(prompt):].strip()
102
+ # Only use the extra text if it is longer than a threshold (e.g. 20 characters).
103
+ if len(extra_text) > 20:
104
+ raw_story = extra_text
105
+ else:
106
+ # If not, use the full raw_story instead.
107
+ raw_story = raw_story
108
  words = raw_story.split()
109
  story = " ".join(words[:100])
110
  return story
 
113
  def get_audio(story):
114
  """
115
  Converts the generated story text into audio.
116
+ Splits the text into 300-character chunks to reduce repeated TTS calls.
117
+ Checks each chunk; if no valid audio is produced, creates 1 second of silence.
118
  """
119
  chunks = textwrap.wrap(story, width=300)
120
  audio_chunks = []
 
159
  with st.spinner("Generating story..."):
160
  story = get_story(caption)
161
  st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
 
162
  if not story.strip():
163
  st.write("No story was generated. Please try again.")
164
  else: