import gradio as gr
import spaces
import torch
from threading import Thread
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer

model_id = "textcleanlm/textclean-4B"
model = None
tokenizer = None

def load_model():
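    """Lazily load the tokenizer and model into module-level globals.

    Loading once per process keeps the checkpoint cached between requests
    instead of re-initializing it on every call.
    """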
    global model, tokenizer
    if model is None:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        
        # Add padding token if needed
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Try model classes in order until one matches the checkpoint's architecture
        for model_class in [AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel]:
            try:
                model = model_class.from_pretrained(
                    model_id,
                    torch_dtype=torch.bfloat16,
                    device_map="auto"
                )
                break
            except Exception:  # a bare except would also swallow KeyboardInterrupt
                continue
                
        if model is None:
            raise ValueError(f"Could not load model {model_id}")
            
    return model, tokenizer

@spaces.GPU(duration=60)
def clean_text(text):
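    """Clean `text` with the model, yielding the accumulated output as it streams.

    The @spaces.GPU decorator requests a ZeroGPU slot for up to 60 seconds, so
    the whole tokenize-generate-stream pipeline must finish inside this call.
    """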
    model, tokenizer = load_model()
    
    inputs = tokenizer(text, return_tensors="pt", max_length=4096, truncation=True)
    # Move tensors to wherever device_map placed the model (works on CPU too,
    # unlike a hard-coded .cuda() call)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Stream decoded tokens as they are generated; skip_prompt avoids echoing
    # the input back for decoder-only models
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=4096,  # cap on generated tokens, independent of prompt length
        num_beams=1,  # greedy/sampling only; beam search cannot stream
        do_sample=True,
        temperature=1.0,
        streamer=streamer,
    )
    
    # Run generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    
    # Yield text as it's generated
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    
    thread.join()

iface = gr.Interface(
    fn=clean_text,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter text to clean...",
        label="Input Text"
    ),
    outputs=gr.Textbox(
        lines=5,
        label="Cleaned Text"
    ),
    title="TextClean-4B Demo",
    description="Streaming text-cleaning demo built on the textcleanlm/textclean-4B model"
)

if __name__ == "__main__":
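    # Generator functions stream through Gradio's queue, which is enabled
    # by default in Gradio 4+; on older versions call iface.queue() first.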
    iface.launch()
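
# A minimal client-side sketch for calling this demo once deployed. The Space
# id below is a placeholder (assumption), and gradio_client is a separate
# dependency (`pip install gradio_client`):
#
#   from gradio_client import Client
#
#   client = Client("your-username/textclean-4B-demo")  # hypothetical Space id
#   result = client.predict("Som3 n0isy   text to cle@n", api_name="/predict")
#   print(result)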