Billy_Space / app.py
import streamlit as st
from PIL import Image
import io
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
from gtts import gTTS
import os
import base64
import time
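
# The third-party packages used here (streamlit, Pillow, torch, transformers, gtts)
# would typically be pinned in the Space's requirements.txt.
# To run locally once they are installed: streamlit run app.py
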
# Set page configuration (st.set_page_config must be the first Streamlit call in the script)
st.set_page_config(
    page_title="Storyteller for Kids",
    page_icon="📚",
    layout="centered"
)

# Custom CSS
st.markdown("""
<style>
    .main {
        background-color: #f5f7ff;
    }
    .stTitle {
        color: #3366cc;
        font-family: 'Comic Sans MS', cursive;
    }
    .stHeader {
        font-family: 'Comic Sans MS', cursive;
    }
    .stImage {
        border-radius: 15px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }
    .story-container {
        background-color: #e6f2ff;
        padding: 20px;
        border-radius: 10px;
        border: 2px dashed #3366cc;
        font-size: 18px;
        line-height: 1.6;
    }
</style>
""", unsafe_allow_html=True)

# Title and description
st.title("🧸 Kid's Storyteller 🧸")
st.markdown("### Upload an image and I'll tell you a magical story about it!")
# Function to load image captioning model
@st.cache_resource
def load_caption_model():
    try:
        with st.spinner("Loading image captioning model... (This may take a minute)"):
            processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
            return processor, model, None
    except Exception as e:
        return None, None, str(e)

# Function to load story generation model
@st.cache_resource
def load_story_model():
    try:
        with st.spinner("Loading story generation model... (This may take a minute)"):
            story_generator = pipeline("text-generation", model="gpt2")
            return story_generator, None
    except Exception as e:
        return None, str(e)

# Function to generate caption from image
def generate_caption(image, processor, model):
    # BLIP expects an RGB image; uploaded PNGs may be RGBA or greyscale
    image = image.convert("RGB")
    inputs = processor(image, return_tensors="pt")
    out = model.generate(**inputs, max_length=50)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption
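
# Note: the caption is generated unconditionally here. BLIP also supports
# conditional captioning, where a short text prefix is passed to the processor
# alongside the image to steer the caption; that option is not used in this app.
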
# Function to generate story from caption
def generate_story(caption, story_generator):
    # Make the prompt child-friendly and whimsical
    prompt = f"Once upon a time in a magical land, {caption}. The children were amazed when "
    # do_sample=True is required for the temperature setting to have any effect
    result = story_generator(prompt, max_length=150, num_return_sequences=1, temperature=0.8, do_sample=True)
    story = result[0]['generated_text']
    # Keep the story between roughly 50 and 100 words
    story_words = story.split()
    if len(story_words) > 100:
        story = ' '.join(story_words[:100])
        # Add a closing sentence
        story += ". And they all lived happily ever after."
    elif len(story_words) < 50:
        # If too short, continue generating from the story so far
        # (max_new_tokens counts only the continuation, unlike max_length,
        # which also counts the prompt text already generated)
        additional = story_generator(story, max_new_tokens=100, num_return_sequences=1)
        story = additional[0]['generated_text']
        story_words = story.split()
        if len(story_words) > 100:
            story = ' '.join(story_words[:100])
        story += ". And they all lived happily ever after."
    return story

# Function to convert text to speech and create audio player
def text_to_speech(text):
    try:
        tts = gTTS(text=text, lang='en', slow=False)
        audio_file = "story_audio.mp3"
        tts.save(audio_file)
        # Create audio player
        with open(audio_file, "rb") as file:
            audio_bytes = file.read()
        audio_b64 = base64.b64encode(audio_bytes).decode()
        audio_player = f"""
        <audio controls autoplay>
            <source src="data:audio/mp3;base64,{audio_b64}" type="audio/mp3">
            Your browser does not support the audio element.
        </audio>
        """
        return audio_player, None
    except Exception as e:
        return None, str(e)
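
# Note: Streamlit's built-in st.audio(audio_bytes, format="audio/mp3") could play the
# clip without hand-rolling a base64 <audio> tag; the custom tag is presumably used
# here to get the autoplay attribute.
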
# Main application flow
try:
    # Load models with status checks
    with st.spinner("Loading AI models... This may take a moment the first time you run the app."):
        caption_processor, caption_model, caption_error = load_caption_model()
        story_model, story_error = load_story_model()

    if caption_error:
        st.error(f"Error loading caption model: {caption_error}")
    if story_error:
        st.error(f"Error loading story model: {story_error}")

    # If models loaded successfully
    if caption_processor and caption_model and story_model:
        # Show example images for kids to understand
        st.markdown("### 🌟 Examples of images you can upload:")
        col1, col2, col3 = st.columns(3)
        with col1:
            st.markdown("🐱 Pets")
        with col2:
            st.markdown("🏰 Places")
        with col3:
            st.markdown("🧩 Toys")

        # File uploader
        uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])

        if uploaded_file is not None:
            # Display the uploaded image
            image_bytes = uploaded_file.getvalue()
            image = Image.open(io.BytesIO(image_bytes))
            st.image(image, caption='Uploaded Image', use_column_width=True, output_format="JPEG")

            with st.spinner('Creating your story... 📝'):
                # Generate caption
                caption = generate_caption(image, caption_processor, caption_model)
                # Generate story
                story = generate_story(caption, story_model)

            # Display the story with some styling
            st.markdown("## 📖 Your Magical Story")
            st.markdown(f"<div class='story-container'>{story}</div>",
                        unsafe_allow_html=True)

            # Convert to speech and play
            st.markdown("## 🔊 Listen to the Story")
            audio_player, audio_error = text_to_speech(story)
            if audio_player:
                st.markdown(audio_player, unsafe_allow_html=True)
            else:
                st.error(f"Could not generate audio: {audio_error}")

            # Download options
            st.download_button(
                label="Download Story (Text)",
                data=story,
                file_name="my_story.txt",
                mime="text/plain"
            )
    else:
        st.warning("Some AI models didn't load correctly. Please refresh the page or try again later.")
except Exception as e:
    st.error(f"An error occurred: {e}")
    st.markdown("Please try again with a different image.")

# Footer
st.markdown("---")
st.markdown("Created for young storytellers aged 3-10 years old 🌈")
st.markdown("Powered by Hugging Face Transformers 🤗")