# func.py ── utilities for Hugging Face Space

# -------------------------------------------------------------------
# Step 1. Image ➜ Text (BLIP captioning)
# -------------------------------------------------------------------
from typing import Union
from pathlib import Path

from PIL import Image
from transformers import pipeline

# lazy-load caption model once
_captioner = None


def _get_captioner():
    global _captioner
    if _captioner is None:
        _captioner = pipeline(
            "image-to-text",
            model="Salesforce/blip-image-captioning-large",
        )
    return _captioner


def img2text(img: Union[Image.Image, str, Path]) -> str:
    """
    Return a short English caption for an image.

    Args:
        img: PIL.Image, local path, or pathlib.Path.

    Returns:
        Caption string.
    """
    # ensure a 3-channel PIL.Image (palette/RGBA files are normalized to RGB)
    if not isinstance(img, Image.Image):
        img = Image.open(img)
    img = img.convert("RGB")
    return _get_captioner()(img)[0]["generated_text"]
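
# Usage sketch for Step 1 (illustrative only; "photo.jpg" is a placeholder
# path, not a file shipped with this Space):
#
#     caption = img2text("photo.jpg")
#     print(caption)  # a short scene description, typically one sentence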


# -------------------------------------------------------------------
# Step 2. Caption ➜ Children’s story (BLOOM-560M)
# -------------------------------------------------------------------
import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

_PROMPT_TMPL = (
    "Write a funny and warm children's story (50-100 words) for ages 3-10, "
    "fully and strictly based on this scene: {caption}\nStory:"
)

_tokenizer = None
_model = None


def _get_model_and_tokenizer():
    """Lazy-load BLOOM-560M model and tokenizer once (GPU if available)."""
    global _tokenizer, _model
    if _tokenizer is None or _model is None:
        _tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
        _model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
        if torch.cuda.is_available():
            _model = _model.to("cuda")
    return _tokenizer, _model


def _dedup_sentences(text: str) -> str:
    """Remove exact duplicate sentences while preserving order."""
    seen, cleaned = set(), []
    for sent in re.split(r"(?<=[.!?])\s+", text.strip()):
        s = sent.strip()
        if s and s not in seen:
            cleaned.append(s)
            seen.add(s)
    return " ".join(cleaned)


def text2story(caption: str) -> str:
    """
    Generate a ≤100-word children’s story from the image caption using BLOOM-560M.

    Args:
        caption: scene description string.

    Returns:
        Story text (plain string, ≤100 words, no exact duplicate sentences).
    """
    prompt = _PROMPT_TMPL.format(caption=caption)
    tokenizer, model = _get_model_and_tokenizer()

    # tokenize input, moving tensors to GPU when the model lives there
    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # generate text
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=150,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        no_repeat_ngram_size=4,   # block 4-gram repeats
        repetition_penalty=1.15,  # soften copy-loops
        pad_token_id=tokenizer.eos_token_id,
    )

    # decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) can misalign when the tokenizer normalizes whitespace
    gen_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    story = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

    # deduplicate sentences
    story = _dedup_sentences(story)

    # ensure ending punctuation
    if story and story[-1] not in ".!?":
        story += "."

    # hard cap at 100 words
    return " ".join(story.split()[:100])


# -------------------------------------------------------------------
# Step 3. Text ➜ Audio (MMS-TTS)
# -------------------------------------------------------------------
import textwrap

import numpy as np
import soundfile as sf

# `pipeline` is already imported in Step 1
_TTS_MODEL = "facebook/mms-tts-eng"
_tts = None


def _get_tts():
    """Lazy-load the TTS pipeline once."""
    global _tts
    if _tts is None:
        _tts = pipeline("text-to-speech", model=_TTS_MODEL)
    return _tts


def story2audio(story: str, wav_path: str = "story.wav") -> str:
    """
    Synthesize speech for a story and save as WAV.

    Args:
        story: Text returned by `text2story(...)`.
        wav_path: Output file name.

    Returns:
        Path to the saved WAV file.
    """
    tts = _get_tts()
    chunks = textwrap.wrap(story, width=200)  # long text → stable chunks
    if not chunks:
        raise ValueError("story2audio received empty text")
    outputs = [tts(c) for c in chunks]
    audio = np.concatenate([np.asarray(o["audio"]).squeeze() for o in outputs])
    # use the sampling rate reported by the pipeline output itself
    sf.write(wav_path, audio, outputs[0]["sampling_rate"])
    return wav_path
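

# End-to-end sketch of the Space's pipeline (illustrative; "photo.jpg" and
# the default "story.wav" output are placeholder file names):
if __name__ == "__main__":
    caption = img2text("photo.jpg")  # Step 1: BLIP caption
    story = text2story(caption)      # Step 2: BLOOM-560M story
    wav = story2audio(story)         # Step 3: MMS-TTS narration
    print(caption)
    print(story)
    print("audio saved to", wav)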