Szeyu's picture
Update app.py
e35a81f verified
raw
history blame
2.47 kB
import streamlit as st
from transformers import pipeline
import textwrap
import numpy as np
import soundfile as sf
import tempfile
import os
from PIL import Image
# Initialize pipelines
@st.cache_resource
def load_pipelines():
    """Build the captioning, story-generation and text-to-speech pipelines.

    Decorated with ``st.cache_resource`` so the pipelines are constructed
    once and shared instead of being rebuilt on every script run.

    Returns:
        A ``(captioner, storyer, tts)`` tuple of ``transformers`` pipelines,
        in that order.
    """
    # (task, model id) pairs, listed in the order they are returned.
    model_specs = (
        ("image-to-text", "Salesforce/blip-image-captioning-base"),
        ("text-generation", "aspis/gpt2-genre-story-generation"),
        ("text-to-speech", "facebook/mms-tts-eng"),
    )
    caption_pipe, story_pipe, speech_pipe = [
        pipeline(task, model=model_id) for task, model_id in model_specs
    ]
    return caption_pipe, story_pipe, speech_pipe
# Bind the three pipelines at module level so generate_content() can use them.
# load_pipelines() is wrapped in st.cache_resource, so this call is cheap after
# the first script run.
captioner, storyer, tts = load_pipelines()
# Main logic
def _write_wav_tempfile(samples, sample_rate):
    """Write *samples* to a new temporary ``.wav`` file and return its path."""
    # Reserve a unique name, then close the handle before soundfile writes to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        wav_path = temp_file.name
    try:
        sf.write(wav_path, samples, sample_rate)
    except Exception:
        # delete=False means nothing else will clean this file up on failure.
        os.remove(wav_path)
        raise
    return wav_path


def generate_content(image):
    """Caption an image, write a short children's story about it, and narrate it.

    The caption and the story are also written to the Streamlit page as soon
    as each one is available.

    Args:
        image: A path or binary file-like object accepted by
            ``PIL.Image.open`` (the UI passes the ``st.file_uploader`` result).

    Returns:
        A ``(caption, story, audio_path)`` tuple. ``audio_path`` names a
        temporary ``.wav`` file; the caller is responsible for deleting it.

    Raises:
        ValueError: If the story model returns no text, so there is nothing
            to narrate.
    """
    pil_image = Image.open(image)

    # Generate caption
    caption = captioner(pil_image)[0]["generated_text"]
    st.write("**Caption:**", caption)

    # Generate story
    prompt = (
        f"Write a funny, warm children's story for ages 3-10, 50–100 words, "
        f"in third-person narrative, that describes this scene exactly: {caption} "
        f"mention the exact place or venue within {caption}"
    )
    raw = storyer(
        prompt,
        max_new_tokens=150,
        # temperature/top_p only shape the output when sampling is enabled;
        # request it explicitly instead of relying on the model's defaults.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        no_repeat_ngram_size=2,
        return_full_text=False,
    )[0]["generated_text"].strip()

    # Hard cap at 100 words, regardless of how much text the model produced.
    story = " ".join(raw.split()[:100])
    st.write("**Story:**", story)

    # Synthesize speech in pieces of at most 200 characters, then join them.
    chunks = textwrap.wrap(story, width=200)
    if not chunks:
        # np.concatenate([]) would fail with an opaque message; say what happened.
        raise ValueError("The story model returned no text to narrate.")
    audio = np.concatenate([tts(chunk)["audio"].squeeze() for chunk in chunks])

    temp_file_path = _write_wav_tempfile(audio, tts.model.config.sampling_rate)
    return caption, story, temp_file_path
# Streamlit UI
# Page header and usage hint.
st.title("Image to Children's Story and Audio")
st.write("Upload an image to generate a caption, a children's story, and an audio narration.")

uploaded_image = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])

if uploaded_image is not None:
    # Echo the upload back so the user can confirm the right file was picked.
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

    if st.button("Generate Story and Audio"):
        with st.spinner("Generating content..."):
            caption, story, audio_path = generate_content(uploaded_image)

        try:
            # NOTE(review): this relies on st.audio loading the file during the
            # call, since the file is removed right after -- confirm against
            # the Streamlit version in use.
            st.audio(audio_path, format="audio/wav")
        finally:
            # Always remove the temporary WAV, even if rendering the player fails.
            os.remove(audio_path)