Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import os, re
|
2 |
from typing import List
|
3 |
import torch
|
4 |
import gradio as gr
|
@@ -8,19 +8,16 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
8 |
MODEL_ID = os.getenv("MODEL_ID", "Eemansleepdeprived/Humaneyes")
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
|
11 |
-
# Lazy-loaded
|
12 |
_tokenizer = None
|
13 |
_model = None
|
14 |
|
15 |
def load_model():
    """Return the cached ``(tokenizer, model)`` pair, loading both on first use.

    Loading is deferred to the first call so that importing the module stays
    cheap. The model is loaded in float32, moved to ``device`` and switched
    to eval mode before it is cached.
    """
    global _tokenizer, _model

    # Fast path: both halves are already cached.
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    loaded = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_ID,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,  # CPU safe
    )
    _model = loaded.to(device).eval()
    return _tokenizer, _model
|
26 |
|
@@ -45,29 +42,10 @@ def protect(text: str):
|
|
45 |
|
46 |
def restore(text: str, protected: List[str]):
    """Swap numbered sentinel markers in ``text`` back to their protected spans.

    Args:
        text: Text that may contain ``SENTINEL_OPEN<idx>SENTINEL_CLOSE``
            markers.
        protected: Spans indexed by the number embedded in each marker.

    Returns:
        ``text`` with every well-formed marker replaced by its span and any
        leftover sentinel fragments removed.
    """
    def unwrap(m):
        idx = int(m.group(1))
        # A marker number that was never issued (e.g. invented by the model
        # that produced ``text``) is dropped instead of raising IndexError.
        return protected[idx] if idx < len(protected) else ""

    # Escape the sentinels so the regex matches them literally, exactly like
    # the str.replace calls below do.
    pattern = rf"{re.escape(SENTINEL_OPEN)}(\d+){re.escape(SENTINEL_CLOSE)}"
    text = re.sub(pattern, unwrap, text)
    # Remove half-markers that no longer form a complete pair.
    return text.replace(SENTINEL_OPEN, "").replace(SENTINEL_CLOSE, "")
|
52 |
|
53 |
-
# Instruction preamble for the rewrite prompt: one rule per line, each line
# terminated by a newline.
SYSTEM_RULES = "".join(
    f"{line}\n"
    for line in (
        "Rewrite the text to sound natural, nuanced, and human while preserving meaning.",
        "Rules:",
        "1) Do not change anything between §§KEEP_OPEN§§<id>§§KEEP_CLOSE§§.",
        "2) Keep citations, links, numbers, and code exactly the same.",
        "3) Keep facts the same. Improve clarity, flow, and rhythm. Vary sentence length.",
        "4) No em dashes; use simple punctuation.",
        "5) Keep the requested tone and region.",
    )
)
|
62 |
-
|
63 |
-
def build_input(text: str, tone: str, region: str, level: str, intensity: int) -> str:
    """Assemble the full model prompt: rules, style settings, then the text.

    Args:
        text: The text to be rewritten; placed under the ``INPUT:`` heading.
        tone: Requested tone, interpolated verbatim.
        region: English variant, rendered as ``"<region> English"``.
        level: Target reading level, interpolated verbatim.
        intensity: Humanization strength; the prompt states 10 is strongest.

    Returns:
        The prompt string, ending with an empty ``OUTPUT:`` section.
    """
    preamble = f"{SYSTEM_RULES}\n"
    settings = (
        f"Tone: {tone}. Region: {region} English. Reading level: {level}. "
        f"Humanization intensity: {intensity} (10 strongest).\n\n"
    )
    payload = f"INPUT:\n{text}\n\nOUTPUT:\n"
    return preamble + settings + payload
|
70 |
-
|
71 |
def chunk_text(s: str, max_chars: int = 1100):
|
72 |
parts, buf, cur = [], [], 0
|
73 |
for block in re.split(r"(\n{2,})", s):
|
@@ -79,33 +57,65 @@ def chunk_text(s: str, max_chars: int = 1100):
|
|
79 |
return parts
|
80 |
|
81 |
@torch.inference_mode()
|
82 |
-
def
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
)
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
96 |
protected_text, bag = protect(text)
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
final_text = restore(draft, bag)
|
|
|
|
|
102 |
for i, span in enumerate(bag):
|
103 |
marker = f"{SENTINEL_OPEN}{i}{SENTINEL_CLOSE}"
|
104 |
if marker in protected_text and span not in final_text:
|
105 |
final_text = final_text.replace(marker, span)
|
|
|
106 |
return final_text
|
107 |
|
108 |
-
# ---------- Gradio UI
|
109 |
def ui_humanize(text, tone, region, reading_level, intensity):
    """Coerce ``intensity`` to ``int`` and delegate to ``humanize_core``."""
    strength = int(intensity)
    return humanize_core(text, tone, region, reading_level, strength)
|
111 |
|
@@ -120,9 +130,8 @@ demo = gr.Interface(
|
|
120 |
],
|
121 |
outputs=gr.Textbox(label="Humanized"),
|
122 |
title="NoteCraft Humanizer (Humaneyes)",
|
123 |
-
description="
|
124 |
).queue()
|
125 |
|
126 |
if __name__ == "__main__":
|
127 |
-
# No heavy work here—model loads on first call
|
128 |
demo.launch()
|
|
|
1 |
+
import os, re, difflib
|
2 |
from typing import List
|
3 |
import torch
|
4 |
import gradio as gr
|
|
|
8 |
MODEL_ID = os.getenv("MODEL_ID", "Eemansleepdeprived/Humaneyes")
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
|
11 |
+
# Lazy-loaded
|
12 |
_tokenizer = None
|
13 |
_model = None
|
14 |
|
15 |
def load_model():
    """Return the cached ``(tokenizer, model)`` pair, loading both on first use.

    The model is loaded in float32, moved to ``device`` and switched to eval
    mode before it is cached in the module-level globals.
    """
    global _tokenizer, _model

    # Fast path: both halves are already cached.
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    loaded = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_ID, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    _model = loaded.to(device).eval()
    return _tokenizer, _model
|
23 |
|
|
|
42 |
|
43 |
def restore(text: str, protected: List[str]):
    """Swap numbered sentinel markers in ``text`` back to their protected spans.

    Args:
        text: Text that may contain ``SENTINEL_OPEN<idx>SENTINEL_CLOSE``
            markers.
        protected: Spans indexed by the number embedded in each marker.

    Returns:
        ``text`` with every well-formed marker replaced by its span and any
        leftover sentinel fragments removed.
    """
    def unwrap(m):
        idx = int(m.group(1))
        # A marker number that was never issued (e.g. invented by the model
        # that produced ``text``) is dropped instead of raising IndexError.
        return protected[idx] if idx < len(protected) else ""

    # Escape the sentinels so the regex matches them literally, exactly like
    # the str.replace calls below do.
    pattern = rf"{re.escape(SENTINEL_OPEN)}(\d+){re.escape(SENTINEL_CLOSE)}"
    text = re.sub(pattern, unwrap, text)
    # Remove half-markers that no longer form a complete pair.
    return text.replace(SENTINEL_OPEN, "").replace(SENTINEL_CLOSE, "")
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def chunk_text(s: str, max_chars: int = 1100):
|
50 |
parts, buf, cur = [], [], 0
|
51 |
for block in re.split(r"(\n{2,})", s):
|
|
|
57 |
return parts
|
58 |
|
59 |
@torch.inference_mode()
def generate_raw(prompt: str, beams: int = 5, temp: float = None, top_p: float = None, max_new: int = 256) -> str:
    """Run one generation pass over ``prompt`` and return the decoded text.

    Args:
        prompt: Text fed to the tokenizer (truncation enabled).
        beams: Beam count, used only when ``temp`` is ``None``.
        temp: Sampling temperature. ``None`` selects beam search without
            sampling; any other value switches to single-beam sampling.
        top_p: Nucleus-sampling cutoff for the sampling branch; a falsy
            value falls back to 0.9.
        max_new: Upper bound on newly generated tokens.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    tokenizer, model = load_model()
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

    if temp is None:
        # deterministic beam search (good first try)
        decoding = {"num_beams": beams, "do_sample": False}
    else:
        # stronger change fallback
        decoding = {
            "do_sample": True,
            "temperature": temp,
            "top_p": top_p or 0.9,
            "num_beams": 1,
        }

    sequences = model.generate(
        **inputs,
        **decoding,
        max_new_tokens=max_new,
        no_repeat_ngram_size=3,
    )
    return tokenizer.decode(sequences[0], skip_special_tokens=True)
|
71 |
+
|
72 |
+
def difference_ratio(a: str, b: str) -> float:
    """Return ``difflib``'s similarity ratio between ``a`` and ``b``.

    Despite the name, a higher value means the strings are MORE alike:
    1.0 for identical strings, 0.0 when nothing matches.
    """
    matcher = difflib.SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()
|
74 |
+
|
75 |
+
def humanize_core(text: str, tone: str, region: str, level: str, intensity: int):
    """Paraphrase ``text`` with the model while keeping protected spans intact.

    The text goes through ``protect``, is offered to the model under several
    short prompts (beam search first, sampling if the result is almost
    unchanged), and the protected spans are put back with ``restore``.

    Args:
        text: Raw user text.
        tone: Accepted for interface compatibility; not used by the prompts.
        region: Accepted for interface compatibility; not used by the prompts.
        level: Accepted for interface compatibility; not used by the prompts.
        intensity: Accepted for interface compatibility; not used by the prompts.

    Returns:
        The rewritten text with protected spans restored. Empty or
        whitespace-only input is returned unchanged.
    """
    # Nothing to rewrite: skip the model entirely instead of running up to
    # six generations on an empty prompt.
    if not text or not text.strip():
        return text

    # 1) Protect spans
    protected_text, bag = protect(text)

    # 2) Create short, model-friendly prompts
    # Many Pegasus/T5 paraphrasers respond to one of these:
    prompts = [
        f"humanize: {protected_text}",
        f"paraphrase: {protected_text}",
        protected_text,
    ]

    # 3) Try deterministic first
    draft = None
    changed_enough = False
    for p in prompts:
        out = generate_raw(p, beams=5, max_new=320).strip()
        if not out:
            continue
        draft = out
        # Record the verdict so step 4 does not recompute the ratio on the
        # same pair of strings.
        if difference_ratio(protected_text, draft) < 0.98:
            changed_enough = True
            break

    # 4) If barely changed (or nothing came back), try a stronger pass (sampling)
    if not changed_enough:
        for p in prompts:
            out = generate_raw(p, temp=0.8, top_p=0.92, max_new=320).strip()
            if out:
                draft = out
                break

    if not draft:
        draft = protected_text  # absolute fallback

    # Done before restore() so protected spans keep their own punctuation.
    draft = draft.replace("—", "-")  # enforce simple punctuation
    final_text = restore(draft, bag)

    # paranoia: ensure protected spans survived
    # NOTE(review): restore() already strips every sentinel, so this replace
    # normally finds nothing; a span the model dropped is not re-inserted.
    for i, span in enumerate(bag):
        marker = f"{SENTINEL_OPEN}{i}{SENTINEL_CLOSE}"
        if marker in protected_text and span not in final_text:
            final_text = final_text.replace(marker, span)

    return final_text
|
117 |
|
118 |
+
# ---------- Gradio UI + REST (/api/predict/) ----------
|
119 |
def ui_humanize(text, tone, region, reading_level, intensity):
    """Coerce ``intensity`` to ``int`` and delegate to ``humanize_core``."""
    strength = int(intensity)
    return humanize_core(text, tone, region, reading_level, strength)
|
121 |
|
|
|
130 |
],
|
131 |
outputs=gr.Textbox(label="Humanized"),
|
132 |
title="NoteCraft Humanizer (Humaneyes)",
|
133 |
+
description="REST: POST /api/predict/ with { data: [text,tone,region,level,intensity] }",
|
134 |
).queue()
|
135 |
|
136 |
if __name__ == "__main__":
|
|
|
137 |
demo.launch()
|