import pickle

import torch
import tiktoken
from rich.traceback import install

from model import GPTConfig, GPT
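
# enable rich's pretty tracebacks for more readable errors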
install()
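
# checkpoint/metadata paths, device, and sampling settings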
ckpt_path = 'out/ckpt.pt'
meta_path = 'data/mydata/meta.pkl'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer_name = 'cl100k_base'
max_new_tokens = 1024
temperature = 0.8
top_k = 100
special_tokens = {"<|endoftext|>", "<|im_start|>", "<|im_stop|>"}
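
# tiktoken BPE encoding used to encode prompts and decode generated ids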
enc = tiktoken.get_encoding(tokenizer_name)
encode = enc.encode
decode = enc.decode
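
# read the vocab size recorded in the dataset's meta.pkl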
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']
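
# load the checkpoint and the model arguments saved in it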
checkpoint = torch.load(ckpt_path, map_location=device)
model_args = checkpoint['model_args']
model_args['vocab_size'] = vocab_size
block_size = model_args.get('block_size', 1024)
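
# rebuild the model from the saved arguments, load the weights, switch to eval mode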
model = GPT(GPTConfig(**model_args))
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()
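

# stream generation: sample one token at a time and print it as soon as it is decoded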
@torch.no_grad()
def generate_stream(model, input_ids, max_new_tokens, temperature=1.0, top_k=None):
    model.eval()
    special_token_id = encode("<|endoftext|>", allowed_special=special_tokens)[0]
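
    # sample up to max_new_tokens, stopping early at the stop token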
    for _ in range(max_new_tokens):
        # keep the running context within the model's block size
        if input_ids.size(1) > block_size:
            input_ids = input_ids[:, -block_size:]
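
        # forward pass; keep only the last position's logits, scaled by temperature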
        logits, _ = model(input_ids)
        logits = logits[:, -1, :] / temperature
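
        # top-k filtering: mask every logit below the k-th largest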
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
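
        # sample the next token id from the resulting distribution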
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token_id = next_token.item()

        input_ids = torch.cat((input_ids, next_token), dim=1)

        # stream the decoded token, hiding special tokens from the output
        decoded_token = decode([next_token_id])
        if decoded_token not in special_tokens:
            print(decoded_token, end='', flush=True)

        # <|endoftext|> marks the end of the reply
        if next_token_id == special_token_id:
            break

    print()
    return input_ids
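

# minimal chat REPL: read a message, wrap it in the prompt template, stream the reply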
def main():
    print("🤖 AI Assistant is ready. Type 'exit' or press Ctrl+C to quit.\n")
    try:
        while True:
            user_input = input("You: ")
            if user_input.lower() in {"exit", "quit"}:
                print("👋 Exiting assistant.")
                break
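
            # chat-style prompt; the special tokens here should match the training data format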
            prompt = f"""
<|im_start|>user
{user_input}<|endoftext|>
<|im_stop|>

<|im_start|>assistant

"""

            input_ids = torch.tensor(encode(prompt, allowed_special=special_tokens), dtype=torch.long, device=device)[None, ...]

print("π€ Assistant:", end=' ', flush=True)
|
|
generate_stream(model, input_ids, max_new_tokens, temperature, top_k)
|
|
print("-" * 50)
|
|
|
|
    except KeyboardInterrupt:
        print("\n👋 Exiting assistant.")


if __name__ == "__main__":
    main()