BillyZ1129 commited on
Commit
c0da145
·
verified ·
1 Parent(s): 7d6fb10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -162
app.py CHANGED
@@ -4,200 +4,193 @@ import io
4
  import torch
5
  from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
6
  from gtts import gTTS
7
- import tempfile
8
  import os
9
  import base64
10
- import numpy as np
11
 
12
- # Set page config
13
  st.set_page_config(
14
- page_title="StoryTime: Kids' Storyteller",
15
  page_icon="📚",
16
  layout="centered"
17
  )
18
 
19
- # Load and apply CSS
20
- def load_css(file_name):
21
- with open(file_name) as f:
22
- st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- try:
25
- load_css("style.css")
26
- except:
27
- st.warning("Style file not found. Using default styling.")
28
 
29
  # Function to load image captioning model
30
  @st.cache_resource
31
- def load_captioning_model():
32
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
33
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
34
- return processor, model
 
 
 
 
35
 
36
  # Function to load story generation model
37
  @st.cache_resource
38
- def load_story_generator():
39
- return pipeline('text-generation', model='gpt2')
 
 
 
 
 
40
 
41
  # Function to generate caption from image
42
  def generate_caption(image, processor, model):
43
- # 确保图像是RGB格式
44
- if image.mode != 'RGB':
45
- image = image.convert('RGB')
46
-
47
- # 标准预处理:调整大小到BLIP模型期望的输入尺寸
48
- image = image.resize((384, 384))
49
-
50
- try:
51
- # 使用处理器准备图像
52
- inputs = processor(image, return_tensors="pt", padding=True)
53
-
54
- # 生成caption
55
- out = model.generate(**inputs, max_length=30)
56
- caption = processor.decode(out[0], skip_special_tokens=True)
57
- return caption
58
- except Exception as e:
59
- # 如果有错误,使用一个备用方法
60
- st.warning(f"Caption generation error: {str(e)}. Using fallback method.")
61
-
62
- # 转换图像为numpy数组
63
- img_array = np.array(image)
64
-
65
- # 手动准备图像为模型输入
66
- pixel_values = processor.image_processor(images=img_array, return_tensors="pt").pixel_values
67
-
68
- # 生成caption
69
- generated_ids = model.generate(pixel_values=pixel_values, max_length=30)
70
- caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
71
- return caption
72
 
73
  # Function to generate story from caption
74
- def generate_story(caption, generator):
75
- prompt = f"{caption} "
76
- story = generator(
77
- prompt,
78
- max_length=150,
79
- num_return_sequences=1,
80
- temperature=0.8,
81
- top_k=50
82
- )[0]['generated_text']
83
-
84
- # Clean up the story
85
- story = story.replace('\n', ' ')
86
- sentences = story.split('.')
87
- if len(sentences) > 5:
88
- story = '.'.join(sentences[:5]) + '.'
89
-
90
- # Strictly control word count between 50-100 words
91
- words = story.split()
92
- word_count = len(words)
93
 
94
- if word_count < 50:
95
- # If story is too short, generate more content
96
- additional_content = generator(
97
- story + " Then, ",
98
- max_length=100,
99
- num_return_sequences=1,
100
- temperature=0.8,
101
- top_k=50
102
- )[0]['generated_text']
103
-
104
- # Add only what's needed to reach 50 words
105
- additional_words = additional_content.split()[word_count:]
106
- words_needed = 50 - word_count
107
- story = ' '.join(words + additional_words[:words_needed])
108
-
109
- if word_count > 100:
110
- # If story is too long, truncate to exactly 100 words
111
- story = ' '.join(words[:100])
112
 
113
- # Ensure the story ends with a period
114
- if not story.endswith('.'):
115
- story += '.'
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  return story
118
 
119
- # Function to convert text to speech
120
  def text_to_speech(text):
121
- tts = gTTS(text=text, lang='en', slow=False)
122
- fp = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
123
- tts.save(fp.name)
124
- return fp.name
125
-
126
- # Add background decorations
127
- def add_background_decorations():
128
- st.markdown(
 
 
 
 
 
 
 
129
  """
130
- <div style="position: fixed; top: 0; right: 0; z-index: -1; opacity: 0.3;">
131
- <img src="data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyMDAiIGhlaWdodD0iMjAwIiB2aWV3Qm94PSIwIDAgMjAwIDIwMCI+PHBhdGggZD0iTTEwMCAxOTBjNTEuNCAwIDkwLTQwLjYgOTAtOTBTMTUxLjQgMTAgMTAwIDEwIDEwIDUwLjYgMTAgMTAwczM4LjYgOTAgOTAgOTB6IiBmaWxsPSIjNzZiNWM1IiBvcGFjaXR5PSIwLjIiLz48cGF0aCBkPSJNMTgwIDEwMGMwIDQ0LjEtMzUuOSA4MC04MCA4MHMtODAtMzUuOS04MC04MCAzNS45LTgwIDgwLTgwIDgwIDM1LjkgODAgODB6IiBmaWxsPSIjM2Q4NWM2IiBvcGFjaXR5PSIwLjIiLz48L3N2Zz4=" width="200"/>
132
- </div>
133
- <div style="position: fixed; bottom: 0; left: 0; z-index: -1; opacity: 0.3;">
134
- <img src="data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIyMDAiIGhlaWdodD0iMjAwIiB2aWV3Qm94PSIwIDAgMjAwIDIwMCI+PHBhdGggZD0iTTEwMCAxOTBjNTEuNCAwIDkwLTQwLjYgOTAtOTBTMTUxLjQgMTAgMTAwIDEwIDEwIDUwLjYgMTAgMTAwczM4LjYgOTAgOTAgOTB6IiBmaWxsPSIjNzZiNWM1IiBvcGFjaXR5PSIwLjIiLz48cGF0aCBkPSJNMTgwIDEwMGMwIDQ0LjEtMzUuOSA4MC04MCA4MHMtODAtMzUuOS04MC04MCAzNS45LTgwIDgwLTgwIDgwIDM1LjkgODAgODB6IiBmaWxsPSIjM2Q4NWM2IiBvcGFjaXR5PSIwLjIiLz48L3N2Zz4=" width="200"/>
135
- </div>
136
- """,
137
- unsafe_allow_html=True
138
- )
139
 
140
- # Main UI
141
- def main():
142
- add_background_decorations()
143
-
144
- st.title("📚 StoryTime: Kids' Storyteller")
145
- st.markdown("### Upload a picture and listen to a magical story!")
146
-
147
- # Load models
148
- with st.spinner("Loading models... This might take a moment!"):
149
- processor, caption_model = load_captioning_model()
150
- story_generator = load_story_generator()
151
-
152
- # Create columns for better layout
153
- col1, col2 = st.columns([1, 1])
154
-
155
- # Image upload
156
- with col1:
157
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
158
-
159
- if uploaded_file is not None:
160
- # Display the uploaded image
161
- image_bytes = uploaded_file.getvalue()
162
- original_image = Image.open(io.BytesIO(image_bytes))
163
 
164
- # Display the image
 
 
 
 
 
 
 
 
 
165
  with col1:
166
- st.image(original_image, caption='Your Magical Picture', use_column_width=True)
 
 
 
 
167
 
168
- # Generate caption and story
169
- with st.spinner("Looking at your picture and thinking of a story..."):
170
- caption = generate_caption(original_image, processor, caption_model)
171
- # 打印图片描述
172
- st.info(f"Image caption: {caption}")
173
- story = generate_story(caption, story_generator)
174
 
175
- # Display the story and audio
176
- with col2:
177
- st.markdown("### Here's your story:")
178
- st.write(story)
179
-
180
- # Word count display
181
- word_count = len(story.split())
182
- st.caption(f"Story length: {word_count} words")
183
 
184
- # Convert to speech and play
185
- with st.spinner("Creating the storytelling voice..."):
186
- audio_file = text_to_speech(story)
187
 
188
- st.audio(audio_file, format='audio/mp3')
189
- st.success("Story created! Click the play button to listen!")
190
-
191
- # Add download button for audio
192
- with open(audio_file, "rb") as f:
193
- audio_bytes = f.read()
194
-
195
- audio_b64 = base64.b64encode(audio_bytes).decode()
196
- href = f'<a href="data:audio/mp3;base64,{audio_b64}" download="story.mp3">Download the story audio</a>'
197
- st.markdown(href, unsafe_allow_html=True)
198
-
199
- # Clean up the temp file
200
- os.unlink(audio_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
- if __name__ == "__main__":
203
- main()
 
 
 
4
  import torch
5
  from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
6
  from gtts import gTTS
 
7
  import os
8
  import base64
9
+ import time
10
 
11
+ # Set page configuration
12
  st.set_page_config(
13
+ page_title="Storyteller for Kids",
14
  page_icon="📚",
15
  layout="centered"
16
  )
17
 
18
+ # Custom CSS
19
+ st.markdown("""
20
+ <style>
21
+ .main {
22
+ background-color: #f5f7ff;
23
+ }
24
+ .stTitle {
25
+ color: #3366cc;
26
+ font-family: 'Comic Sans MS', cursive;
27
+ }
28
+ .stHeader {
29
+ font-family: 'Comic Sans MS', cursive;
30
+ }
31
+ .stImage {
32
+ border-radius: 15px;
33
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
34
+ }
35
+ .story-container {
36
+ background-color: #e6f2ff;
37
+ padding: 20px;
38
+ border-radius: 10px;
39
+ border: 2px dashed #3366cc;
40
+ font-size: 18px;
41
+ line-height: 1.6;
42
+ }
43
+ </style>
44
+ """, unsafe_allow_html=True)
45
 
46
+ # Title and description
47
+ st.title("🧸 Kid's Storyteller 🧸")
48
+ st.markdown("### Upload an image and I'll tell you a magical story about it!")
 
49
 
50
  # Function to load image captioning model
51
  @st.cache_resource
52
+ def load_caption_model():
53
+ try:
54
+ with st.spinner("Loading image captioning model... (This may take a minute)"):
55
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
56
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
57
+ return processor, model, None
58
+ except Exception as e:
59
+ return None, None, str(e)
60
 
61
  # Function to load story generation model
62
  @st.cache_resource
63
+ def load_story_model():
64
+ try:
65
+ with st.spinner("Loading story generation model... (This may take a minute)"):
66
+ story_generator = pipeline("text-generation", model="gpt2")
67
+ return story_generator, None
68
+ except Exception as e:
69
+ return None, str(e)
70
 
71
  # Function to generate caption from image
72
  def generate_caption(image, processor, model):
73
+ inputs = processor(image, return_tensors="pt")
74
+ out = model.generate(**inputs, max_length=50)
75
+ caption = processor.decode(out[0], skip_special_tokens=True)
76
+ return caption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # Function to generate story from caption
79
+ def generate_story(caption, story_generator):
80
+ # Make the prompt child-friendly and whimsical
81
+ prompt = f"Once upon a time in a magical land, {caption}. The children were amazed when "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ result = story_generator(prompt, max_length=150, num_return_sequences=1, temperature=0.8)
84
+ story = result[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ # Make sure the story is between 50-100 words
87
+ story_words = story.split()
88
+ if len(story_words) > 100:
89
+ story = ' '.join(story_words[:100])
90
+ # Add a closing sentence
91
+ story += ". And they all lived happily ever after."
92
+ elif len(story_words) < 50:
93
+ # If too short, generate more
94
+ additional = story_generator(story, max_length=150, num_return_sequences=1)
95
+ story = additional[0]['generated_text']
96
+ story_words = story.split()
97
+ if len(story_words) > 100:
98
+ story = ' '.join(story_words[:100])
99
+ story += ". And they all lived happily ever after."
100
 
101
  return story
102
 
103
+ # Function to convert text to speech and create audio player
104
  def text_to_speech(text):
105
+ try:
106
+ tts = gTTS(text=text, lang='en', slow=False)
107
+ audio_file = "story_audio.mp3"
108
+ tts.save(audio_file)
109
+
110
+ # Create audio player
111
+ with open(audio_file, "rb") as file:
112
+ audio_bytes = file.read()
113
+
114
+ audio_b64 = base64.b64encode(audio_bytes).decode()
115
+ audio_player = f"""
116
+ <audio controls autoplay>
117
+ <source src="data:audio/mp3;base64,{audio_b64}" type="audio/mp3">
118
+ Your browser does not support the audio element.
119
+ </audio>
120
  """
121
+ return audio_player, None
122
+ except Exception as e:
123
+ return None, str(e)
 
 
 
 
 
 
124
 
125
+ # Main application flow
126
+ try:
127
+ # Load models with status checks
128
+ with st.spinner("Loading AI models... This may take a moment the first time you run the app."):
129
+ caption_processor, caption_model, caption_error = load_caption_model()
130
+ story_model, story_error = load_story_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ if caption_error:
133
+ st.error(f"Error loading caption model: {caption_error}")
134
+ if story_error:
135
+ st.error(f"Error loading story model: {story_error}")
136
+
137
+ # If models loaded successfully
138
+ if caption_processor and caption_model and story_model:
139
+ # Show example images for kids to understand
140
+ st.markdown("### 🌟 Examples of images you can upload:")
141
+ col1, col2, col3 = st.columns(3)
142
  with col1:
143
+ st.markdown("🐱 Pets")
144
+ with col2:
145
+ st.markdown("🏰 Places")
146
+ with col3:
147
+ st.markdown("🧩 Toys")
148
 
149
+ # File uploader
150
+ uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
 
 
 
 
151
 
152
+ if uploaded_file is not None:
153
+ # Display the uploaded image
154
+ image_bytes = uploaded_file.getvalue()
155
+ image = Image.open(io.BytesIO(image_bytes))
156
+ st.image(image, caption='Uploaded Image', use_column_width=True, output_format="JPEG")
 
 
 
157
 
158
+ with st.spinner('Creating your story... 📝'):
159
+ # Generate caption
160
+ caption = generate_caption(image, caption_processor, caption_model)
161
 
162
+ # Generate story
163
+ story = generate_story(caption, story_model)
164
+
165
+ # Display the story with some styling
166
+ st.markdown("## 📖 Your Magical Story")
167
+ st.markdown(f"<div class='story-container'>{story}</div>",
168
+ unsafe_allow_html=True)
169
+
170
+ # Convert to speech and play
171
+ st.markdown("## 🔊 Listen to the Story")
172
+ audio_player, audio_error = text_to_speech(story)
173
+
174
+ if audio_player:
175
+ st.markdown(audio_player, unsafe_allow_html=True)
176
+ else:
177
+ st.error(f"Could not generate audio: {audio_error}")
178
+
179
+ # Download options
180
+ st.download_button(
181
+ label="Download Story (Text)",
182
+ data=story,
183
+ file_name="my_story.txt",
184
+ mime="text/plain"
185
+ )
186
+ else:
187
+ st.warning("Some AI models didn't load correctly. Please refresh the page or try again later.")
188
+
189
+ except Exception as e:
190
+ st.error(f"An error occurred: {e}")
191
+ st.markdown("Please try again with a different image.")
192
 
193
+ # Footer
194
+ st.markdown("---")
195
+ st.markdown("Created for young storytellers aged 3-10 years old 🌈")
196
+ st.markdown("Powered by Hugging Face Transformers 🤗")