"""
Example script for running inference with the Rose Beeper model.
"""
import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import os
from typing import Optional

# Import the inference components (defined in beeper_inference.py)
from beeper_inference import (
    BeeperRoseGPT,
    BeeperIO,
    generate,
    get_default_config,
)


class BeeperInference:
    """Wrapper class for easy inference with the Rose Beeper model."""

    def __init__(self,
                 checkpoint_path: Optional[str] = None,
                 tokenizer_path: str = "beeper.tokenizer.json",
                 device: Optional[str] = None,
                 hf_repo: str = "AbstractPhil/beeper-rose-v5"):
        """
        Initialize the Beeper model for inference.

        Args:
            checkpoint_path: Path to a local checkpoint file (.pt or .safetensors)
            tokenizer_path: Path to the tokenizer file
            device: Device to run on ('cuda', 'cpu', or None for auto-select)
            hf_repo: HuggingFace repository to download from if no local checkpoint
        """
        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # Load configuration
        self.config = get_default_config()

        # Initialize model
        self.model = BeeperRoseGPT(self.config).to(self.device)

        # Initialize pentachora banks
        cap_cfg = self.config.get("capoera", {})
        # Using default sizes since we don't have the exact corpus info at inference
        self.model.ensure_pentachora(
            coarse_C=20,  # Approximate number of datasets
            medium_C=int(cap_cfg.get("topic_bins", 512)),
            fine_C=int(cap_cfg.get("mood_bins", 7)),
            dim=self.config["dim"],
            device=self.device,
        )

        # Load weights
        self._load_weights(checkpoint_path, hf_repo)

        # Load tokenizer
        self._load_tokenizer(tokenizer_path, hf_repo)

        # Set to eval mode
        self.model.eval()

    def _load_weights(self, checkpoint_path: Optional[str], hf_repo: str):
        """Load model weights from a local file or HuggingFace."""
        loaded = False

        # Try a local checkpoint first
        if checkpoint_path and os.path.exists(checkpoint_path):
            print(f"Loading weights from: {checkpoint_path}")
            missing, unexpected = BeeperIO.load_into_model(
                self.model, checkpoint_path, map_location=str(self.device), strict=False
            )
            print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
            loaded = True

        # Fall back to HuggingFace if no local checkpoint
        if not loaded and hf_repo:
            try:
                print(f"Downloading weights from HuggingFace: {hf_repo}")
                path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors")
                missing, unexpected = BeeperIO.load_into_model(
                    self.model, path, map_location=str(self.device), strict=False
                )
                print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
                loaded = True
            except Exception as e:
                print(f"Failed to download from HuggingFace: {e}")

        if not loaded:
            print("WARNING: No weights loaded, using random initialization!")

    def _load_tokenizer(self, tokenizer_path: str, hf_repo: str):
        """Load tokenizer from local file or HuggingFace."""
        if os.path.exists(tokenizer_path):
            print(f"Loading tokenizer from: {tokenizer_path}")
            self.tokenizer = Tokenizer.from_file(tokenizer_path)
        else:
            try:
                print(f"Downloading tokenizer from HuggingFace: {hf_repo}")
                path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json")
                self.tokenizer = Tokenizer.from_file(path)
            except Exception as e:
                raise RuntimeError(f"Failed to load tokenizer: {e}")

    def generate_text(self,
                      prompt: str,
                      max_new_tokens: int = 120,
                      temperature: float = 0.9,
                      top_k: int = 40,
                      top_p: float = 0.9,
                      repetition_penalty: float = 1.1,
                      presence_penalty: float = 0.6,
                      frequency_penalty: float = 0.0) -> str:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text to continue from
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0.1-2.0 typical)
            top_k: Top-k sampling (0 to disable)
            top_p: Nucleus sampling threshold (0.0-1.0)
            repetition_penalty: Penalty for repeated tokens
            presence_penalty: Penalty for tokens that have appeared
            frequency_penalty: Penalty based on token frequency

        Returns:
            Generated text string
        """
        return generate(
            model=self.model,
            tok=self.tokenizer,
            cfg=self.config,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            device=self.device,
            detokenize=True,
        )

    def batch_generate(self, prompts: list, **kwargs) -> list:
        """Generate text for multiple prompts."""
        results = []
        for prompt in prompts:
            results.append(self.generate_text(prompt, **kwargs))
        return results


def main():
    """Example usage of the Beeper inference class."""
    # Initialize the model
    print("Initializing Rose Beeper model...")
    beeper = BeeperInference(
        checkpoint_path=None,  # Will download from HF
        device=None,           # Auto-select GPU if available
    )
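
    # Alternatively, weights can be loaded from a local checkpoint instead of
    # downloading. This is an illustrative sketch; the path below is a
    # placeholder, not a file shipped with this repo.
    # beeper = BeeperInference(
    #     checkpoint_path="path/to/checkpoint.safetensors",
    #     device="cpu",
    # )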

    # Example prompts
    prompts = [
        "The robot went to school and",
        "Once upon a time in a distant galaxy,",
        "The meaning of life is",
        "In the beginning, there was",
        "The scientist discovered that",
    ]

    print("\n" + "=" * 60)
    print("GENERATING SAMPLES")
    print("=" * 60 + "\n")
    for prompt in prompts:
        print(f"Prompt: {prompt}")
        print("-" * 40)

        # Generate with different settings
        # Standard generation
        output = beeper.generate_text(
            prompt=prompt,
            max_new_tokens=100,
            temperature=0.9,
            top_k=40,
            top_p=0.9,
        )
        print(f"Output: {output}")
        print()

        # More creative generation
        creative_output = beeper.generate_text(
            prompt=prompt,
            max_new_tokens=50,
            temperature=1.2,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )
        print(f"Creative: {creative_output}")
        print("\n" + "=" * 60 + "\n")


if __name__ == "__main__":
    main()