LinkLinkWu committed on
Commit
79a38ee
·
verified ·
1 Parent(s): 8ac7c32

Update func.py

Browse files
Files changed (1) hide show
  1. func.py +25 -14
func.py CHANGED
@@ -32,9 +32,10 @@ def img2text(img: Union[Image.Image, str, Path]) -> str:
32
  img = Image.open(img)
33
  return _get_captioner()(img)[0]["generated_text"]
34
 
35
- # Step 2. Caption ➜ Children’s story (DeepSeek-R1 1.5 B)
36
  # -------------------------------------------------------------------
37
- import torch
 
 
38
  from transformers import pipeline
39
 
40
  _GEN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@@ -52,15 +53,27 @@ def _get_generator():
52
  "text-generation",
53
  model=_GEN_MODEL,
54
  device=0 if torch.cuda.is_available() else -1,
55
- # common decoding params – can still be overridden in the call
56
  max_new_tokens=150,
57
  do_sample=True,
58
  top_p=0.9,
59
  temperature=0.8,
 
 
60
  )
61
  return _generator
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
64
  def text2story(caption: str) -> str:
65
  """
66
  Generate a ≤100-word children’s story from the image caption.
@@ -69,18 +82,16 @@ def text2story(caption: str) -> str:
69
  caption: scene description string.
70
 
71
  Returns:
72
- Story text (plain string, trimmed to ≤100 words).
73
  """
74
- prompt = _PROMPT_TMPL.format(caption=caption)
75
- gen = _get_generator()(
76
- prompt,
77
- return_full_text=False # only the completion, not the prompt
78
- )[0]["generated_text"]
79
-
80
- # ensure last sentence is closed
81
- story = gen.strip()
82
- if "." in story:
83
- story = story[: story.rfind(".") + 1]
84
 
85
  # hard cap at 100 words
86
  return " ".join(story.split()[:100])
 
32
  img = Image.open(img)
33
  return _get_captioner()(img)[0]["generated_text"]
34
 
 
35
  # -------------------------------------------------------------------
36
+ # Step 2. Caption ➜ Children’s story (DeepSeek-R1 1.5 B)
37
+ # -------------------------------------------------------------------
38
+ import torch, re
39
  from transformers import pipeline
40
 
41
  _GEN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 
53
  "text-generation",
54
  model=_GEN_MODEL,
55
  device=0 if torch.cuda.is_available() else -1,
 
56
  max_new_tokens=150,
57
  do_sample=True,
58
  top_p=0.9,
59
  temperature=0.8,
60
+ no_repeat_ngram_size=4, # ← block 4-gram repeats
61
+ repetition_penalty=1.15 # ← soften copy-loops
62
  )
63
  return _generator
64
 
65
 
66
+ def _dedup_sentences(text: str) -> str:
67
+ """Remove exact duplicate sentences while preserving order."""
68
+ seen, cleaned = set(), []
69
+ for sent in re.split(r'(?<=[.!?])\s+', text.strip()):
70
+ s = sent.strip()
71
+ if s and s not in seen:
72
+ cleaned.append(s)
73
+ seen.add(s)
74
+ return " ".join(cleaned)
75
+
76
+
77
  def text2story(caption: str) -> str:
78
  """
79
  Generate a ≤100-word children’s story from the image caption.
 
82
  caption: scene description string.
83
 
84
  Returns:
85
+ Story text (plain string, ≤100 words, no exact duplicate sentences).
86
  """
87
+ prompt = _PROMPT_TMPL.format(caption=caption)
88
+ raw = _get_generator()(prompt, return_full_text=False)[0]["generated_text"]
89
+
90
+ story = _dedup_sentences(raw)
91
+
92
+ # ensure ending punctuation
93
+ if story and story[-1] not in ".!?":
94
+ story += "."
 
 
95
 
96
  # hard cap at 100 words
97
  return " ".join(story.split()[:100])