LinkLinkWu committed
Commit 9634e12 · verified · 1 Parent(s): 79a38ee

Update func.py

Files changed (1)
  1. func.py +43 -27
func.py CHANGED
@@ -33,34 +33,27 @@ def img2text(img: Union[Image.Image, str, Path]) -> str:
     return _get_captioner()(img)[0]["generated_text"]
 
 # -------------------------------------------------------------------
-# Step 2. Caption ➜ Children’s story (DeepSeek-R1 1.5 B)
+# Step 2. Caption ➜ Children’s story (BLOOM-560M)
 # -------------------------------------------------------------------
 import torch, re
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
-_GEN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 _PROMPT_TMPL = (
     "Write a funny and warm children's story (50-100 words) for ages 3-10, "
     "fully and strictly based on this scene: {caption}\nStory:"
 )
 
-_generator = None
-def _get_generator():
-    """Lazy-load DeepSeek generator once (GPU if available)."""
-    global _generator
-    if _generator is None:
-        _generator = pipeline(
-            "text-generation",
-            model=_GEN_MODEL,
-            device=0 if torch.cuda.is_available() else -1,
-            max_new_tokens=150,
-            do_sample=True,
-            top_p=0.9,
-            temperature=0.8,
-            no_repeat_ngram_size=4,   # ← block 4-gram repeats
-            repetition_penalty=1.15   # ← soften copy-loops
-        )
-    return _generator
+_tokenizer = None
+_model = None
+def _get_model_and_tokenizer():
+    """Lazy-load BLOOM-560M model and tokenizer once (GPU if available)."""
+    global _tokenizer, _model
+    if _tokenizer is None or _model is None:
+        _tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+        _model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
+        if torch.cuda.is_available():
+            _model = _model.to("cuda")
+    return _tokenizer, _model
 
 
 def _dedup_sentences(text: str) -> str:
@@ -76,7 +69,7 @@ def _dedup_sentences(text: str) -> str:
 
 def text2story(caption: str) -> str:
     """
-    Generate a ≤100-word children’s story from the image caption.
+    Generate a ≤100-word children’s story from the image caption using BLOOM-560M.
 
     Args:
         caption: scene description string.
@@ -85,15 +78,38 @@ def text2story(caption: str) -> str:
         Story text (plain string, ≤100 words, no exact duplicate sentences).
     """
     prompt = _PROMPT_TMPL.format(caption=caption)
-    raw = _get_generator()(prompt, return_full_text=False)[0]["generated_text"]
-
-    story = _dedup_sentences(raw)
-
-    # ensure ending punctuation
+    tokenizer, model = _get_model_and_tokenizer()
+
+    # Tokenize input
+    inputs = tokenizer(prompt, return_tensors="pt")
+    if torch.cuda.is_available():
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+    # Generate text
+    outputs = model.generate(
+        inputs["input_ids"],
+        max_new_tokens=150,
+        do_sample=True,
+        top_p=0.9,
+        temperature=0.8,
+        no_repeat_ngram_size=4,   # Block 4-gram repeats
+        repetition_penalty=1.15,  # Soften copy-loops
+        pad_token_id=tokenizer.eos_token_id
+    )
+
+    # Decode generated text
+    raw = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    # Remove prompt from output
+    story = raw[len(prompt):].strip()
+
+    # Deduplicate sentences
+    story = _dedup_sentences(story)
+
+    # Ensure ending punctuation
     if story and story[-1] not in ".!?":
         story += "."
 
-    # hard cap at 100 words
+    # Hard cap at 100 words
     return " ".join(story.split()[:100])
 
 # Step3. Text to Audio
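
For context, a minimal usage sketch of the updated pipeline: it assumes func.py is importable and exposes img2text and text2story exactly as shown in the diff, and that the bigscience/bloom-560m weights can be downloaded from the Hub; the image path "sample.jpg" is hypothetical.

# Minimal usage sketch (assumptions: func.py is on the import path; "sample.jpg"
# is a placeholder image file; BLOOM-560M weights are fetched on first call).
from func import img2text, text2story

caption = img2text("sample.jpg")   # Step 1: image ➜ scene caption
story = text2story(caption)        # Step 2: caption ➜ ≤100-word children's story (BLOOM-560M)
print(caption)
print(story)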