File size: 7,151 Bytes
64fd107 ed3e053 64fd107 ed3e053 64fd107 ed3e053 64fd107 b540ff3 64fd107 ed3e053 64fd107 e1351c4 95bff35 64fd107 95bff35 ed3e053 64fd107 ed3e053 e1351c4 ed3e053 64fd107 95bff35 03cd04b 64fd107 2e8ed85 e1351c4 2e8ed85 64fd107 2e8ed85 64fd107 e1351c4 64fd107 ed3e053 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import streamlit as st
from transformers import pipeline
from PIL import Image
import io, textwrap, numpy as np, soundfile as sf
import logging
# Set up logging for debugging: records at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)
# ------------------ Streamlit Page Configuration ------------------
# Browser-tab title, tab icon and the centered single-column layout.
_PAGE_CONFIG = dict(
    page_title="Picture to Story Magic",
    page_icon="🦄",
    layout="centered",
)
st.set_page_config(**_PAGE_CONFIG)
# ------------------ Custom CSS for a Colorful Background ------------------
# NOTE(review): Streamlit renders the app inside a `.stApp` container that has
# its own theme background, so a rule on `body` alone is hidden behind it.
# Styling `.stApp` as well makes the peach background actually visible —
# verify the selector against the installed Streamlit version.
st.markdown(
    """
<style>
body, .stApp {
background-color: #FDEBD0;
}
</style>
""",
    unsafe_allow_html=True,
)
# ------------------ Playful Header for Young Users ------------------
# A centered pink title followed by a large friendly greeting. The snippet is
# raw HTML, so it must be rendered with unsafe_allow_html=True.
_HEADER_HTML = """
<h1 style='text-align: center; color: #ff66cc;'>Picture to Story Magic!</h1>
<p style='text-align: center; font-size: 24px;'>
Hi little artist! Upload your picture and let us create a fun story just for you! 🎉
</p>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# ------------------ Lazy Model Loading ------------------
@st.cache_resource(show_spinner=False)
def _load_pipelines():
    """Build the captioning, story-generation and text-to-speech pipelines.

    Wrapped in ``st.cache_resource`` so the large models are loaded once per
    server process and shared by all browser sessions. Previously they were
    stored only in ``st.session_state``, so every new session reloaded them.

    NOTE(review): the shared pipeline objects may be invoked from several
    sessions concurrently — confirm that is acceptable for this deployment.

    Returns:
        Tuple ``(captioner, storyer, tts)``.
    """
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-large",
    )
    try:
        storyer = pipeline(
            "text-generation",
            model="aspis/gpt2-genre-story-generation",
        )
    except Exception as e:
        # The fine-tuned story model is optional; plain GPT-2 keeps the app usable.
        logging.warning(
            "Failed to load aspis/gpt2-genre-story-generation: %s. Falling back to gpt2.",
            e,
        )
        storyer = pipeline("text-generation", model="gpt2")
    tts = pipeline("text-to-speech", model="facebook/mms-tts-eng")
    return captioner, storyer, tts


def load_models():
    """Ensure ``captioner``, ``storyer`` and ``tts`` are present in session state.

    Keys that already exist in ``st.session_state`` are left untouched; the
    missing ones are filled from the process-wide cached pipelines.
    """
    keys = ("captioner", "storyer", "tts")
    missing = [key for key in keys if key not in st.session_state]
    if not missing:
        return
    shared = dict(zip(keys, _load_pipelines()))
    for key in missing:
        st.session_state[key] = shared[key]
# ------------------ Caching Functions ------------------
@st.cache_data(show_spinner=False)
def get_caption(image_bytes):
    """Return a text caption for an uploaded image.

    The raw bytes are decoded with PIL and converted to RGB. The image is then
    shrunk in place with ``thumbnail`` so that it fits within 256x256 while
    keeping its aspect ratio. The captioning pipeline stored in session state
    runs on the result, and the ``generated_text`` of its first output is
    returned.

    Args:
        image_bytes: Encoded image file contents (e.g. JPEG or PNG).

    Returns:
        The caption string.
    """
    picture = Image.open(io.BytesIO(image_bytes))
    picture = picture.convert("RGB")
    picture.thumbnail((256, 256))
    outputs = st.session_state.captioner(picture)
    first = outputs[0]
    return first["generated_text"]
_MIN_STORY_WORDS = 10  # generations shorter than this are replaced by the fallback
_MAX_STORY_WORDS = 100  # upper bound requested in the prompt


def _fallback_story(caption):
    """Return the canned story used when generation fails or is too short."""
    return (
        f"Once upon a time, in a land of {caption}, a silly squirrel named Sammy "
        "found a shiny treasure! He danced with joy, but oh no! It was a magic acorn! "
        "It grew into a giant tree, and Sammy climbed to the top, giggling all the way. "
        "The tree sang funny songs, and all the animals joined in for a big party!"
    )


def _tidy_story(text, max_words=_MAX_STORY_WORDS, min_words=_MIN_STORY_WORDS):
    """Normalize whitespace, cap at ``max_words`` words, and drop a trailing
    unfinished sentence when at least ``min_words`` words of complete
    sentences remain.
    """
    story = " ".join(text.split()[:max_words])
    cut = max(story.rfind(mark) for mark in ".!?")
    if cut == -1:
        return story
    end = cut + 1
    # Keep closing quotes/brackets that belong to the final sentence.
    while end < len(story) and story[end] in "\"')\u201d":
        end += 1
    trimmed = story[:end]
    if len(trimmed.split()) >= min_words:
        return trimmed
    return story


@st.cache_data(show_spinner=False)
def get_story(caption):
    """Generate a short, funny children's story from an image caption.

    The story pipeline stored in session state is sampled with the prompt
    below. Only the newly generated text is kept. Outputs shorter than
    ``_MIN_STORY_WORDS`` words, or any pipeline error, yield a canned
    fallback story. Otherwise the text is capped at ``_MAX_STORY_WORDS``
    words and trimmed back to the last complete sentence where possible.

    NOTE(review): because of ``st.cache_data``, a fallback returned after a
    transient error is cached for that caption too.

    Args:
        caption: Image caption used as the story's subject.

    Returns:
        The story text (never empty).
    """
    prompt = (
        f"Create a funny, warm children's story (50-100 words) for ages 3-10 based on: {caption}. "
        "Use third-person narrative, as if playfully describing the scene."
    )
    try:
        result = st.session_state.storyer(
            prompt,
            max_new_tokens=150,  # room for ~100 words before trimming
            do_sample=True,
            temperature=0.8,  # slightly higher for creativity
            top_p=0.9,
            return_full_text=False,  # drop the prompt from the output
        )
        logging.info("Story generation raw result: %s", result)
        raw_story = result[0].get("generated_text", "").strip() if result else ""
    except Exception:
        logging.exception("Story generation failed")
        return _fallback_story(caption)

    if len(raw_story.split()) < _MIN_STORY_WORDS:
        logging.warning("Generated story too short or empty. Using fallback.")
        return _fallback_story(caption)
    return _tidy_story(raw_story)
@st.cache_data(show_spinner=False)
def get_audio(story):
    """Synthesize ``story`` to speech and return it as an in-memory WAV file.

    The text is split with ``textwrap.wrap`` into pieces of at most 300
    characters, broken on whitespace. Each piece is voiced by the TTS pipeline
    in session state, and the waveforms are concatenated. A piece whose
    synthesis fails is logged and skipped. If nothing could be synthesized,
    one second of silence is written so the audio widget still has content.

    Args:
        story: Text to read aloud.

    Returns:
        ``io.BytesIO`` positioned at offset 0, containing WAV data.
    """
    tts = st.session_state.tts
    sampling_rate = None
    audio_chunks = []
    for chunk in textwrap.wrap(story, width=300):
        try:
            output = tts(chunk)
            if isinstance(output, list):
                output = output[0]
            if "audio" not in output:
                logging.warning("TTS output had no 'audio' field; skipping chunk.")
                continue
            # ravel() always yields 1-D (squeeze() could give a 0-d array that
            # np.concatenate rejects).
            audio_chunks.append(np.asarray(output["audio"]).ravel())
            # Prefer the rate reported alongside the audio the pipeline returned.
            sampling_rate = output.get("sampling_rate", sampling_rate)
        except Exception:
            logging.exception("TTS failed for a text chunk; skipping it.")
            continue

    if sampling_rate is None:
        sampling_rate = tts.model.config.sampling_rate

    if audio_chunks:
        audio = np.concatenate(audio_chunks)
    else:
        audio = np.zeros(sampling_rate, dtype=np.float32)

    buffer = io.BytesIO()
    sf.write(buffer, audio, sampling_rate, format="WAV")
    buffer.seek(0)
    return buffer
# ------------------ Main App Logic ------------------
uploaded_file = st.file_uploader("Choose a Picture...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # First use downloads/loads several large models; show progress.
        with st.spinner("Waking up the story magic..."):
            load_models()
        image_bytes = uploaded_file.getvalue()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # NOTE(review): `use_column_width` is deprecated in newer Streamlit
        # releases; switch to the replacement once the minimum version allows.
        st.image(image, caption="Your Amazing Picture!", use_column_width=True)
        st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)
        if st.button("Story, Please!"):
            with st.spinner("Generating caption..."):
                caption = get_caption(image_bytes)
            st.markdown("<h3 style='text-align: center;'>Caption:</h3>", unsafe_allow_html=True)
            st.write(caption)

            with st.spinner("Generating story..."):
                story = get_story(caption)
            st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
            if not story.strip():
                st.write("No story was generated. Please try again.")
            else:
                st.write(story)
                with st.spinner("Generating audio..."):
                    audio_buffer = get_audio(story)
                st.audio(audio_buffer, format="audio/wav", start_time=0)
                st.markdown(
                    "<p style='text-align: center; font-weight: bold;'>Enjoy your magical story! 🎶</p>",
                    unsafe_allow_html=True,
                )
    except Exception as e:
        # Keep the full traceback in the server log; show a friendly message in the UI.
        logging.exception("Picture-to-story pipeline failed")
        st.error("Oops! Something went wrong. Please try a different picture or check the file format!")
        st.error(f"Error details: {e}")