AbstractPhil committed
Commit 6a080c2 · verified · 1 Parent(s): d741dd0

Update load_for_inference.py

Files changed (1)
  1. load_for_inference.py +181 -207
load_for_inference.py CHANGED
@@ -1,241 +1,215 @@
"""
-Rose Beeper Model - Inference Example
-Simple script showing how to load and use the model for text generation
"""

import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

-# Import the extracted components (assuming they're in a module called 'beeper_inference')
-# from beeper_inference import BeeperRoseGPT, BeeperIO, generate, get_default_config

-def load_model_for_inference(
-    checkpoint_path: str = None,
-    tokenizer_path: str = "beeper.tokenizer.json",
-    hf_repo: str = "AbstractPhil/beeper-rose-v5",
-    device: str = "cuda"
-):
-    """
-    Load the Rose Beeper model for inference.
-
-    Args:
-        checkpoint_path: Path to local checkpoint file (.pt or .safetensors)
-        tokenizer_path: Path to tokenizer file
-        hf_repo: HuggingFace repository to download from if no local checkpoint
-        device: Device to load model on ("cuda" or "cpu")
-
-    Returns:
-        Tuple of (model, tokenizer, config)
-    """
-    # Get default configuration
-    config = get_default_config()
-
-    # Set device
-    device = torch.device(device if torch.cuda.is_available() else "cpu")
-
-    # Initialize model
-    model = BeeperRoseGPT(config).to(device)
-
-    # Initialize pentachora banks
-    # These are the default sizes from the training configuration
-    cap_cfg = config.get("capoera", {})
-    coarse_C = 20 # Approximate number of alive datasets
-    model.ensure_pentachora(
-        coarse_C=coarse_C,
-        medium_C=int(cap_cfg.get("topic_bins", 512)),
-        fine_C=int(cap_cfg.get("mood_bins", 7)),
-        dim=config["dim"],
-        device=device
-    )
-
-    # Load checkpoint
-    loaded = False
-
-    # Try loading from local path
-    if checkpoint_path and os.path.exists(checkpoint_path):
-        print(f"Loading model from: {checkpoint_path}")
-        missing, unexpected = BeeperIO.load_into_model(
-            model, checkpoint_path, map_location="cpu", strict=False
        )
-        print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
-        loaded = True
-
-    # Try downloading from HuggingFace
-    if not loaded and hf_repo:
-        try:
-            print(f"Downloading model from HuggingFace: {hf_repo}")
-            path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors")
            missing, unexpected = BeeperIO.load_into_model(
-                model, path, map_location="cpu", strict=False
            )
            print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
            loaded = True
-        except Exception as e:
-            print(f"Failed to download from HuggingFace: {e}")
-
-    if not loaded:
-        print("WARNING: No weights loaded, using random initialization!")
-
-    # Load tokenizer
-    if os.path.exists(tokenizer_path):
-        tok = Tokenizer.from_file(tokenizer_path)
-        print(f"Loaded tokenizer from: {tokenizer_path}")
-    else:
-        # Try downloading tokenizer from HF
-        try:
-            tok_path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json")
-            tok = Tokenizer.from_file(tok_path)
-            print(f"Downloaded tokenizer from HuggingFace")
-        except Exception as e:
-            raise RuntimeError(f"Could not load tokenizer: {e}")
-
-    # Set model to eval mode
-    model.eval()
-
-    return model, tok, config
-
-
-def interactive_generation(model, tokenizer, config, device="cuda"):
-    """
-    Interactive text generation loop.
-
-    Args:
-        model: The loaded BeeperRoseGPT model
-        tokenizer: The tokenizer
-        config: Model configuration
-        device: Device to run on
-    """
-    device = torch.device(device if torch.cuda.is_available() else "cpu")
-    model = model.to(device)
-
-    print("\n=== Rose Beeper Interactive Generation ===")
-    print("Enter your prompt (or 'quit' to exit)")
-    print("Commands: /temp <value>, /top_k <value>, /top_p <value>, /max <tokens>")
-    print("-" * 50)
-
-    # Generation settings (can be modified)
-    settings = {
-        "max_new_tokens": 100,
-        "temperature": config["temperature"],
-        "top_k": config["top_k"],
-        "top_p": config["top_p"],
-        "repetition_penalty": config["repetition_penalty"],
-        "presence_penalty": config["presence_penalty"],
-        "frequency_penalty": config["frequency_penalty"],
-    }
-
-    while True:
-        prompt = input("\nPrompt: ").strip()

-        if prompt.lower() == 'quit':
-            break

-        # Handle commands
-        if prompt.startswith('/'):
-            parts = prompt.split()
-            cmd = parts[0].lower()

-            if cmd == '/temp' and len(parts) > 1:
-                settings["temperature"] = float(parts[1])
-                print(f"Temperature set to {settings['temperature']}")
-                continue
-            elif cmd == '/top_k' and len(parts) > 1:
-                settings["top_k"] = int(parts[1])
-                print(f"Top-k set to {settings['top_k']}")
-                continue
-            elif cmd == '/top_p' and len(parts) > 1:
-                settings["top_p"] = float(parts[1])
-                print(f"Top-p set to {settings['top_p']}")
-                continue
-            elif cmd == '/max' and len(parts) > 1:
-                settings["max_new_tokens"] = int(parts[1])
-                print(f"Max tokens set to {settings['max_new_tokens']}")
-                continue
-            else:
-                print("Unknown command")
-                continue
-
-        if not prompt:
-            continue
-
-        # Generate text
-        print("\nGenerating...")
-        output = generate(
-            model=model,
-            tok=tokenizer,
-            cfg=config,
            prompt=prompt,
-            device=device,
-            **settings
        )
-
-        print("\nOutput:", output)
-        print("-" * 50)


-def batch_generation_example(model, tokenizer, config, device="cuda"):
-    """
-    Example of batch generation with different settings.
-    """
-    device = torch.device(device if torch.cuda.is_available() else "cpu")
-    model = model.to(device)

    prompts = [
        "The robot went to school and",
-        "Once upon a time in a magical forest",
        "The scientist discovered that",
-        "In the year 2050, humanity",
-        "The philosophy of mind suggests",
    ]

-    print("\n=== Batch Generation Examples ===\n")

    for prompt in prompts:
        print(f"Prompt: {prompt}")

-        # Generate with different temperatures
-        for temp in [0.5, 0.9, 1.2]:
-            output = generate(
-                model=model,
-                tok=tokenizer,
-                cfg=config,
-                prompt=prompt,
-                max_new_tokens=50,
-                temperature=temp,
-                device=device
-            )
-            print(f" Temp {temp}: {output}")

-        print("-" * 50)


-# Main execution example
if __name__ == "__main__":
-    import os
-
-    # Load model
-    model, tokenizer, config = load_model_for_inference(
-        checkpoint_path=None, # Will download from HF
-        hf_repo="AbstractPhil/beeper-rose-v5",
-        device="cuda"
-    )
-
-    # Example: Single generation
-    print("\n=== Single Generation Example ===")
-    output = generate(
-        model=model,
-        tok=tokenizer,
-        cfg=config,
-        prompt="The meaning of life is",
-        max_new_tokens=100,
-        temperature=0.9,
-        device="cuda"
-    )
-    print(f"Output: {output}")
-
-    # Example: Batch generation with different settings
-    # batch_generation_example(model, tokenizer, config)
-
-    # Example: Interactive generation
-    # interactive_generation(model, tokenizer, config)
"""
+Example script for running inference with the Rose Beeper model.
"""

import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
+import os

+# Import the inference components (from the previous artifact)
+from beeper_inference import (
+    BeeperRoseGPT,
+    BeeperIO,
+    generate,
+    get_default_config
+)

+
+class BeeperInference:
+    """Wrapper class for easy inference with the Rose Beeper model."""
+
+    def __init__(self,
+                 checkpoint_path: str = None,
+                 tokenizer_path: str = "beeper.tokenizer.json",
+                 device: str = None,
+                 hf_repo: str = "AbstractPhil/beeper-rose-v5"):
+        """
+        Initialize the Beeper model for inference.
+
+        Args:
+            checkpoint_path: Path to local checkpoint file (.pt or .safetensors)
+            tokenizer_path: Path to tokenizer file
+            device: Device to run on ('cuda', 'cpu', or None for auto)
+            hf_repo: HuggingFace repository to download from if no local checkpoint
+        """
+
+        # Set device
+        if device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = torch.device(device)
+
+        print(f"Using device: {self.device}")
+
+        # Load configuration
+        self.config = get_default_config()
+
+        # Initialize model
+        self.model = BeeperRoseGPT(self.config).to(self.device)
+
+        # Initialize pentachora banks
+        cap_cfg = self.config.get("capoera", {})
+        # Using default sizes since we don't have the exact corpus info at inference
+        self.model.ensure_pentachora(
+            coarse_C=20, # Approximate number of datasets
+            medium_C=int(cap_cfg.get("topic_bins", 512)),
+            fine_C=int(cap_cfg.get("mood_bins", 7)),
+            dim=self.config["dim"],
+            device=self.device
        )
+
+        # Load weights
+        self._load_weights(checkpoint_path, hf_repo)
+
+        # Load tokenizer
+        self._load_tokenizer(tokenizer_path, hf_repo)
+
+        # Set to eval mode
+        self.model.eval()
+
+    def _load_weights(self, checkpoint_path: str, hf_repo: str):
+        """Load model weights from local file or HuggingFace."""
+        loaded = False
+
+        # Try local checkpoint first
+        if checkpoint_path and os.path.exists(checkpoint_path):
+            print(f"Loading weights from: {checkpoint_path}")
            missing, unexpected = BeeperIO.load_into_model(
+                self.model, checkpoint_path, map_location=str(self.device), strict=False
            )
            print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
            loaded = True

+        # Try HuggingFace if no local checkpoint
+        if not loaded and hf_repo:
+            try:
+                print(f"Downloading weights from HuggingFace: {hf_repo}")
+                path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors")
+                missing, unexpected = BeeperIO.load_into_model(
+                    self.model, path, map_location=str(self.device), strict=False
+                )
+                print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
+                loaded = True
+            except Exception as e:
+                print(f"Failed to download from HuggingFace: {e}")

+        if not loaded:
+            print("WARNING: No weights loaded, using random initialization!")
+
+    def _load_tokenizer(self, tokenizer_path: str, hf_repo: str):
+        """Load tokenizer from local file or HuggingFace."""
+        if os.path.exists(tokenizer_path):
+            print(f"Loading tokenizer from: {tokenizer_path}")
+            self.tokenizer = Tokenizer.from_file(tokenizer_path)
+        else:
+            try:
+                print(f"Downloading tokenizer from HuggingFace: {hf_repo}")
+                path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json")
+                self.tokenizer = Tokenizer.from_file(path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to load tokenizer: {e}")
+
+    def generate_text(self,
+                      prompt: str,
+                      max_new_tokens: int = 120,
+                      temperature: float = 0.9,
+                      top_k: int = 40,
+                      top_p: float = 0.9,
+                      repetition_penalty: float = 1.1,
+                      presence_penalty: float = 0.6,
+                      frequency_penalty: float = 0.0) -> str:
+        """
+        Generate text from a prompt.
+
+        Args:
+            prompt: Input text to continue from
+            max_new_tokens: Maximum tokens to generate
+            temperature: Sampling temperature (0.1-2.0 typical)
+            top_k: Top-k sampling (0 to disable)
+            top_p: Nucleus sampling threshold (0.0-1.0)
+            repetition_penalty: Penalty for repeated tokens
+            presence_penalty: Penalty for tokens that have appeared
+            frequency_penalty: Penalty based on token frequency

+        Returns:
+            Generated text string
+        """
+        return generate(
+            model=self.model,
+            tok=self.tokenizer,
+            cfg=self.config,
            prompt=prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            device=self.device,
+            detokenize=True
        )
+
+    def batch_generate(self, prompts: list, **kwargs) -> list:
+        """Generate text for multiple prompts."""
+        results = []
+        for prompt in prompts:
+            results.append(self.generate_text(prompt, **kwargs))
+        return results


+def main():
+    """Example usage of the Beeper inference class."""
+
+    # Initialize the model
+    print("Initializing Rose Beeper model...")
+    beeper = BeeperInference(
+        checkpoint_path=None, # Will download from HF
+        device=None # Auto-select GPU if available
+    )

+    # Example prompts
    prompts = [
        "The robot went to school and",
+        "Once upon a time in a distant galaxy,",
+        "The meaning of life is",
+        "In the beginning, there was",
        "The scientist discovered that",
    ]

+    print("\n" + "="*60)
+    print("GENERATING SAMPLES")
+    print("="*60 + "\n")

    for prompt in prompts:
        print(f"Prompt: {prompt}")
+        print("-" * 40)

+        # Generate with different settings
+        # Standard generation
+        output = beeper.generate_text(
+            prompt=prompt,
+            max_new_tokens=100,
+            temperature=0.9,
+            top_k=40,
+            top_p=0.9
+        )
+        print(f"Output: {output}")
+        print()

+        # More creative generation
+        creative_output = beeper.generate_text(
+            prompt=prompt,
+            max_new_tokens=50,
+            temperature=1.2,
+            top_k=50,
+            top_p=0.95,
+            repetition_penalty=1.2
+        )
+        print(f"Creative: {creative_output}")
+        print("\n" + "="*60 + "\n")


if __name__ == "__main__":
+    main()
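
For reference, a minimal usage sketch of the BeeperInference wrapper introduced in this commit. It assumes load_for_inference.py and the beeper_inference module it imports are on the Python path; the import of load_for_inference itself is illustrative and not part of the diff above.

# Sketch only: exercises the BeeperInference class added in this commit.
# Assumes load_for_inference.py (and its beeper_inference dependency) are importable;
# weights and tokenizer are fetched from AbstractPhil/beeper-rose-v5 by default.
from load_for_inference import BeeperInference

beeper = BeeperInference(checkpoint_path=None, device=None)  # auto-selects GPU/CPU

# Single prompt
print(beeper.generate_text("The robot went to school and", max_new_tokens=60))

# Several prompts sharing the same sampling settings
for text in beeper.batch_generate(
    ["The scientist discovered that", "The meaning of life is"],
    max_new_tokens=40,
    temperature=0.8,
):
    print(text)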