Spaces:
Sleeping
Sleeping
import base64
import html
import io
import os
import time

import streamlit as st
import torch
from gtts import gTTS
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
# Browser-tab settings for the app: title, icon and a centered layout.
# NOTE(review): the icon "π" looks like an emoji whose bytes were damaged by a
# bad transcode; the intended glyph cannot be recovered from this file, so the
# value is kept as-is. Confirm against the original source.
st.set_page_config(
    layout="centered",
    page_icon="π",
    page_title="Storyteller for Kids",
)
# Custom CSS: page background, playful fonts, rounded images, and the dashed
# "story-container" box that the generated story is rendered into.
_CUSTOM_CSS = """
<style>
.main {
background-color: #f5f7ff;
}
.stTitle {
color: #3366cc;
font-family: 'Comic Sans MS', cursive;
}
.stHeader {
font-family: 'Comic Sans MS', cursive;
}
.stImage {
border-radius: 15px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.story-container {
background-color: #e6f2ff;
padding: 20px;
border-radius: 10px;
border: 2px dashed #3366cc;
font-size: 18px;
line-height: 1.6;
}
</style>
"""
# unsafe_allow_html is required so the <style> tag is emitted as real HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Title and description
# The title emoji had been mis-decoded (UTF-8 bytes F0 9F A7 B8 read as
# ISO-8859-7 render as "π§Έ"); it is restored to the teddy bear it encodes.
st.title("🧸 Kid's Storyteller 🧸")
st.markdown("### Upload an image and I'll tell you a magical story about it!")
# Function to load image captioning model
@st.cache_resource(show_spinner=False)
def _load_caption_components():
    """Load the BLIP processor and model once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, both objects would be rebuilt from the checkpoint on each
    rerun. Exceptions propagate to the caller and are not cached, so a failed
    load is retried on the next rerun.
    """
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    return processor, model


def load_caption_model():
    """Return the image-captioning processor and model.

    Returns:
        A ``(processor, model, error)`` tuple. On success ``error`` is
        ``None``; on failure ``processor`` and ``model`` are ``None`` and
        ``error`` is the exception message as a string.
    """
    try:
        with st.spinner("Loading image captioning model... (This may take a minute)"):
            processor, model = _load_caption_components()
        return processor, model, None
    except Exception as e:
        # Report the failure to the caller instead of crashing the page.
        return None, None, str(e)
# Function to load story generation model
@st.cache_resource(show_spinner=False)
def _load_story_pipeline():
    """Build the GPT-2 text-generation pipeline once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the pipeline would be rebuilt on each rerun. Exceptions
    propagate to the caller and are not cached, so a failed load is retried on
    the next rerun.
    """
    return pipeline("text-generation", model="gpt2")


def load_story_model():
    """Return the story-generation pipeline.

    Returns:
        A ``(story_generator, error)`` tuple. On success ``error`` is
        ``None``; on failure ``story_generator`` is ``None`` and ``error`` is
        the exception message as a string.
    """
    try:
        with st.spinner("Loading story generation model... (This may take a minute)"):
            story_generator = _load_story_pipeline()
        return story_generator, None
    except Exception as e:
        # Report the failure to the caller instead of crashing the page.
        return None, str(e)
# Function to generate caption from image
def generate_caption(image, processor, model):
    """Describe ``image`` in a short piece of text.

    Args:
        image: The image to caption; passed straight to ``processor``.
        processor: Callable that encodes the image (called with
            ``return_tensors="pt"``) and exposes ``decode``.
        model: Object whose ``generate`` accepts the encoded inputs as
            keyword arguments.

    Returns:
        The decoded first generated sequence, with special tokens skipped.
    """
    encoded = processor(image, return_tensors="pt")
    token_ids = model.generate(max_length=50, **encoded)
    return processor.decode(token_ids[0], skip_special_tokens=True)
# Function to generate story from caption
def _cap_story(story, max_words):
    """Clip ``story`` to ``max_words`` words and add the closing sentence.

    Stories within the limit are returned unchanged. When clipping happens,
    punctuation left dangling on the last kept word is removed first so the
    closing sentence does not produce sequences such as ".." or "!.".
    """
    words = story.split()
    if len(words) <= max_words:
        return story
    clipped = ' '.join(words[:max_words]).rstrip('.,;:!?')
    return clipped + ". And they all lived happily ever after."


def generate_story(caption, story_generator, *, min_words=50, max_words=100):
    """Turn an image caption into a short children's story.

    Args:
        caption: Text describing the image; embedded in the story prompt.
        story_generator: Callable taking a prompt plus generation keyword
            arguments and returning a list of dicts with a
            ``'generated_text'`` key.
        min_words: If the first generation has fewer words than this, the
            generator is called once more with that text as the prompt.
        max_words: Stories longer than this are clipped to this many words
            and closed with a fixed final sentence.

    Returns:
        The story text. The length bounds are best effort: a story that is
        still shorter than ``min_words`` after the second generation is
        returned as-is.
    """
    # Make the prompt child-friendly and whimsical
    prompt = f"Once upon a time in a magical land, {caption}. The children were amazed when "
    result = story_generator(prompt, max_length=150, num_return_sequences=1, temperature=0.8)
    story = result[0]['generated_text']

    if len(story.split()) < min_words:
        # If too short, let the model continue from what it already wrote.
        additional = story_generator(story, max_length=150, num_return_sequences=1)
        story = additional[0]['generated_text']

    return _cap_story(story, max_words)
# Function to convert text to speech and create audio player
def text_to_speech(text):
    """Synthesize ``text`` as speech and wrap it in an HTML audio player.

    The MP3 is rendered into an in-memory buffer rather than a fixed file in
    the working directory, so concurrent sessions cannot overwrite each
    other's audio and nothing is left on disk.

    Args:
        text: The text to read aloud (English).

    Returns:
        An ``(audio_player, error)`` tuple. On success ``audio_player`` is an
        HTML ``<audio>`` snippet with the MP3 embedded as a base64 data URI
        and ``error`` is ``None``; on failure ``audio_player`` is ``None`` and
        ``error`` is the exception message as a string.
    """
    try:
        tts = gTTS(text=text, lang='en', slow=False)
        buffer = io.BytesIO()
        tts.write_to_fp(buffer)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()
        # Create audio player
        audio_player = (
            "\n<audio controls autoplay>\n"
            f'<source src="data:audio/mp3;base64,{audio_b64}" type="audio/mp3">\n'
            "Your browser does not support the audio element.\n"
            "</audio>\n"
        )
        return audio_player, None
    except Exception as e:
        # Audio is optional; report the failure instead of aborting the page.
        return None, str(e)
# Main application flow
#
# Loads both models, shows the upload widget and, for an uploaded image,
# renders caption -> story -> audio -> download button.
#
# Streamlit re-executes this script on every widget interaction (including a
# click on the download button). The generated story and its audio are
# therefore kept in st.session_state, keyed by the uploaded image, so the story
# the user downloads is the one they were shown.
#
# NOTE(review): the lone "π" characters in some UI strings below look like
# emoji damaged by a bad transcode. The intended glyphs cannot be recovered
# from this file, so those strings are left untouched; the three example
# labels could be decoded unambiguously and are restored.
try:
    # Load models with status checks
    with st.spinner("Loading AI models... This may take a moment the first time you run the app."):
        caption_processor, caption_model, caption_error = load_caption_model()
        story_model, story_error = load_story_model()

    if caption_error:
        st.error(f"Error loading caption model: {caption_error}")
    if story_error:
        st.error(f"Error loading story model: {story_error}")

    # Compare against None explicitly instead of relying on the truthiness of
    # third-party model objects.
    models_ready = all(
        component is not None
        for component in (caption_processor, caption_model, story_model)
    )

    # If models loaded successfully
    if models_ready:
        # Show example images for kids to understand
        st.markdown("### π Examples of images you can upload:")
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("🐱 Pets")
        with col2:
            st.markdown("🏰 Places")
        with col3:
            st.markdown("🧩 Toys")

        # File uploader
        uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            # Display the uploaded image
            image_bytes = uploaded_file.getvalue()
            image = Image.open(io.BytesIO(image_bytes))
            st.image(image, caption='Uploaded Image', use_column_width=True, output_format="JPEG")

            # Only generate a new story when the upload actually changed.
            upload_key = (uploaded_file.name, hash(image_bytes))
            cached = st.session_state.get("story_cache")
            if cached is None or cached["upload_key"] != upload_key:
                with st.spinner('Creating your story... π'):
                    # Generate caption
                    caption = generate_caption(image, caption_processor, caption_model)
                    # Generate story
                    story = generate_story(caption, story_model)
                cached = {"upload_key": upload_key, "story": story, "audio_player": None}
                st.session_state["story_cache"] = cached
            story = cached["story"]

            # Display the story with some styling. The text comes from a
            # language model, so escape it before embedding it in raw HTML.
            st.markdown("## π Your Magical Story")
            st.markdown(
                f"<div class='story-container'>{html.escape(story)}</div>",
                unsafe_allow_html=True,
            )

            # Convert to speech and play
            st.markdown("## π Listen to the Story")
            audio_player, audio_error = cached["audio_player"], None
            if audio_player is None:
                audio_player, audio_error = text_to_speech(story)
                # Failures leave None in the cache, so they are retried on the
                # next rerun.
                cached["audio_player"] = audio_player
            if audio_player:
                st.markdown(audio_player, unsafe_allow_html=True)
            else:
                st.error(f"Could not generate audio: {audio_error}")

            # Download options
            st.download_button(
                label="Download Story (Text)",
                data=story,
                file_name="my_story.txt",
                mime="text/plain",
            )
    else:
        st.warning("Some AI models didn't load correctly. Please refresh the page or try again later.")
except Exception as e:
    # Top-level boundary: show the failure in the page instead of a traceback.
    st.error(f"An error occurred: {e}")
    st.markdown("Please try again with a different image.")
# Footer
# The Hugging Face emoji had been mis-decoded (UTF-8 bytes F0 9F A4 97 read as
# ISO-8859-7 render as "π€"); it is restored here.
# NOTE(review): the trailing "π" on the age line is likewise a damaged emoji,
# but its original bytes are lost, so that string is left untouched.
st.markdown("---")
st.markdown("Created for young storytellers aged 3-10 years old π")
st.markdown("Powered by Hugging Face Transformers 🤗")