KWRegan committed
Commit db98505 · 1 Parent(s): 1a0df7f
Files changed (1):
  1. app2.py +589 -0
app2.py ADDED
@@ -0,0 +1,589 @@
+ # %% File "lifgen-hook.py" by Gishnu Madhu with editing and feature input by KWR
+ # Usage:
+ # python lifgen-hook.py -i <source-file> -o <output.lif> -pid <identifier> -pt <initial prompt> -mpv <# of scored items>
+ # optional: -nt <# of tokens to search for text word> -bw <width of beam search, 1 for greedy> -nw <# of words>
+ #           -st <word/token to start from, 1-based> -a <mode in 0...4 of treating tokens as in-bounds or matching>
+ #           -model <model to use>  # not yet implemented---change the model manually for now.
+ #
+ # Qwen requires a Hugging Face token, needs "pip install transformers --upgrade", and does a new 4.2GB download.
+ # DeepSeek does not need an access token---just hit Enter to give the empty string.
+ #
+ # Example:
+ # python lifgen-hook.py -i YaoJokic.txt -o YaoJokicm50.lif -pid YaoTestByDeepSeek -pt "Compare Yao Ming and Nikola Jokic in NBA basketball" -mpv 50
+
+ import gradio as gr
+
+ import math
+ import torch
+ import gc
+ import time
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import os
+ import argparse
+ import sys
+ import re
+ from huggingface_hub import login
+ from unidecode import unidecode
+
+ # Holder for the most recent lm_head output. Only the latest capture is kept so
+ # memory does not grow over a long run; the main loop reads its logits from
+ # outputs.scores, so this hook is an independent capture path.
+ last_captured = {"logits": None}
+
+ def capture_logits_hook(module, input_args, output):
+     """
+     Hook function to capture the output of the lm_head layer.
+     The output might be a tensor or a tuple containing the tensor.
+     We are interested in the tensor containing logits.
+     """
+     if isinstance(output, torch.Tensor):
+         logits = output
+     elif isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], torch.Tensor):
+         # Common case for models returning more than just logits (e.g., past_key_values).
+         # We assume the first element is the logits tensor. Check model docs if unsure.
+         logits = output[0]
+     else:
+         # Cannot determine logits tensor, skip capture for this call
+         print(f"Warning: Hook captured unexpected output type: {type(output)}")
+         return
+     last_captured["logits"] = logits.detach()
+
+ parser = argparse.ArgumentParser(
+     description="LifGenerator for CPU with Hugging Face models with greedy decoding",
+     epilog="Help Documentation"
+ )
+ parser.add_argument(
+     "-input_file", "-i",
+     type=str,
+     help="The path to the input file."
+ )
+
+ parser.add_argument(
+     "-output_file", "-o",
+     type=str,
+     help="Name and path of output file"
+ )
+
+ parser.add_argument(
+     "-prompt_id", "-pid",
+     type=str,
+     help="Overall name of item"
+ )
+
+ parser.add_argument(
+     "-prompt_topic", "-pt",
+     type=str,
+     help="Topic given to LLM before stem words"
+ )
+
+ parser.add_argument(
+     "-multi_pv", "-mpv",
+     type=int,
+     help="Number of options to consider at each turn"
+ )
+
+ parser.add_argument(
+     "-num_words", "-nw",
+     type=int,
+     help="Cap on # of text words to iterate"
+ )
+
+ parser.add_argument(
+     "-num_tokens", "-nt",
+     type=int,
+     help="# of tokens to search for text word match"
+ )
+
+ parser.add_argument(
+     "-beam_width", "-bw",
+     type=int,
+     help="Width of beam search, 0 or 1 for greedy"
+ )
+
+ parser.add_argument(
+     "-alpha_mode", "-a",
+     type=int,
+     help="0 = all tokens, up through 4 = alpha chars plus ' only"
+ )
+
+ parser.add_argument(
+     "-start_turn", "-st",
+     type=int,
+     help="1 by default, adds st-1 words to prompt"
+ )
+
+ parser.add_argument(
+     "-model", "-model",
+     type=str,
+     help="DS for DeepSeek, QWEN for Qwen"
+ )
+
+ args = parser.parse_args()
+ print("Welcome to the LifGenerator CPU script!")
+ print("This script generates .lif files using a Hugging Face model and greedy decoding.")
+ print(f"Input file path: {args.input_file}")
+ print(f"Output file path: {args.output_file}")
+ INPUT_FILE = args.input_file if args.input_file else "Kangaroos.txt"
+ INPUT_FILE_STEM = os.path.splitext(INPUT_FILE)[0]  # safer than split('.') for paths containing dots
+ OUTPUT_FILE = args.output_file if args.output_file else (INPUT_FILE_STEM + ".lif")
+ PROMPT_ID = args.prompt_id if args.prompt_id else INPUT_FILE
+ PROMPT_TOPIC = args.prompt_topic if args.prompt_topic else INPUT_FILE
+ MULTI_PV = args.multi_pv if args.multi_pv else 100
+ NUM_WORDS = args.num_words if args.num_words else 10000
+ NUM_TOKENS = args.num_tokens if args.num_tokens else 10000
+ BEAM_WIDTH = args.beam_width if args.beam_width else 1
+ ALPHA_MODE = args.alpha_mode if args.alpha_mode else 0
+ START_TURN = args.start_turn if args.start_turn else 1
+ MODEL_TAG = args.model if args.model else "Qwen"
+ MINUS_INF = -1000.0
+ # main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE, MODEL_TAG)
+
+ def borderedMatch(arg, st, hyphenDelimits=False, underscoreDelimits=False):
+     """
+     Match if arg occurs in st surrounded by string ends or non-alphanumeric chars.
+
+     The intent is, e.g., for "Karp" to match "Karp, R" but not "Karpov".
+     Whether "Karp" matches "Karp-Lipton" depends on whether the hyphen is part of the name.
+     Works even if arg itself has non-alpha characters.
+     Used for player and event names AND to identify tokens in command streams.
+     Uses C++ "isalpha" for the local definition of names.
+     Prefer to override it to count underscore as a non-delimiting char.
+     Hyphen is always part of tokens but can be used to delimit place and person names,
+     so "Khanty" and "Khanty-Mansiysk" can both match "Khanty-Mansiysk" and
+     "Vachier" can match "Vachier-Lagrave".
+
+     With LLM tokens, this allows arg="abc" to match st=" abc" but not vice-versa.
+     However, if called with arg.strip() then vice-versa is fine.
+     If the token is @-@ then it will match "--" but NOT match a hyphenated word.
+     """
+     fromPos = st.find(arg)
+     while fromPos != -1:
+         leftOK = (fromPos == 0)
+         if fromPos > 0:
+             c = st[fromPos - 1]
+             if c == '-':
+                 leftOK = hyphenDelimits
+             elif c == '_':
+                 leftOK = underscoreDelimits
+             else:
+                 leftOK = (not c.isalnum())
+
+         rightEdge = fromPos + len(arg)
+         rightOK = (rightEdge == len(st))
+         if not rightOK:
+             d = st[rightEdge]
+             if d == '-':
+                 rightOK = hyphenDelimits
+             elif d == '_':
+                 rightOK = underscoreDelimits
+             else:
+                 rightOK = (not d.isalnum())
+
+         if rightOK and leftOK:
+             return True
+         else:  # try to find another match
+             fromPos = st.find(arg, fromPos + 1)
+
+     return False
+
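+ # Illustrative checks drawn from the docstring's own examples (a sketch, not
+ # part of the script's execution path):
+ #   borderedMatch("Karp", "Karp, R")             -> True   (comma delimits)
+ #   borderedMatch("Karp", "Karpov")              -> False  (embedded in a longer name)
+ #   borderedMatch("Karp", "Karp-Lipton")         -> False  (hyphen non-delimiting by default)
+ #   borderedMatch("Vachier", "Vachier-Lagrave", hyphenDelimits=True) -> True
+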
+
+
+ def reprat(tok):
+     # Wrap the token's repr in @...@ delimiters: transliterate to ASCII with
+     # unidecode, escape any literal '@' as '(at)', and strip repr's outer quotes.
+     rep = unidecode(repr(tok))
+     return f"@{rep.replace('@','(at)')[1:-1]}@"
+
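+ # For example (assuming plain ASCII input): reprat(" abc") yields "@ abc@",
+ # since repr(" abc") == "' abc'" and the [1:-1] slice drops the quotes.
+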
+ # Prefer the environment variable; otherwise prompt. (DeepSeek needs no token,
+ # so just hit Enter to give the empty string.)
+ hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or input("Enter your Hugging Face token: ")
+
+ if hf_token:
+     print("Logging in to Hugging Face Hub...")
+     login(token=hf_token)
+ else:
+     print("HF Token not found. Gated model download might fail.")
+
+
+ def main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE,
+          MODEL_TAG):
+     # %% Constants and Configuration
+     # MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+     # MODEL_NAME = "google/gemma-3-4b-it"
+     MODEL_NAME = "Qwen/Qwen3-1.7B"
+     # MODEL_NAME = "Qwen/Qwen3-0.6B"
+     # MODEL_NAME = "microsoft/Phi-4-mini-instruct"
+     # MODEL_NAME = "meta-llama/Llama-2-7b-hf"
+     # MODEL_NAME = input(f"Enter hugging face model name or press enter to default to [{MODEL_NAME}]: ") or MODEL_NAME
+     DEVICE = "cpu"
+     TORCH_DTYPE = torch.float32
+     DEPTH_RANGE = 1
+     # Ensure INPUT_FILE path is correct for your environment.
+     # Create the input file if it doesn't exist, for testing:
+     if not os.path.exists(INPUT_FILE):
+         print(f"Warning: Input file '{INPUT_FILE}' not found. Creating a dummy file.")
+         with open(INPUT_FILE, 'w', encoding='utf-8') as f:
+             f.write("The quick brown fox jumps over the lazy dog")
+
+     MODEL_CONTEXT_WINDOW = 128_000  # Example context window; adjust if needed for the actual model
+     SAFETY_THRESHOLD = 2_000
+     MAX_INPUT_TOKENS = MODEL_CONTEXT_WINDOW - SAFETY_THRESHOLD  # Max tokens per model *input slice*
+
+     # %% Load and Quantize Model & Tokenizer
+     print("Step 1: Loading model...")
+     # Add trust_remote_code=True if necessary for the specific model architecture
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         torch_dtype=TORCH_DTYPE,
+         trust_remote_code=True,  # Often needed for Qwen-based models
+         token=hf_token
+     ).to(DEVICE)
+     print(f"  Model loaded to {DEVICE}.")
+
+     # print("Step 2: Applying dynamic quantization for faster CPU inference...")
+     # Note: Quantization might slightly affect raw logit values compared to fp32/fp16
+     # model = torch.quantization.quantize_dynamic(
+     #     model,
+     #     {torch.nn.Linear},
+     #     dtype=torch.qint8
+     # )
+     hook_handle = model.lm_head.register_forward_hook(capture_logits_hook)
+
+     ## KWR: NEW
+     # model.generation_config.temperature = 0
+     # model.generation_config.top_p = 1.0
+
+     model.eval()
+     print("  Model is ready for inference.\n")  # (the quantization step above is currently disabled)
+
+     print("Step 3: Loading tokenizer...")
+     # Add trust_remote_code=True if necessary for the specific model architecture
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True, token=hf_token)
+     if tokenizer.pad_token is None:
+         print("  Tokenizer missing pad token; setting pad_token = eos_token")
+         tokenizer.pad_token = tokenizer.eos_token
+         # Important: Ensure model config also reflects this if needed by generation args
+         if hasattr(model, 'config'):
+             model.config.pad_token_id = tokenizer.eos_token_id
+     print("  Tokenizer loaded and configured.\n")
+
+     # %% User Inputs
+     print("Step 4: Gathering run settings...")
+     # Defaults now come from the command line rather than interactive prompts.
+     promptID = PROMPT_ID  # formerly: input("  Enter Prompt ID [Default: VanityTestGreedy]: ")
+     MultiPV = MULTI_PV  # controls how many top logits to record per turn
+     LegalNumberOfMove = NUM_WORDS
+     # EngineID = f"DeepSeek R1 1.5B Qwen-Distil Greedy ({DEVICE.upper()})"
+     EngineID = "Qwen/Qwen3-1.7B"
+     # EngineID = "Qwen/Qwen3-0.6B"
+     # EngineID = f"Gemma-3-4b-it ({DEVICE.upper()})"  # Indicate CPU in EngineID
+     Depth = 1
+     print("  Run settings captured.\n")
+
+     # %% Pre-tokenize entire relevant input sequence
+     print("Step 5: Pre-tokenizing input sequence...")
+     # NOTE: PROMPT_TOPIC is currently unused here; the stem uses this fixed prompt.
+     initial_prompt = "Complete successive parts of a sentence given one word at a time:"
+     initial_prompt_ids = tokenizer.encode(initial_prompt, add_special_tokens=False)
+
+     print(f"  Reading words from {INPUT_FILE}...")
+     lines = []
+     try:
+         with open(INPUT_FILE, 'r', encoding='utf-8') as f:
+             lines = f.readlines()
+         # Join lines with a space so words on adjacent lines stay separated.
+         words_from_file = " ".join(line.strip() for line in lines)
+         # Split into alphabetic runs, digit runs, and the punctuation between them.
+         wordList = re.split(r'([a-zA-Z]+|\d+)', words_from_file)
+         wordList = [x for x in wordList if x != ' ' and x != '']
+
+         numChars = 0
+         numTextTokens = len(wordList)
+         for word in wordList:
+             numChars += len(word)
+         avgTokenLength = round(numChars / numTextTokens, 4) if numTextTokens > 0 else 0
+         print(f"\nFound {numTextTokens} text word/tokens with average length {avgTokenLength}.\n")
+
+     except FileNotFoundError:
+         print(f"Error: Input file '{INPUT_FILE}' not found. Exiting.")
+         sys.exit(1)
+
+     all_tokens = list(initial_prompt_ids)
+     word_end_indices = [len(initial_prompt_ids)]  # Index *after* the last token of each word (or initial prompt)
+     processed_words = []  # Store the actual words processed
+
+     print("  Tokenizing words and building full sequence...")
+     for word in wordList:
+         word_tokens = tokenizer.encode(" " + word, add_special_tokens=False)
+         all_tokens.extend(word_tokens)
+         word_end_indices.append(len(all_tokens))
+         processed_words.append(word)
+
+     full_token_tensor = torch.tensor(all_tokens, dtype=torch.long).unsqueeze(0)
+     print(f"  Pre-tokenized {len(processed_words)} words into a sequence of {len(all_tokens)} tokens.\n")
+
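+     # Illustration (hypothetical token counts): if the initial prompt takes 10
+     # tokens and the first two words take 1 and 2 tokens respectively, then
+     # word_end_indices == [10, 11, 13, ...], so the context for turn t is
+     # all_tokens[:word_end_indices[t - 1]], i.e. everything *before* word t's
+     # own tokens, which is exactly the slice taken in the main loop below.
+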
+     num_words_to_process = min(len(processed_words), LegalNumberOfMove) - (START_TURN - 1)
+     if num_words_to_process < len(processed_words) - (START_TURN - 1):
+         print(f"  Will process the first {num_words_to_process} words due to NUM_WORDS limit.\n")
+     elif num_words_to_process == 0:
+         print("  Warning: No words to process based on input file or limits.\n")
+
+     # %% Build file header
+     print("Step 8: Preparing output file header...")
+     header_lines = [
+         f'[PromptID "{promptID}"]\n',
+         f'[EngineID "{EngineID}"]\n',
+         f'[MultiPV "{MultiPV}"]\n',
+         '[DepthRange "1:1"]\n\n',
+     ] + lines + ['\n\n']
+     print("  Header prepared.\n")
+
+     # %% Main Generation Loop (Using Slicing & Greedy Decoding)
+     print("Step 9: Entering main generation loop (using pre-tokenized slicing and greedy decoding)...\n")
+     PrevEval = "n.a."
+     start_time = time.time()
+     current_time = start_time
+     numMatchedWords = 0
+     numMatchedChars = 0
+
+     if num_words_to_process > 0:
+         if START_TURN > 1:
+             OUTPUT_FILE = os.path.splitext(OUTPUT_FILE)[0] + "from" + str(START_TURN) + ".lif"
+         with open(OUTPUT_FILE, 'w', encoding='utf-8') as writer:
+             print("  Writing header to output file...")
+             writer.write(''.join(header_lines))
+             print("  Header written. Starting word-by-word prediction.\n")
+
+             for turnCount in range(START_TURN, START_TURN + num_words_to_process):
+                 current_word = processed_words[turnCount - 1].strip()
+                 # print(f"Turn {turnCount}: Predicting after word '{current_word}'")
+
+                 # 9.1/9.2: Slice the context to everything before this word's tokens.
+                 slice_end_index = word_end_indices[turnCount - 1]
+                 slice_start_index = max(0, slice_end_index - MAX_INPUT_TOKENS)
+
+                 input_tensor = full_token_tensor[:, slice_start_index:slice_end_index]
+                 current_input_len = input_tensor.shape[1]
+                 # print(f"  9.3: Sliced input tensor shape: {input_tensor.shape}")
+
+                 input_tensor_dev = input_tensor.to(DEVICE)
+
+                 start_time_gen = time.time()
+                 # 9.4: Generate the next tokens (greedy when BEAM_WIDTH == 1).
+                 with torch.no_grad():
+                     outputs = model.generate(
+                         input_tensor_dev,
+                         max_new_tokens=2,
+                         min_new_tokens=2,  # Explicitly require 2 new tokens
+                         output_scores=True,  # Get logits
+                         return_dict_in_generate=True,  # Get dict output
+                         do_sample=False,  # Disable sampling -> deterministic decoding
+                         pad_token_id=tokenizer.pad_token_id,
+                         num_beams=BEAM_WIDTH,
+                         num_return_sequences=BEAM_WIDTH,
+                         temperature=None,
+                         top_k=None,
+                         top_p=None,
+                     )
+                 end_time_gen = time.time()
+                 gen_duration = end_time_gen - start_time_gen
+                 # print(f"  Model generation took: {gen_duration:.4f} seconds")
+
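+                 # Shape notes (per the transformers generate() API): outputs.scores
+                 # is a tuple with one [num_return_sequences, vocab_size] tensor per
+                 # generated step. With num_beams == 1 these are the raw logits; with
+                 # beam search they are processed beam scores (log-probabilities), so
+                 # the "Eval" values written below change meaning when BEAM_WIDTH > 1.
+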
+
+                 # ----- TopK logits for this turn -----
+                 # outputs.scores is a tuple of length max_new_tokens (2 here);
+                 # each element has shape [num_return_sequences, vocab_size].
+                 logits_for_step = outputs.scores[0]  # Logits for the first generated token step
+
+                 # Get the logits from the first batch item (greedy / best-beam path)
+                 logits_for_greedy_path = logits_for_step[0]  # Shape: [vocab_size]
+
+                 # Get the top K (MultiPV) logits and their corresponding token IDs.
+                 # Note: The highest logit corresponds to the token chosen by greedy decoding.
+                 top_k_logits_values, top_k_logits_indices = torch.topk(
+                     logits_for_greedy_path, k=MultiPV, dim=-1
+                 )
+
+                 # Convert results to lists
+                 top_k_logits_values = top_k_logits_values.tolist()
+                 top_k_logits_indices = top_k_logits_indices.tolist()
+
+                 # Decode the top K tokens based on logits
+                 top_k_tokens = [tokenizer.decode(tid) for tid in top_k_logits_indices]
+                 """
+                 print(f"Top {MultiPV} Logits from greedy path (Token | Logit Value):")
+                 for i in range(MultiPV):
+                     token_str_cleaned = top_k_tokens[i].strip()
+                     print(f"  - '{token_str_cleaned}': {top_k_logits_values[i]:.4f} (ID: {top_k_logits_indices[i]})")
+                 """
+                 # The first token actually generated this turn. It sits at index
+                 # current_input_len; [-1] would be the *second* new token, since
+                 # two new tokens are generated.
+                 greedy_selected_token_id = outputs.sequences[0, current_input_len].item()
+                 greedy_selected_token_str = tokenizer.decode(greedy_selected_token_id).strip()
+                 # With BEAM_WIDTH == 1 this matches top_k_tokens[0] because do_sample=False
+                 # ----- END TopK logic -----
+
+                 # Derive metrics
+                 modelToken = reprat(top_k_tokens[0])  # Equivalent to greedy_selected_token_str
+                 # modelEval is the highest raw logit value
+                 modelEval = round(top_k_logits_values[0], 4)
+                 # NextEval is set below from the matched text token's score.
+
+                 print("Turn ", turnCount, " now matching text word ", current_word, " ...", end='', sep='')
+
+                 topNUMTvals, topNUMTindices = torch.topk(logits_for_greedy_path, k=NUM_TOKENS, dim=-1)
+                 topNUMTvalList = topNUMTvals.tolist()
+                 topNUMTindList = topNUMTindices.tolist()
+                 topNUMTtokens = [reprat(tokenizer.decode(tind)) for tind in topNUMTindList]
+                 matchingTextToken = "@@"
+
+                 textTokenIndex = 0
+                 textTokenValue = 0
+                 for tok in topNUMTtokens:
+                     # if tok.find(current_word) != -1:  # (older, looser test)
+                     # Debug leftover, disabled:
+                     # if current_word.find("Joki") >= 0 and tok.find("J") >= 0:
+                     #     print("Why doesn't", current_word, "match", tok, "at index", textTokenIndex, "?")
+                     if borderedMatch(current_word, tok, True, True):
+                         matchingTextToken = tok
+                         textTokenValue = topNUMTvalList[textTokenIndex]
+                         if math.isinf(textTokenValue) and textTokenValue < 0.0:
+                             textTokenValue = MINUS_INF
+                         else:
+                             textTokenValue = round(textTokenValue, 4)
+                         if textTokenIndex == 0:
+                             print("***matches top model token", modelToken, "with score ", textTokenValue)
+                             numMatchedWords += 1
+                             numMatchedChars += len(current_word)
+                         else:
+                             print("found at index", textTokenIndex, "in token", matchingTextToken, "with score ", textTokenValue, "; top is ", modelToken, modelEval)
+                         break
+                     textTokenIndex += 1
+
+                 if textTokenIndex >= NUM_TOKENS:
+                     textTokenValue = round(topNUMTvalList[-1], 4)
+                     print("not found, using bottom score", textTokenValue)
+
+                 NextEval = textTokenValue
+
+                 # Build lines for this turn
+                 current_stem = initial_prompt + " " + " ".join(processed_words[:turnCount])
+                 lines = [
+                     f'[PID "{promptID}"]\n',
+                     f'[EID "{MODEL_NAME}"]\n',
+                     f'[Turn "{turnCount}-w"]\n',
+                     f'[TextToken "@{current_word}@"]\n',
+                     f'[ModelToken "{modelToken}"]\n',  # The model's greedy prediction
+                     f'[TextTokenIndex "{textTokenIndex}"]\n',
+                     f'[TextTokenValue "{textTokenValue}"]\n',
+                     f'[Eval "{modelEval}"]\n',  # The highest raw logit value
+                     f'[PrevEval "{PrevEval}"]\n',
+                     f'[NextEval "{NextEval}"]\n',  # The matched text token's value (see above)
+                     f'[Depth "{Depth}"]\n',
+                     f'[STEM "{current_stem}"]\n',
+                     f'[NumLegalMoves "{MultiPV}"]\n',
+                     "---------------\n",
+                     f"{DEPTH_RANGE}\n",
+                     "---------------\n"
+                 ]
+                 for token_str, logit_val in zip(top_k_tokens, top_k_logits_values):
+                     rep = reprat(token_str)  # wraps the token in @ ... @
+                     lines.append(f"{rep} {logit_val:.4f}\n")
+
+                 lines.append(
+                     "===========================================================================================================\n")
+                 lines.append("[Comments]\n")
+                 lines.append("[EndMove]\n\n")
+
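+                 # A single record in the .lif output then looks roughly like this
+                 # (illustrative values only):
+                 #   [PID "YaoTestByDeepSeek"]
+                 #   [EID "Qwen/Qwen3-1.7B"]
+                 #   [Turn "7-w"]
+                 #   [TextToken "@basketball@"]
+                 #   [ModelToken "@ basketball@"]
+                 #   ...remaining tags...
+                 #   ---------------
+                 #   1
+                 #   ---------------
+                 #   @ basketball@ 21.5000
+                 #   @ the@ 19.2500
+                 #   ...top-MultiPV tokens with their logits...
+                 #   [Comments]
+                 #   [EndMove]
+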
+                 # 9.7: Write this turn's record to the output file
+                 writer.write(''.join(lines))
+
+                 # 9.8: Update state
+                 PrevEval = modelEval
+
+                 # 9.9: Status update
+                 status_interval = min(100, num_words_to_process // 2 if num_words_to_process >= 10 else 10)
+                 if turnCount % status_interval == 0 or turnCount == num_words_to_process:
+                     last_time = current_time
+                     current_time = time.time()
+                     elapsed = current_time - start_time
+                     elapsedLast = current_time - last_time
+                     rate = (turnCount - START_TURN + 1) / elapsed if elapsed > 0 else 0
+                     rateLast = status_interval / elapsedLast if elapsedLast > 0 else 0
+                     print()
+                     print(f"Processed Turn {turnCount}. Rate: {rate:.2f} words/sec., last {status_interval} rate: {rateLast:.2f}")
+
+             # end-for
+             averageCharsMatched = 0 if numMatchedWords == 0 else round(numMatchedChars / numMatchedWords, 4)
+             matchPercent = 0.0 if numTextTokens == 0 else round(100.0 * numMatchedWords / numTextTokens, 2)
+             matchPercentStr = f"({matchPercent}%)"
+             print("Done: matched", numMatchedWords, matchPercentStr, "tokens of average length", averageCharsMatched)
+             print("from", numTextTokens, "tokens of average length", avgTokenLength)
+
+     else:
+         print("Skipping main generation loop as there are no words to process.")
+
+     hook_handle.remove()
+     print("Removed forward hook.")
+     # %% Final Stats
+     print("Step 10: Reporting final statistics...")
+     total_time = time.time() - start_time
+     avg_rate = (num_words_to_process / total_time) if total_time > 0 and num_words_to_process > 0 else 0
+     print(f"  Total turns processed: {num_words_to_process}")
+     print(f"  Total time: {total_time:.2f} seconds")
+     print(f"  Average speed: {avg_rate:.2f} words/second")
+     print(f"  Output written to {OUTPUT_FILE}")
+
+     # Optional: Clean up memory
+     print("\nCleaning up resources...")
+     del model
+     del tokenizer
+     del full_token_tensor
+     if 'outputs' in locals():
+         del outputs
+     if 'input_tensor' in locals():
+         del input_tensor
+     if 'input_tensor_dev' in locals():
+         del input_tensor_dev
+     gc.collect()
+     if DEVICE == 'cuda':
+         print("Emptying CUDA cache...")
+         torch.cuda.empty_cache()
+     print("\nScript finished.")
+
+
+ ### RUN MAIN ####
+
+ main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE,
+      MODEL_TAG)
+
+ # main() takes ten arguments, so it cannot be wired directly to a single-textbox
+ # Interface. A thin wrapper (a sketch: the textbox supplies the input file path;
+ # everything else keeps its command-line default):
+ def gradio_main(input_file):
+     main(input_file, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS,
+          NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE, MODEL_TAG)
+     return f"Output written to {OUTPUT_FILE}"
+
+ demo = gr.Interface(fn=gradio_main, inputs="text", outputs="text")
+ demo.launch()
+