File size: 3,894 Bytes
fea4095
 
25c11ba
fea4095
 
 
25c11ba
fea4095
 
25c11ba
7276d4c
fea4095
 
25c11ba
fea4095
bf2292c
25c11ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fea4095
25c11ba
fea4095
 
 
 
fee88b4
25c11ba
 
 
 
 
 
 
 
 
 
fea4095
 
 
25c11ba
 
 
 
 
 
 
 
 
 
fea4095
 
25c11ba
 
 
fee88b4
fea4095
cddc4c2
fea4095
 
cddc4c2
 
25c11ba
cddc4c2
 
25c11ba
 
 
 
cddc4c2
25c11ba
 
 
 
 
 
cddc4c2
25c11ba
 
 
 
 
cddc4c2
 
 
25c11ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Module-level cache for the loaded model and tokenizer so that the
# (expensive) HuggingFace download/initialization happens at most once
# per process. Both are populated together by initialize().
MODEL = None
TOKENIZER = None

def initialize():
    """Load the SmolLM2 model and tokenizer into the module-level cache.

    Idempotent: returns immediately if both ``MODEL`` and ``TOKENIZER``
    are already populated. Downloads ``jatingocodeo/SmolLM2`` from the
    HuggingFace Hub, registers pad/eos/bos special tokens, and moves the
    model to GPU when available (fp16 on CUDA, fp32 on CPU).

    Raises:
        Exception: re-raises any error from tokenizer/model loading
            after printing it, so callers see the startup failure.
    """
    global MODEL, TOKENIZER

    # Guard both globals: a partially-failed previous attempt could have
    # set TOKENIZER but not MODEL (or vice versa).
    if MODEL is not None and TOKENIZER is not None:
        return

    print("Loading model and tokenizer...")
    model_id = "jatingocodeo/SmolLM2"

    try:
        # Load tokenizer
        print("\n1. Loading tokenizer...")
        TOKENIZER = AutoTokenizer.from_pretrained(model_id)
        print("✓ Tokenizer loaded successfully")

        # Add special tokens if needed
        special_tokens = {
            'pad_token': '[PAD]',
            'eos_token': '</s>',
            'bos_token': '<s>'
        }
        num_added = TOKENIZER.add_special_tokens(special_tokens)
        print(f"✓ Added {num_added} special tokens")

        # Load model
        print("\n2. Loading model...")
        MODEL = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        )

        # BUG FIX: if add_special_tokens minted any *new* token ids, the
        # model's embedding matrix must grow to cover them; otherwise
        # encoding a prompt containing those tokens indexes out of range
        # of the embedding table during generate().
        if num_added > 0:
            MODEL.resize_token_embeddings(len(TOKENIZER))

        # Move model to appropriate device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        MODEL = MODEL.to(device)
        print(f"✓ Model loaded successfully and moved to {device}")

    except Exception as e:
        print(f"Error initializing model: {str(e)}")
        raise

def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    """Generate a sampled continuation of ``prompt`` with the cached model.

    Args:
        prompt: User text to continue. A BOS token is prepended if absent.
        max_length: Number of additional tokens requested (total sequence
            capped at the 2048-token context window).
        temperature: Sampling temperature (higher = more random).
        top_k: Restrict sampling to the k most likely next tokens.

    Returns:
        The decoded generation (special tokens stripped), a prompt-missing
        notice for blank input, or an error message string on failure.
    """
    # Lazily load model/tokenizer on first use.
    if MODEL is None:
        initialize()

    try:
        # Reject blank/whitespace-only prompts up front.
        if not prompt.strip():
            return "Please enter a prompt."

        # The model expects sequences to start with its BOS token.
        bos = TOKENIZER.bos_token
        if not prompt.startswith(bos):
            prompt = bos + prompt

        # Tokenize, truncating to the context window, and move to the
        # same device the model lives on.
        encoded = TOKENIZER.encode(
            prompt, return_tensors="pt", truncation=True, max_length=2048
        )
        encoded = encoded.to(MODEL.device)

        # Total length = prompt tokens + requested new tokens, capped at 2048.
        prompt_tokens = len(encoded[0])
        total_length = min(max_length + prompt_tokens, 2048)

        # Sample without tracking gradients.
        with torch.no_grad():
            outputs = MODEL.generate(
                encoded,
                max_length=total_length,
                temperature=temperature,
                top_k=top_k,
                do_sample=True,
                pad_token_id=TOKENIZER.pad_token_id,
                eos_token_id=TOKENIZER.eos_token_id,
                num_return_sequences=1
            )

        # Decode the single returned sequence, dropping special tokens.
        decoded = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        return decoded.strip()

    except Exception as e:
        return f"Error generating text: {str(e)}"

# Initialize on startup — load the model eagerly so the first user request
# does not pay the (slow) download/load cost. Raises at import time if the
# model cannot be loaded, which surfaces the failure immediately.
initialize()

# Create Gradio interface: three sliders mirror generate_text's keyword
# parameters (max_length, temperature, top_k) in positional order.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=2),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="SmolLM2 Text Generator",
    description="""Generate text using the fine-tuned SmolLM2 model.
    - Max Length: Controls the length of generated text
    - Temperature: Controls randomness (higher = more creative)
    - Top K: Controls diversity of word choices""",
    # Pre-filled example rows: [prompt, max_length, temperature, top_k].
    examples=[
        ["Once upon a time", 100, 0.7, 50],
        ["The quick brown fox", 150, 0.8, 40],
        ["In a galaxy far far away", 200, 0.9, 30],
    ],
    allow_flagging="never"
)

# Launch the Gradio web server only when run as a script (not on import,
# e.g. when hosted by HuggingFace Spaces which imports the module).
if __name__ == "__main__":
    iface.launch()