# Source: Hugging Face Spaces file viewer export — "Update app.py", commit f0a6b70 (verified).
# The following viewer chrome was captured during export and is kept as comments:
#   raw / history / blame — 3.56 kB
import streamlit as st
from transformers import pipeline
import textwrap
import numpy as np
import soundfile as sf
import tempfile
import os
from PIL import Image
# Initialize pipelines with caching to avoid reloading the models on every rerun
@st.cache_resource
def load_pipelines():
    """Load and cache the three Hugging Face pipelines used by the app.

    Returns
    -------
    tuple
        (captioner, storyer, tts):
        - captioner: image-to-text pipeline (BLIP large) for captioning images
        - storyer: text-generation pipeline (GPT-2 genre story model)
        - tts: text-to-speech pipeline (MMS English) for narration
    """
    # Load pipeline for generating captions from images
    captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    # Load pipeline for generating stories from text prompts
    storyer = pipeline("text-generation", model="aspis/gpt2-genre-story-generation")
    # Load pipeline for converting text to speech
    tts = pipeline("text-to-speech", model="facebook/mms-tts-eng")
    return captioner, storyer, tts
# Load the pipelines once at module level; st.cache_resource makes this a
# cheap cache hit on every Streamlit rerun after the first load.
captioner, storyer, tts = load_pipelines()
# Function to generate caption, story, and audio from an uploaded image
def generate_content(image):
    """Generate a caption, a short children's story, and narrated audio.

    Parameters
    ----------
    image : file-like or path accepted by ``PIL.Image.open``
        The uploaded image (e.g. a Streamlit ``UploadedFile``).

    Returns
    -------
    tuple[str, str, str]
        ``(caption, story, temp_wav_path)``. The caller is responsible for
        deleting the temporary WAV file when done.
    """
    # Convert the uploaded image to a PIL image format
    pil_image = Image.open(image)

    # Generate a caption based on the image content
    caption = captioner(pil_image)[0]["generated_text"]
    st.write("**Caption:**", caption)

    # Create a prompt for generating a children's story
    prompt = (
        f"Write a funny, warm children's story for ages 3-10, 50–100 words, "
        f"in third-person narrative, that describes this scene exactly: {caption} "
        f"mention the exact place or venue within {caption}"
    )

    # Generate the story based on the prompt (sampled, capped at 150 new tokens)
    raw = storyer(
        prompt,
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.9,
        no_repeat_ngram_size=2,
        return_full_text=False,
    )[0]["generated_text"].strip()

    # Trim the generated story to a maximum of 100 words
    story = " ".join(raw.split()[:100])
    st.write("**Story:**", story)

    # Split the story into chunks of 200 characters for text-to-speech processing
    chunks = textwrap.wrap(story, width=200)
    if not chunks:
        # Guard: np.concatenate raises on an empty sequence if the model
        # produced no usable story text — fall back to narrating the caption.
        chunks = [caption]

    # Generate and concatenate audio for each text chunk
    audio = np.concatenate([tts(chunk)["audio"].squeeze() for chunk in chunks])

    # Save the concatenated audio to a temporary WAV file for st.audio
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        sf.write(temp_file.name, audio, tts.model.config.sampling_rate)
        temp_file_path = temp_file.name

    return caption, story, temp_file_path
# Streamlit UI for the application
st.title("Image to Children's Story and Audio")
st.markdown("""
Upload an image below to generate a caption, a funny children's story,
and an audio narration based on the image. The story will be tailored
for children aged 3-10.
""")

# File uploader for image input
uploaded_image = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"], help="Supported formats: JPG, JPEG, PNG")

if uploaded_image is not None:
    # Display the uploaded image with a caption.
    # NOTE(review): use_column_width is deprecated in newer Streamlit releases;
    # confirm the deployed version before switching to use_container_width.
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

    # Button to trigger content generation
    if st.button("Generate Story and Audio", help="Click to create the story and audio"):
        # Show a spinner while content is being generated; the caption and
        # story are rendered by generate_content itself via st.write.
        with st.spinner("Generating your story and audio narration..."):
            caption, story, audio_path = generate_content(uploaded_image)

        # Display the audio player with the generated narration
        st.audio(audio_path, format="audio/wav")
        # Remove the temporary audio file after use
        os.remove(audio_path)