#!/usr/bin/env python3
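"""Interactive chat loop that streams samples from a nanoGPT-style checkpoint."""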
import torch
import pickle
from model import GPTConfig, GPT
import tiktoken
from rich.traceback import install

install()

# ----- CONFIG -----
ckpt_path      = 'out/ckpt.pt'
meta_path      = 'data/mydata/meta.pkl'
device         = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer_name = 'cl100k_base'
max_new_tokens = 1024
temperature    = 0.8
top_k          = 100
special_tokens = {"<|endoftext|>", "<|im_start|>", "<|im_stop|>"}
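# NOTE: stock cl100k_base registers only <|endoftext|> (plus the FIM/endofprompt
# tokens) as special tokens; <|im_start|> and <|im_stop|> are assumed to be this
# project's own chat delimiters and are tokenized as ordinary text, so
# allowed_special effectively applies to <|endoftext|> only.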

# ----- LOAD TOKENIZER -----
enc = tiktoken.get_encoding(tokenizer_name)
encode = enc.encode
decode = enc.decode
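
# sanity check: ordinary text should round-trip through the tokenizer
assert decode(encode("hello world")) == "hello world"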

# ----- LOAD METADATA -----
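# meta.pkl is assumed to come from the dataset's prepare script (nanoGPT
# convention) and to carry the vocab size the model was trained with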
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# ----- LOAD CHECKPOINT -----
checkpoint = torch.load(ckpt_path, map_location=device)
model_args = checkpoint['model_args']
model_args['vocab_size'] = vocab_size
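# block_size caps the model's context window; the sampling loop crops to it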
block_size = model_args.get('block_size', 1024)

# ----- INITIALIZE MODEL -----
model = GPT(GPTConfig(**model_args))
state_dict = checkpoint['model']
# strip the '_orig_mod.' prefix that torch.compile adds to parameter names
# (the same fixup as nanoGPT's sample.py), in case the checkpoint was compiled
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

@torch.no_grad()
def generate_stream(model, input_ids, max_new_tokens, temperature=1.0, top_k=None):
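    """Sample up to max_new_tokens autoregressively, printing tokens as they arrive.

    Generation stops early if the <|endoftext|> token is sampled.
    """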
    model.eval()
    special_token_id = encode("<|endoftext|>", allowed_special=special_tokens)[0]

    for _ in range(max_new_tokens):
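        # crop the running context to the model's block size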
        if input_ids.size(1) > block_size:
            input_ids = input_ids[:, -block_size:]

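        # forward pass; keep the last position's logits, scaled by temperature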
        logits, _ = model(input_ids)
        logits = logits[:, -1, :] / temperature

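        # top-k filtering: mask out everything below the k-th largest logit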
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')

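        # renormalize and sample the next token id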
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token_id = next_token.item()

        input_ids = torch.cat((input_ids, next_token), dim=1)

        decoded_token = decode([next_token_id])
        # stream the token, suppressing raw special-token strings
        if decoded_token not in special_tokens:
            print(decoded_token, end='', flush=True)

        if next_token_id == special_token_id:
            break

    print()  # Ensure newline after generation
    return input_ids

def main():
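    """Minimal REPL: wrap each user turn in the chat template and stream a reply."""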
    print("๐Ÿค– AI Assistant is ready. Type 'exit' or press Ctrl+C to quit.\n")
    try:
        while True:
            user_input = input("You: ")
            if user_input.lower() in {"exit", "quit"}:
                print("๐Ÿ‘‹ Exiting assistant.")
                break

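            # the template below is assumed to mirror the fine-tuning data's
            # turn format exactly, whitespace included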
            prompt = f"""
<|im_start|>user
{user_input}<|endoftext|>
<|im_stop|>

<|im_start|>assistant

"""
            prompt_ids = encode(prompt, allowed_special=special_tokens)
            input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device)[None, ...]

            print("๐Ÿค– Assistant:", end=' ', flush=True)
            generate_stream(model, input_ids, max_new_tokens, temperature, top_k)
            print("-" * 50)

    except KeyboardInterrupt:
        print("\n๐Ÿ‘‹ Exiting assistant.")

if __name__ == "__main__":
    main()