File size: 4,569 Bytes
0804cf7
 
 
67a2d65
0440349
 
0804cf7
b3935eb
0804cf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67a2d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0804cf7
 
 
 
1d375c6
 
 
 
 
 
 
 
 
0804cf7
 
0440349
 
0804cf7
0440349
 
 
 
 
 
 
 
 
 
 
 
 
67a2d65
0440349
67a2d65
 
 
0440349
67a2d65
 
 
 
 
0440349
 
67a2d65
 
 
 
0804cf7
67a2d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0804cf7
 
67a2d65
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import difflib
import html
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer

model_id = "textcleanlm/textcleanlm-1-8b"
model = None
tokenizer = None

def load_model():
    """Lazily load the tokenizer and model on first use (module-level cache).

    Returns:
        tuple: ``(model, tokenizer)`` for ``model_id``.

    Raises:
        ValueError: if none of the candidate auto-classes can load the
            checkpoint; the last underlying load error is chained as the cause.
    """
    global model, tokenizer
    if model is None:
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Some checkpoints ship without a pad token; fall back to EOS so
        # padding/truncation in the tokenizer keeps working.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # The checkpoint's architecture is not known up front, so probe the
        # plausible auto-classes in order of specificity.
        last_error = None
        for model_class in (AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel):
            try:
                model = model_class.from_pretrained(
                    model_id,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                )
                break
            # Bug fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and lost the real failure reason.
            except Exception as exc:
                last_error = exc
                continue

        if model is None:
            # Chain the last load failure so the real cause is visible.
            raise ValueError(f"Could not load model {model_id}") from last_error

    return model, tokenizer

def create_diff_html(original, cleaned):
    """Create an HTML visualization of the unified diff between two texts.

    Args:
        original: Text before cleaning.
        cleaned: Text after cleaning.

    Returns:
        str: An HTML fragment with GitHub-style add/remove coloring. Equal
        inputs yield just the empty wrapper ``<div>``.
    """
    original_lines = original.splitlines(keepends=True)
    cleaned_lines = cleaned.splitlines(keepends=True)

    differ = difflib.unified_diff(
        original_lines, cleaned_lines,
        fromfile='Original', tofile='Cleaned', lineterm=''
    )

    # Accumulate parts and join once instead of repeated string `+=`.
    parts = ['<div style="font-family: monospace; font-size: 12px; white-space: pre-wrap;">']

    for line in differ:
        # Bug fix: escape the diff content. Previously raw text was
        # interpolated into the markup, so input containing <, >, & or
        # quotes broke the rendering (and allowed HTML injection).
        escaped = html.escape(line)
        if line.startswith('+++') or line.startswith('---'):
            parts.append(f'<div style="color: #666;">{escaped}</div>')
        elif line.startswith('@@'):
            parts.append(f'<div style="color: #0066cc; font-weight: bold;">{escaped}</div>')
        elif line.startswith('+'):
            parts.append(f'<div style="background-color: #e6ffed; color: #24292e;">{escaped}</div>')
        elif line.startswith('-'):
            parts.append(f'<div style="background-color: #ffeef0; color: #24292e;">{escaped}</div>')
        else:
            parts.append(f'<div>{escaped}</div>')

    parts.append('</div>')
    return ''.join(parts)

@spaces.GPU(duration=60)
def clean_text(text):
    """Clean *text* with the model, streaming partial output as it generates.

    Args:
        text: Raw user text to clean.

    Yields:
        tuple[str, str]: ``(cleaned_text_so_far, diff_html)``. The diff is the
        empty string while streaming; the final yield carries the full diff
        between the input and the completed output.
    """
    model, tokenizer = load_model()

    # Wrap the raw text in the model's chat format.
    messages = [
        {"role": "user", "content": text}
    ]
    formatted_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=4096, truncation=True)
    # Bug fix: move tensors to wherever device_map placed the model instead
    # of assuming CUDA is available (`.cuda()` crashed on CPU-only hosts).
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Bug fix: skip_prompt=True makes the streamer emit only newly generated
    # tokens. The previous character-length comparison against the prompt
    # string was unreliable because the decoded stream (with special tokens
    # skipped) need not match the prompt byte-for-byte.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,
        # Bug fix: max_new_tokens bounds the continuation; the previous
        # max_length=4096 counted the prompt too, so inputs near the limit
        # left no budget for generation at all.
        max_new_tokens=4096,
        num_beams=1,  # beam search is incompatible with streaming
        do_sample=True,
        temperature=1.0,
        streamer=streamer,
    )

    # generate() blocks, so run it on a worker thread and consume the
    # streamer from this generator.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text.strip(), ""

    thread.join()

    # Generation finished: emit the final text together with the diff view.
    generated_text = generated_text.strip()
    yield generated_text, create_diff_html(text, generated_text)

# Build the UI. Blocks (rather than Interface) gives separate rows for the
# streamed output textbox and the diff view.
with gr.Blocks(title="TextClean Demo") as demo:
    gr.Markdown("# TextClean Demo")
    # Bug fix: the description previously hard-coded "textcleanlm/textclean-4B",
    # which is not the model this app actually loads; derive it from model_id
    # so the UI stays consistent with the loaded checkpoint.
    gr.Markdown(f"Simple demo for text cleaning using the {model_id} model")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to clean...",
                label="Input Text"
            )
            submit_btn = gr.Button("Clean Text", variant="primary")

    with gr.Row():
        output_text = gr.Textbox(
            lines=5,
            label="Cleaned Text",
            interactive=False
        )

    with gr.Row():
        diff_display = gr.HTML(label="Diff View")

    # clean_text is a generator, so the outputs update as tokens stream in.
    submit_btn.click(
        fn=clean_text,
        inputs=input_text,
        outputs=[output_text, diff_display]
    )

if __name__ == "__main__":
    demo.launch()