4b-demo / app.py
sumuks's picture
sumuks HF Staff
Update app.py
c6de1b4 verified
import gradio as gr
import spaces
import torch
import difflib
from threading import Thread
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer
# Hub repository id handed to every ``from_pretrained`` call in this file.
model_id = "textcleanlm/textcleanlm-1-4b"

# Process-wide cache: both stay ``None`` until load_model() fills them in.
model, tokenizer = None, None
def load_model():
    """Load the tokenizer and model once and cache them in module globals.

    On the first call the tokenizer is loaded from ``model_id`` and, when it
    has no padding token, its EOS token is reused as padding. The model is
    then loaded by trying ``AutoModelForSeq2SeqLM``, ``AutoModelForCausalLM``
    and ``AutoModel`` in that order; the first class that loads wins.
    Later calls return the cached objects without reloading.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        ValueError: If none of the candidate classes could load ``model_id``.
            The last loading error is chained as ``__cause__`` so the real
            failure (missing weights, out of memory, ...) stays visible.
    """
    global model, tokenizer
    if model is None:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Add padding token if needed
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Try different model classes, remembering why each attempt failed.
        last_error = None
        for model_class in (AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel):
            try:
                model = model_class.from_pretrained(
                    model_id,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                )
            except Exception as exc:
                # Catch Exception rather than using a bare ``except`` so that
                # KeyboardInterrupt / SystemExit still propagate.
                last_error = exc
                continue
            break

        if model is None:
            raise ValueError(f"Could not load model {model_id}") from last_error
    return model, tokenizer
def create_diff_html(original, cleaned):
    """Create HTML diff visualization.

    Renders a unified diff of *original* against *cleaned* as a monospace
    ``<div>`` containing one child ``<div>`` per diff line: grey file
    headers, bold blue hunk markers, green additions, red removals and
    unstyled context lines.

    Args:
        original: The text before cleaning.
        cleaned: The text after cleaning.

    Returns:
        str: The HTML fragment. When the two texts have the same lines the
        wrapper ``<div>`` is returned empty.
    """
    from html import escape

    # Compare newline-free lines, as difflib recommends when lineterm='':
    # otherwise a text that differs only by its trailing newline shows up
    # as a changed last line.
    diff_lines = difflib.unified_diff(
        original.splitlines(),
        cleaned.splitlines(),
        fromfile='Original',
        tofile='Cleaned',
        lineterm='',
    )

    parts = ['<div style="font-family: monospace; font-size: 12px; white-space: pre-wrap;">']
    for index, line in enumerate(diff_lines):
        # Both texts are untrusted (user input / model output) and end up in
        # an HTML component, so escape them before interpolation.
        safe_line = escape(line)
        if index < 2:
            # unified_diff always emits the two file headers first. Detecting
            # them by position keeps content such as a removed "-- note"
            # (rendered "--- note") from being styled as a header.
            parts.append(f'<div style="color: #666;">{safe_line}</div>')
        elif line.startswith('@@'):
            parts.append(f'<div style="color: #0066cc; font-weight: bold;">{safe_line}</div>')
        elif line.startswith('+'):
            parts.append(f'<div style="background-color: #e6ffed; color: #24292e;">{safe_line}</div>')
        elif line.startswith('-'):
            parts.append(f'<div style="background-color: #ffeef0; color: #24292e;">{safe_line}</div>')
        else:
            parts.append(f'<div>{safe_line}</div>')
    parts.append('</div>')
    return ''.join(parts)
@spaces.GPU(duration=60)
def clean_text(text):
    """Stream the model's cleaned version of *text*, then a diff against it.

    Generator used as the Gradio click handler. The input is wrapped in a
    single user message, rendered through the tokenizer's chat template and
    passed to ``model.generate`` on a background thread, while this
    generator relays the streamed text to the UI.

    Args:
        text: Raw text from the input box.

    Yields:
        tuple[str, str]: ``(cleaned_text, diff_html)``. ``diff_html`` is
        ``""`` while tokens are streaming; the final item carries the HTML
        diff between *text* and the complete cleaned text.
    """
    if not text or not text.strip():
        # Nothing to clean: clear both outputs instead of running generation.
        yield "", ""
        return

    model, tokenizer = load_model()

    # Apply the chat template
    messages = [{"role": "user", "content": text}]
    formatted_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=4096, truncation=True)
    inputs = {k: v.cuda() for k, v in inputs.items()}

    # Enable streaming. skip_prompt=True makes the streamer drop the echoed
    # prompt itself. Slicing len(formatted_text) characters off the output is
    # wrong because the streamed text has special tokens removed and is
    # therefore shorter than the templated prompt, which cut off the start
    # of the real answer.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        # NOTE(review): max_length counts prompt tokens too, so a prompt close
        # to the 4096-token truncation limit leaves little or no room for
        # output. Consider max_new_tokens if long inputs matter.
        max_length=4096,
        num_beams=1,  # Set to 1 for streaming
        do_sample=True,
        temperature=1.0,
        streamer=streamer,
    )

    # Run generation in a separate thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    full_output = ""
    for new_text in streamer:
        full_output += new_text
        generated_text = full_output.strip()
        yield generated_text, ""
    thread.join()

    # After generation is complete, create diff
    diff_html = create_diff_html(text, generated_text)
    yield generated_text, diff_html
# Create the interface with blocks for better control
with gr.Blocks(title="TextClean-4B Demo") as demo:
    gr.Markdown("# TextClean-4B Demo")
    # Interpolate model_id so the description always names the checkpoint
    # that is actually loaded (the hard-coded name had drifted from it).
    gr.Markdown(f"Simple demo for text cleaning using {model_id} model")

    # Input area: text box with the submit button underneath.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to clean...",
                label="Input Text",
            )
            submit_btn = gr.Button("Clean Text", variant="primary")

    # Read-only box that receives the streamed cleaned text.
    with gr.Row():
        output_text = gr.Textbox(
            lines=5,
            label="Cleaned Text",
            interactive=False,
        )

    # HTML diff, filled in by the final item clean_text yields.
    with gr.Row():
        diff_display = gr.HTML(label="Diff View")

    # clean_text is a generator, so each yielded (text, diff) pair updates
    # both outputs while generation is in progress.
    submit_btn.click(
        fn=clean_text,
        inputs=input_text,
        outputs=[output_text, diff_display],
    )

if __name__ == "__main__":
    demo.launch()