# func.py  ── utilities for a Hugging Face Space

# -------------------------------------------------------------------
# Step 1. Image ➜ Text (BLIP)
# -------------------------------------------------------------------
from typing import Union
from pathlib import Path
from PIL import Image
from transformers import pipeline

# lazy-load caption model once
_captioner = None
def _get_captioner():
    global _captioner
    if _captioner is None:
        _captioner = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-large"
        )
    return _captioner

def img2text(img: Union[Image.Image, str, Path]) -> str:
    """
    Return a short English caption for an image.

    Args:
        img: PIL.Image, local path, or pathlib.Path.

    Returns:
        Caption string.
    """
    # ensure PIL.Image
    if not isinstance(img, Image.Image):
        img = Image.open(img)
    return _get_captioner()(img)[0]["generated_text"]
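
# Example (hypothetical image path; the exact caption wording varies by
# model version and image):
#   >>> img2text("photos/beach.jpg")
#   'a dog running on the beach with a frisbee'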

# -------------------------------------------------------------------
# Step 2. Caption ➜ Children’s story (BLOOM-560M)
# -------------------------------------------------------------------
import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

_PROMPT_TMPL = (
    "Write a funny and warm children's story (50-100 words) for ages 3-10, "
    "fully and strictly based on this scene: {caption}\nStory:"
)

_tokenizer = None
_model = None
def _get_model_and_tokenizer():
    """Lazy-load BLOOM-560M model and tokenizer once (GPU if available)."""
    global _tokenizer, _model
    if _tokenizer is None or _model is None:
        _tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
        _model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
        if torch.cuda.is_available():
            _model = _model.to("cuda")
    return _tokenizer, _model


def _dedup_sentences(text: str) -> str:
    """Remove exact duplicate sentences while preserving order."""
    seen, cleaned = set(), []
    for sent in re.split(r'(?<=[.!?])\s+', text.strip()):
        s = sent.strip()
        if s and s not in seen:
            cleaned.append(s)
            seen.add(s)
    return " ".join(cleaned)


def text2story(caption: str) -> str:
    """
    Generate a ≤100-word children’s story from the image caption using BLOOM-560M.

    Args:
        caption: scene description string.

    Returns:
        Story text (plain string, ≤100 words, no exact duplicate sentences).
    """
    prompt = _PROMPT_TMPL.format(caption=caption)
    tokenizer, model = _get_model_and_tokenizer()

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Generate text (inference only, so skip gradient tracking)
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=150,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            no_repeat_ngram_size=4,    # Block 4-gram repeats
            repetition_penalty=1.15,   # Soften copy-loops
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens; slicing by token count is more
    # robust than stripping the prompt string from the decoded text, which can
    # fail when detokenization does not reproduce the prompt exactly.
    prompt_len = inputs["input_ids"].shape[1]
    story = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    # Deduplicate sentences
    story = _dedup_sentences(story)

    # Ensure ending punctuation
    if story and story[-1] not in ".!?":
        story += "."

    # Hard cap at 100 words
    return " ".join(story.split()[:100])

# -------------------------------------------------------------------
# Step 3. Text ➜ Audio (MMS-TTS)
# -------------------------------------------------------------------
import numpy as np
import textwrap
import soundfile as sf
from transformers import pipeline

_TTS_MODEL = "facebook/mms-tts-eng"
_tts = None
def _get_tts():
    """Lazy-load the TTS pipeline once."""
    global _tts
    if _tts is None:
        _tts = pipeline("text-to-speech", model=_TTS_MODEL)
    return _tts


def story2audio(story: str, wav_path: str = "story.wav") -> str:
    """
    Synthesize speech for a story and save as WAV.

    Args:
        story: Text returned by `text2story(...)`.
        wav_path: Output file name.

    Returns:
        Path to the saved WAV file.
    """
    tts = _get_tts()
    # long text → stable chunks (textwrap breaks on whitespace, keeping words intact)
    chunks = textwrap.wrap(story, width=200)
    outputs = [tts(c) for c in chunks]
    audio = np.concatenate([o["audio"].squeeze() for o in outputs])
    # use the sampling rate reported by the pipeline output rather than
    # reaching into the model config, which not every TTS model exposes
    sf.write(wav_path, audio, outputs[0]["sampling_rate"])
    return wav_path
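

# -------------------------------------------------------------------
# Minimal local smoke test (not part of the Space UI). Assumes a sample
# image exists at the hypothetical path "sample.jpg"; the first run also
# downloads all three checkpoints, which may take several minutes.
# -------------------------------------------------------------------
if __name__ == "__main__":
    caption = img2text("sample.jpg")
    print("Caption:", caption)
    story = text2story(caption)
    print("Story:", story)
    wav_file = story2audio(story)
    print("Audio saved to:", wav_file)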