KWRegan committed on
Commit
ed96d7f
·
1 Parent(s): 48f6843

Add application file

Files changed (1)
  1. lifgen-hook.py +578 -0
lifgen-hook.py ADDED
@@ -0,0 +1,578 @@
+ # %% File "lifgen-hook.py" by Gishnu Madhu with editing and feature input by KWR
+ # Usage:
+ # python lifgen-hook.py -i <source-file> -o <output.lif> -pid <identifier> -pt <initial prompt> -mpv <# of scored items>
+ # optional: -nt <# of tokens to search for text word> -bw <width of beam search, 1 for greedy>, -nw <# of words>
+ #           -st <word/token to start from, 1-based> -a <mode in 0...4 of treating tokens as in-bounds or matching>
+ #           -model <model to use>  # not yet implemented---change model manually for now.
+ #
+ # Qwen requires a HuggingFace token, needs "pip install transformers --upgrade", and does a new 4.2GB download.
+ # DeepSeek does not need an access token---just hit enter to give the empty string.
+ #
+ # Example:
+ # python lifgen-hook.py -i YaoJokic.txt -o YaoJokicm50.lif -pid YaoTestByDeepSeek -pt "Compare Yao Ming and Nikola Jokic in NBA basketball" -mpv 50
+ import math
+
+ import torch
+ import gc
+ import time
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import os
+ import argparse
+ import sys
+ import re
+ from huggingface_hub import login
+ from unidecode import unidecode
+
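+ # The imports above imply these installs (a sketch; pin versions as your environment requires):
+ #     pip install torch transformers huggingface_hub unidecode
+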
+ captured_logits = []  # most recent lm_head logits captured by the hook below
+
+
+ def capture_logits_hook(module, input_args, output):
+     """
+     Hook function to capture the output of the lm_head layer.
+     The output might be a tensor or a tuple containing the tensor.
+     We are interested in the tensor containing logits.
+     """
+     if isinstance(output, torch.Tensor):
+         logits = output
+     elif isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], torch.Tensor):
+         # Common case for models returning more than just logits (e.g., past_key_values)
+         # We assume the first element is the logits tensor. Check model docs if unsure.
+         logits = output[0]
+     else:
+         # Cannot determine logits tensor, skip capture for this call
+         print(f"Warning: Hook captured unexpected output type: {type(output)}")
+         return
+     # Keep only the latest capture so memory stays bounded across forward passes;
+     # kept for inspection/debugging (the generation loop reads outputs.scores instead).
+     captured_logits.clear()
+     captured_logits.append(logits.detach())
+
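+ # How the hook is wired up (the actual calls appear in main() below):
+ #     hook_handle = model.lm_head.register_forward_hook(capture_logits_hook)
+ #     ...forward passes fill captured_logits...
+ #     hook_handle.remove()
+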
+ parser = argparse.ArgumentParser(
+     description="LifGenerator for CPU with Hugging Face models with greedy decoding",
+     epilog="Help Documentation"
+ )
+ parser.add_argument(
+     "-input_file", "-i",
+     type=str,
+     help="The path to the input file."
+ )
+
+ parser.add_argument(
+     "-output_file", "-o",
+     type=str,
+     help="Name and path of output file"
+ )
+
+ parser.add_argument(
+     "-prompt_id", "-pid",
+     type=str,
+     help="Overall name of item"
+ )
+
+ parser.add_argument(
+     "-prompt_topic", "-pt",
+     type=str,
+     help="Topic given to LLM before stem words"
+ )
+
+ parser.add_argument(
+     "-multi_pv", "-mpv",
+     type=int,
+     help="Number of options to consider at each turn"
+ )
+
+ parser.add_argument(
+     "-num_words", "-nw",
+     type=int,
+     help="Cap on # of text words to iterate"
+ )
+
+ parser.add_argument(
+     "-num_tokens", "-nt",
+     type=int,
+     help="# of tokens to search for text word match"
+ )
+
+ parser.add_argument(
+     "-beam_width", "-bw",
+     type=int,
+     help="Width of beam search, 0 or 1 for greedy"
+ )
+
+ parser.add_argument(
+     "-alpha_mode", "-a",
+     type=int,
+     help="0 = all tokens, up thru 4 = alpha chars plus ' only"
+ )
+
+ parser.add_argument(
+     "-start_turn", "-st",
+     type=int,
+     help="1 by default, adds st-1 words to prompt"
+ )
+
+ parser.add_argument(
+     "-model",
+     type=str,
+     help="DS for DeepSeek, QWEN for Qwen"
+ )
+
+ args = parser.parse_args()
+ if not args.input_file:
+     parser.error("an input file is required (-i <source-file>)")
+ print("Welcome to the LifGenerator CPU script!")
+ print("This script generates lif files using a Hugging Face model and greedy decoding.")
+ print(f"Input file path: {args.input_file}")
+ print(f"Output file path: {args.output_file}")
+ INPUT_FILE = args.input_file
+ INPUT_FILE_STEM = os.path.splitext(INPUT_FILE)[0]  # safer than split('.') for dotted paths
+ OUTPUT_FILE = args.output_file if args.output_file else (INPUT_FILE_STEM + ".lif")
+ PROMPT_ID = args.prompt_id if args.prompt_id else INPUT_FILE
+ PROMPT_TOPIC = args.prompt_topic if args.prompt_topic else INPUT_FILE
+ MULTI_PV = args.multi_pv if args.multi_pv else 100
+ NUM_WORDS = args.num_words if args.num_words else 10000
+ NUM_TOKENS = args.num_tokens if args.num_tokens else 10000
+ BEAM_WIDTH = args.beam_width if args.beam_width else 1
+ ALPHA_MODE = args.alpha_mode if args.alpha_mode else 0
+ START_TURN = args.start_turn if args.start_turn else 1
+ MODEL_TAG = args.model if args.model else "Qwen"
+ MINUS_INF = -1000.0
+ # main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE, MODEL_TAG)
+
+
+ def borderedMatch(arg, st, hyphenDelimits=False, underscoreDelimits=False):
+     """
+     Match if arg occurs in st surrounded by ends or non-alphanumeric chars.
+
+     Intent is e.g. for "Karp" to match "Karp, R" but not "Karpov".
+     Whether "Karp" matches "Karp-Lipton" depends on whether hyphen is part of name.
+     Works even if arg itself has non-alpha characters.
+     Used for player and event names AND to identify tokens in command streams.
+     Uses isalnum() for the local definition of name characters.
+     Prefer to override it to count underscore as a non-delimiting char.
+     Hyphen is always part of tokens but can be used to delimit place and person names,
+     so "Khanty" and "Khanty-Mansiysk" can both match "Khanty-Mansiysk" and
+     "Vachier" can match "Vachier-Lagrave".
+
+     With LLM tokens, this allows arg="abc" to match st=" abc" but not vice-versa.
+     However, if called with arg.strip() then vice-versa is fine.
+     If the token is @-@ then it will match "--" but NOT match a hyphenated word.
+     """
+     fromPos = st.find(arg)
+     while fromPos != -1:
+         leftOK = (fromPos == 0)
+         if (fromPos > 0):
+             c = st[fromPos - 1]
+             if c == '-':
+                 leftOK = hyphenDelimits
+             elif c == '_':
+                 leftOK = underscoreDelimits
+             else:
+                 leftOK = (not c.isalnum())
+
+         rightEdge = fromPos + len(arg)
+         rightOK = (rightEdge == len(st))
+         if (not rightOK):
+             d = st[rightEdge]
+             if d == '-':
+                 rightOK = hyphenDelimits
+             elif d == '_':
+                 rightOK = underscoreDelimits
+             else:
+                 rightOK = (not d.isalnum())
+
+         if rightOK and leftOK:
+             return True
+         else:  # try to find another match
+             fromPos = st.find(arg, fromPos + 1)
+
+     return False
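+
+ # Illustrative checks of borderedMatch (a sketch added for clarity, not part of
+ # the pipeline; expected values follow from the docstring and code above):
+ #     borderedMatch("Karp", "Karp, R")             -> True   (comma delimits)
+ #     borderedMatch("Karp", "Karpov")              -> False  ('o' is alphanumeric)
+ #     borderedMatch("Vachier", "Vachier-Lagrave")  -> False by default,
+ #                                                     True with hyphenDelimits=True
+ #     borderedMatch("abc", " abc")                 -> True   (leading space delimits)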
+
+
+ def reprat(tok):
+     rep = unidecode(repr(tok))
+     return f"@{rep.replace('@','(at)')[1:-1]}@"
+
+
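+ # What reprat produces (illustrative values, derived from the two lines above):
+ #     reprat(" the")  -> "@ the@"    (repr's quotes stripped, @ wrapper added)
+ #     reprat("a@b")   -> "@a(at)b@"  (embedded @ escaped as "(at)")
+ #     reprat("café")  -> "@cafe@"    (unidecode folds non-ASCII to ASCII)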
+
+ # Prefer the environment variable; fall back to an interactive prompt.
+ hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or input("Enter your Hugging Face token (or press Enter for none): ")
+
+ if hf_token:
+     print("Logging in to Hugging Face Hub...")
+     login(token=hf_token)
+ else:
+     print("HF token not found. Gated model download might fail.")
+
+
+ def main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE,
+          MODEL_TAG):
+     # %% Constants and Configuration
+     # MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+     # MODEL_NAME = "google/gemma-3-4b-it"
+     # MODEL_NAME = "Qwen/Qwen3-1.7B"
+     MODEL_NAME = "Qwen/Qwen3-0.6B"
+     # MODEL_NAME = "microsoft/Phi-4-mini-instruct"
+     # MODEL_NAME = "meta-llama/Llama-2-7b-hf"
+     # MODEL_NAME = input(f"Enter Hugging Face model name or press enter to default to [{MODEL_NAME}]: ") or MODEL_NAME
+     DEVICE = "cpu"
+     TORCH_DTYPE = torch.float32
+     DEPTH_RANGE = 1
+     # Ensure INPUT_FILE path is correct for your environment
+     # INPUT_FILE = 'feed.txt'  # Assuming it's in the same directory or provide full path
+     # Create the input file if it doesn't exist for testing
+     if not os.path.exists(INPUT_FILE):
+         print(f"Warning: Input file '{INPUT_FILE}' not found. Creating a dummy file.")
+         with open(INPUT_FILE, 'w', encoding='utf-8') as f:
+             f.write("The quick brown fox jumps over the lazy dog")
+
+     # OUTPUT_FILE = "output.lif"  # Changed output filename
+     MODEL_CONTEXT_WINDOW = 128_000  # Example context window, adjust if needed for the actual model
+     SAFETY_THRESHOLD = 2_000
+     MAX_INPUT_TOKENS = MODEL_CONTEXT_WINDOW - SAFETY_THRESHOLD  # Max tokens per model *input slice*
+
+     # %% Load Model & Tokenizer
+     print("Step 1: Loading model...")
+     # Add trust_remote_code=True if necessary for the specific model architecture
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         torch_dtype=TORCH_DTYPE,
+         trust_remote_code=True,  # Often needed for Qwen-based models
+         token=hf_token
+     ).to(DEVICE)
+     print(f" Model loaded to {DEVICE}.")
+
+     # print("Step 2: Applying dynamic quantization for faster CPU inference...")
+     # Note: Quantization might slightly affect raw logit values compared to fp32/fp16
+     # model = torch.quantization.quantize_dynamic(
+     #     model,
+     #     {torch.nn.Linear},
+     #     dtype=torch.qint8
+     # )
+     hook_handle = model.lm_head.register_forward_hook(capture_logits_hook)
+
+     ##KWR: NEW
+     # model.generation_config.temperature=0
+     # model.generation_config.top_p=1.0
+
+     model.eval()
+     print(" Model is ready for inference.\n")
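+
+     # Sketch: after any forward pass the hook's capture could be inspected, e.g.
+     #     last_logits = captured_logits[0] if captured_logits else None
+     # (the generation loop below relies on outputs.scores instead)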
+
+     print("Step 3: Loading tokenizer...")
+     # Add trust_remote_code=True if necessary for the specific model architecture
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True, token=hf_token)
+     if tokenizer.pad_token is None:
+         print(" Tokenizer missing pad token; setting pad_token = eos_token")
+         tokenizer.pad_token = tokenizer.eos_token
+         # Important: Ensure model config also reflects this if needed by generation args
+         if hasattr(model, 'config'):
+             model.config.pad_token_id = tokenizer.eos_token_id
+     print(" Tokenizer loaded and configured.\n")
+
+     # %% User Inputs
+     print("Step 4: Prompting user for inputs...")
+     # Use default values for easier testing
+     promptID = PROMPT_ID  # input(" Enter Prompt ID [Default: VanityTestGreedy]: ") or "VanityTestGreedy"
+     # MultiPV_str = input(" Enter MultiPV (top logits to show) [Default: 5]: ") or "5"
+     MultiPV = MULTI_PV  # Now only controls how many top logits to display
+     # LegalNumberOfMove_str = input(" Enter Max Number of moves [Default: 10]: ") or "10"
+     LegalNumberOfMove = NUM_WORDS
+     # EngineID = f"DeepSeek R1 1.5B Qwen-Distil Greedy ({DEVICE.upper()})"
+     # EngineID = "Qwen/Qwen3-1.7B"
+     EngineID = MODEL_NAME  # keep the reported engine in sync with the loaded model
+     # EngineID = f"Gemma-3-4b-it ({DEVICE.upper()})"  # Indicate CPU in EngineID
+     Depth = 1
+     print(" User inputs captured.\n")
+
+     # %% Pre-tokenize entire relevant input sequence
+     print("Step 5: Pre-tokenizing input sequence...")
+     # Note: the stem prompt is fixed here; PROMPT_TOPIC (-pt) is not yet wired in.
+     initial_prompt = "Complete successive parts of a sentence given one word at a time:"
+     initial_prompt_ids = tokenizer.encode(initial_prompt, add_special_tokens=False)
+
+     print(f" Reading words from {INPUT_FILE}...")
+     lines = []
+     try:
+         with open(INPUT_FILE, 'r', encoding='utf-8') as f:
+             # words_from_file = f.read().split()
+             lines = f.readlines()
+             # Replace newlines with spaces so words on adjacent lines do not merge.
+             words_from_file = "".join(line.replace('\n', ' ') for line in lines)
+             wordList = re.split(r'([a-zA-Z]+|\d+)', words_from_file)
+             wordList = [x for x in wordList if x.strip() != '']
+             # print("The words are:\n", words_from_file)
+
+             numChars = 0
+             numTextTokens = len(wordList)
+             for word in wordList:
+                 numChars += len(word)
+             avgTokenLength = round(numChars/numTextTokens, 4) if numTextTokens else 0
+             print(f"\nFound {numTextTokens} text word/tokens with average length {avgTokenLength}.\n")
+
+     except FileNotFoundError:
+         print(f"Error: Input file '{INPUT_FILE}' not found. Exiting.")
+         sys.exit(1)
+
+     all_tokens = list(initial_prompt_ids)
+     word_end_indices = [len(initial_prompt_ids)]  # Index *after* the last token of each word (or initial prompt)
+     processed_words = []  # Store the actual words processed
+
+     print(" Tokenizing words and building full sequence...")
+     for word in wordList:
+         word_tokens = tokenizer.encode(" " + word, add_special_tokens=False)
+         all_tokens.extend(word_tokens)
+         word_end_indices.append(len(all_tokens))
+         processed_words.append(word)
+
+     full_token_tensor = torch.tensor(all_tokens, dtype=torch.long).unsqueeze(0)
+     print(f" Pre-tokenized {len(processed_words)} words into a sequence of {len(all_tokens)} tokens.\n")
+
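+     # Illustrative sketch (token counts invented): if the prompt tokenizes to 11 ids
+     # and " The" / " quick" to one id each, then word_end_indices == [11, 12, 13] and
+     # turn k's context is all_tokens[:word_end_indices[k-1]], i.e. the prompt plus the
+     # first k-1 words, so the model's next-token scores are for text word k.
+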
+     num_words_to_process = min(len(processed_words), LegalNumberOfMove) - (START_TURN - 1)
+     if num_words_to_process < len(processed_words) - (START_TURN - 1):
+         print(f" Will process the first {num_words_to_process} words due to NUM_WORDS limit.\n")
+     elif num_words_to_process == 0:
+         print(" Warning: No words to process based on input file or limits.\n")
+
+     # %% Build file header
+     print("Step 8: Preparing output file header...")
+     header_lines = [
+         f'[PromptID "{promptID}"]\n',
+         f'[EngineID "{EngineID}"]\n',
+         f'[MultiPV "{MultiPV}"]\n',
+         '[DepthRange "1:1"]\n\n',
+     ] + lines + ['\n\n']
+     print(" Header prepared.\n")
+
+     # %% Main Generation Loop (Using Slicing & Greedy Decoding)
+     print("Step 9: Entering main generation loop (using pre-tokenized slicing and greedy decoding)...\n")
+     PrevEval = "n.a."
+     start_time = time.time()
+     current_time = start_time
+     numMatchedWords = 0
+     numMatchedChars = 0
+
+     if num_words_to_process > 0:
+         if (START_TURN > 1):
+             OUTPUT_FILE = os.path.splitext(OUTPUT_FILE)[0] + "from" + str(START_TURN) + ".lif"
+         with open(OUTPUT_FILE, 'w', encoding='utf-8') as writer:
+             print(" Writing header to output file...")
+             writer.write(''.join(header_lines))
+             print(" Header written. Starting word-by-word prediction.\n")
+
+             for turnCount in range(START_TURN, START_TURN + num_words_to_process):
+                 current_word = processed_words[turnCount - 1].strip()
+                 # print(f"Turn {turnCount}: Predicting after word '{current_word}'")
+
+                 slice_end_index = word_end_indices[turnCount - 1]
+                 slice_start_index = max(0, slice_end_index - MAX_INPUT_TOKENS)
+                 # print(f" 9.1/9.2: Context slice indices: [{slice_start_index}:{slice_end_index}]")
+
+                 input_tensor = full_token_tensor[:, slice_start_index:slice_end_index]
+                 current_input_len = input_tensor.shape[1]
+                 # print(f" 9.3: Sliced input tensor shape: {input_tensor.shape}")
+
+                 input_tensor_dev = input_tensor.to(DEVICE)
+
+                 start_time_gen = time.time()
+                 # 9.4 Generate next token using GREEDY DECODING
+                 # print(f" 9.4: Running model.generate() with {current_input_len} input tokens (Greedy Decoding)...")
+                 with torch.no_grad():
+                     outputs = model.generate(
+                         input_tensor_dev,
+                         max_new_tokens=2,
+                         min_new_tokens=2,  # generate two steps; only the first step's scores are used below
+                         output_scores=True,  # Get logits
+                         return_dict_in_generate=True,  # Get dict output
+                         do_sample=False,  # Disable sampling -> Use Greedy Decoding
+                         pad_token_id=tokenizer.pad_token_id,
+                         num_beams=BEAM_WIDTH,  # 1 gives plain greedy search
+                         num_return_sequences=BEAM_WIDTH,
+                         temperature=None,
+                         top_k=None,
+                         top_p=None,
+                     )
+                 end_time_gen = time.time()
+                 gen_duration = end_time_gen - start_time_gen
+                 # print(f" Model generation took: {gen_duration:.4f} seconds")
+
+                 # ----- UPDATED LOGIC for TopK Logits (Greedy Path) -----
+                 # outputs.scores is a tuple of length max_new_tokens (2 here);
+                 # each element has shape [batch_size, vocab_size] (batch_size is 1 here)
+                 logits_for_step = outputs.scores[0]  # Logits for the first generated step. Shape: [1, vocab_size]
+
+                 # Get the logits from the single batch item (greedy path)
+                 logits_for_greedy_path = logits_for_step[0]  # Shape: [vocab_size]
+
+                 # Get the top K (MultiPV) logits and their corresponding token IDs
+                 # Note: The highest logit corresponds to the token chosen by greedy decoding
+                 top_k_logits_values, top_k_logits_indices = torch.topk(
+                     logits_for_greedy_path, k=MultiPV, dim=-1
+                 )
+
+                 # Convert results to lists
+                 top_k_logits_values = top_k_logits_values.tolist()
+                 top_k_logits_indices = top_k_logits_indices.tolist()
+
+                 # Decode the top K tokens based on logits
+                 top_k_tokens = [tokenizer.decode(tid) for tid in top_k_logits_indices]
+                 """
+                 print(f"Top {MultiPV} Logits from greedy path (Token | Logit Value):")
+                 for i in range(MultiPV):
+                     token_str_cleaned = top_k_tokens[i].strip()
+                     print(f" - '{token_str_cleaned}': {top_k_logits_values[i]:.4f} (ID: {top_k_logits_indices[i]})")
+                 """
+                 # The first token actually generated by greedy decoding
+                 greedy_selected_token_id = outputs.sequences[0, current_input_len].item()  # first generated token (sequences include the input ids)
+                 greedy_selected_token_str = tokenizer.decode(greedy_selected_token_id).strip()
+                 # With do_sample=False and num_beams=1 this matches top_k_tokens[0]
+                 # print(f" (Greedy search selected token: '{greedy_selected_token_str}' ID: {greedy_selected_token_id})")  # Optional confirmation
+                 # ----- END of UPDATED LOGIC -----
+
+                 # Derive metrics
+                 modelToken = reprat(top_k_tokens[0])  # Equivalent to greedy_selected_token_str
+                 # modelToken = modelToken.replace('@','(at)')
+                 # modelToken = f"@{modelToken[1:-1]}@"
+                 # modelEval is the highest logit value
+                 modelEval = round(top_k_logits_values[0], 4)
+                 # modelEval = round(float(modelEval)*100)
+                 # NextEval = (f"{top_k_logits_values[1]:.4f}" if MultiPV > 1 else "n.a.")
+                 # NextEval = round(float(NextEval)*100) if MultiPV > 1 and isinstance(top_k_logits_values[1], float) else "n.a."
+
+                 print("Turn ", turnCount, " now matching text word ", current_word, " ...", end='', sep='')
+
+                 # Clamp k so topk cannot exceed the vocabulary size
+                 topNUMTvals, topNUMTindices = torch.topk(
+                     logits_for_greedy_path, k=min(NUM_TOKENS, logits_for_greedy_path.shape[-1]), dim=-1)
+                 topNUMTvalList = topNUMTvals.tolist()
+                 topNUMTindList = topNUMTindices.tolist()
+                 topNUMTtokens = [reprat(tokenizer.decode(tind)) for tind in topNUMTindList]
+                 matchingTextToken = "@@"
+
+                 textTokenIndex = 0
+                 textTokenValue = 0
+                 for tok in topNUMTtokens:
+                     # if tok.find(current_word) != -1:
+                     # if current_word.find("Joki") >= 0 and tok.find("J") >= 0:  # debug trace
+                     #     print("Why doesn't", current_word, "match", tok, "at index", textTokenIndex, "?")
+                     if borderedMatch(current_word, tok, True, True):
+                         matchingTextToken = tok  # f"@{tok.replace('@','(at)')[1:-1]}@"
+                         textTokenValue = topNUMTvalList[textTokenIndex]
+                         if math.isinf(textTokenValue) and textTokenValue < 0.0:
+                             textTokenValue = MINUS_INF
+                         else:
+                             textTokenValue = round(textTokenValue, 4)
+                         if textTokenIndex == 0:
+                             print("***matches top model token", modelToken, "with score ", textTokenValue)
+                             numMatchedWords += 1
+                             numMatchedChars += len(current_word)
+                         else:
+                             print("found at index", textTokenIndex, "in token", matchingTextToken, "with score ", textTokenValue, "; top is ", modelToken, modelEval)
+                         break
+                     textTokenIndex += 1
+
+                 if textTokenIndex >= len(topNUMTtokens):
+                     textTokenValue = round(topNUMTvalList[-1], 4)
+                     print("not found, using bottom score", textTokenValue)
+
+                 NextEval = textTokenValue
+
+
+                 # print(
+                 #     f" 9.5: Top token (greedy choice): '{modelToken}' (Evaluation: {modelEval}) | Logit value: {top_k_logits_values[0]:.4f} | Next best Eval: {NextEval}")
+
+                 # Build lines for this turn
+                 current_stem = initial_prompt + " " + " ".join(processed_words[:turnCount])
+                 lines = [
+                     f'[PID "{promptID}"]\n',
+                     f'[EID "{MODEL_NAME}"]\n',
+                     f'[Turn "{turnCount}-w"]\n',
+                     f'[TextToken "@{current_word}@"]\n',
+                     f'[ModelToken "{modelToken}"]\n',  # The model's greedy prediction
+                     f'[TextTokenIndex "{textTokenIndex}"]\n',
+                     f'[TextTokenValue "{textTokenValue}"]\n',
+                     f'[Eval "{modelEval}"]\n',  # The highest raw logit value
+                     f'[PrevEval "{PrevEval}"]\n',
+                     f'[NextEval "{NextEval}"]\n',  # The matched text token's score from above
+                     f'[Depth "{Depth}"]\n',
+                     f'[STEM "{current_stem}"]\n',
+                     f'[NumLegalMoves "{MultiPV}"]\n',
+                     "---------------\n",
+                     f"{DEPTH_RANGE}\n",
+                     "---------------\n"
+                 ]
+                 for token_str, logit_val in zip(top_k_tokens, top_k_logits_values):
+                     rep = reprat(token_str)  # has @ ... @ around it
+                     lines.append(f"{rep} {logit_val:.4f}\n")
+
+                 lines.append(
+                     "===========================================================================================================\n")
+                 lines.append("[Comments]\n")
+                 lines.append("[EndMove]\n\n")
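+
+                 # Shape of one record as written (values illustrative only):
+                 #     [PID "YaoTestByDeepSeek"]
+                 #     [EID "Qwen/Qwen3-0.6B"]
+                 #     [Turn "7-w"]
+                 #     [TextToken "@Jokic@"] ... [NumLegalMoves "50"]
+                 #     ---------------
+                 #     1
+                 #     ---------------
+                 #     @ the@ 21.5312      <- one "@token@ logit" line per MultiPV entry
+                 #     ...
+                 #     [Comments]
+                 #     [EndMove]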
+
+                 # print(" Lines built.")
+
+                 # 9.7 Write to file
+                 # print(" 9.7: Writing lines to output file...")
+                 writer.write(''.join(lines))
+                 # print(" Write complete.\n")
+
+                 # 9.8 Update state
+                 PrevEval = modelEval
+
+                 # 9.9 Status update
+                 status_interval = min(100, num_words_to_process // 2 if num_words_to_process >= 10 else 10)
+                 if turnCount % status_interval == 0 or turnCount == START_TURN + num_words_to_process - 1:
+                     last_time = current_time
+                     current_time = time.time()
+                     elapsed = current_time - start_time
+                     elapsedLast = current_time - last_time
+                     rate = (turnCount - 1) / elapsed if elapsed > 0 else 0
+                     rateLast = status_interval / elapsedLast if elapsedLast > 0 else 0
+                     print()
+                     print(f"Processed Turn {turnCount}. Rate: {rate:.2f} words/sec., last {status_interval} rate: {rateLast:.2f}")
+
+             # end-for
+             averageCharsMatched = 0 if numMatchedWords == 0 else round(numMatchedChars/numMatchedWords, 4)
+             matchPercent = 0.0 if numTextTokens == 0 else round(100.0*numMatchedWords/numTextTokens, 2)
+             matchPercentStr = f"({matchPercent}%)"
+             print("Done: matched", numMatchedWords, matchPercentStr, "tokens of average length", averageCharsMatched)
+             print("from", numTextTokens, "tokens of average length", avgTokenLength)
+
+     else:
+         print("Skipping main generation loop as there are no words to process.")
+
+     hook_handle.remove()
+     print("Removed forward hook.")
+     # %% Final Stats
+     print("Step 10: Reporting final statistics...")
+     total_time = time.time() - start_time
+     avg_rate = (num_words_to_process / total_time) if total_time > 0 and num_words_to_process > 0 else 0
+     print(f" Total turns processed: {num_words_to_process}")
+     print(f" Total time: {total_time:.2f} seconds")
+     print(f" Average speed: {avg_rate:.2f} words/second")
+     print(f" Output written to {OUTPUT_FILE}")
+
+     # Optional: Clean up memory
+     print("\nCleaning up resources...")
+     del model
+     del tokenizer
+     del full_token_tensor
+     if 'outputs' in locals():
+         del outputs
+     if 'input_tensor' in locals():
+         del input_tensor
+     if 'input_tensor_dev' in locals():
+         del input_tensor_dev
+     gc.collect()
+     if DEVICE == 'cuda':
+         print("Emptying CUDA cache...")
+         torch.cuda.empty_cache()
+     print("\nScript finished.")
+
+
+ ### RUN MAIN ####
+
+ if __name__ == "__main__":
+     main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE,
+          MODEL_TAG)