Szeyu commited on
Commit
f7c507d
·
verified ·
1 Parent(s): 7aefbcf

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -194
app.py DELETED
@@ -1,194 +0,0 @@
1
- import streamlit as st
2
- from transformers import pipeline
3
- from PIL import Image
4
- import io, textwrap, numpy as np, soundfile as sf
5
- import logging
6
-
7
- # Set up logging for debugging
8
- logging.basicConfig(level=logging.INFO)
9
-
10
- # ------------------ Streamlit Page Configuration ------------------
11
- st.set_page_config(
12
- page_title="Picture to Story Magic",
13
- page_icon="🦄",
14
- layout="centered"
15
- )
16
-
17
- # ------------------ Custom CSS for a Colorful Background ------------------
18
- st.markdown(
19
- """
20
- <style>
21
- body {
22
- background-color: #FDEBD0;
23
- }
24
- </style>
25
- """,
26
- unsafe_allow_html=True
27
- )
28
-
29
- # ------------------ Playful Header for Young Users ------------------
30
- st.markdown(
31
- """
32
- <h1 style='text-align: center; color: #ff66cc;'>Picture to Story Magic!</h1>
33
- <p style='text-align: center; font-size: 24px;'>
34
- Hi little artist! Upload your picture and let us create a fun story just for you! 🎉
35
- </p>
36
- """,
37
- unsafe_allow_html=True
38
- )
39
-
40
- # ------------------ Lazy Model Loading ------------------
41
- def load_models():
42
- """
43
- Lazy-load the required pipelines and store them in session state.
44
- """
45
- if "captioner" not in st.session_state:
46
- st.session_state.captioner = pipeline(
47
- "image-to-text",
48
- model="Salesforce/blip-image-captioning-large"
49
- )
50
- if "storyer" not in st.session_state:
51
- try:
52
- st.session_state.storyer = pipeline(
53
- "text-generation",
54
- model="aspis/gpt2-genre-story-generation"
55
- )
56
- except Exception as e:
57
- logging.error(f"Failed to load aspis/gpt2-genre-story-generation: {e}")
58
- raise Exception("Text generation model could not be loaded.")
59
- if "tts" not in st.session_state:
60
- st.session_state.tts = pipeline(
61
- "text-to-speech",
62
- model="facebook/mms-tts-eng"
63
- )
64
-
65
- # ------------------ Caching Functions ------------------
66
- @st.cache_data(show_spinner=False)
67
- def get_caption(image_bytes):
68
- """
69
- Converts image bytes into a lower resolution image (maximum 256x256)
70
- and generates a caption.
71
- """
72
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
73
- image.thumbnail((256, 256))
74
- caption = st.session_state.captioner(image)[0]["generated_text"]
75
- return caption
76
-
77
- @st.cache_data(show_spinner=False)
78
- def get_story(caption):
79
- """
80
- Generates a humorous children's story (50-100 words) based on the caption.
81
- Optimized for faster generation with aspis/gpt2-genre-story-generation.
82
- """
83
- prompt = (
84
- f"Write a funny, warm children's story (50-100 words) for ages 3-10 based on: {caption}. "
85
- f"Third-person narrative, playful tone."
86
- )
87
- try:
88
- result = st.session_state.storyer(
89
- prompt,
90
- max_new_tokens=80, # Reduced for faster generation
91
- do_sample=True,
92
- temperature=0.7, # Balanced creativity
93
- top_k=40, # Faster sampling with smaller k
94
- top_p=0.85, # Tighter sampling for coherence
95
- return_full_text=False
96
- )
97
- logging.info(f"Story generation raw result: {result}")
98
-
99
- # Extract and clean generated text
100
- raw_story = result[0].get("generated_text", "").strip()
101
-
102
- # Ensure story is 50-100 words
103
- words = raw_story.split()
104
- if len(words) < 50 or not raw_story:
105
- logging.warning("Generated story too short or empty. Using fallback.")
106
- raw_story = (
107
- f"In a land of {caption}, a silly bunny named Bouncy found a shiny star! "
108
- f"It sparkled, making Bouncy giggle and hop high. The star said, 'Dance!' "
109
- f"So Bouncy twirled with squirrels and birds. They threw a forest party, "
110
- f"singing silly songs under the twinkling sky, laughing all night."
111
- )
112
- words = raw_story.split()
113
-
114
- # Truncate to 100 words if too long
115
- story = " ".join(words[:100])
116
- # Pad if too short
117
- if len(words) < 50:
118
- story += " And they all lived happily ever after!"
119
-
120
- return story
121
- except Exception as e:
122
- logging.error(f"Story generation failed: {e}")
123
- # Fallback story
124
- return (
125
- f"In a land of {caption}, a silly bunny named Bouncy found a shiny star! "
126
- f"It sparkled, making Bouncy giggle and hop high. The star said, 'Dance!' "
127
- f"So Bouncy twirled with squirrels and birds. They threw a forest party, "
128
- f"singing silly songs under the twinkling sky, laughing all night."
129
- )
130
-
131
- @st.cache_data(show_spinner=False)
132
- def get_audio(story):
133
- """
134
- Converts the generated story text into audio.
135
- Splits the text into 300-character chunks to reduce repeated TTS calls.
136
- """
137
- chunks = textwrap.wrap(story, width=300)
138
- audio_chunks = []
139
- for chunk in chunks:
140
- try:
141
- output = st.session_state.tts(chunk)
142
- if isinstance(output, list):
143
- output = output[0]
144
- if "audio" in output:
145
- audio_array = np.array(output["audio"]).squeeze()
146
- audio_chunks.append(audio_array)
147
- except Exception:
148
- continue
149
-
150
- if not audio_chunks:
151
- sr = st.session_state.tts.model.config.sampling_rate
152
- audio = np.zeros(sr, dtype=np.float32)
153
- else:
154
- audio = np.concatenate(audio_chunks)
155
-
156
- buffer = io.BytesIO()
157
- sf.write(buffer, audio, st.session_state.tts.model.config.sampling_rate, format="WAV")
158
- buffer.seek(0)
159
- return buffer
160
-
161
- # ------------------ Main App Logic ------------------
162
- uploaded_file = st.file_uploader("Choose a Picture...", type=["jpg", "jpeg", "png"])
163
- if uploaded_file is not None:
164
- try:
165
- load_models() # Ensure models are loaded
166
- image_bytes = uploaded_file.getvalue()
167
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
168
- st.image(image, caption="Your Amazing Picture!", use_column_width=True)
169
- st.markdown("<h3 style='text-align: center;'>Ready for your story?</h3>", unsafe_allow_html=True)
170
-
171
- if st.button("Story, Please!"):
172
- with st.spinner("Generating caption..."):
173
- caption = get_caption(image_bytes)
174
- st.markdown("<h3 style='text-align: center;'>Caption:</h3>", unsafe_allow_html=True)
175
- st.write(caption)
176
-
177
- with st.spinner("Generating story..."):
178
- story = get_story(caption)
179
- st.markdown("<h3 style='text-align: center;'>Your Story:</h3>", unsafe_allow_html=True)
180
- if not story.strip():
181
- st.write("No story was generated. Please try again.")
182
- else:
183
- st.write(story)
184
-
185
- with st.spinner("Generating audio..."):
186
- audio_buffer = get_audio(story)
187
- st.audio(audio_buffer, format="audio/wav", start_time=0)
188
- st.markdown(
189
- "<p style='text-align: center; font-weight: bold;'>Enjoy your magical story! 🎶</p>",
190
- unsafe_allow_html=True
191
- )
192
- except Exception as e:
193
- st.error("Oops! Something went wrong. Please try a different picture or check the file format!")
194
- st.error(f"Error details: {e}")