|
import streamlit as st |
|
from transformers import pipeline |
|
from PIL import Image |
|
import io, textwrap, numpy as np, soundfile as sf |
|
|
|
|
|
# Browser-tab title and icon, plus a centred single-column page layout.
st.set_page_config(
    layout="centered",
    page_icon="🦄",
    page_title="Picture to Story Magic",
)
|
|
|
|
|
# Soft pastel page background. Streamlit draws its own `.stApp` container on top
# of <body>, so the colour must also target that container to actually show.
st.markdown(
    """
<style>
body, .stApp {
    background-color: #FDEBD0; /* A soft pastel color */
}
</style>
""",
    unsafe_allow_html=True,
)
|
|
|
|
|
# Page heading and greeting. The HTML deliberately contains no blank lines:
# in Markdown a blank line ends a raw-HTML block, which would leave the greeting
# text outside the styled <p> (losing its centring and font size).
st.markdown(
    """
<h1 style='text-align: center; color: #ff66cc;'>Picture to Story Magic!</h1>
<p style='text-align: center; font-size: 24px;'>
Hi little artist! Upload your picture and let us create a fun story just for you! 🎉
</p>
""",
    unsafe_allow_html=True,
)
|
|
|
|
|
# session_state key -> (transformers pipeline task, Hugging Face model id)
_MODEL_SPECS = {
    "captioner": ("image-to-text", "Salesforce/blip-image-captioning-base"),
    "storyer": ("text-generation", "aspis/gpt2-genre-story-generation"),
    "tts": ("text-to-speech", "facebook/mms-tts-eng"),
}


@st.cache_resource(show_spinner="Getting the story helpers ready...")
def _shared_pipeline(task, model):
    """Build one pipeline per (task, model) and share it across all sessions.

    st.cache_resource returns the same object to every session and rerun, so
    the model weights are loaded once per server process instead of once per
    browser session.
    """
    return pipeline(task, model=model)


def load_models():
    """Ensure the three pipelines are available in ``st.session_state``.

    Populates these keys when missing (same names as before, so existing
    callers keep working):

    1. ``captioner``: image-to-text (BLIP base) that describes the picture.
    2. ``storyer``: text-generation (aspis/gpt2-genre-story-generation) that
       writes the story from a prompt.
    3. ``tts``: text-to-speech (facebook/mms-tts-eng) that voices the story.

    The pipelines themselves come from a process-wide resource cache, so
    populating a new session does not reload the weights.
    """
    for key, (task, model) in _MODEL_SPECS.items():
        if key not in st.session_state:
            st.session_state[key] = _shared_pipeline(task, model)
|
|
|
|
|
@st.cache_data(show_spinner=False)
def get_caption(image_bytes):
    """Return a text caption describing the uploaded picture.

    The raw bytes are decoded, converted to RGB and shrunk in place so that
    neither side exceeds 256 px (aspect ratio preserved). The result is then
    passed to the captioning pipeline in ``st.session_state.captioner``, so
    ``load_models()`` must have run first. ``st.cache_data`` memoises the
    caption per distinct ``image_bytes``.
    """
    picture = Image.open(io.BytesIO(image_bytes))
    picture = picture.convert("RGB")
    picture.thumbnail((256, 256))
    predictions = st.session_state.captioner(picture)
    best = predictions[0]
    return best["generated_text"]
|
|
|
def _trim_to_last_sentence(text):
    """Cut *text* just after its last '.', '!' or '?' (plus any closing quote).

    Generation stops at a token limit, which usually leaves a half-finished
    final sentence. The text is returned unchanged when it has no terminator,
    or when trimming would drop more than half of it (e.g. the only '.' is in
    "Mr."), so a story is never reduced to a stub.
    """
    cut = max(text.rfind(mark) for mark in ".!?")
    if cut < len(text) // 2:
        return text
    # Keep a closing quote/bracket that directly follows the terminator.
    while cut + 1 < len(text) and text[cut + 1] in "\"')”’":
        cut += 1
    return text[: cut + 1]


@st.cache_data(show_spinner=False)
def get_story(caption, max_new_tokens=80, max_words=100):
    """Generate a short, playful children's story about an image caption.

    Args:
        caption: Text description of the picture (e.g. from ``get_caption``).
        max_new_tokens: Upper bound on tokens the model may generate.
        max_words: The returned story is cut to at most this many words.

    Returns:
        The story text, with whitespace normalised to single spaces. A trailing
        unfinished sentence is dropped (see ``_trim_to_last_sentence``).

    Requires ``load_models()`` to have populated ``st.session_state.storyer``.
    Results are memoised per argument combination by ``st.cache_data``.
    """
    # Every instruction goes *before* the "Story:" cue. Text after the cue is
    # continued by the model as the story itself, not read as guidance.
    prompt = (
        "Write a funny, warm, and imaginative children's story for ages 3-10, "
        "50-100 words, in third-person narrative, as if the author is "
        f"playfully describing this scene: {caption}\nStory:"
    )
    raw_story = st.session_state.storyer(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        return_full_text=False,  # only the continuation, not the prompt
    )[0]["generated_text"].strip()

    story = " ".join(raw_story.split()[:max_words])
    return _trim_to_last_sentence(story)
|
|
|
@st.cache_data(show_spinner=False)
def get_audio(story):
    """Convert the story text into an in-memory WAV file.

    The text is split into chunks of at most 300 characters (word boundaries
    kept by ``textwrap.wrap``), and each chunk is voiced with
    ``st.session_state.tts``. This is best effort: a chunk whose synthesis
    fails or yields no ``"audio"`` is skipped and the failure is logged. If no
    chunk produces audio, one second of silence is returned instead.

    Returns:
        io.BytesIO holding WAV data, rewound to position 0.
    """
    import logging

    tts = st.session_state.tts
    # The model-config rate is the fallback. A rate reported by the pipeline
    # output itself takes precedence.
    sample_rate = tts.model.config.sampling_rate
    audio_chunks = []

    for chunk in textwrap.wrap(story, width=300):
        try:
            output = tts(chunk)
            if isinstance(output, list):
                output = output[0]
            if "audio" not in output:
                continue
            # reshape(-1) flattens e.g. (1, n) output and never yields a 0-d
            # array, which np.concatenate would reject.
            audio_chunks.append(np.asarray(output["audio"]).reshape(-1))
            sample_rate = output.get("sampling_rate", sample_rate)
        except Exception:
            # Deliberately best effort, but leave a trace instead of hiding it.
            logging.getLogger(__name__).warning(
                "TTS failed for chunk %r; skipping it", chunk, exc_info=True
            )

    if audio_chunks:
        audio = np.concatenate(audio_chunks)
    else:
        audio = np.zeros(sample_rate, dtype=np.float32)  # 1 s of silence

    buffer = io.BytesIO()
    sf.write(buffer, audio, sample_rate, format="WAV")
    buffer.seek(0)
    return buffer
|
|
|
|
|
def _show_story_for(image_bytes):
    """Display the uploaded picture and, on request, its caption, story and audio.

    Expects ``load_models()`` to have succeeded. Any exception (for example an
    unreadable image file) propagates to the caller, which reports it.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; switch once the pinned version allows it.
    st.image(image, caption="Your Amazing Picture!", use_column_width=True)
    st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)

    if not st.button("Story, Please!"):
        return

    with st.spinner("Generating caption..."):
        caption = get_caption(image_bytes)
        st.markdown("<h3 style='text-align: center;'>Caption:</h3>", unsafe_allow_html=True)
        st.write(caption)

    with st.spinner("Generating story..."):
        story = get_story(caption)
        st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
        st.write(story)

    with st.spinner("Generating audio..."):
        audio_buffer = get_audio(story)
        st.audio(audio_buffer, format="audio/wav", start_time=0)
        st.markdown(
            "<p style='text-align: center; font-weight: bold;'>Enjoy your magical story! 🎶</p>",
            unsafe_allow_html=True,
        )


uploaded_file = st.file_uploader("Choose a Picture...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        load_models()
    except Exception as e:
        # A model download / memory failure is not the picture's fault, so do
        # not tell the user to try a different picture.
        st.error("Oops! The story helpers could not get ready. Please try again in a little while!")
        st.error(f"Error details: {e}")
    else:
        try:
            _show_story_for(uploaded_file.getvalue())
        except Exception as e:
            st.error("Oops! Something went wrong. Please try a different picture or check the file format!")
            st.error(f"Error details: {e}")
|
|