File size: 3,473 Bytes
75cc814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c212840
75cc814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# func.py  ── utilities for Hugging Face Space

# Step1. Image to Text
from typing import Union
from pathlib import Path
from PIL import Image
from transformers import pipeline

# Image-captioning pipeline: built on the first call to _get_captioner() and
# reused for every later call.
_captioner = None


def _get_captioner():
    """Return the shared BLIP captioning pipeline, creating it on first use."""
    global _captioner
    if _captioner is not None:
        return _captioner
    _captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-large",
    )
    return _captioner

def img2text(img: Union[Image.Image, str, Path]) -> str:
    """
    Return a short English caption for an image.

    Args:
        img: An already-open PIL.Image, or a local file path (str or
            pathlib.Path) that will be opened here.

    Returns:
        Caption string (the ``generated_text`` of the first pipeline result).
    """
    if isinstance(img, Image.Image):
        # The caller owns this image object, so caption it without closing it.
        return _get_captioner()(img)[0]["generated_text"]

    # Opened here from disk. Image.open is lazy and may keep the file handle
    # open until garbage collection, so the context manager releases it as
    # soon as captioning is done.
    with Image.open(img) as opened:
        return _get_captioner()(opened)[0]["generated_text"]

# Step2. Text Generation (Based on Caption)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

_MODEL_NAME = "aspis/gpt2-genre-story-generation"
_PROMPT = (
    "Write a funny and warm children's story (50-100 words) for ages 3-10, "
    "fully and strictly based on this scene: {caption}\nStory:"
)

# Cached tokenizer / model pair, filled in by _load_story_model().
_tokenizer = None
_model = None


def _load_story_model():
    """Return the cached (tokenizer, model) pair, loading both on first call.

    The model is moved to the "cuda" device when torch reports CUDA as
    available; otherwise it stays where from_pretrained put it.
    """
    global _tokenizer, _model
    if _model is not None:
        return _tokenizer, _model

    _tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
    _model = AutoModelForCausalLM.from_pretrained(_MODEL_NAME)
    if torch.cuda.is_available():
        _model = _model.to("cuda")
    return _tokenizer, _model


def _trim_story(text: str, max_words: int = 100, terminators: str = ".!?") -> str:
    """Cap *text* at *max_words* words, then cut back to the last sentence end.

    Whitespace is collapsed to single spaces. If the capped text contains no
    character from *terminators*, it is returned as-is (still word-capped).
    """
    clipped = " ".join(text.split()[:max_words])
    cut = max((clipped.rfind(mark) for mark in terminators), default=-1)
    return clipped[: cut + 1] if cut >= 0 else clipped


def text2story(caption: str) -> str:
    """
    Generate a 50-100-word children’s story from an image caption.

    The caption is substituted into ``_PROMPT`` and continued by the story
    model using nucleus sampling, so results vary between calls.

    Args:
        caption: Scene description string.

    Returns:
        Story text of at most 100 words, ending at the last sentence
        terminator (``.``, ``!`` or ``?``) when one is present.
    """
    tok, mdl = _load_story_model()

    prompt = _PROMPT.format(caption=caption)
    # BatchEncoding.to() moves input_ids and attention_mask to the model's
    # device together (a no-op on CPU).
    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(mdl.device)

    gen_ids = mdl.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        pad_token_id=tok.eos_token_id,
        repetition_penalty=1.1
    )[0]

    # generate() returns prompt + continuation; keep only the new tokens.
    story_ids = gen_ids[inputs["input_ids"].shape[-1]:]
    story = tok.decode(story_ids, skip_special_tokens=True).strip()
    # Cap the word count first and only then cut to a sentence boundary.
    # Doing it the other way round (as before) cut long stories mid-sentence.
    return _trim_story(story)

# Step3. Text to Audio
import numpy as np
import textwrap
import soundfile as sf
from transformers import pipeline

_TTS_MODEL = "facebook/mms-tts-eng"
_tts = None  # shared text-to-speech pipeline, created by _get_tts()


def _get_tts():
    """Return the shared MMS-TTS pipeline, constructing it on the first call."""
    global _tts
    if _tts is not None:
        return _tts
    _tts = pipeline("text-to-speech", model=_TTS_MODEL)
    return _tts


def story2audio(story: str, wav_path: str = "story.wav") -> str:
    """
    Synthesize speech for a story and save as WAV.

    The text is split into chunks of at most 200 characters (textwrap.wrap),
    each chunk is synthesized separately, and the waveforms are concatenated.

    Args:
        story: Text returned by `text2story(...)`.
        wav_path: Output file name.

    Returns:
        Path to the saved WAV file (``wav_path`` unchanged).

    Raises:
        ValueError: If ``story`` is empty or whitespace-only.
    """
    chunks = textwrap.wrap(story, width=200)               # long text → stable chunks
    # wrap() yields [] for blank text, and np.concatenate([]) would fail with
    # an opaque error after loading the TTS model, so reject it up front.
    if not chunks:
        raise ValueError("story is empty; nothing to synthesize")

    tts = _get_tts()
    outputs = [tts(chunk) for chunk in chunks]
    # np.ravel (not squeeze) always yields 1-D, even for a 1-sample waveform.
    audio = np.concatenate([np.ravel(out["audio"]) for out in outputs])
    sf.write(wav_path, audio, outputs[0]["sampling_rate"])
    return wav_path