Szeyu's picture
Update app.py
df08c46 verified
raw
history blame
4.6 kB
import re
import streamlit as st
from transformers import pipeline
import textwrap
import numpy as np
import soundfile as sf
import tempfile
import os
from PIL import Image
import string
# Initialize pipelines with caching
@st.cache_resource
def load_pipelines():
    """
    Build the three Hugging Face pipelines used by the app.

    Returns a tuple ``(captioner, storyer, tts)``: an image-to-text pipeline,
    a text-generation pipeline and a text-to-speech pipeline, created in that
    order. The function is wrapped in ``st.cache_resource``.
    """
    pipeline_specs = (
        ("image-to-text", "Salesforce/blip-image-captioning-large"),
        ("text-generation", "aspis/gpt2-genre-story-generation"),
        ("text-to-speech", "facebook/mms-tts-eng"),
    )
    return tuple(pipeline(task, model=model_id) for task, model_id in pipeline_specs)
# Module-level pipeline handles; get_caption, get_story and generate_audio below read these names.
captioner, storyer, tts = load_pipelines()
def clean_generated_story(raw_story: str, max_words: int = 100) -> str:
    """
    Clean up raw text produced by the story generator.

    The cleaning steps are:
    1. Remove every digit.
    2. Drop words longer than one character that contain no vowel; such words
       are treated as random letter combinations. "y" counts as a vowel so
       that ordinary words such as "my", "by", "sky" and "fly" are kept.
    3. Drop single-character words other than "a" and "I" (any case).
    4. Keep at most ``max_words`` of the remaining words.

    Args:
        raw_story: Text returned by the generator.
        max_words: Maximum number of words kept in the result (default 100).

    Returns:
        The surviving words joined by single spaces; an empty string when
        nothing survives.
    """
    story_without_numbers = re.sub(r"\d+", "", raw_story)
    vowels = frozenset("aeiouyAEIOUY")
    single_letter_words = {"a", "i"}

    def is_valid_word(word: str) -> bool:
        # One-character tokens are only kept when they are real words.
        if len(word) == 1:
            return word.lower() in single_letter_words
        # Longer tokens need at least one vowel to count as a word.
        return any(char in vowels for char in word)

    kept_words = [
        word for word in story_without_numbers.split() if is_valid_word(word)
    ]
    return " ".join(kept_words[:max_words])
def get_caption(image) -> str:
    """
    Generate a caption for an image and show it on the page.

    ``image`` is passed straight to ``PIL.Image.open``. The first result of
    the module-level ``captioner`` pipeline supplies the caption text, which
    is written to the Streamlit page under a heading and then returned.
    """
    picture = Image.open(image)
    predictions = captioner(picture)
    description = predictions[0]["generated_text"]

    st.write("**๐ŸŒŸ What's in the picture: ๐ŸŒŸ**")
    st.write(description)
    return description
def get_story(caption: str) -> str:
    """
    Generate a short children's story from an image caption and display it.

    A prompt is built around ``caption`` and sent to the module-level
    ``storyer`` text-generation pipeline. Only the newly generated text is
    requested (``return_full_text=False``); it is stripped, passed through
    ``clean_generated_story``, written to the Streamlit page and returned.

    Args:
        caption: Description of the scene the story should be about.

    Returns:
        The cleaned story text.
    """
    prompt = (
        f"Write a funny, bright, and playful story for young children precisely centered on this scene {caption}\nStory: "
        f"mention the exact place, location or venue within {caption}. "
        f"Make the story magical and exciting, with lots of friendly descriptions that little ones can enjoy."
    )
    outputs = storyer(
        prompt,
        max_new_tokens=150,
        # temperature/top_p only influence generation when sampling is on;
        # enable it explicitly instead of relying on the model's defaults.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        no_repeat_ngram_size=2,
        return_full_text=False,
    )
    raw = outputs[0]["generated_text"].strip()
    story = clean_generated_story(raw)
    st.write("**๐Ÿ“– Your funny story: ๐Ÿ“–**")
    st.write(story)
    return story
def generate_audio(story: str) -> str:
    """
    Convert a story into speech and save it as a temporary WAV file.

    The story is wrapped into chunks of at most 200 characters, each chunk is
    synthesized with the module-level ``tts`` pipeline, and the resulting
    arrays are concatenated into a single waveform.

    Args:
        story: Text to read aloud.

    Returns:
        Path of the ".wav" file that was written. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        ValueError: If ``story`` is empty or contains only whitespace.
    """
    chunks = textwrap.wrap(story, width=200)
    if not chunks:
        # np.concatenate() on an empty list fails with an unhelpful message.
        raise ValueError("Cannot generate audio from an empty story.")
    audio = np.concatenate([tts(chunk)["audio"].squeeze() for chunk in chunks])

    # Reserve a unique path, then close the handle before writing to it:
    # reopening a still-open NamedTemporaryFile by name fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file_path = temp_file.name
    try:
        sf.write(temp_file_path, audio, tts.model.config.sampling_rate)
    except Exception:
        # Do not leave a stray file behind when the write fails.
        os.remove(temp_file_path)
        raise
    return temp_file_path
def generate_content(image):
    """
    Run the full image -> caption -> story -> audio chain.

    Calls ``get_caption``, ``get_story`` and ``generate_audio`` in that order
    and returns ``(caption, story, audio_path)``.
    """
    scene_caption = get_caption(image)
    generated_story = get_story(scene_caption)
    return scene_caption, generated_story, generate_audio(generated_story)
# Streamlit UI section: title, uploader, image preview and the button that
# runs the caption -> story -> audio chain on the uploaded picture.
st.title("โœจ Magic Story Maker โœจ")
st.markdown("Upload a picture to make a funny story and hear it too! ๐Ÿ“ธ")
uploaded_image = st.file_uploader("Choose your picture", type=["jpg", "jpeg", "png"])

# Show a placeholder until the user has uploaded something.
if uploaded_image is None:
    st.image("https://example.com/placeholder_image.jpg", caption="Upload your picture here! ๐Ÿ“ท", use_container_width=True)
else:
    st.image(uploaded_image, caption="Your Picture ๐ŸŒŸ", use_container_width=True)

if st.button("โœจ Make My Story! โœจ"):
    if uploaded_image is not None:
        with st.spinner("๐Ÿ”ฎ Creating your magical story..."):
            # The caption and story are displayed by the helpers themselves;
            # only the audio path is needed here.
            _, _, audio_path = generate_content(uploaded_image)
            try:
                st.success("๐ŸŽ‰ Your story is ready! ๐ŸŽ‰")
                st.audio(audio_path, format="audio/wav")
            finally:
                # Remove the temporary file even if showing the result fails.
                os.remove(audio_path)
    else:
        st.warning("Please upload a picture first! ๐Ÿ“ธ")