import os, re, difflib
from typing import List
import gradio as gr
from ctransformers import AutoModelForCausalLM
# ---------------- Model (GGUF on CPU) ----------------
# These defaults work on HF free CPU Spaces.
REPO_ID = os.getenv("LLAMA_GGUF_REPO", "bartowski/Llama-3.2-3B-Instruct-GGUF")
FILENAME = os.getenv("LLAMA_GGUF_FILE", "Llama-3.2-3B-Instruct-Q5_0.gguf")  # fall back to the Q8_0 file if this quant is unavailable
MODEL_TYPE = "llama"
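# Both defaults can be overridden without code changes, e.g. to pin the Q8_0 build:
#   LLAMA_GGUF_REPO=bartowski/Llama-3.2-3B-Instruct-GGUF \
#   LLAMA_GGUF_FILE=Llama-3.2-3B-Instruct-Q8_0.gguf python app.py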
# lazy-load for fast startup
_llm = None

def load_model():
    global _llm
    if _llm is None:
        _llm = AutoModelForCausalLM.from_pretrained(
            REPO_ID,
            model_file=FILENAME,
            model_type=MODEL_TYPE,
            gpu_layers=0,
            context_length=8192,
        )
    return _llm
# ---------------- Protect / restore ----------------
SENTINEL_OPEN, SENTINEL_CLOSE = "§§KEEP_OPEN§§", "§§KEEP_CLOSE§§"
URL_RE = re.compile(r'(https?://\S+)')
CODE_RE = re.compile(r'`{1,3}[\s\S]*?`{1,3}')
CITE_RE = re.compile(r'\[(?:[^\]]+?)\]|\(\d{4}\)|\[\d+(?:-\d+)?\]')
# the lookarounds keep this from re-wrapping the index digits inside a
# sentinel marker produced by an earlier pass (which would corrupt the marker)
NUM_RE = re.compile(r'(?<!§)\b\d[\d,.\-/]*\b(?!§)')
def protect(text: str):
    protected = []
    def wrap(m):
        protected.append(m.group(0))
        return f"{SENTINEL_OPEN}{len(protected)-1}{SENTINEL_CLOSE}"
    # order matters: code spans first, so URLs and numbers inside backticks
    # are captured once rather than wrapped twice
    text = CODE_RE.sub(wrap, text)
    text = URL_RE.sub(wrap, text)
    text = CITE_RE.sub(wrap, text)
    text = NUM_RE.sub(wrap, text)
    return text, protected
def restore(text: str, protected: List[str]):
    def unwrap(m):
        i = int(m.group(1))
        # guard against indices the model hallucinated out of range
        return protected[i] if i < len(protected) else m.group(0)
    text = re.sub(rf"{re.escape(SENTINEL_OPEN)}(\d+){re.escape(SENTINEL_CLOSE)}", unwrap, text)
    # strip any sentinel fragments the model mangled beyond recovery
    return text.replace(SENTINEL_OPEN, "").replace(SENTINEL_CLOSE, "")
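# Round-trip illustration (comment only; values traced by hand, not executed):
#   masked, bag = protect("Totals rose 12% (see [3]).")
#   masked -> "Totals rose §§KEEP_OPEN§§1§§KEEP_CLOSE§§% (see §§KEEP_OPEN§§0§§KEEP_CLOSE§§)."
#   restore(masked, bag) -> "Totals rose 12% (see [3])."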
# ---------------- Prompting (Llama 3.x chat template) ----------------
SYSTEM = (
    "You are an expert editor. Humanize the user's text: improve flow, vary sentence length, "
    "split run-ons, replace stiff phrasing with natural alternatives, and preserve meaning. "
    "Do NOT alter anything wrapped by §§KEEP_OPEN§§<id>§§KEEP_CLOSE§§ (citations, URLs, numbers, code). "
    "Keep the requested tone and region. No em dashes; use simple punctuation."
)
def build_prompt(text: str, tone: str, region: str, level: str, intensity: int) -> str:
    user = (
        f"Tone: {tone}. Region: {region} English. Reading level: {level}. "
        f"Humanization intensity: {intensity} (10 strongest).\n\n"
        f"Rewrite this text. Keep markers intact:\n\n{text}"
    )
    # Llama 3.x chat format
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        f"{SYSTEM}\n"
        "<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
        f"{user}\n"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    )
def diff_ratio(a: str, b: str) -> float:
    return difflib.SequenceMatcher(None, a, b).ratio()
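# diff_ratio("abc", "abc") == 1.0; values near 1.0 mean the rewrite barely
# changed anything, which is what the 0.97 retry threshold in humanize_core tests.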
def generate_once(prompt: str, temperature: float, max_new: int = 768) -> str:
    llm = load_model()
    out = llm(
        prompt,
        temperature=temperature,
        top_p=0.95,
        max_new_tokens=max_new,
        stop=["<|eot_id|>"],
    )
    return out.strip()
# ---------------- Main humanizer ----------------
def humanize_core(text: str, tone: str, region: str, level: str, intensity: int):
    protected_text, bag = protect(text)
    prompt = build_prompt(protected_text, tone, region, level, intensity)
    # pass 1 (conservative); retry hotter if the draft barely differs from the input
    draft = generate_once(prompt, temperature=0.35)
    if diff_ratio(protected_text, draft) > 0.97:
        draft = generate_once(prompt, temperature=0.9)
    draft = draft.replace("—", "-")
    final = restore(draft, bag)
    # ensure all protected spans survived; a span whose marker the model
    # dropped cannot be re-placed automatically, so flag it in the logs
    for span in bag:
        if span not in final:
            print(f"warning: protected span lost during rewrite: {span!r}")
    return final
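# Local smoke test (illustrative; run from a REPL so the model download
# happens on demand rather than at import time):
#   print(humanize_core("The aforementioned findings were utilized.", "casual", "US", "general", 6))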
# ---------------- Gradio UI (and REST at /api/predict/) ----------------
def ui_humanize(text, tone, region, level, intensity):
    return humanize_core(text, tone, region, level, int(intensity))

demo = gr.Interface(
    fn=ui_humanize,
    inputs=[
        gr.Textbox(lines=12, label="Input text"),
        gr.Dropdown(["professional", "casual", "academic", "friendly", "persuasive"], value="professional", label="Tone"),
        gr.Dropdown(["US", "UK", "KE"], value="US", label="Region"),
        gr.Dropdown(["general", "simple", "advanced"], value="general", label="Reading level"),
        gr.Slider(1, 10, value=6, step=1, label="Humanization intensity"),
    ],
    outputs=gr.Textbox(label="Humanized"),
    title="NoteCraft Humanizer (Llama-3.2-3B-Instruct)",
    description="REST: POST /api/predict/ with { data: [text, tone, region, level, intensity] }",
).queue()
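# Example REST call against the endpoint named in the description above
# (the host is a placeholder for your Space URL; the exact route can vary
# with the installed Gradio version):
#   curl -X POST https://<your-space>.hf.space/api/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"data": ["Text to rewrite.", "professional", "US", "general", 6]}'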
if __name__ == "__main__":
    demo.launch()