"""
Example script for running inference with the Rose Beeper model.
"""
import os
from typing import Optional

import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

# Import the inference components from the companion beeper_inference module
from beeper_inference import (
    BeeperRoseGPT,
    BeeperIO,
    generate,
    get_default_config,
)


class BeeperInference:
"""Wrapper class for easy inference with the Rose Beeper model."""
def __init__(self,
checkpoint_path: str = None,
tokenizer_path: str = "beeper.tokenizer.json",
device: str = None,
hf_repo: str = "AbstractPhil/beeper-rose-v5"):
"""
Initialize the Beeper model for inference.
Args:
checkpoint_path: Path to local checkpoint file (.pt or .safetensors)
tokenizer_path: Path to tokenizer file
device: Device to run on ('cuda', 'cpu', or None for auto)
hf_repo: HuggingFace repository to download from if no local checkpoint
"""
# Set device
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
print(f"Using device: {self.device}")
# Load configuration
self.config = get_default_config()
# Initialize model
self.model = BeeperRoseGPT(self.config).to(self.device)
# Initialize pentachora banks
cap_cfg = self.config.get("capoera", {})
# Using default sizes since we don't have the exact corpus info at inference
self.model.ensure_pentachora(
coarse_C=20, # Approximate number of datasets
medium_C=int(cap_cfg.get("topic_bins", 512)),
fine_C=int(cap_cfg.get("mood_bins", 7)),
dim=self.config["dim"],
device=self.device
)
# Load weights
self._load_weights(checkpoint_path, hf_repo)
# Load tokenizer
self._load_tokenizer(tokenizer_path, hf_repo)
# Set to eval mode
self.model.eval()

    def _load_weights(self, checkpoint_path: Optional[str], hf_repo: str):
        """Load model weights from a local file or from HuggingFace."""
        loaded = False

        # Try a local checkpoint first
        if checkpoint_path and os.path.exists(checkpoint_path):
            print(f"Loading weights from: {checkpoint_path}")
            missing, unexpected = BeeperIO.load_into_model(
                self.model, checkpoint_path, map_location=str(self.device), strict=False
            )
            print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
            loaded = True

        # Fall back to HuggingFace if no local checkpoint was loaded
        if not loaded and hf_repo:
            try:
                print(f"Downloading weights from HuggingFace: {hf_repo}")
                path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors")
                missing, unexpected = BeeperIO.load_into_model(
                    self.model, path, map_location=str(self.device), strict=False
                )
                print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
                loaded = True
            except Exception as e:
                print(f"Failed to download from HuggingFace: {e}")

        if not loaded:
            print("WARNING: No weights loaded, using random initialization!")

    def _load_tokenizer(self, tokenizer_path: str, hf_repo: str):
        """Load tokenizer from local file or HuggingFace."""
        if os.path.exists(tokenizer_path):
            print(f"Loading tokenizer from: {tokenizer_path}")
            self.tokenizer = Tokenizer.from_file(tokenizer_path)
        else:
            try:
                print(f"Downloading tokenizer from HuggingFace: {hf_repo}")
                path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json")
                self.tokenizer = Tokenizer.from_file(path)
            except Exception as e:
                raise RuntimeError(f"Failed to load tokenizer: {e}")

    def generate_text(self,
                      prompt: str,
                      max_new_tokens: int = 120,
                      temperature: float = 0.9,
                      top_k: int = 40,
                      top_p: float = 0.9,
                      repetition_penalty: float = 1.1,
                      presence_penalty: float = 0.6,
                      frequency_penalty: float = 0.0) -> str:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text to continue from
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0.1-2.0 typical)
            top_k: Top-k sampling (0 to disable)
            top_p: Nucleus sampling threshold (0.0-1.0)
            repetition_penalty: Penalty for repeated tokens
            presence_penalty: Penalty for tokens that have appeared
            frequency_penalty: Penalty based on token frequency

        Returns:
            Generated text string
        """
        return generate(
            model=self.model,
            tok=self.tokenizer,
            cfg=self.config,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            device=self.device,
            detokenize=True,
        )

    def batch_generate(self, prompts: list, **kwargs) -> list:
        """Generate text for multiple prompts."""
        results = []
        for prompt in prompts:
            results.append(self.generate_text(prompt, **kwargs))
        return results
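

# A minimal sketch of batch usage, assuming the default HF repo and tokenizer
# download succeed; the prompts and sampling values below are illustrative
# assumptions, not project defaults.
def demo_batch_generation():
    """Generate completions for several prompts with a single model instance."""
    beeper = BeeperInference()  # Downloads weights/tokenizer from HF when no local files are given
    outputs = beeper.batch_generate(
        ["The robot went to school and", "Once upon a time in a distant galaxy,"],
        max_new_tokens=60,
        temperature=0.8,
        top_k=40,
    )
    for text in outputs:
        print(text)

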
def main():
"""Example usage of the Beeper inference class."""
# Initialize the model
print("Initializing Rose Beeper model...")
beeper = BeeperInference(
checkpoint_path=None, # Will download from HF
device=None # Auto-select GPU if available
)
# Example prompts
prompts = [
"The robot went to school and",
"Once upon a time in a distant galaxy,",
"The meaning of life is",
"In the beginning, there was",
"The scientist discovered that",
]
print("\n" + "="*60)
print("GENERATING SAMPLES")
print("="*60 + "\n")
for prompt in prompts:
print(f"Prompt: {prompt}")
print("-" * 40)
# Generate with different settings
# Standard generation
output = beeper.generate_text(
prompt=prompt,
max_new_tokens=100,
temperature=0.9,
top_k=40,
top_p=0.9
)
print(f"Output: {output}")
print()
# More creative generation
creative_output = beeper.generate_text(
prompt=prompt,
max_new_tokens=50,
temperature=1.2,
top_k=50,
top_p=0.95,
repetition_penalty=1.2
)
print(f"Creative: {creative_output}")
print("\n" + "="*60 + "\n")
if __name__ == "__main__":
    main()
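
# Loading from a local checkpoint instead of HuggingFace is also possible; this is an
# illustrative sketch and the file names below are assumptions, not shipped paths:
#
#   beeper = BeeperInference(
#       checkpoint_path="checkpoints/beeper_final.safetensors",
#       tokenizer_path="beeper.tokenizer.json",
#       device="cpu",  # force CPU even when CUDA is available
#   )
#   print(beeper.generate_text("The robot went to school and", max_new_tokens=40))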