#!/usr/bin/env python3
import pickle

import torch
import tiktoken
from rich.traceback import install

from model import GPTConfig, GPT

install()

# ----- CONFIG -----
ckpt_path = 'out/ckpt.pt'
meta_path = 'data/mydata/meta.pkl'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer_name = 'cl100k_base'
max_new_tokens = 1024
temperature = 0.8
top_k = 100
special_tokens = {"<|endoftext|>", "<|im_start|>", "<|im_stop|>"}

# ----- LOAD TOKENIZER -----
enc = tiktoken.get_encoding(tokenizer_name)
encode = enc.encode
decode = enc.decode

# ----- LOAD METADATA -----
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']  # vocab size recorded during data preparation

# ----- LOAD CHECKPOINT -----
checkpoint = torch.load(ckpt_path, map_location=device)
model_args = checkpoint['model_args']
model_args['vocab_size'] = vocab_size
block_size = model_args.get('block_size', 1024)

# ----- INITIALIZE MODEL -----
model = GPT(GPTConfig(**model_args))
state_dict = checkpoint['model']
# If the checkpoint was saved from a torch.compile'd model, its keys carry an
# '_orig_mod.' prefix; strip it so load_state_dict matches (same handling as
# nanoGPT's sample.py).
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(device)
model.eval()


@torch.no_grad()
def generate_stream(model, input_ids, max_new_tokens, temperature=1.0, top_k=None):
    """Sample tokens autoregressively, streaming each decoded token to stdout."""
    model.eval()
    # Stop generating once the model emits the end-of-text token.
    special_token_id = encode("<|endoftext|>", allowed_special=special_tokens)[0]
    for _ in range(max_new_tokens):
        # Crop the running context to the model's block size.
        if input_ids.size(1) > block_size:
            input_ids = input_ids[:, -block_size:]
        logits, _ = model(input_ids)
        # Temperature-scale the logits for the last position.
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Mask out everything below the k-th largest logit.
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token_id = next_token.item()
        input_ids = torch.cat((input_ids, next_token), dim=1)
        decoded_token = decode([next_token_id])
        # Stream the token, hiding special tokens from the user.
        if decoded_token not in special_tokens:
            print(decoded_token, end='', flush=True)
        if next_token_id == special_token_id:
            break
    print()  # ensure a newline after generation
    return input_ids


def main():
    print("šŸ¤– AI Assistant is ready. Type 'exit' or press Ctrl+C to quit.\n")
    try:
        while True:
            user_input = input("You: ")
            if user_input.lower() in {"exit", "quit"}:
                print("šŸ‘‹ Exiting assistant.")
                break
            # Chat-style prompt template; must match the format used at training time.
            prompt = f"""
<|im_start|>user
{user_input}<|endoftext|>
<|im_stop|>
<|im_start|>assistant
"""
            input_ids = torch.tensor(
                encode(prompt, allowed_special=special_tokens),
                dtype=torch.long,
                device=device,
            )[None, ...]
            print("šŸ¤– Assistant:", end=' ', flush=True)
            generate_stream(model, input_ids, max_new_tokens, temperature, top_k)
            print("-" * 50)
    except KeyboardInterrupt:
        print("\nšŸ‘‹ Exiting assistant.")


if __name__ == "__main__":
    main()