#!/usr/bin/env python3
# ShaNet / chat.py — interactive terminal chat with a trained GPT checkpoint.
# Source: uploaded by umm-dev in "Upload 8 files (#1)" (verified commit 336661a).
import os
import torch
import pickle
from model import GPTConfig, GPT
import tiktoken
from rich.traceback import install
install()  # rich: pretty-printed tracebacks for any uncaught exception

# ----- CONFIG -----
ckpt_path = 'out/ckpt.pt'           # training checkpoint to load
meta_path = 'data/mydata/meta.pkl'  # dataset metadata; must provide 'vocab_size'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer_name = 'cl100k_base'
max_new_tokens = 1024               # cap on generated tokens per reply
temperature = 0.8
top_k = 100
# Marker strings passed to tiktoken as allowed special tokens and hidden from
# the streamed output.
# NOTE(review): cl100k_base appears to register only "<|endoftext|>" among
# these as a real special token. "<|im_start|>" / "<|im_stop|>" are then
# encoded as ordinary multi-token text — presumably matching how the training
# data was tokenized; confirm before changing.
special_tokens = {"<|endoftext|>", "<|im_start|>", "<|im_stop|>"}

# ----- LOAD TOKENIZER -----
enc = tiktoken.get_encoding(tokenizer_name)
encode = enc.encode
decode = enc.decode

# ----- LOAD METADATA -----
# NOTE: pickle can execute arbitrary code on load; only use meta files you produced.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# ----- LOAD CHECKPOINT -----
checkpoint = torch.load(ckpt_path, map_location=device)
model_args = checkpoint['model_args']
# The vocab size from the dataset metadata overrides the checkpoint's value.
model_args['vocab_size'] = vocab_size
block_size = model_args.get('block_size', 1024)

# ----- INITIALIZE MODEL -----
model = GPT(GPTConfig(**model_args))
state_dict = checkpoint['model']
# A model trained under torch.compile() saves every parameter key with an
# "_orig_mod." prefix; strip it so the plain GPT module accepts the weights.
# Mutating in place keeps the state dict's own type/metadata intact, and this
# is a no-op for checkpoints saved without compilation.
unwanted_prefix = '_orig_mod.'
for key in list(state_dict):
    if key.startswith(unwanted_prefix):
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
@torch.no_grad()
def generate_stream(model, input_ids, max_new_tokens, temperature=1.0, top_k=None):
    """Sample up to ``max_new_tokens`` tokens from ``model``, streaming text to stdout.

    Uses the module-level ``enc``, ``encode``, ``block_size`` and
    ``special_tokens`` defined by this script.

    Args:
        model: GPT module; called as ``model(ids)`` and expected to return a
            ``(logits, loss)`` pair whose logits are indexed ``[:, -1, :]``.
        input_ids: LongTensor of prompt token ids indexed as ``(batch, seq)``.
            Only batch size 1 works, because the sampled token is read with
            ``.item()``.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Divisor applied to the logits. A value <= 0 selects the
            most likely token (greedy) instead of sampling.
        top_k: If a positive int, sample only among the ``top_k`` most likely
            tokens (clamped to the vocabulary size). ``None`` or 0 disables
            the filter.

    Returns:
        The full token sequence: prompt plus every generated token, including
        a final ``<|endoftext|>`` if one was sampled. Only the model's input
        is cropped to ``block_size``; the returned history is not.
    """
    import codecs  # only needed for the streaming UTF-8 decoder below

    model.eval()
    eot_id = encode("<|endoftext|>", allowed_special=special_tokens)[0]
    # Byte forms of the marker strings, so a token whose bytes are exactly one
    # of them is never echoed to the terminal.
    special_bytes = {tok.encode("utf-8") for tok in special_tokens}
    # BPE tokens are byte strings, and a multi-byte UTF-8 character (accents,
    # CJK, emoji) can be split across tokens. Decoding each token on its own
    # would print U+FFFD garbage; the incremental decoder instead buffers
    # incomplete byte sequences until the rest of the character arrives.
    utf8 = codecs.getincrementaldecoder("utf-8")(errors="replace")

    for _ in range(max_new_tokens):
        # Feed the model only its last block_size tokens, but keep the full
        # history in input_ids.
        if input_ids.size(1) <= block_size:
            context = input_ids
        else:
            context = input_ids[:, -block_size:]
        logits, _ = model(context)
        logits = logits[:, -1, :]

        if temperature <= 0:
            # Dividing by zero would turn the logits into inf/nan; go greedy.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if top_k:
                k = min(top_k, logits.size(-1))  # topk fails if k > vocab size
                kth_largest = torch.topk(logits, k).values[:, [-1]]
                logits = logits.masked_fill(logits < kth_largest, float("-inf"))
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        input_ids = torch.cat((input_ids, next_token), dim=1)
        next_token_id = next_token.item()
        if next_token_id == eot_id:
            break
        token_bytes = enc.decode_single_token_bytes(next_token_id)
        if token_bytes in special_bytes:
            continue  # never echo chat markers
        print(utf8.decode(token_bytes), end="", flush=True)

    # Flush any buffered partial character (it becomes U+FFFD) and end the line.
    print(utf8.decode(b"", final=True))
    return input_ids
def main():
    """Run the interactive chat loop on stdin/stdout.

    Each turn is independent: only the current user message is placed in the
    prompt, and no conversation history is carried over. Typing ``exit`` or
    ``quit`` (case-insensitive, surrounding whitespace ignored), pressing
    Ctrl+C, or sending EOF (Ctrl+D) ends the session.
    """
    # Chat template around the user's message. The concatenated text is the
    # same prompt as before:
    #   "\n<|im_start|>user\n" + user_input
    #   + "<|endoftext|>\n<|im_stop|>\n<|im_start|>assistant\n"
    # The template is encoded with the markers allowed as special tokens. The
    # user's text is encoded separately (below) with all special tokens
    # disabled. The pieces split at a newline boundary, so ordinary input
    # produces the same token ids as encoding the whole string at once.
    prompt_head = "\n<|im_start|>user\n"
    prompt_tail = "<|endoftext|>\n<|im_stop|>\n<|im_start|>assistant\n"
    head_ids = encode(prompt_head, allowed_special=special_tokens)
    tail_ids = encode(prompt_tail, allowed_special=special_tokens)

    print("🤖 AI Assistant is ready. Type 'exit' or press Ctrl+C to quit.\n")
    try:
        while True:
            user_input = input("You: ")
            if user_input.strip().lower() in {"exit", "quit"}:
                print("👋 Exiting assistant.")
                break
            # Treat typed special-token strings as plain text. A user-typed
            # "<|endoftext|>" must not act as a control token, and a typed
            # cl100k special outside `special_tokens` would otherwise make
            # tiktoken raise ValueError and end the session.
            user_ids = encode(user_input, disallowed_special=())
            input_ids = torch.tensor(
                head_ids + user_ids + tail_ids, dtype=torch.long, device=device
            )[None, ...]
            print("🤖 Assistant:", end=' ', flush=True)
            generate_stream(model, input_ids, max_new_tokens, temperature, top_k)
            print("-" * 50)
    except (KeyboardInterrupt, EOFError):
        # Ctrl+C, or EOF from input() (Ctrl+D / closed stdin): exit cleanly.
        print("\n👋 Exiting assistant.")


if __name__ == "__main__":
    main()