# LinkLinkWu's picture
# Update func.py
# 79a38ee verified
# raw
# history blame
# 3.74 kB
# func.py ── utilities for Hugging Face Space
# Step 1. Image ➜ caption text
from typing import Union
from pathlib import Path
from PIL import Image
from transformers import pipeline
# Shared caption pipeline; stays None until _get_captioner() first runs.
_captioner = None


def _get_captioner():
    """Return the module-wide BLIP image-to-text pipeline, building it on first call."""
    global _captioner
    if _captioner is not None:
        return _captioner
    _captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-large",
    )
    return _captioner
def img2text(img: Union[Image.Image, str, Path]) -> str:
    """
    Return a short English caption for an image.
    Args:
        img: PIL.Image, local path, or pathlib.Path.
    Returns:
        Caption string.
    """
    if isinstance(img, Image.Image):
        image = img
    else:
        # Open inside a context manager so the file handle is always
        # released, even if decoding fails. Without this, multi-frame
        # formats (GIF/TIFF) keep the handle open after loading.
        # convert() decodes the pixels into a new, file-independent RGB
        # image before the handle is closed.
        with Image.open(img) as opened:
            image = opened.convert("RGB")
    return _get_captioner()(image)[0]["generated_text"]
# -------------------------------------------------------------------
# Step 2. Caption ➜ Children’s story (DeepSeek-R1 1.5 B)
# -------------------------------------------------------------------
import torch, re
from transformers import pipeline
# Hugging Face checkpoint used for story generation.
_GEN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Story prompt; the "{caption}" placeholder is filled by text2story().
_PROMPT_TMPL = (
    "Write a funny and warm children's story (50-100 words) "
    "for ages 3-10, fully and strictly based on this scene: "
    "{caption}\nStory:"
)
# Shared text-generation pipeline; stays None until _get_generator() first runs.
_generator = None


def _get_generator():
    """Return the module-wide DeepSeek text-generation pipeline, creating it once.

    Uses device 0 when CUDA is available and device -1 (CPU) otherwise. The
    sampling settings are fixed at construction time, so every call shares them.
    """
    global _generator
    if _generator is not None:
        return _generator
    sampling = dict(
        max_new_tokens=150,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        no_repeat_ngram_size=4,   # forbid any 4-gram from repeating
        repetition_penalty=1.15,  # discourage copy-loops
    )
    _generator = pipeline(
        "text-generation",
        model=_GEN_MODEL,
        device=0 if torch.cuda.is_available() else -1,
        **sampling,
    )
    return _generator
def _dedup_sentences(text: str) -> str:
    """Drop exact-repeat sentences, keeping the first occurrence of each in order.

    Sentences are split on whitespace that follows '.', '!' or '?'. Survivors
    are re-joined with single spaces.
    """
    pieces = (chunk.strip() for chunk in re.split(r'(?<=[.!?])\s+', text.strip()))
    # dict keys are unique and keep insertion order -> ordered de-duplication
    unique = dict.fromkeys(piece for piece in pieces if piece)
    return " ".join(unique)
def text2story(caption: str) -> str:
    """
    Generate a ≤100-word children’s story from the image caption.
    Args:
        caption: scene description string.
    Returns:
        Story text (plain string, ≤100 words, no exact duplicate sentences).
        A non-empty result always ends with '.', '!' or '?'.
    """
    prompt = _PROMPT_TMPL.format(caption=caption)
    raw = _get_generator()(prompt, return_full_text=False)[0]["generated_text"]
    # DeepSeek-R1 distilled models may emit a "<think>...</think>" reasoning
    # section. Drop complete sections and any stray tags so only story text
    # remains.
    raw = re.sub(r"<think>.*?</think>", " ", raw, flags=re.DOTALL)
    raw = re.sub(r"</?think>", " ", raw)
    story = _dedup_sentences(raw)
    # Hard cap at 100 words FIRST. Capping after the punctuation fix (as
    # before) could cut the added full stop off again.
    story = " ".join(story.split()[:100])
    # Drop a dangling clause separator so we never produce e.g. "and,."
    story = story.rstrip(",;:- ")
    # Ensure ending punctuation. Appending to the last word keeps the count.
    if story and story[-1] not in ".!?":
        story += "."
    return story
# Step 3. Story text ➜ audio (WAV)
import numpy as np
import textwrap
import soundfile as sf
from transformers import pipeline
# Hugging Face checkpoint used for English speech synthesis.
_TTS_MODEL = "facebook/mms-tts-eng"
# Shared TTS pipeline; stays None until _get_tts() first runs.
_tts = None


def _get_tts():
    """Return the module-wide text-to-speech pipeline, building it on first call."""
    global _tts
    if _tts is not None:
        return _tts
    _tts = pipeline("text-to-speech", model=_TTS_MODEL)
    return _tts
def story2audio(story: str, wav_path: str = "story.wav") -> str:
    """
    Synthesize speech for a story and save it as a WAV file.
    Args:
        story: Text returned by `text2story(...)`.
        wav_path: Output file name.
    Returns:
        Path to the saved WAV file (``wav_path`` as given).
    Raises:
        ValueError: if ``story`` has no speakable text (empty or
            whitespace-only).
    """
    # Split long text on word boundaries into chunks of at most 200 chars
    # so each TTS call stays short.
    chunks = textwrap.wrap(story, width=200)
    if not chunks:
        # Fail early with a clear message instead of numpy's opaque
        # "need at least one array to concatenate", and without loading
        # the TTS model for nothing.
        raise ValueError("story2audio: story is empty; nothing to synthesize")
    tts = _get_tts()
    outputs = [tts(chunk) for chunk in chunks]
    # ravel() rather than squeeze(): it flattens e.g. (1, n) -> (n,) and never
    # collapses a single-sample output into a 0-d array, which concatenate rejects.
    audio = np.concatenate([np.ravel(out["audio"]) for out in outputs])
    # Prefer the rate the pipeline reports for its output; fall back to the
    # model config, which was the original source of the rate.
    sampling_rate = outputs[0].get("sampling_rate")
    if sampling_rate is None:
        sampling_rate = tts.model.config.sampling_rate
    sf.write(wav_path, audio, sampling_rate)
    return wav_path