"""
Rose Beeper Model - Inference Example
Simple script showing how to load and use the model for text generation
"""
import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
# Import the extracted components (assuming they're in a module called 'beeper_inference')
# from beeper_inference import BeeperRoseGPT, BeeperIO, generate, get_default_config
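# Minimal import sketch (assumption: the extracted components live in a local module
# importable as 'beeper_inference'; adjust the module name to match your layout).
# Failing early here gives a clearer error than a NameError later on.
try:
    from beeper_inference import BeeperRoseGPT, BeeperIO, generate, get_default_config
except ImportError as exc:
    raise ImportError(
        "Could not import BeeperRoseGPT/BeeperIO/generate/get_default_config. "
        "Place them in a module importable as 'beeper_inference' or edit this import."
    ) from exc
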
def load_model_for_inference(
    checkpoint_path: Optional[str] = None,
    tokenizer_path: str = "beeper.tokenizer.json",
    hf_repo: str = "AbstractPhil/beeper-rose-v5",
    device: str = "cuda",
):
    """
    Load the Rose Beeper model for inference.

    Args:
        checkpoint_path: Path to a local checkpoint file (.pt or .safetensors)
        tokenizer_path: Path to the tokenizer file
        hf_repo: HuggingFace repository to download from if no local checkpoint is given
        device: Device to load the model on ("cuda" or "cpu")

    Returns:
        Tuple of (model, tokenizer, config)
    """
    # Get default configuration
    config = get_default_config()

    # Set device (fall back to CPU if CUDA is unavailable)
    device = torch.device(device if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = BeeperRoseGPT(config).to(device)

    # Initialize pentachora banks before loading weights so parameter shapes
    # match the checkpoint. These are the default sizes from the training configuration.
    cap_cfg = config.get("capoera", {})
    coarse_C = 20  # Approximate number of alive datasets
    model.ensure_pentachora(
        coarse_C=coarse_C,
        medium_C=int(cap_cfg.get("topic_bins", 512)),
        fine_C=int(cap_cfg.get("mood_bins", 7)),
        dim=config["dim"],
        device=device,
    )

    # Load checkpoint
    loaded = False

    # Try loading from a local path first
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading model from: {checkpoint_path}")
        missing, unexpected = BeeperIO.load_into_model(
            model, checkpoint_path, map_location="cpu", strict=False
        )
        print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
        loaded = True

    # Otherwise, try downloading from HuggingFace
    if not loaded and hf_repo:
        try:
            print(f"Downloading model from HuggingFace: {hf_repo}")
            path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors")
            missing, unexpected = BeeperIO.load_into_model(
                model, path, map_location="cpu", strict=False
            )
            print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
            loaded = True
        except Exception as e:
            print(f"Failed to download from HuggingFace: {e}")

    if not loaded:
        print("WARNING: No weights loaded, using random initialization!")

    # Load tokenizer
    if os.path.exists(tokenizer_path):
        tok = Tokenizer.from_file(tokenizer_path)
        print(f"Loaded tokenizer from: {tokenizer_path}")
    else:
        # Try downloading the tokenizer from HF
        try:
            tok_path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json")
            tok = Tokenizer.from_file(tok_path)
            print(f"Downloaded tokenizer from HuggingFace: {hf_repo}")
        except Exception as e:
            raise RuntimeError(f"Could not load tokenizer: {e}")

    # Set model to eval mode
    model.eval()

    return model, tok, config
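

# Usage sketch for loading from a local checkpoint instead of the Hub
# (the path below is hypothetical; point it at your own file, or pass
# checkpoint_path=None to download the weights as in the __main__ block):
#
#   model, tok, cfg = load_model_for_inference(
#       checkpoint_path="checkpoints/beeper_final.safetensors",
#       tokenizer_path="beeper.tokenizer.json",
#       device="cpu",
#   )
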
def interactive_generation(model, tokenizer, config, device="cuda"):
    """
    Interactive text generation loop.

    Args:
        model: The loaded BeeperRoseGPT model
        tokenizer: The tokenizer
        config: Model configuration
        device: Device to run on
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    print("\n=== Rose Beeper Interactive Generation ===")
    print("Enter your prompt (or 'quit' to exit)")
    print("Commands: /temp <value>, /top_k <value>, /top_p <value>, /max <tokens>")
    print("-" * 50)

    # Generation settings (can be modified)
    settings = {
        "max_new_tokens": 100,
        "temperature": config["temperature"],
        "top_k": config["top_k"],
        "top_p": config["top_p"],
        "repetition_penalty": config["repetition_penalty"],
        "presence_penalty": config["presence_penalty"],
        "frequency_penalty": config["frequency_penalty"],
    }

    while True:
        prompt = input("\nPrompt: ").strip()

        if prompt.lower() == 'quit':
            break

        # Handle commands
        if prompt.startswith('/'):
            parts = prompt.split()
            cmd = parts[0].lower()

            if cmd == '/temp' and len(parts) > 1:
                settings["temperature"] = float(parts[1])
                print(f"Temperature set to {settings['temperature']}")
                continue
            elif cmd == '/top_k' and len(parts) > 1:
                settings["top_k"] = int(parts[1])
                print(f"Top-k set to {settings['top_k']}")
                continue
            elif cmd == '/top_p' and len(parts) > 1:
                settings["top_p"] = float(parts[1])
                print(f"Top-p set to {settings['top_p']}")
                continue
            elif cmd == '/max' and len(parts) > 1:
                settings["max_new_tokens"] = int(parts[1])
                print(f"Max tokens set to {settings['max_new_tokens']}")
                continue
            else:
                print("Unknown command")
                continue

        if not prompt:
            continue

        # Generate text
        print("\nGenerating...")
        output = generate(
            model=model,
            tok=tokenizer,
            cfg=config,
            prompt=prompt,
            device=device,
            **settings,
        )
        print("\nOutput:", output)
        print("-" * 50)
def batch_generation_example(model, tokenizer, config, device="cuda"):
    """
    Example of batch generation with different settings.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    prompts = [
        "The robot went to school and",
        "Once upon a time in a magical forest",
        "The scientist discovered that",
        "In the year 2050, humanity",
        "The philosophy of mind suggests",
    ]

    print("\n=== Batch Generation Examples ===\n")
    for prompt in prompts:
        print(f"Prompt: {prompt}")

        # Generate with different temperatures
        for temp in [0.5, 0.9, 1.2]:
            output = generate(
                model=model,
                tok=tokenizer,
                cfg=config,
                prompt=prompt,
                max_new_tokens=50,
                temperature=temp,
                device=device,
            )
            print(f"  Temp {temp}: {output}")
        print("-" * 50)
# Main execution example
if __name__ == "__main__":
    # Load model
    model, tokenizer, config = load_model_for_inference(
        checkpoint_path=None,  # Will download from HF
        hf_repo="AbstractPhil/beeper-rose-v5",
        device="cuda",
    )

    # Example: Single generation
    print("\n=== Single Generation Example ===")
    output = generate(
        model=model,
        tok=tokenizer,
        cfg=config,
        prompt="The meaning of life is",
        max_new_tokens=100,
        temperature=0.9,
        device="cuda",
    )
    print(f"Output: {output}")

    # Example: Batch generation with different settings
    # batch_generation_example(model, tokenizer, config)

    # Example: Interactive generation
    # interactive_generation(model, tokenizer, config)
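
    # CPU-only environments (sketch): the same entry points accept device="cpu";
    # generation is just slower.
    # model, tokenizer, config = load_model_for_inference(device="cpu")
    # interactive_generation(model, tokenizer, config, device="cpu")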