# app.py

import io                   # for creating in-memory binary streams
import wave                 # for writing WAV audio files
import re                   # for regular expression utilities
import streamlit as st      # Streamlit UI library
from transformers import pipeline  # Hugging Face inference pipelines
from PIL import Image       # Python Imaging Library for image loading
import numpy as np          # numerical operations, especially array handling

# 1) CACHE & LOAD MODELS
@st.cache_resource(show_spinner=False)
def load_captioner():
    # Loads BLIP image-to-text model; cached so it loads only once.
    # Returns: a function captioner(image: PIL.Image) -> List[Dict].
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device="cpu" # Can change to "cuda" if GPU is available
    )
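
# Sketch of the expected return shape (standard Hugging Face image-to-text
# pipeline contract; the caption text is illustrative, not a real model output):
#   load_captioner()(img) -> [{"generated_text": "a dog lying on a couch"}]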

@st.cache_resource(show_spinner=False)
def load_story_pipe():
    # Loads FLAN-T5 text-to-text model for story generation; cached once.
    # Returns: a function story_pipe(prompt: str, **kwargs) -> List[Dict].
    return pipeline(
        "text2text-generation",
        model="google/flan-t5-base",
        device="cpu" # Can change to "cuda" if GPU is available
    )
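
# Same output contract as the captioner: story_pipe(prompt, **kwargs) returns
# a list like [{"generated_text": "..."}] per the text2text-generation
# pipeline interface.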

@st.cache_resource(show_spinner=False)
def load_tts_pipe():
    # Loads Meta MMS-TTS text-to-speech model; cached once.
    # Returns: a function tts_pipe(text: str) -> dict (or one-element list)
    # with "audio" and "sampling_rate" keys.
    return pipeline(
        "text-to-speech",
        model="facebook/mms-tts-eng",
        device="cpu" # Can change to "cuda" if GPU is available
    )
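
# Sketch of the expected return shape for text-to-speech (the 16000 Hz rate is
# typical of MMS-TTS checkpoints but is an assumption; tts_bytes() below reads
# the actual rate from the output rather than hard-coding it):
#   load_tts_pipe()("Hello.") -> {"audio": np.ndarray, "sampling_rate": 16000}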

# 2) HELPER FUNCTIONS
def sentence_case(text: str) -> str:
    # Splits text into sentences on .!? delimiters, capitalizes the first
    # character of each sentence, then rejoins into a single string.
    parts = re.split(r'([.!?])', text)  # e.g. ["hello", ".", " world", "!"]
    out = []
    for i in range(0, len(parts) - 1, 2):
        sentence = parts[i].strip().capitalize()  # capitalize first letter
        delimiter = parts[i + 1]                  # punctuation mark
        out.append(f"{sentence}{delimiter}")

    # If trailing text without punctuation exists, capitalize and append it.
    if len(parts) % 2:
        last = parts[-1].strip().capitalize()
        if last:
            out.append(last)

    # Join the pieces, then collapse any doubled whitespace from the split.
    return " ".join(" ".join(out).split())


def caption_image(img: Image.Image, captioner) -> str:
    # Given a PIL image and a captioner pipeline, returns a single-line caption.
    results = captioner(img)  # run model
    if not results:
        return ""
    # extract "generated_text" field from first result
    return results[0].get("generated_text", "")

def story_from_caption(caption: str, pipe) -> str:
    # Given a caption string and a text2text pipeline, returns a ~100-word story.
    prompt = f"Write a vivid, imaginative ~100-word story about this scene: {caption}"
    results = pipe(
        prompt,
        max_length=120,            # maximum generated tokens
        min_length=80,             # minimum generated tokens
        do_sample=True,            # enable sampling
        top_k=100,                 # sample from top_k tokens
        top_p=0.9,                 # nucleus sampling threshold
        temperature=0.7,           # sampling temperature
        repetition_penalty=1.1,    # discourage repetition
        no_repeat_ngram_size=4,    # block repeated n-grams
        early_stopping=False       # only affects beam search; a no-op when sampling
    )
    raw = results[0]["generated_text"].strip()  # full generated text
    # strip out the prompt if it echoes back - make comparison case-insensitive
    if raw.lower().startswith(prompt.lower()):
        raw = raw[len(prompt):].strip()

    # Trim to the last complete sentence ending in . ! or ?
    match = re.search(r'[.!?]', raw[::-1])  # first punctuation from the end
    if match:
        raw = raw[:len(raw) - match.start()]  # cut just after that mark
    elif len(raw) > 80:
        # No sentence-ending punctuation at all; hard-trim a long story.
        raw = raw[:80] + "..."

    return sentence_case(raw)
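
# Example of the trimming step (illustrative): if the model emits
# "She smiled. The end was nea", the reversed-string search finds the '.'
# after "smiled" and the story is trimmed to "She smiled."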

def tts_bytes(text: str, tts_pipe) -> bytes:
    # Given a text string and a tts pipeline, returns WAV-format bytes.
    # Clean up text for TTS - remove leading/trailing quotes, etc.
    cleaned_text = re.sub(r'^["\']|["\']$', '', text).strip()
    # Basic punctuation cleaning (optional, depending on TTS model)
    cleaned_text = re.sub(r'\.{2,}', '.', cleaned_text) # Replace multiple periods with one
    cleaned_text = cleaned_text.replace('…', '...') # Replace ellipsis char with dots
    # Add a period if the text doesn't end with punctuation (helps the TTS
    # model finalize the utterance).
    if cleaned_text and cleaned_text[-1] not in '.!?':
        cleaned_text += '.'

    output = tts_pipe(cleaned_text)
    # pipeline may return list or single dict
    result = output[0] if isinstance(output, list) else output
    audio_array = result["audio"]            # numpy array: (channels, samples) or (samples,)
    rate = result["sampling_rate"]           # sampling rate integer

    # Ensure audio_array is 2D (samples, channels) for consistent handling.
    if audio_array.ndim == 1:
        data = audio_array[:, np.newaxis]  # add a channel dimension
    else:
        data = audio_array.T  # (channels, samples) -> (samples, channels)

    # Convert float32 [-1..1] to int16 PCM; clip first so any samples slightly
    # outside [-1, 1] saturate instead of wrapping around when cast.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(data.shape[1])  # number of channels
        wf.setsampwidth(2)              # 16 bits = 2 bytes
        wf.setframerate(rate)           # samples per second
        wf.writeframes(pcm.tobytes())   # write PCM data
    buffer.seek(0)
    return buffer.read()                # return raw WAV bytes
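
# For local debugging, the same bytes can be written to disk (hypothetical
# snippet, not executed by the app):
#   with open("story.wav", "wb") as f:
#       f.write(tts_bytes("Hello world.", load_tts_pipe()))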

# 3) STREAMLIT USER INTERFACE
st.set_page_config(page_title="Imagine & Narrate", page_icon="✨", layout="centered")
st.title("✨ Imagine & Narrate")
st.write("Upload any image below to see AI imagine and narrate a story about it!")

# -- Upload image widget --
uploaded = st.file_uploader(
    "Choose an image file",
    type=["jpg", "jpeg", "png"]
)
if not uploaded:
    st.info("➡️ Upload an image above to start the magic!")
    st.stop()

# Load the uploaded file into a PIL Image; convert to RGB so RGBA or palette
# PNGs don't trip up the captioning model.
try:
    img = Image.open(uploaded).convert("RGB")
except Exception as e:
    st.error(f"Error loading image: {e}")
    st.stop()


# -- Step 1: Display the image --
st.subheader("📸 Your Visual Input")
st.image(img, use_container_width=True)
st.divider()

# -- Step 2: Generate and display caption --
st.subheader("🧠 Generating Insights")
with st.spinner("Scanning image for key elements…"):
    captioner = load_captioner()
    raw_caption = caption_image(img, captioner)
    if not raw_caption:
        st.warning("Could not generate a caption for the image.")
        st.stop()
    caption = sentence_case(raw_caption)
st.markdown(f"**Identified Scene:** {caption}")
st.divider()

# -- Step 3: Generate and display story --
st.subheader("📖 Crafting a Narrative")
with st.spinner("Writing a compelling story…"):
    story_pipe = load_story_pipe()
    story = story_from_caption(caption, story_pipe)
    if not story or story.strip() == '...':  # check for an empty or minimal story
        st.warning("Could not generate a meaningful story from the caption.")
        st.stop()
st.write(story)
st.divider()

# -- Step 4: Synthesize and play audio --
st.subheader("👂 Hear the Story")
with st.spinner("Synthesizing audio narration…"):
    tts_pipe = load_tts_pipe()
    try:
        audio_bytes = tts_bytes(story, tts_pipe)
        st.audio(audio_bytes, format="audio/wav")
    except Exception as e:
        st.error(f"Error generating audio: {e}")


# Celebration animation
st.balloons()
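
# To run the app locally (standard Streamlit invocation):
#   streamlit run app.py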