import io
import re
import wave

import numpy as np
import streamlit as st
from PIL import Image
from transformers import pipeline

# 1) CACHE & LOAD MODELS (CPU only)
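# st.cache_resource keeps each Hugging Face pipeline in memory across Streamlit
# reruns, so every model is downloaded and instantiated only once per process.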
@st.cache_resource(show_spinner=False)
def load_captioner():
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1  # force CPU
    )

@st.cache_resource(show_spinner=False)
def load_story_pipe():
    return pipeline(
        "text2text-generation",
        model="google/flan-t5-base",
        device=-1  # force CPU
    )

@st.cache_resource(show_spinner=False)
def load_tts_pipe():
    return pipeline(
        "text-to-speech",
        model="facebook/mms-tts-eng",
        device=-1  # force CPU
    )

# 2) HELPER FUNCTIONS
def sentence_case(text: str) -> str:
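    """Capitalise the first letter of every sentence and collapse repeated whitespace."""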
    parts = re.split(r'([.!?])', text)
    out = []
    for i in range(0, len(parts) - 1, 2):
        sentence = parts[i].strip()
        delimiter = parts[i + 1]
        if sentence:
            formatted = sentence[0].upper() + sentence[1:]
            out.append(f"{formatted}{delimiter}")
    if len(parts) % 2:
        last = parts[-1].strip()
        if last:
            formatted = last[0].upper() + last[1:]
            out.append(formatted)
    return " ".join(" ".join(out).split())

def caption_image(img: Image.Image, captioner) -> str:
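    """Run the BLIP captioner on the image and return its caption, or "" if nothing is returned."""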
    if img.mode != "RGB":
        img = img.convert("RGB")
    results = captioner(img)
    return (results[0].get("generated_text", "") if results else "")

def story_from_caption(caption: str, pipe) -> str:
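    """Prompt FLAN-T5 with the caption and return a short story, trimmed to the last complete sentence."""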
    if not caption:
        return "Could not generate a story without a caption."
    prompt = f"Write a creative imaginary 50–100 word story using this scene: {caption}"
    results = pipe(
        prompt,
        max_length=100,            # generation budget is counted in tokens, not words
        min_length=80,
        do_sample=True,            # sample so each run produces a different story
        top_k=100,
        top_p=0.9,
        temperature=0.5,
        repetition_penalty=1.1,    # discourage the model from echoing the prompt
        no_repeat_ngram_size=4,
        early_stopping=False       # only relevant for beam search; a no-op when sampling
    )
    raw = results[0]["generated_text"].strip()
    raw = re.sub(re.escape(prompt), "", raw, flags=re.IGNORECASE).strip()
    idx = max(raw.rfind("."), raw.rfind("!"), raw.rfind("?"))
    if idx != -1:
        raw = raw[:idx+1]
    return sentence_case(raw)
def tts_bytes(text: str, tts_pipe) -> bytes:
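    """Synthesise speech with MMS-TTS and return it as 16-bit PCM WAV bytes."""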
    if not text:
        return b""
    # Strip surrounding quotes, normalise ellipses to single periods, and make
    # sure the text ends with sentence-final punctuation before synthesis.
    cleaned = re.sub(r'^["\']|["\']$', '', text).strip()
    cleaned = cleaned.replace('…', '...')
    cleaned = re.sub(r'\.{2,}', '.', cleaned)
    cleaned = " ".join(cleaned.split())
    if not cleaned:
        return b""
    if cleaned[-1] not in ".!?":
        cleaned += "."
    output = tts_pipe(cleaned)
    result = output[0] if isinstance(output, list) else output
    audio_array = result.get("audio")
    rate = result.get("sampling_rate")
    if audio_array is None or rate is None:
        return b""
    # The pipeline returns float audio in roughly [-1, 1]; reshape to
    # (frames, channels) and convert to 16-bit PCM for the WAV container.
    if audio_array.ndim == 1:
        data = audio_array[:, np.newaxis]
    else:
        data = audio_array.T
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(data.shape[1])
        wf.setsampwidth(2)  # 2 bytes per sample = 16-bit audio
        wf.setframerate(rate)
        wf.writeframes(pcm.tobytes())
    return buf.getvalue()

# 3) STREAMLIT USER INTERFACE
st.set_page_config(page_title="✨ Imagine & Narrate", page_icon="✨", layout="centered")

# Persist upload across reruns
if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None

new_upload = st.file_uploader(
    "Choose an image file",
    type=["jpg", "jpeg", "png"]
)
if new_upload is not None:
    st.session_state.uploaded_file = new_upload

if st.session_state.uploaded_file is None:
    st.title("✨ Imagine & Narrate")
    st.info("➡️ Upload an image above to start the magic!")
    st.stop()

uploaded = st.session_state.uploaded_file
try:
    img = Image.open(uploaded)
except Exception as e:
    st.error(f"Could not load the image: {e}")
    st.stop()

st.title("✨ Imagine & Narrate")
st.subheader("📸 Your Visual Input")
st.image(img, caption=uploaded.name, use_container_width=True)
st.divider()

# Step 1: Generate Caption
st.subheader("🧠 Generating Caption")
with st.spinner("Analyzing image..."):
    captioner = load_captioner()
    raw_caption = caption_image(img, captioner)
    if not raw_caption:
        st.error("Failed to generate caption.")
        st.stop()
    caption = sentence_case(raw_caption)
st.markdown(f"**Identified Scene:** {caption}")
st.divider()

# Step 2: Generate Story
st.subheader("📖 Crafting a Story")
with st.spinner("Writing story..."):
    story_pipe = load_story_pipe()
    story = story_from_caption(caption, story_pipe)
    if not story or story.strip() in {".", "..", "..."}:
        st.error("Failed to generate story.")
        st.stop()
st.write(story)
st.divider()

# Step 3: Synthesize Audio
st.subheader("👂 Hear the Story")
with st.spinner("Synthesizing audio..."):
    tts_pipe = load_tts_pipe()
    audio_bytes = tts_bytes(story, tts_pipe)
    if not audio_bytes:
        st.warning("Audio generation failed.")
    else:
        st.audio(audio_bytes, format="audio/wav")
st.balloons()