|
import io

import numpy as np
import streamlit as st
import torch
from kokoro import KPipeline
from PIL import Image
from scipy.io import wavfile
from transformers import AutoProcessor, AutoModelForImageTextToText, pipeline
|
|
|
|
|
|
|
|
|
# Streamlit re-executes this script on every user interaction.  Without
# caching, all three (large) models would be re-downloaded/re-instantiated
# per rerun.  st.cache_resource loads each exactly once per server process.

@st.cache_resource
def _load_caption_components():
    """Load the BLIP captioning processor and model once."""
    proc = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = AutoModelForImageTextToText.from_pretrained("Salesforce/blip-image-captioning-large")
    return proc, model


@st.cache_resource
def _load_story_generator():
    """Load the text-generation pipeline for story writing once."""
    return pipeline("text-generation", model="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B")


@st.cache_resource
def _load_audio_pipeline():
    """Load the Kokoro TTS pipeline once (lang_code 'a' = American English)."""
    return KPipeline(lang_code='a')


# Module-level names preserved so the rest of the file is unchanged.
processor, caption_model = _load_caption_components()
story_generator = _load_story_generator()
audio_pipeline = _load_audio_pipeline()
|
|
|
|
|
def generate_caption(image_bytes):
    """Generate a text caption for an image using the BLIP model.

    Args:
        image_bytes: Raw encoded image bytes (jpg/jpeg/png upload).

    Returns:
        The decoded caption string.
    """
    # Force RGB: uploaded PNGs can be RGBA or palette mode, which the
    # BLIP processor does not accept (expects 3-channel input).
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    inputs = processor(images=image, text="Generate a caption:", return_tensors="pt")
    outputs = caption_model.generate(**inputs)

    # skip_special_tokens drops BOS/EOS markers from the decoded text.
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    return caption
|
|
|
|
|
def generate_story(caption):
    """Generate a short children's story (at most 100 words) from a caption.

    Args:
        caption: Image caption used to seed the story.

    Returns:
        The generated story text, hard-truncated to 100 words.
    """
    prompt = f"Based on the description '{caption}', tell a short story for children aged 3 to 10 in no more than 100 words."

    # return_full_text=False: by default the pipeline echoes the prompt at the
    # start of "generated_text", so the user would see the instruction text.
    # max_new_tokens (not max_length) so the budget is for generation only and
    # is not consumed by the prompt's own tokens.
    story_output = story_generator(
        prompt,
        max_new_tokens=150,
        num_return_sequences=1,
        return_full_text=False,
    )
    story = story_output[0]["generated_text"].strip()

    # Enforce the <=100-word contract regardless of what the model produced.
    story_words = story.split()
    if len(story_words) > 100:
        story = " ".join(story_words[:100])
    return story
|
|
|
|
|
def generate_audio(story):
    """Synthesize the story to WAV audio with the Kokoro TTS pipeline.

    Args:
        story: Story text to speak.

    Returns:
        io.BytesIO positioned at the start of WAV data, or None when the
        pipeline produced no audio segments.
    """
    audio_generator = audio_pipeline(
        story, voice='af_heart', speed=1, split_pattern=r'\n+'
    )

    # Each yielded item is (graphemes, phonemes, audio); only audio is needed.
    audio_segments = [np.asarray(audio) for _, _, audio in audio_generator]
    if not audio_segments:
        return None

    concatenated_audio = np.concatenate(audio_segments)

    # Convert the float waveform (assumed in [-1, 1]) to 16-bit PCM so the
    # WAV is playable everywhere.  NOTE: the original called sf.write from the
    # never-imported `soundfile` package (a NameError at runtime); scipy's
    # wavfile.write is the in-scope replacement and PCM_16 matches soundfile's
    # default WAV subtype.
    pcm = (np.clip(concatenated_audio, -1.0, 1.0) * 32767).astype(np.int16)

    audio_buffer = io.BytesIO()
    wavfile.write(audio_buffer, 24000, pcm)  # Kokoro outputs 24 kHz audio
    audio_buffer.seek(0)
    return audio_buffer
|
|
|
|
|
# --- Streamlit UI: upload image -> caption -> story -> audio ---

st.title("Image to Story Audio Generator")
st.write("Upload an image to generate a short children's story (≤100 words) as audio.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image_bytes = uploaded_file.read()
    # use_container_width replaces the deprecated use_column_width parameter.
    st.image(image_bytes, caption="Uploaded Image", use_container_width=True)

    with st.spinner("Generating caption..."):
        caption = generate_caption(image_bytes)
    st.write("**Generated Caption:**")
    st.write(caption)

    with st.spinner("Generating story..."):
        story = generate_story(caption)
    st.write("**Generated Story:**")
    st.write(story)

    with st.spinner("Generating audio..."):
        audio_buffer = generate_audio(story)
    if audio_buffer:
        # Materialize to bytes once: the same BytesIO must not be consumed by
        # two widgets (the second reader could see an exhausted stream).
        audio_bytes = audio_buffer.getvalue()
        st.audio(audio_bytes, format="audio/wav")
        st.download_button(
            label="Download Story Audio",
            data=audio_bytes,
            file_name="story_audio.wav",
            mime="audio/wav"
        )
    else:
        st.error("Failed to generate audio.")