Szeyu committed on
Commit
64fd107
·
verified ·
1 Parent(s): 82a099b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+ import io, textwrap, numpy as np, soundfile as sf
5
+
6
# ------------------ Streamlit Page Configuration ------------------
st.set_page_config(
    page_title="Picture to Story Magic",  # App title on browser tab
    page_icon="🦄",  # Fun unicorn icon
    layout="centered",
)

# ------------------ Custom CSS for a Colorful Background ------------------
# Streamlit's root container (class ``stApp``) paints the theme background on
# top of <body>, so styling <body> alone is normally not visible. Both
# selectors are targeted so the pastel colour actually shows.
st.markdown(
    """
<style>
body, .stApp {
    background-color: #FDEBD0; /* A soft pastel color */
}
</style>
""",
    unsafe_allow_html=True,
)

# ------------------ Playful Header for Young Users ------------------
# Raw HTML is used (hence unsafe_allow_html) to centre and colour the title.
st.markdown(
    """
<h1 style='text-align: center; color: #ff66cc;'>Picture to Story Magic!</h1>
<p style='text-align: center; font-size: 24px;'>
Hi little artist! Upload your picture and let us create a fun story just for you! 🎉
</p>
""",
    unsafe_allow_html=True,
)
35
+
36
+ # ------------------ Lazy Model Loading ------------------
37
@st.cache_resource(show_spinner=False)
def _load_pipeline(task, model_name):
    """Build one Hugging Face pipeline and keep it in Streamlit's resource cache.

    ``st.cache_resource`` returns the same object for the same arguments on
    later calls, so the model weights are loaded once instead of once per
    browser session.
    """
    return pipeline(task, model=model_name)


def load_models():
    """
    Lazy-load the required pipelines and expose them through session state.

    Pipelines (session-state key -> purpose):
    1. captioner: generates descriptive text from an image
       (Salesforce/blip-image-captioning-base, the lighter "base" model).
    2. storyer: generates a children's story
       (aspis/gpt2-genre-story-generation).
    3. tts: converts text into audio (facebook/mms-tts-eng).

    The pipelines themselves are created via ``_load_pipeline``; this function
    only attaches them to ``st.session_state`` under the keys the rest of the
    app reads. Keys that are already present are left untouched.
    """
    specs = {
        "captioner": ("image-to-text", "Salesforce/blip-image-captioning-base"),
        "storyer": ("text-generation", "aspis/gpt2-genre-story-generation"),
        "tts": ("text-to-speech", "facebook/mms-tts-eng"),
    }
    for key, (task, model_name) in specs.items():
        if key not in st.session_state:
            st.session_state[key] = _load_pipeline(task, model_name)
61
+
62
+ # ------------------ Caching Functions ------------------
63
@st.cache_data(show_spinner=False)
def get_caption(image_bytes):
    """
    Return a caption for the picture encoded in ``image_bytes``.

    The bytes are decoded to an RGB image, shrunk in place so that neither
    side exceeds 256 pixels (aspect ratio preserved), and passed to the
    captioning pipeline stored in ``st.session_state.captioner``.
    """
    source = io.BytesIO(image_bytes)
    picture = Image.open(source).convert("RGB")
    # thumbnail() resizes in place and never enlarges the image.
    picture.thumbnail((256, 256))
    results = st.session_state.captioner(picture)
    return results[0]["generated_text"]
74
+
75
def _trim_to_last_sentence(text):
    """Cut ``text`` back to its last sentence terminator when that is safe.

    The text is returned unchanged when it contains no ``.``, ``!`` or ``?``,
    or when trimming would discard more than half of it (for example when the
    only full stop belongs to an abbreviation near the start).
    """
    end = max(text.rfind(mark) for mark in ".!?")
    if end == -1:
        return text
    # Keep a closing quote that directly follows the terminator.
    if end + 1 < len(text) and text[end + 1] in "\"'”’":
        end += 1
    trimmed = text[: end + 1]
    return trimmed if len(trimmed) * 2 >= len(text) else text


@st.cache_data(show_spinner=False)
def get_story(caption):
    """
    Generate a humorous and engaging children's story using the caption.

    The prompt asks the model in ``st.session_state.storyer`` for a playful
    50-100 word story. Generation is limited to 80 new tokens to keep the
    response fast, the result is capped at 100 words, and the text is then
    trimmed back to its last complete sentence, because the token limit
    usually stops generation mid-sentence.

    Returns the story as a single whitespace-normalised string; it may be
    empty if the model produced no text.
    """
    prompt = (
        "Write a funny, warm, and imaginative children's story for ages 3-10, 50-100 words, "
        f"in third-person narrative, as if the author is playfully describing the scene in the image: {caption}. "
        "Explicitly mention the exact venue or location (such as a park, school, or home), describe specific characters "
        "(for example, a little girl named Lily or a boy named Jack), and detail the humorous actions they perform. "
        "Ensure the story is playful, engaging, and ends with a complete sentence."
    )
    raw_story = st.session_state.storyer(
        prompt,
        max_new_tokens=80,  # Reduced token generation for faster response
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        return_full_text=False,  # Only the continuation, not the prompt
    )[0]["generated_text"].strip()
    capped = " ".join(raw_story.split()[:100])
    return _trim_to_last_sentence(capped)
99
+
100
@st.cache_data(show_spinner=False)
def get_audio(story):
    """
    Convert the generated story text into audio.

    The text is split into chunks of at most 300 characters, each chunk is
    synthesised with the pipeline in ``st.session_state.tts``, the resulting
    arrays are concatenated, and the audio is written to an in-memory WAV
    buffer at the TTS model's configured sampling rate.

    Returns:
        io.BytesIO positioned at the start of the WAV data.

    Raises:
        ValueError: if ``story`` is empty or only whitespace, since there is
            nothing to synthesise.
    """
    chunks = textwrap.wrap(story, width=300)
    if not chunks:
        # textwrap.wrap() yields [] for empty/whitespace-only text, and
        # np.concatenate([]) would fail with a much less helpful message.
        raise ValueError("Cannot generate audio from an empty story.")
    tts = st.session_state.tts
    audio_chunks = [tts(chunk)["audio"].squeeze() for chunk in chunks]
    audio = np.concatenate(audio_chunks)
    buffer = io.BytesIO()
    sf.write(buffer, audio, tts.model.config.sampling_rate, format="WAV")
    buffer.seek(0)  # Rewind so the caller reads from the beginning
    return buffer
114
+
115
# ------------------ Main App Logic ------------------
def _render_story(picture_bytes):
    """Run caption -> story -> audio for one picture, rendering each result."""
    with st.spinner("Generating caption..."):
        caption = get_caption(picture_bytes)
    st.markdown("<h3 style='text-align: center;'>Caption:</h3>", unsafe_allow_html=True)
    st.write(caption)

    with st.spinner("Generating story..."):
        story = get_story(caption)
    st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
    st.write(story)

    with st.spinner("Generating audio..."):
        audio_buffer = get_audio(story)
    st.audio(audio_buffer, format="audio/wav", start_time=0)
    st.markdown(
        "<p style='text-align: center; font-weight: bold;'>Enjoy your magical story! 🎶</p>",
        unsafe_allow_html=True,
    )


uploaded_file = st.file_uploader("Choose a Picture...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Any failure while loading, previewing or generating is reported to the
    # user instead of crashing the page.
    try:
        load_models()  # Ensure models are loaded once
        picture_bytes = uploaded_file.getvalue()

        # Show the uploaded picture back to the user.
        preview = Image.open(io.BytesIO(picture_bytes)).convert("RGB")
        st.image(preview, caption="Your Amazing Picture!", use_column_width=True)
        st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)

        if st.button("Story, Please!"):
            _render_story(picture_bytes)
    except Exception as exc:
        st.error("Oops! Something went wrong. Please try a different picture or check the file format!")
        st.error(f"Error details: {exc}")