# Scraped from a Hugging Face Spaces file view (Space status: Sleeping).
# File size: 7,128 Bytes; revision e48bf39.
import streamlit as st
from PIL import Image
import io
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
from gtts import gTTS
import os
import base64
import time
# Page-level settings: browser-tab title, tab icon, and a centered layout.
_PAGE_SETTINGS = {
    "page_title": "Storyteller for Kids",
    "page_icon": "π",
    "layout": "centered",
}
st.set_page_config(**_PAGE_SETTINGS)
# Custom CSS: page background, playful fonts for title/headers, rounded
# images, and the dashed "story-container" box used to display the story.
_CUSTOM_CSS = """
<style>
.main {
background-color: #f5f7ff;
}
.stTitle {
color: #3366cc;
font-family: 'Comic Sans MS', cursive;
}
.stHeader {
font-family: 'Comic Sans MS', cursive;
}
.stImage {
border-radius: 15px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.story-container {
background-color: #e6f2ff;
padding: 20px;
border-radius: 10px;
border: 2px dashed #3366cc;
font-size: 18px;
line-height: 1.6;
}
</style>
"""
# Raw HTML must be explicitly allowed for the <style> tag to be emitted.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Title and description
# The teddy-bear emoji is restored from its mis-encoded form ("π§Έ").
st.title("🧸 Kid's Storyteller 🧸")
st.markdown("### Upload an image and I'll tell you a magical story about it!")
# Function to load image captioning model
@st.cache_resource
def _load_caption_components():
    """Instantiate the BLIP processor and captioning model.

    Only this helper is cached. It lets exceptions propagate, so a failed
    load leaves nothing in the cache and the next script run retries.
    """
    model_id = "Salesforce/blip-image-captioning-base"
    processor = BlipProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id)
    return processor, model


def load_caption_model():
    """Return ``(processor, model, error)`` for the image captioning model.

    On success ``error`` is ``None``. On failure ``processor`` and ``model``
    are ``None`` and ``error`` is the exception message. The error tuple is
    built outside the cached helper so a transient failure (e.g. a download
    problem) is not remembered for the lifetime of the app.
    """
    try:
        with st.spinner("Loading image captioning model... (This may take a minute)"):
            processor, model = _load_caption_components()
        return processor, model, None
    except Exception as e:
        return None, None, str(e)
# Function to load story generation model
@st.cache_resource
def _load_story_pipeline():
    """Build the GPT-2 text-generation pipeline.

    Only this helper is cached. It lets exceptions propagate, so a failed
    load leaves nothing in the cache and the next script run retries.
    """
    return pipeline("text-generation", model="gpt2")


def load_story_model():
    """Return ``(story_generator, error)`` for the story generation model.

    On success ``error`` is ``None``. On failure ``story_generator`` is
    ``None`` and ``error`` is the exception message. The error tuple is built
    outside the cached helper so a transient failure is not remembered for
    the lifetime of the app.
    """
    try:
        with st.spinner("Loading story generation model... (This may take a minute)"):
            story_generator = _load_story_pipeline()
        return story_generator, None
    except Exception as e:
        return None, str(e)
# Function to generate caption from image
def generate_caption(image, processor, model):
    """Produce a text caption for *image*.

    The processor turns the image into model inputs, the model generates
    output token ids (capped at ``max_length=50``), and the first generated
    sequence is decoded back to text with special tokens removed.
    """
    model_inputs = processor(image, return_tensors="pt")
    generated = model.generate(**model_inputs, max_length=50)
    first_sequence = generated[0]
    return processor.decode(first_sequence, skip_special_tokens=True)
# Function to generate story from caption
def _finish_story(story, max_words):
    """Cap *story* at *max_words* words, closing a trimmed story with the stock ending.

    A story already within the limit is returned unchanged.
    """
    words = story.split()
    if len(words) <= max_words:
        return story
    # Drop punctuation left dangling by the cut so the appended ending does
    # not produce doubled marks such as "castle.. And they ...".
    trimmed = ' '.join(words[:max_words]).rstrip(" .,;:!?")
    return trimmed + ". And they all lived happily ever after."


def generate_story(caption, story_generator, min_words=50, max_words=100):
    """Generate a short children's story seeded by an image caption.

    Args:
        caption: Text description of the image; embedded in the prompt.
        story_generator: Callable invoked as ``story_generator(prompt, **kwargs)``
            that returns a list of dicts carrying a ``'generated_text'`` key.
        min_words: If the first draft has fewer words than this, it is fed
            back to the generator once to extend it.
        max_words: Stories longer than this are cut to this many words and
            closed with a fixed fairy-tale ending.

    Returns:
        The story text. It may still be shorter than ``min_words`` if the
        second generation pass does not add enough text.
    """
    # Make the prompt child-friendly and whimsical
    prompt = f"Once upon a time in a magical land, {caption}. The children were amazed when "
    result = story_generator(prompt, max_length=150, num_return_sequences=1, temperature=0.8)
    story = result[0]['generated_text']
    if len(story.split()) < min_words:
        # If too short, generate more, using the draft itself as the prompt.
        additional = story_generator(story, max_length=150, num_return_sequences=1)
        story = additional[0]['generated_text']
    return _finish_story(story, max_words)
# Function to convert text to speech and create audio player
def text_to_speech(text):
    """Synthesize *text* with gTTS and return an embeddable HTML audio player.

    Returns:
        ``(audio_player_html, None)`` on success, where the MP3 data is
        inlined as a base64 data URI; ``(None, error_message)`` if synthesis
        fails for any reason.
    """
    try:
        tts = gTTS(text=text, lang='en', slow=False)
        # Render into memory rather than a fixed "story_audio.mp3" on disk:
        # a shared filename can be overwritten by another session between the
        # write and the read, and the file would otherwise be left behind.
        audio_buffer = io.BytesIO()
        tts.write_to_fp(audio_buffer)
        audio_b64 = base64.b64encode(audio_buffer.getvalue()).decode()
        # Create audio player
        audio_player = f"""
<audio controls autoplay>
<source src="data:audio/mp3;base64,{audio_b64}" type="audio/mp3">
Your browser does not support the audio element.
</audio>
"""
        return audio_player, None
    except Exception as e:
        return None, str(e)
# Main application flow
# Loads both models, then for an uploaded image: caption -> story -> display,
# audio playback, and a text download. Any unexpected failure is reported to
# the user instead of crashing the page.
import html  # local to this script section: escapes generated text before it is embedded in HTML

# NOTE(review): the bare "π" glyphs in the headings below look like mis-encoded
# emoji whose original code points cannot be recovered from this file — restore
# them from the original source if available.
try:
    # Load models with status checks
    with st.spinner("Loading AI models... This may take a moment the first time you run the app."):
        caption_processor, caption_model, caption_error = load_caption_model()
        story_model, story_error = load_story_model()
    if caption_error:
        st.error(f"Error loading caption model: {caption_error}")
    if story_error:
        st.error(f"Error loading story model: {story_error}")
    # If models loaded successfully
    if caption_processor and caption_model and story_model:
        # Show example images for kids to understand
        st.markdown("### π Examples of images you can upload:")
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("🐱 Pets")
        with col2:
            st.markdown("🏰 Places")
        with col3:
            st.markdown("🧩 Toys")
        # File uploader
        uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            # Display the uploaded image
            image_bytes = uploaded_file.getvalue()
            image = Image.open(io.BytesIO(image_bytes))
            st.image(image, caption='Uploaded Image', use_column_width=True, output_format="JPEG")
            with st.spinner('Creating your story... π'):
                # Generate caption
                caption = generate_caption(image, caption_processor, caption_model)
                # Generate story
                story = generate_story(caption, story_model)
            # Display the story with some styling
            st.markdown("## π Your Magical Story")
            # The story is model output and may contain characters such as
            # "<" or "&"; escape it so it is shown as text rather than being
            # interpreted as markup inside the raw-HTML container.
            st.markdown(f"<div class='story-container'>{html.escape(story)}</div>",
                        unsafe_allow_html=True)
            # Convert to speech and play
            st.markdown("## π Listen to the Story")
            audio_player, audio_error = text_to_speech(story)
            if audio_player:
                st.markdown(audio_player, unsafe_allow_html=True)
            else:
                st.error(f"Could not generate audio: {audio_error}")
            # Download options (the plain, unescaped story text)
            st.download_button(
                label="Download Story (Text)",
                data=story,
                file_name="my_story.txt",
                mime="text/plain"
            )
    else:
        st.warning("Some AI models didn't load correctly. Please refresh the page or try again later.")
except Exception as e:
    st.error(f"An error occurred: {e}")
    st.markdown("Please try again with a different image.")
# Footer
# A stray trailing "|" after the last call (a syntax error) is removed and the
# hugging-face emoji is restored from its mis-encoded form ("π€").
st.markdown("---")
st.markdown("Created for young storytellers aged 3-10 years old π")
st.markdown("Powered by Hugging Face Transformers 🤗")