# app.py — Magic Story Maker (author: Szeyu)
# Last commit: 822643b "Update app.py" (verified); file size 3.6 kB
import streamlit as st
from transformers import pipeline
import textwrap
import numpy as np
import soundfile as sf
import tempfile
import os
from PIL import Image
@st.cache_resource
def load_pipelines():
    """Build the captioning, story-writing and speech pipelines.

    Decorated with ``st.cache_resource`` so the models are loaded once and
    reused on later script reruns instead of being reloaded each time.

    Returns:
        A ``(captioner, storyer, tts)`` tuple of Hugging Face pipelines:
        image -> caption, text prompt -> story, text -> speech.
    """
    # (task, model id) pairs, built in the order the tuple is unpacked.
    task_models = (
        ("image-to-text", "Salesforce/blip-image-captioning-large"),
        ("text-generation", "aspis/gpt2-genre-story-generation"),
        ("text-to-speech", "facebook/mms-tts-eng"),
    )
    return tuple(pipeline(task, model=model_id) for task, model_id in task_models)
# Fetch the (cached) pipelines a single time at script start; every rerun
# of the Streamlit script gets the same objects back from the cache.
_pipelines = load_pipelines()
captioner, storyer, tts = _pipelines
def _caption_image(image):
    """Return a caption describing *image* (a path or binary file object)."""
    # Normalise to RGB so RGBA / palette PNG uploads are handled uniformly.
    pil_image = Image.open(image).convert("RGB")
    return captioner(pil_image)[0]["generated_text"]


def _write_story(caption, max_words=100):
    """Generate a short children's story about *caption*, capped at *max_words* words."""
    prompt = (
        "Write a funny, warm children's story for ages 3-10, 50–100 words, "
        f"in third-person narrative, that describes this scene exactly: {caption}. "
        f"Mention the exact place or venue within {caption}."
    )
    raw = storyer(
        prompt,
        max_new_tokens=150,
        # temperature/top_p only apply when sampling; without do_sample=True
        # generation is greedy and both settings are silently ignored.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        no_repeat_ngram_size=2,
        return_full_text=False,
    )[0]["generated_text"].strip()
    # Hard cap on length: keep at most max_words whitespace-separated words.
    return " ".join(raw.split()[:max_words])


def _synthesize_speech(text, chunk_chars=200):
    """Narrate *text* and save it as a WAV file; return the file's path.

    The caller owns the returned file and must delete it when done.
    """
    # The TTS model is fed the text in pieces of at most chunk_chars characters.
    chunks = textwrap.wrap(text, width=chunk_chars)
    if chunks:
        # ravel() flattens each waveform to 1-D; unlike squeeze() it never
        # yields a 0-d array that np.concatenate cannot join.
        audio = np.concatenate([np.ravel(tts(chunk)["audio"]) for chunk in chunks])
    else:
        # Empty story: np.concatenate([]) would raise, so write a valid,
        # zero-length (silent) WAV instead.
        audio = np.zeros(0, dtype=np.float32)

    # Create the file and close our handle before soundfile reopens it by
    # name -- reopening a still-open NamedTemporaryFile fails on Windows.
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        sf.write(path, audio, tts.model.config.sampling_rate)
    except BaseException:
        # Don't leak the temp file if writing fails.
        os.remove(path)
        raise
    return path


def generate_content(image):
    """Caption an uploaded image, write a story about it, and narrate it.

    The caption and the story are written to the Streamlit page as each
    stage finishes.

    Args:
        image: The uploaded picture -- a path or binary file-like object
            accepted by ``PIL.Image.open`` (e.g. a Streamlit ``UploadedFile``).

    Returns:
        ``(caption, story, audio_path)``, where ``audio_path`` is a temporary
        WAV file that the caller is responsible for deleting.
    """
    caption = _caption_image(image)
    st.write("**🌟 What's in the picture: 🌟**")
    st.write(caption)

    story = _write_story(caption)
    st.write("**📖 Your funny story: 📖**")
    st.write(story)

    audio_path = _synthesize_speech(story)
    return caption, story, audio_path
# Streamlit UI for the application
st.title("✨ Magic Story Maker ✨")
st.markdown("Upload a picture to make a funny story and hear it too! 📸")

# File uploader for image input
uploaded_image = st.file_uploader(
    "Choose your picture",
    type=["jpg", "jpeg", "png"],
    help="Pick a photo to start the magic!",
)

# Placeholder image URL (replace with an actual URL of a child-friendly image)
placeholder_url = "https://example.com/placeholder_image.jpg"

# NOTE(review): newer Streamlit releases deprecate `use_column_width` in favour
# of `use_container_width`; switch once the deployed version supports it.
if uploaded_image is None:
    st.image(placeholder_url, caption="Upload your picture here! 📷", use_column_width=True)
else:
    st.image(uploaded_image, caption="Your Picture 🌟", use_column_width=True)

if st.button("✨ Make My Story! ✨", help="Click to create your magic story"):
    if uploaded_image is not None:
        with st.spinner("🔮 Creating your magical story..."):
            caption, story, audio_path = generate_content(uploaded_image)
        # Pull the WAV into memory and delete the temp file immediately, so it
        # is cleaned up even if rendering the audio widget fails.
        try:
            with open(audio_path, "rb") as audio_file:
                audio_bytes = audio_file.read()
        finally:
            os.remove(audio_path)
        st.success("🎉 Your story is ready! 🎉")
        st.audio(audio_bytes, format="audio/wav")
    else:
        st.warning("Please upload a picture first! 📸")