Szeyu committed on
Commit
95bff35
·
verified ·
1 Parent(s): d31e539

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -43,7 +43,6 @@ def load_models():
43
  3. TTS: Converts text into audio.
44
  """
45
  if "captioner" not in st.session_state:
46
- # Use the "base" version for faster/cost-effective captioning.
47
  st.session_state.captioner = pipeline(
48
  "image-to-text",
49
  model="Salesforce/blip-image-captioning-base"
@@ -63,11 +62,11 @@ def load_models():
63
  @st.cache_data(show_spinner=False)
64
  def get_caption(image_bytes):
65
  """
66
- Convert the image bytes into a smaller image to speed up captioning,
67
- then return the generated caption.
68
  """
69
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
70
- # Resize the image (preserving aspect ratio) to only 256x256 for faster processing.
71
  image.thumbnail((256, 256))
72
  caption = st.session_state.captioner(image)[0]["generated_text"]
73
  return caption
@@ -75,9 +74,8 @@ def get_caption(image_bytes):
75
  @st.cache_data(show_spinner=False)
76
  def get_story(caption):
77
  """
78
- Generate a humorous and engaging children's story using the caption.
79
- The prompt instructs the model to produce a playful story (50-100 words).
80
- We lower max_new_tokens to 80 so that it generates its text faster.
81
  """
82
  prompt = (
83
  f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
@@ -88,7 +86,7 @@ def get_story(caption):
88
  )
89
  raw_story = st.session_state.storyer(
90
  prompt,
91
- max_new_tokens=80, # Reduced token generation for faster response
92
  do_sample=True,
93
  temperature=0.7,
94
  top_p=0.9,
@@ -100,9 +98,9 @@ def get_story(caption):
100
  @st.cache_data(show_spinner=False)
101
  def get_audio(story):
102
  """
103
- Convert the generated story text into audio.
104
- The text is split into 300-character chunks to reduce repeated TTS calls,
105
- the audio chunks are concatenated, and then stored in an in-memory WAV buffer.
106
  """
107
  chunks = textwrap.wrap(story, width=300)
108
  audio_chunks = [st.session_state.tts(chunk)["audio"].squeeze() for chunk in chunks]
@@ -116,9 +114,9 @@ def get_audio(story):
116
  uploaded_file = st.file_uploader("Choose a Picture...", type=["jpg", "jpeg", "png"])
117
  if uploaded_file is not None:
118
  try:
119
- load_models() # Ensure models are loaded once
120
  image_bytes = uploaded_file.getvalue()
121
- # Display the user-uploaded image
122
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
123
  st.image(image, caption="Your Amazing Picture!", use_column_width=True)
124
  st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)
 
43
  3. TTS: Converts text into audio.
44
  """
45
  if "captioner" not in st.session_state:
 
46
  st.session_state.captioner = pipeline(
47
  "image-to-text",
48
  model="Salesforce/blip-image-captioning-base"
 
62
  @st.cache_data(show_spinner=False)
63
  def get_caption(image_bytes):
64
  """
65
+ Converts image bytes into a lower resolution image (256x256 maximum)
66
+ and generates a caption.
67
  """
68
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
69
+ # Resize to speed up processing
70
  image.thumbnail((256, 256))
71
  caption = st.session_state.captioner(image)[0]["generated_text"]
72
  return caption
 
74
  @st.cache_data(show_spinner=False)
75
  def get_story(caption):
76
  """
77
+ Generates a humorous and engaging children's story based on the caption.
78
+ Uses a prompt to instruct the model and limits token generation to 80 tokens.
 
79
  """
80
  prompt = (
81
  f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
 
86
  )
87
  raw_story = st.session_state.storyer(
88
  prompt,
89
+ max_new_tokens=80,
90
  do_sample=True,
91
  temperature=0.7,
92
  top_p=0.9,
 
98
  @st.cache_data(show_spinner=False)
99
  def get_audio(story):
100
  """
101
+ Converts the generated story text into audio.
102
+ Splits the text into 300-character chunks to reduce repeated TTS calls,
103
+ concatenates the resulting audio chunks, and returns an in-memory WAV buffer.
104
  """
105
  chunks = textwrap.wrap(story, width=300)
106
  audio_chunks = [st.session_state.tts(chunk)["audio"].squeeze() for chunk in chunks]
 
114
  uploaded_file = st.file_uploader("Choose a Picture...", type=["jpg", "jpeg", "png"])
115
  if uploaded_file is not None:
116
  try:
117
+ load_models() # Make sure models are loaded
118
  image_bytes = uploaded_file.getvalue()
119
+ # Display the uploaded image
120
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
121
  st.image(image, caption="Your Amazing Picture!", use_column_width=True)
122
  st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)