AbstractPhil committed
Commit c84b8a9 · verified · 1 Parent(s): 99d979b

Create load_for_inference.py

Files changed (1)
  1. load_for_inference.py +241 -0
load_for_inference.py ADDED
"""
Rose Beeper Model - Inference Example

Simple script showing how to load and use the model for text generation.
"""

import os
from typing import Optional

import torch
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

# Import the extracted components (assuming they're in a module called 'beeper_inference')
# from beeper_inference import BeeperRoseGPT, BeeperIO, generate, get_default_config


def load_model_for_inference(
    checkpoint_path: Optional[str] = None,
    tokenizer_path: str = "beeper.tokenizer.json",
    hf_repo: str = "AbstractPhil/beeper-rose-v5",
    device: str = "cuda",
):
    """
    Load the Rose Beeper model for inference.

    Args:
        checkpoint_path: Path to a local checkpoint file (.pt or .safetensors)
        tokenizer_path: Path to the tokenizer file
        hf_repo: HuggingFace repository to download from if no local checkpoint is given
        device: Device to load the model on ("cuda" or "cpu")

    Returns:
        Tuple of (model, tokenizer, config)
    """
    # Get default configuration
    config = get_default_config()

    # Set device (fall back to CPU if CUDA is unavailable)
    device = torch.device(device if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = BeeperRoseGPT(config).to(device)

    # Initialize pentachora banks.
    # These are the default sizes from the training configuration.
    cap_cfg = config.get("capoera", {})
    coarse_C = 20  # Approximate number of alive datasets
    model.ensure_pentachora(
        coarse_C=coarse_C,
        medium_C=int(cap_cfg.get("topic_bins", 512)),
        fine_C=int(cap_cfg.get("mood_bins", 7)),
        dim=config["dim"],
        device=device,
    )

    # Load checkpoint
    loaded = False

    # Try loading from a local path
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading model from: {checkpoint_path}")
        missing, unexpected = BeeperIO.load_into_model(
            model, checkpoint_path, map_location="cpu", strict=False
        )
        print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
        loaded = True

    # Otherwise try downloading from HuggingFace
    if not loaded and hf_repo:
        try:
            print(f"Downloading model from HuggingFace: {hf_repo}")
            path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors")
            missing, unexpected = BeeperIO.load_into_model(
                model, path, map_location="cpu", strict=False
            )
            print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}")
            loaded = True
        except Exception as e:
            print(f"Failed to download from HuggingFace: {e}")

    if not loaded:
        print("WARNING: No weights loaded, using random initialization!")

    # Load tokenizer
    if os.path.exists(tokenizer_path):
        tok = Tokenizer.from_file(tokenizer_path)
        print(f"Loaded tokenizer from: {tokenizer_path}")
    else:
        # Try downloading the tokenizer from HF
        try:
            tok_path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json")
            tok = Tokenizer.from_file(tok_path)
            print("Downloaded tokenizer from HuggingFace")
        except Exception as e:
            raise RuntimeError(f"Could not load tokenizer: {e}") from e

    # Set model to eval mode
    model.eval()

    return model, tok, config

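
# Minimal usage sketch for the loader above (illustrative values: the checkpoint
# filename and device are assumptions; pass checkpoint_path=None to pull
# "beeper_final.safetensors" from the default HF repo instead):
#
#   model, tok, cfg = load_model_for_inference(
#       checkpoint_path="beeper_final.safetensors",
#       tokenizer_path="beeper.tokenizer.json",
#       device="cuda",
#   )
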
def interactive_generation(model, tokenizer, config, device="cuda"):
    """
    Interactive text generation loop.

    Args:
        model: The loaded BeeperRoseGPT model
        tokenizer: The tokenizer
        config: Model configuration
        device: Device to run on
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    print("\n=== Rose Beeper Interactive Generation ===")
    print("Enter your prompt (or 'quit' to exit)")
    print("Commands: /temp <value>, /top_k <value>, /top_p <value>, /max <tokens>")
    print("-" * 50)

    # Generation settings (can be modified with the slash commands below)
    settings = {
        "max_new_tokens": 100,
        "temperature": config["temperature"],
        "top_k": config["top_k"],
        "top_p": config["top_p"],
        "repetition_penalty": config["repetition_penalty"],
        "presence_penalty": config["presence_penalty"],
        "frequency_penalty": config["frequency_penalty"],
    }

    while True:
        prompt = input("\nPrompt: ").strip()

        if prompt.lower() == 'quit':
            break

        # Handle commands
        if prompt.startswith('/'):
            parts = prompt.split()
            cmd = parts[0].lower()

            if cmd == '/temp' and len(parts) > 1:
                settings["temperature"] = float(parts[1])
                print(f"Temperature set to {settings['temperature']}")
                continue
            elif cmd == '/top_k' and len(parts) > 1:
                settings["top_k"] = int(parts[1])
                print(f"Top-k set to {settings['top_k']}")
                continue
            elif cmd == '/top_p' and len(parts) > 1:
                settings["top_p"] = float(parts[1])
                print(f"Top-p set to {settings['top_p']}")
                continue
            elif cmd == '/max' and len(parts) > 1:
                settings["max_new_tokens"] = int(parts[1])
                print(f"Max tokens set to {settings['max_new_tokens']}")
                continue
            else:
                print("Unknown command")
                continue

        if not prompt:
            continue

        # Generate text
        print("\nGenerating...")
        output = generate(
            model=model,
            tok=tokenizer,
            cfg=config,
            prompt=prompt,
            device=device,
            **settings
        )

        print("\nOutput:", output)
        print("-" * 50)

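
# Example session for the loop above (model outputs are illustrative, not real):
#
#   Prompt: /temp 0.7        -> Temperature set to 0.7
#   Prompt: /max 60          -> Max tokens set to 60
#   Prompt: The robot said   -> generates up to 60 new tokens at temperature 0.7
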
def batch_generation_example(model, tokenizer, config, device="cuda"):
    """
    Example of batch generation across several prompts and sampling temperatures.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    prompts = [
        "The robot went to school and",
        "Once upon a time in a magical forest",
        "The scientist discovered that",
        "In the year 2050, humanity",
        "The philosophy of mind suggests",
    ]

    print("\n=== Batch Generation Examples ===\n")

    for prompt in prompts:
        print(f"Prompt: {prompt}")

        # Generate with different temperatures
        for temp in [0.5, 0.9, 1.2]:
            output = generate(
                model=model,
                tok=tokenizer,
                cfg=config,
                prompt=prompt,
                max_new_tokens=50,
                temperature=temp,
                device=device
            )
            print(f"  Temp {temp}: {output}")

        print("-" * 50)

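
# Reproducibility sketch (assumption: generate() samples from torch's global RNG,
# so seeding it makes the temperature comparison above repeatable; the seed is arbitrary):
#
#   torch.manual_seed(0)
#   batch_generation_example(model, tokenizer, config)
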
# Main execution example
if __name__ == "__main__":
    # Load model
    model, tokenizer, config = load_model_for_inference(
        checkpoint_path=None,  # Will download from HF
        hf_repo="AbstractPhil/beeper-rose-v5",
        device="cuda"
    )

    # Example: Single generation
    print("\n=== Single Generation Example ===")
    output = generate(
        model=model,
        tok=tokenizer,
        cfg=config,
        prompt="The meaning of life is",
        max_new_tokens=100,
        temperature=0.9,
        device="cuda"
    )
    print(f"Output: {output}")

    # Example: Batch generation with different settings
    # batch_generation_example(model, tokenizer, config)

    # Example: Interactive generation
    # interactive_generation(model, tokenizer, config)
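
    # Example (sketch): explicit CPU loading, reusing only the helpers defined above.
    # The prompt text and token count are illustrative.
    # model_cpu, tok_cpu, cfg_cpu = load_model_for_inference(device="cpu")
    # print(generate(model=model_cpu, tok=tok_cpu, cfg=cfg_cpu,
    #                prompt="Beeper woke up and", max_new_tokens=30, device="cpu"))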