import gradio as gr
import spaces
import torch
import difflib
from threading import Thread
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer
# Hugging Face Hub repository id passed to every `from_pretrained` call below.
model_id = "textcleanlm/textcleanlm-1-8b"

# Process-wide cache slots; both stay None until `load_model()` fills them.
model, tokenizer = None, None
def load_model():
    """Load the tokenizer and model for ``model_id`` once and cache them.

    The first call populates the module-level ``model`` and ``tokenizer``
    globals; later calls return the cached objects without reloading.

    The checkpoint is tried against several ``transformers`` auto classes,
    in order, and the first one that loads successfully is kept.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If none of the candidate model classes could load the
            checkpoint. The last underlying error is attached as
            ``__cause__`` so the real failure reason is not lost.
    """
    global model, tokenizer

    if model is not None:
        return model, tokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Add padding token if needed
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Try different model classes
    last_error = None
    for model_class in (AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel):
        try:
            model = model_class.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        except Exception as err:
            # Best-effort fallback: remember why this class failed and try the
            # next one. Catching Exception (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            last_error = err
        else:
            break

    if model is None:
        raise ValueError(f"Could not load model {model_id}") from last_error
    return model, tokenizer
def create_diff_html(original, cleaned):
    """Create an HTML visualization of the unified diff between two texts.

    Each diff line becomes its own ``<div>`` inside a monospace container,
    styled by kind: file header, hunk header (``@@``), added (``+``),
    removed (``-``) or unchanged context.

    Args:
        original: The text before cleaning.
        cleaned: The text after cleaning.

    Returns:
        An HTML fragment as a string. When the two texts are identical the
        fragment is just the empty container.
    """
    # Local import so this block does not depend on a file-level import.
    import html

    header_style = "color: #6a737d; font-weight: bold;"
    hunk_style = "color: #0366d6;"
    added_style = "background-color: #e6ffed; color: #22863a;"
    removed_style = "background-color: #ffeef0; color: #cb2431;"
    context_style = "color: #24292e;"

    diff_lines = difflib.unified_diff(
        original.splitlines(keepends=True),
        cleaned.splitlines(keepends=True),
        fromfile='Original',
        tofile='Cleaned',
        lineterm='',
    )

    parts = ['<div style="font-family: monospace; white-space: pre-wrap;">']
    for index, line in enumerate(diff_lines):
        if index < 2:
            # unified_diff emits the '---' / '+++' file headers exactly once,
            # as its first two lines. Classifying by position (not by prefix)
            # keeps an added line whose text starts with '++' from being
            # mistaken for a header.
            style = header_style
        elif line.startswith('@@'):
            style = hunk_style
        elif line.startswith('+'):
            style = added_style
        elif line.startswith('-'):
            style = removed_style
        else:
            style = context_style

        # Content lines keep their line ending (keepends=True); each line is
        # already its own <div>, so drop the terminator. Escape the text so
        # characters such as '<' and '&' are shown literally instead of being
        # interpreted as markup.
        text = html.escape(line.rstrip('\r\n'))
        parts.append(f'<div style="{style}">{text}</div>')
    parts.append('</div>')
    return ''.join(parts)
@spaces.GPU(duration=60)
def clean_text(text):
    """Stream the model's cleaned version of *text*, then a diff against it.

    This is a generator used as the Gradio click handler. While the model is
    generating it yields ``(partial_cleaned_text, "")`` tuples; once
    generation has finished it yields a final
    ``(cleaned_text, diff_html)`` tuple, where ``diff_html`` comes from
    ``create_diff_html(text, cleaned_text)``.

    Args:
        text: Raw user input to clean.

    Yields:
        ``(cleaned_text, diff_html)`` tuples as described above. For empty
        or whitespace-only input a single ``("", "")`` is yielded and the
        model is never loaded.
    """
    # NOTE(review): the decorator presumably requests a GPU slot for up to
    # 60 s per call on Hugging Face ZeroGPU -- confirm against `spaces` docs.

    # Nothing to clean: avoid loading the model and generating from an
    # empty prompt.
    if not text or not text.strip():
        yield "", ""
        return

    model, tokenizer = load_model()

    # Wrap the input as a single user turn and render the chat template.
    messages = [{"role": "user", "content": text}]
    formatted_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(
        formatted_text, return_tensors="pt", max_length=4096, truncation=True
    )
    # Follow the model's device instead of hard-coding CUDA.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # skip_prompt=True makes the streamer drop the echoed prompt itself.
    # Slicing it off by len(formatted_text) is wrong because the decoded
    # prompt has its special tokens removed and so has a different length
    # than the templated string.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # NOTE(review): max_length counts prompt + generated tokens, so a prompt
    # near the 4096-token truncation limit leaves little or no room for
    # output -- consider max_new_tokens if that matters.
    generation_kwargs = dict(
        **inputs,
        max_length=4096,
        num_beams=1,  # Set to 1 for streaming
        do_sample=True,
        temperature=1.0,
        streamer=streamer,
    )

    # Run generation in a separate thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    raw_output = ""
    generated_text = ""
    for new_text in streamer:
        if not new_text:
            continue
        raw_output += new_text
        generated_text = raw_output.strip()
        yield generated_text, ""
    thread.join()

    # After generation is complete, create diff
    diff_html = create_diff_html(text, generated_text)
    yield generated_text, diff_html
# Create the interface with blocks for better control.
# The title and description are derived from `model_id` so the UI always names
# the checkpoint that is actually loaded (the previous hard-coded labels
# referred to a different model, "textcleanlm/textclean-4B").
with gr.Blocks(title=f"{model_id} Demo") as demo:
    gr.Markdown(f"# {model_id} Demo")
    gr.Markdown(f"Simple demo for text cleaning using {model_id} model")

    # Input area: free-form text plus the button that triggers cleaning.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to clean...",
                label="Input Text"
            )
            submit_btn = gr.Button("Clean Text", variant="primary")

    # Read-only box that receives the streamed cleaned text.
    with gr.Row():
        output_text = gr.Textbox(
            lines=5,
            label="Cleaned Text",
            interactive=False
        )

    # HTML diff of input vs. cleaned text; filled by clean_text's final yield.
    with gr.Row():
        diff_display = gr.HTML(label="Diff View")

    # clean_text is a generator: each yielded (text, diff_html) pair updates
    # the two outputs in order.
    submit_btn.click(
        fn=clean_text,
        inputs=input_text,
        outputs=[output_text, diff_display]
    )
# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()