"""Gradio demo that cleans text with a language model and shows a diff.

The app loads the model lazily on first use, streams the generated
("cleaned") text into a textbox, and, once generation finishes, renders a
unified diff between the input and the cleaned output as HTML.
"""

import difflib
import html
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

model_id = "textcleanlm/textcleanlm-1-4b"

# Lazily populated by load_model() so importing this module stays cheap.
model = None
tokenizer = None

# Token budgets: the prompt is truncated to MAX_INPUT_TOKENS and generation
# may add up to MAX_NEW_TOKENS on top of it.
MAX_INPUT_TOKENS = 4096
MAX_NEW_TOKENS = 4096

# Inline styles for the diff view. NOTE(review): the original markup was not
# recoverable from the source, so these styles are a reconstruction.
_DIFF_CONTAINER_STYLE = (
    "font-family: monospace; white-space: pre-wrap; "
    "padding: 8px; border: 1px solid #d0d7de; border-radius: 4px;"
)
_DIFF_STYLES = {
    "header": "color: #6a737d; font-weight: bold;",
    "hunk": "color: #0366d6; background-color: #f1f8ff;",
    "added": "color: #22863a; background-color: #e6ffed;",
    "removed": "color: #b31d28; background-color: #ffeef0;",
    "context": "color: inherit;",
}


def load_model():
    """Load the tokenizer and model once and cache them in module globals.

    Several ``Auto*`` model classes are tried in order because the
    architecture of ``model_id`` is not known up front; the first class that
    loads successfully wins.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If none of the candidate classes can load ``model_id``.
            The last underlying error is attached as ``__cause__``.
    """
    global model, tokenizer

    if model is not None:
        return model, tokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): the AutoModel fallback loads a bare backbone, which
    # presumably cannot generate text -- confirm it is still wanted.
    candidates = (AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel)
    last_error = None
    for model_class in candidates:
        try:
            model = model_class.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        except Exception as exc:  # wrong class for this checkpoint; try next
            last_error = exc
            continue
        break

    if model is None:
        raise ValueError(f"Could not load model {model_id}") from last_error

    return model, tokenizer


def create_diff_html(original, cleaned):
    """Render a unified diff between two texts as an HTML fragment.

    Args:
        original: The text before cleaning.
        cleaned: The text after cleaning.

    Returns:
        An HTML string with one ``<div>`` per diff line inside a monospace
        container. The container is empty when the texts are identical.
    """
    diff_lines = difflib.unified_diff(
        original.splitlines(keepends=True),
        cleaned.splitlines(keepends=True),
        fromfile="Original",
        tofile="Cleaned",
        lineterm="",
    )

    rows = []
    for index, line in enumerate(diff_lines):
        # unified_diff emits the two file headers first. Classifying them by
        # position (not by a '---'/'+++' prefix) keeps removed/added lines
        # whose own content starts with '--' or '++' from being mislabeled.
        if index < 2:
            kind = "header"
        elif line.startswith("@@"):
            kind = "hunk"
        elif line.startswith("+"):
            kind = "added"
        elif line.startswith("-"):
            kind = "removed"
        else:
            kind = "context"

        # Each line gets its own <div>, so drop the line terminator, and
        # escape the text so user/model content cannot inject markup.
        text = html.escape(line.rstrip("\r\n"))
        rows.append(f'<div style="{_DIFF_STYLES[kind]}">{text}</div>')

    return f'<div style="{_DIFF_CONTAINER_STYLE}">' + "".join(rows) + "</div>"


@spaces.GPU(duration=60)
def clean_text(text):
    """Stream the model's cleaned version of ``text``, then its diff.

    Args:
        text: Raw user input to clean.

    Yields:
        ``(cleaned_text, diff_html)`` tuples. While generation is running
        ``diff_html`` is an empty string; the final tuple carries the
        complete cleaned text together with the rendered diff.

    Raises:
        gr.Error: If generation fails in the background thread.
    """
    if not text or not text.strip():
        # Nothing to clean; avoid loading the model for empty input.
        yield "", ""
        return

    model, tokenizer = load_model()

    messages = [{"role": "user", "content": text}]
    formatted_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The chat template already contains the special tokens, so do not let
    # the tokenizer add them a second time.
    encoded = tokenizer(
        formatted_text,
        return_tensors="pt",
        max_length=MAX_INPUT_TOKENS,
        truncation=True,
        add_special_tokens=False,
    )
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    # skip_prompt makes the streamer emit only newly generated text, which is
    # reliable where slicing the decoded output by prompt length is not
    # (special tokens are stripped during decoding, so lengths differ).
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,  # budget excludes the prompt
        num_beams=1,  # streaming does not support beam search
        do_sample=True,
        temperature=1.0,
        streamer=streamer,
    )

    errors = []

    def _run_generation():
        """Run generate() and unblock the consumer if it fails."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Without this the loop below would wait on the queue forever.
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    generated_text = ""
    full_output = ""
    for new_text in streamer:
        full_output += new_text
        generated_text = full_output.strip()
        yield generated_text, ""

    thread.join()

    if errors:
        raise gr.Error(f"Text generation failed: {errors[0]}") from errors[0]

    # Generation is complete: emit the final text with its diff.
    yield generated_text, create_diff_html(text, generated_text)


# Blocks (rather than Interface) is used for finer control over the layout.
with gr.Blocks(title="TextClean-4B Demo") as demo:
    gr.Markdown("# TextClean-4B Demo")
    gr.Markdown(f"Simple demo for text cleaning using the `{model_id}` model")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to clean...",
                label="Input Text",
            )
            submit_btn = gr.Button("Clean Text", variant="primary")

    with gr.Row():
        output_text = gr.Textbox(
            lines=5,
            label="Cleaned Text",
            interactive=False,
        )

    with gr.Row():
        diff_display = gr.HTML(label="Diff View")

    submit_btn.click(
        fn=clean_text,
        inputs=input_text,
        outputs=[output_text, diff_display],
    )


if __name__ == "__main__":
    demo.launch()