# func.py ── utilities for Hugging Face Space
# Step 1. Image to Text
from typing import Union
from pathlib import Path
from PIL import Image
from transformers import pipeline
# lazy-load caption model once
_captioner = None
def _get_captioner():
    """Lazy-load the captioning pipeline once."""
    global _captioner
    if _captioner is None:
        _captioner = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-large",
        )
    return _captioner
def img2text(img: Union[Image.Image, str, Path]) -> str:
    """
    Return a short English caption for an image.

    Args:
        img: PIL.Image, local path, or pathlib.Path.

    Returns:
        Caption string.
    """
    # ensure a PIL.Image in RGB mode (the caption model expects 3 channels)
    if not isinstance(img, Image.Image):
        img = Image.open(img)
    img = img.convert("RGB")
    return _get_captioner()(img)[0]["generated_text"]
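# Usage sketch (the file name is a hypothetical local image, not part of the Space):
#   caption = img2text("photo.jpg")
#   print(caption)   # e.g. "a dog running through a field"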
# Step 2. Text Generation (Based on Caption)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
_MODEL_NAME = "aspis/gpt2-genre-story-generation"
_PROMPT = (
    "Write a funny and warm children's story (50-100 words) for ages 3-10, "
    "fully and strictly based on this scene: {caption}\nStory:"
)
_tokenizer, _model = None, None
def _load_story_model():
    """Lazy-load tokenizer / model once."""
    global _tokenizer, _model
    if _model is None:
        _tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
        _model = AutoModelForCausalLM.from_pretrained(_MODEL_NAME)
        if torch.cuda.is_available():
            _model = _model.to("cuda")
    return _tokenizer, _model
def text2story(caption: str) -> str:
    """
    Generate a 50-100-word children's story from an image caption.

    Args:
        caption: Scene description string.

    Returns:
        Story text (≤100 words).
    """
    tok, mdl = _load_story_model()
    prompt = _PROMPT.format(caption=caption)
    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False)
    if mdl.device.type == "cuda":
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
    gen_ids = mdl.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        repetition_penalty=1.1,
        pad_token_id=tok.eos_token_id,
    )[0]
    # drop the prompt, decode, cap at 100 words, then end at the last period
    # (capping before the period cut keeps the story from stopping mid-sentence)
    story_ids = gen_ids[inputs["input_ids"].shape[-1]:]
    story = tok.decode(story_ids, skip_special_tokens=True).strip()
    story = " ".join(story.split()[:100])
    return story[: story.rfind(".") + 1] if "." in story else story
# Step 3. Text to Audio
import numpy as np
import textwrap
import soundfile as sf
from transformers import pipeline
_TTS_MODEL = "facebook/mms-tts-eng"
_tts = None
def _get_tts():
    """Lazy-load the TTS pipeline once."""
    global _tts
    if _tts is None:
        _tts = pipeline("text-to-speech", model=_TTS_MODEL)
    return _tts
def story2audio(story: str, wav_path: str = "story.wav") -> str:
    """
    Synthesize speech for a story and save it as a WAV file.

    Args:
        story: Text returned by `text2story(...)`.
        wav_path: Output file name.

    Returns:
        Path to the saved WAV file.
    """
    tts = _get_tts()
    chunks = textwrap.wrap(story, width=200)  # long text → stable chunks
    audio = np.concatenate([tts(c)["audio"].squeeze() for c in chunks])
    sf.write(wav_path, audio, tts.model.config.sampling_rate)
    return wav_path
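# Minimal end-to-end sketch for local testing; not used by the Space UI.
# "example.jpg" is a hypothetical input file — replace it with a real image.
if __name__ == "__main__":
    caption = img2text("example.jpg")       # Step 1: image → caption
    story = text2story(caption)             # Step 2: caption → story
    wav = story2audio(story, "story.wav")   # Step 3: story → speech
    print(caption)
    print(story)
    print(f"audio saved to {wav}")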