BillyZ1129 committed on
Commit
8bb3817
·
verified ·
1 Parent(s): 9f48a22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -22
app.py CHANGED
@@ -1,30 +1,175 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
 
 
 
 
 
 
3
 
4
- # Load the text classification model pipeline
5
- classifier = pipeline("text-classification",model='isom5240ust/bert-base-uncased-emotion', return_all_scores=True)
 
 
 
 
6
 
7
- # Streamlit application title
8
- st.title("Text Classification for you")
9
- st.write("Classification for 6 emotions: sadness, joy, love, anger, fear, surprise")
 
10
 
11
- # Text input for user to enter the text to classify
12
- text = st.text_area("Enter the text to classify", "")
 
 
13
 
14
- # Perform text classification when the user clicks the "Classify" button
15
- if st.button("Classify"):
16
- # Perform text classification on the input text
17
- results = classifier(text)[0]
 
 
18
 
19
- # Display the classification result
20
- max_score = float('-inf')
21
- max_label = ''
 
22
 
23
- for result in results:
24
- if result['score'] > max_score:
25
- max_score = result['score']
26
- max_label = result['label']
 
 
27
 
28
- st.write("Text:", text)
29
- st.write("Label:", max_label)
30
- st.write("Score:", max_score)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from PIL import Image
3
+ import io
4
+ import torch
5
+ from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
6
+ from gtts import gTTS
7
+ import tempfile
8
+ import os
9
+ import base64
10
 
11
# Page-level settings: browser tab title, tab icon and content layout.
_PAGE_SETTINGS = {
    "page_title": "StoryTime: Kids' Storyteller",
    "page_icon": "📚",
    "layout": "centered",
}
st.set_page_config(**_PAGE_SETTINGS)
17
 
18
# Load and apply CSS
def load_css(file_name):
    """Inject the contents of a CSS file into the page inside a ``<style>`` tag.

    Args:
        file_name: Path of the stylesheet to read.

    Raises:
        OSError: If the file cannot be opened (for example, it does not exist).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Explicit encoding so the stylesheet decodes identically on every
    # platform instead of depending on the locale's default codec.
    with open(file_name, encoding="utf-8") as css_file:
        css = css_file.read()
    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)
22
 
23
try:
    load_css("style.css")
except (OSError, UnicodeDecodeError):
    # Custom styling is optional: a missing or unreadable stylesheet only
    # downgrades the look of the page. Catching just file-read failures
    # (instead of a bare ``except``) lets real bugs and interrupts surface.
    st.warning("Style file not found. Using default styling.")
27
 
28
# Function to load image captioning model
@st.cache_resource
def load_captioning_model():
    """Load the BLIP image-captioning processor and model.

    Returns:
        A ``(processor, model)`` tuple, both loaded from the
        "Salesforce/blip-image-captioning-base" checkpoint.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    return (
        BlipProcessor.from_pretrained(checkpoint),
        BlipForConditionalGeneration.from_pretrained(checkpoint),
    )
34
 
35
# Function to load story generation model
@st.cache_resource
def load_story_generator():
    """Build and return the GPT-2 text-generation pipeline used for stories."""
    story_pipeline = pipeline('text-generation', model='gpt2')
    return story_pipeline
39
 
40
# Function to generate caption from image
def generate_caption(image, processor, model, max_length=30):
    """Describe an image with a short caption.

    Args:
        image: The image to caption; passed straight to ``processor``.
        processor: Callable that turns the image into model inputs
            (called with ``return_tensors="pt"``) and also provides
            ``decode`` for turning output ids back into text.
        model: Object whose ``generate`` method accepts the processor's
            outputs as keyword arguments.
        max_length: Upper bound forwarded to ``model.generate``. Defaults to
            30, the value that was previously hard-coded.

    Returns:
        The decoded caption with special tokens removed.
    """
    inputs = processor(image, return_tensors="pt")
    output_ids = model.generate(**inputs, max_length=max_length)
    # generate() returns one sequence per input; only one image was supplied.
    return processor.decode(output_ids[0], skip_special_tokens=True)
46
 
47
# Function to generate story from caption
_MIN_STORY_WORDS = 50
_MAX_STORY_WORDS = 100
_MAX_STORY_SENTENCES = 5
_SENTENCE_ENDINGS = ('.', '!', '?')


def _sample_text(generator, prompt, max_length):
    """Run the generation pipeline once and return the generated string."""
    return generator(
        prompt,
        max_length=max_length,
        num_return_sequences=1,
        temperature=0.8,
        top_k=50,
    )[0]['generated_text']


def generate_story(caption, generator):
    """Turn an image caption into a short story.

    The caption is used as the prompt. The generated text is cut to at most
    five period-delimited sentences, then steered towards 50-100 words: a
    story that is too short is extended with one follow-up generation, and a
    story that is too long is truncated to 100 words. The lower bound is
    best-effort, because the follow-up generation may itself be short.

    Args:
        caption: Text describing the picture; becomes the story opening.
        generator: Callable with the text-generation pipeline interface,
            i.e. ``generator(prompt, **kwargs)[0]['generated_text']``.

    Returns:
        The story as a single line with normalised whitespace. It always ends
        with sentence punctuation; a "." is appended only when the text does
        not already end with ".", "!" or "?".
    """
    story = _sample_text(generator, f"{caption} ", max_length=150)

    # Flatten line breaks, then keep only the first few sentences.
    story = story.replace('\n', ' ')
    sentences = story.split('.')
    if len(sentences) > _MAX_STORY_SENTENCES:
        story = '.'.join(sentences[:_MAX_STORY_SENTENCES]) + '.'

    words = story.split()

    if len(words) < _MIN_STORY_WORDS:
        # Too short: ask the model to continue the story.
        extended = _sample_text(generator, story + " Then, ", max_length=100)
        # The generated text is assumed to echo the prompt, so skip the words
        # that are already part of the story and keep only the new ones.
        continuation = extended.split()[len(words):]
        words_needed = _MIN_STORY_WORDS - len(words)
        words = words + continuation[:words_needed]
    elif len(words) > _MAX_STORY_WORDS:
        # Too long: hard cut at the word limit.
        words = words[:_MAX_STORY_WORDS]

    story = ' '.join(words)

    # A hard cut can leave a dangling separator; drop it so the story does
    # not end in ",." and avoid doubling punctuation such as "!.".
    story = story.rstrip(',;:- ')
    if not story.endswith(_SENTENCE_ENDINGS):
        story += '.'

    return story
92
+
93
# Function to convert text to speech
def text_to_speech(text):
    """Synthesise ``text`` as English speech and save it to a temporary MP3.

    Args:
        text: The text to read aloud.

    Returns:
        Path of the newly created ``.mp3`` file. The caller owns the file and
        is responsible for deleting it.

    Raises:
        Exception: Whatever gTTS raises when synthesis or saving fails; the
            temporary file is removed before the error propagates.
    """
    # Reserve a unique path, then close the handle straight away: leaving it
    # open leaks a file descriptor and stops gTTS from reopening the path on
    # Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as tmp:
        audio_path = tmp.name

    try:
        gTTS(text=text, lang='en', slow=False).save(audio_path)
    except Exception:
        # Do not leave an empty or partial file behind when synthesis fails.
        os.unlink(audio_path)
        raise
    return audio_path
99
+
100
# Add background decorations
def add_background_decorations():
    """Render two fixed-position, semi-transparent decorative images.

    One image is pinned to the top-right corner and one to the bottom-left,
    both behind the page content (``z-index: -1``).
    """
    # NOTE(review): both <img> tags have an empty ``src`` here, so nothing
    # visible is drawn -- confirm whether image URLs were meant to be set.
    decorations_html = """
        <div style="position: fixed; top: 0; right: 0; z-index: -1; opacity: 0.3;">
            <img src="" width="200"/>
        </div>
        <div style="position: fixed; bottom: 0; left: 0; z-index: -1; opacity: 0.3;">
            <img src="" width="200"/>
        </div>
        """
    st.markdown(decorations_html, unsafe_allow_html=True)
113
+
114
# Main UI
def main():
    """Render the StoryTime page.

    Flow: load the models, let the user upload a picture, caption it, turn
    the caption into a story, then show the story together with a spoken
    version and a download link for the audio.
    """
    add_background_decorations()

    st.title("📚 StoryTime: Kids' Storyteller")
    st.markdown("### Upload a picture and listen to a magical story!")

    # Load models
    with st.spinner("Loading models... This might take a moment!"):
        processor, caption_model = load_captioning_model()
        story_generator = load_story_generator()

    # Create columns for better layout
    col1, col2 = st.columns([1, 1])

    # Image upload
    with col1:
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is None:
        # Nothing to do until the user picks a file.
        return

    try:
        original_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        # Force decoding now so a corrupt upload is reported here rather than
        # deep inside the display or captioning steps.
        original_image.load()
    except OSError:
        with col1:
            st.error("Sorry, that file could not be read as an image. Please try another one.")
        return

    # Display the image
    with col1:
        st.image(original_image, caption='Your Magical Picture', use_column_width=True)

    # Generate caption and story
    with st.spinner("Looking at your picture and thinking of a story..."):
        caption = generate_caption(original_image, processor, caption_model)
        story = generate_story(caption, story_generator)

    # Display the story and audio
    with col2:
        st.markdown("### Here's your story:")
        st.write(story)

        # Word count display
        st.caption(f"Story length: {len(story.split())} words")

        # Convert to speech and play
        with st.spinner("Creating the storytelling voice..."):
            audio_file = text_to_speech(story)

        # Read the audio once and always remove the temp file, even when
        # reading fails, so repeated runs do not pile up files on disk.
        try:
            with open(audio_file, "rb") as f:
                audio_bytes = f.read()
        finally:
            os.unlink(audio_file)

        st.audio(audio_bytes, format='audio/mp3')
        st.success("Story created! Click the play button to listen!")

        # Download link as an inline data URI: following a plain link does
        # not rerun the script, so the story is not regenerated on download.
        audio_b64 = base64.b64encode(audio_bytes).decode()
        href = f'<a href="data:audio/mp3;base64,{audio_b64}" download="story.mp3">Download the story audio</a>'
        st.markdown(href, unsafe_allow_html=True)


if __name__ == "__main__":
    main()