# generate.py
import torch
from transformers import BertTokenizer

from evo_model import EvoDecoderModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Initialize model architecture and load the trained weights
vocab_size = tokenizer.vocab_size
model = EvoDecoderModel(vocab_size=vocab_size)
model.load_state_dict(torch.load("evo_decoder_model.pt", map_location=device))
model.to(device)
model.eval()


def generate_response(prompt, max_length=128, use_web=False):
    # use_web is kept for API compatibility but is not used in this offline path
    with torch.no_grad():
        input_ids = tokenizer(
            prompt, return_tensors="pt", padding=False, truncation=True
        ).input_ids.to(device)
        input_ids = input_ids[:, :128]  # ✅ clip to the trained context length

        # Greedy decoding: repeatedly append the argmax of the last position's
        # logits until max_length is reached or (assuming the model was trained
        # with BERT's [SEP] as end-of-sequence) a [SEP] token is produced.
        for _ in range(max_length - input_ids.size(1)):
            logits = model(input_ids)             # [batch, seq_len, vocab_size]
            next_token_logits = logits[:, -1, :]  # take last token's logits
            predicted_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, predicted_id], dim=1)
            if predicted_id.item() == tokenizer.sep_token_id:
                break

        decoded = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return decoded
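
# A minimal smoke test for generate_response. The prompt text and the
# max_length value below are illustrative assumptions, not part of the
# original script; adjust them to match your checkpoint and use case.
if __name__ == "__main__":
    sample_prompt = "hello, how are you?"
    print(generate_response(sample_prompt, max_length=64))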