""" Example script for running inference with the Rose Beeper model. """ import torch from tokenizers import Tokenizer from huggingface_hub import hf_hub_download import os # Import the inference components (from the previous artifact) from beeper_inference import ( BeeperRoseGPT, BeeperIO, generate, get_default_config ) class BeeperInference: """Wrapper class for easy inference with the Rose Beeper model.""" def __init__(self, checkpoint_path: str = None, tokenizer_path: str = "beeper.tokenizer.json", device: str = None, hf_repo: str = "AbstractPhil/beeper-rose-v5"): """ Initialize the Beeper model for inference. Args: checkpoint_path: Path to local checkpoint file (.pt or .safetensors) tokenizer_path: Path to tokenizer file device: Device to run on ('cuda', 'cpu', or None for auto) hf_repo: HuggingFace repository to download from if no local checkpoint """ # Set device if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = torch.device(device) print(f"Using device: {self.device}") # Load configuration self.config = get_default_config() # Initialize model self.model = BeeperRoseGPT(self.config).to(self.device) # Initialize pentachora banks cap_cfg = self.config.get("capoera", {}) # Using default sizes since we don't have the exact corpus info at inference self.model.ensure_pentachora( coarse_C=20, # Approximate number of datasets medium_C=int(cap_cfg.get("topic_bins", 512)), fine_C=int(cap_cfg.get("mood_bins", 7)), dim=self.config["dim"], device=self.device ) # Load weights self._load_weights(checkpoint_path, hf_repo) # Load tokenizer self._load_tokenizer(tokenizer_path, hf_repo) # Set to eval mode self.model.eval() def _load_weights(self, checkpoint_path: str, hf_repo: str): """Load model weights from local file or HuggingFace.""" loaded = False # Try local checkpoint first if checkpoint_path and os.path.exists(checkpoint_path): print(f"Loading weights from: {checkpoint_path}") missing, unexpected = BeeperIO.load_into_model( self.model, checkpoint_path, map_location=str(self.device), strict=False ) print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}") loaded = True # Try HuggingFace if no local checkpoint if not loaded and hf_repo: try: print(f"Downloading weights from HuggingFace: {hf_repo}") path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors") missing, unexpected = BeeperIO.load_into_model( self.model, path, map_location=str(self.device), strict=False ) print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}") loaded = True except Exception as e: print(f"Failed to download from HuggingFace: {e}") if not loaded: print("WARNING: No weights loaded, using random initialization!") def _load_tokenizer(self, tokenizer_path: str, hf_repo: str): """Load tokenizer from local file or HuggingFace.""" if os.path.exists(tokenizer_path): print(f"Loading tokenizer from: {tokenizer_path}") self.tokenizer = Tokenizer.from_file(tokenizer_path) else: try: print(f"Downloading tokenizer from HuggingFace: {hf_repo}") path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json") self.tokenizer = Tokenizer.from_file(path) except Exception as e: raise RuntimeError(f"Failed to load tokenizer: {e}") def generate_text(self, prompt: str, max_new_tokens: int = 120, temperature: float = 0.9, top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1, presence_penalty: float = 0.6, frequency_penalty: float = 0.0) -> str: """ Generate text from a prompt. 

        Args:
            prompt: Input text to continue from.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (0.1-2.0 is typical).
            top_k: Top-k sampling cutoff (0 to disable).
            top_p: Nucleus sampling threshold (0.0-1.0).
            repetition_penalty: Penalty for repeated tokens.
            presence_penalty: Penalty for tokens that have already appeared.
            frequency_penalty: Penalty scaled by how often a token has appeared.

        Returns:
            The generated text as a string.
        """
        return generate(
            model=self.model,
            tok=self.tokenizer,
            cfg=self.config,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            device=self.device,
            detokenize=True,
        )

    def batch_generate(self, prompts: list, **kwargs) -> list:
        """Generate text for each prompt in a list, sequentially."""
        return [self.generate_text(prompt, **kwargs) for prompt in prompts]


def main():
    """Example usage of the Beeper inference class."""
    # Initialize the model
    print("Initializing Rose Beeper model...")
    beeper = BeeperInference(
        checkpoint_path=None,  # will download from HuggingFace
        device=None,           # auto-select GPU if available
    )

    # Example prompts
    prompts = [
        "The robot went to school and",
        "Once upon a time in a distant galaxy,",
        "The meaning of life is",
        "In the beginning, there was",
        "The scientist discovered that",
    ]

    print("\n" + "=" * 60)
    print("GENERATING SAMPLES")
    print("=" * 60 + "\n")

    for prompt in prompts:
        print(f"Prompt: {prompt}")
        print("-" * 40)

        # Generate with two different settings.
        # Standard generation
        output = beeper.generate_text(
            prompt=prompt,
            max_new_tokens=100,
            temperature=0.9,
            top_k=40,
            top_p=0.9,
        )
        print(f"Output: {output}")
        print()

        # More creative generation
        creative_output = beeper.generate_text(
            prompt=prompt,
            max_new_tokens=50,
            temperature=1.2,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )
        print(f"Creative: {creative_output}")
        print("\n" + "=" * 60 + "\n")


if __name__ == "__main__":
    main()
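
# --- Optional usage sketch: local files instead of the HuggingFace download ---
# A minimal example, assuming you have already saved a checkpoint and tokenizer
# locally. The paths below are hypothetical placeholders, not files shipped with
# this script; substitute your own. Kept as comments so that importing or
# running this module does not trigger a second generation pass.
#
#   beeper = BeeperInference(
#       checkpoint_path="checkpoints/beeper_final.safetensors",  # hypothetical path
#       tokenizer_path="beeper.tokenizer.json",
#       device="cuda",
#   )
#   texts = beeper.batch_generate(
#       ["The robot went to school and", "Once upon a time"],
#       max_new_tokens=60,
#       temperature=0.8,
#   )
#   for t in texts:
#       print(t)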