""" | |
Example script for running inference with the Rose Beeper model. | |
""" | |
import torch | |
from tokenizers import Tokenizer | |
from huggingface_hub import hf_hub_download | |
import os | |
# Import the inference components (from the previous artifact) | |
from beeper_inference import ( | |
BeeperRoseGPT, | |
BeeperIO, | |
generate, | |
get_default_config | |
) | |


class BeeperInference:
    """Wrapper class for easy inference with the Rose Beeper model."""

    def __init__(self,
                 checkpoint_path: Optional[str] = None,
                 tokenizer_path: str = "beeper.tokenizer.json",
                 device: Optional[str] = None,
                 hf_repo: str = "AbstractPhil/beeper-rose-v5"):
        """
        Initialize the Beeper model for inference.

        Args:
            checkpoint_path: Path to a local checkpoint file (.pt or .safetensors)
            tokenizer_path: Path to the tokenizer file
            device: Device to run on ("cuda", "cpu", or None to auto-select)
            hf_repo: HuggingFace repository to download from if no local checkpoint
        """
        # Select the device, preferring CUDA when available
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # Load configuration
        self.config = get_default_config()

        # Initialize the model
        self.model = BeeperRoseGPT(self.config).to(self.device)

        # Initialize the pentachora banks, using default sizes since the
        # exact corpus info is not available at inference time
        cap_cfg = self.config.get("capoera", {})
        self.model.ensure_pentachora(
            coarse_C=20,  # approximate number of datasets
            medium_C=int(cap_cfg.get("topic_bins", 512)),
            fine_C=int(cap_cfg.get("mood_bins", 7)),
            dim=self.config["dim"],
            device=self.device,
        )
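        # NOTE: the bank sizes above are assumptions; if they differ from the
        # sizes used during training, the strict=False load below should report
        # the mismatched pentachora tensors as missing/unexpected keys.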

        # Load weights
        self._load_weights(checkpoint_path, hf_repo)

        # Load tokenizer
        self._load_tokenizer(tokenizer_path, hf_repo)

        # Set to eval mode for inference
        self.model.eval()

    def _load_weights(self, checkpoint_path: Optional[str], hf_repo: str):
        """Load model weights from a local file or HuggingFace."""
        loaded = False

        # Try a local checkpoint first
        if checkpoint_path and os.path.exists(checkpoint_path):
            print(f"Loading weights from: {checkpoint_path}")
            missing, unexpected = BeeperIO.load_into_model(
                self.model, checkpoint_path, map_location=str(self.device), strict=False
            )
            print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
            loaded = True

        # Fall back to HuggingFace if no local checkpoint was found
        if not loaded and hf_repo:
            try:
                print(f"Downloading weights from HuggingFace: {hf_repo}")
                path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors")
                missing, unexpected = BeeperIO.load_into_model(
                    self.model, path, map_location=str(self.device), strict=False
                )
                print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
                loaded = True
            except Exception as e:
                print(f"Failed to download from HuggingFace: {e}")

        if not loaded:
            print("WARNING: No weights loaded, using random initialization!")

    def _load_tokenizer(self, tokenizer_path: str, hf_repo: str):
        """Load the tokenizer from a local file or HuggingFace."""
        if os.path.exists(tokenizer_path):
            print(f"Loading tokenizer from: {tokenizer_path}")
            self.tokenizer = Tokenizer.from_file(tokenizer_path)
        else:
            try:
                print(f"Downloading tokenizer from HuggingFace: {hf_repo}")
                path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json")
                self.tokenizer = Tokenizer.from_file(path)
            except Exception as e:
                raise RuntimeError(f"Failed to load tokenizer: {e}")

    def generate_text(self,
                      prompt: str,
                      max_new_tokens: int = 120,
                      temperature: float = 0.9,
                      top_k: int = 40,
                      top_p: float = 0.9,
                      repetition_penalty: float = 1.1,
                      presence_penalty: float = 0.6,
                      frequency_penalty: float = 0.0) -> str:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text to continue from
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.1-2.0 is typical)
            top_k: Top-k sampling cutoff (0 to disable)
            top_p: Nucleus sampling threshold (0.0-1.0)
            repetition_penalty: Penalty applied to repeated tokens
            presence_penalty: Penalty for tokens that have already appeared
            frequency_penalty: Penalty scaled by how often a token has appeared

        Returns:
            The generated text as a string
        """
        return generate(
            model=self.model,
            tok=self.tokenizer,
            cfg=self.config,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            device=self.device,
            detokenize=True,
        )

    def batch_generate(self, prompts: list, **kwargs) -> list:
        """Generate text for multiple prompts, one at a time (sequential, not batched)."""
        results = []
        for prompt in prompts:
            results.append(self.generate_text(prompt, **kwargs))
        return results
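

# A minimal sketch of batch usage (the prompts below are illustrative):
#
#   beeper = BeeperInference()
#   outputs = beeper.batch_generate(
#       ["The robot went to school and", "Once upon a time"],
#       max_new_tokens=60,
#       temperature=0.8,
#   )
#   for text in outputs:
#       print(text)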


def main():
    """Example usage of the Beeper inference class."""
    # Initialize the model
    print("Initializing Rose Beeper model...")
    beeper = BeeperInference(
        checkpoint_path=None,  # will download from HuggingFace
        device=None,           # auto-select GPU if available
    )

    # Example prompts
    prompts = [
        "The robot went to school and",
        "Once upon a time in a distant galaxy,",
        "The meaning of life is",
        "In the beginning, there was",
        "The scientist discovered that",
    ]

    print("\n" + "=" * 60)
    print("GENERATING SAMPLES")
    print("=" * 60 + "\n")

    for prompt in prompts:
        print(f"Prompt: {prompt}")
        print("-" * 40)

        # Standard generation
        output = beeper.generate_text(
            prompt=prompt,
            max_new_tokens=100,
            temperature=0.9,
            top_k=40,
            top_p=0.9,
        )
        print(f"Output: {output}")
        print()

        # More creative generation (higher temperature, stronger repetition penalty)
        creative_output = beeper.generate_text(
            prompt=prompt,
            max_new_tokens=50,
            temperature=1.2,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
        )
        print(f"Creative: {creative_output}")
        print("\n" + "=" * 60 + "\n")


if __name__ == "__main__":
    main()
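
# A minimal sketch of offline usage with a local checkpoint (the path below is
# hypothetical; point it at a real .pt or .safetensors file on disk):
#
#   beeper = BeeperInference(
#       checkpoint_path="checkpoints/beeper_final.safetensors",
#       tokenizer_path="beeper.tokenizer.json",
#       device="cpu",
#   )
#   print(beeper.generate_text("The scientist discovered that", max_new_tokens=80))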