import io
import textwrap

import numpy as np
import soundfile as sf
import streamlit as st
from PIL import Image
from transformers import pipeline

# ------------------ Streamlit Page Configuration ------------------
st.set_page_config(
    page_title="Picture to Story Magic",  # App title on browser tab
    page_icon="🦄",                        # Fun unicorn icon
    layout="centered",
)

# ------------------ Custom CSS for a Colorful Background ------------------
st.markdown(
    """
    """,
    unsafe_allow_html=True,
)
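# A minimal sketch of the kind of CSS the block above is meant to inject via
# st.markdown(..., unsafe_allow_html=True). The selector and gradient colors are
# illustrative assumptions, not the app's original styles, so the sketch is left
# commented out:
#
# st.markdown(
#     """
#     <style>
#     /* Hypothetical colorful background for the whole app */
#     .stApp {
#         background: linear-gradient(135deg, #ffe5ec 0%, #e0f7fa 100%);
#     }
#     </style>
#     """,
#     unsafe_allow_html=True,
# )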
# ------------------ Playful Header for Young Users ------------------
st.markdown(
    """
    Hi little artist! Upload your picture and let us create a fun story just for you! 🎉
    """,
    unsafe_allow_html=True,
)
""", unsafe_allow_html=True ) # ------------------ Lazy Model Loading ------------------ def load_models(): """ Lazy-load the required pipelines and store them in session state. Pipelines: 1. Captioner: Generates descriptive text from an image using a lighter model. 2. Storyer: Generates a humorous children's story using aspis/gpt2-genre-story-generation. 3. TTS: Converts text into audio. """ if "captioner" not in st.session_state: st.session_state.captioner = pipeline( "image-to-text", model="Salesforce/blip-image-captioning-base" ) if "storyer" not in st.session_state: st.session_state.storyer = pipeline( "text-generation", model="aspis/gpt2-genre-story-generation" ) if "tts" not in st.session_state: st.session_state.tts = pipeline( "text-to-speech", model="facebook/mms-tts-eng" ) # ------------------ Caching Functions ------------------ @st.cache_data(show_spinner=False) def get_caption(image_bytes): """ Converts image bytes into a lower resolution image (256x256 maximum) and generates a caption. """ image = Image.open(io.BytesIO(image_bytes)).convert("RGB") # Resize to speed up processing image.thumbnail((256, 256)) caption = st.session_state.captioner(image)[0]["generated_text"] return caption @st.cache_data(show_spinner=False) def get_story(caption): """ Generates a humorous and engaging children's story based on the caption. Uses a prompt to instruct the model and limits token generation to 80 tokens. """ prompt = ( f"Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, " f"{caption}\nStory: in third-person narrative, as if the author is playfully describing the scene in the image." ) raw_story = st.session_state.storyer( prompt, max_new_tokens=80, do_sample=True, temperature=0.7, top_p=0.9, return_full_text=False )[0]["generated_text"].strip() words = raw_story.split() return " ".join(words[:100]) @st.cache_data(show_spinner=False) def get_audio(story): """ Converts the generated story text into audio. Splits the text into 300-character chunks to reduce repeated TTS calls, concatenates the resulting audio chunks, and returns an in-memory WAV buffer. """ chunks = textwrap.wrap(story, width=300) audio_chunks = [st.session_state.tts(chunk)["audio"].squeeze() for chunk in chunks] audio = np.concatenate(audio_chunks) buffer = io.BytesIO() sf.write(buffer, audio, st.session_state.tts.model.config.sampling_rate, format="WAV") buffer.seek(0) return buffer # ------------------ Main App Logic ------------------ uploaded_file = st.file_uploader("Choose a Picture...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: try: load_models() # Make sure models are loaded image_bytes = uploaded_file.getvalue() # Display the uploaded image image = Image.open(io.BytesIO(image_bytes)).convert("RGB") st.image(image, caption="Your Amazing Picture!", use_column_width=True) st.markdown("Enjoy your magical story! 🎶
", unsafe_allow_html=True ) except Exception as e: st.error("Oops! Something went wrong. Please try a different picture or check the file format!") st.error(f"Error details: {e}")