app2.py
# %% File "lifgen-hook.py" by Gishnu Madhu with editing and feature input by KWR
# Usage:
# python lifgen-hook.py -i <source-file> -o <output.lif> -pid <identifier> -pt <initial prompt> -mpv <# of scored items>
# optional: -nt <# of tokens to search for text word> -bw <width of beam search, 1 for greedy>, -nw <# of words>
#           -st <word/token to start from, 1-based> -a <mode in 0...4 of treating tokens as in-bounds or matching>
#           -model <model to use>  # not yet implemented---change model manually for now.
#
# Qwen requires a HuggingFace token, needs "pip install transformers --upgrade", and does a new 4.2GB download
# DeepSeek does not need an access token---just hit enter to give the empty string
#
# Example:
# python lifgen-hook.py -i YaoJokic.txt -o YaoJokicm50.lif -pid YaoTestByDeepSeek -pt "Compare Yao Ming and Nikola Jokic in NBA basketball" -mpv 50

import gradio as gr
import math
import torch
import gc
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import argparse
import sys
import re
from huggingface_hub import login
from unidecode import unidecode

def capture_logits_hook(module, input_args, output):
    """
    Hook function to capture the output of the lm_head layer.
    The output might be a tensor or a tuple containing the tensor.
    We are interested in the tensor containing logits.
    """
    if isinstance(output, torch.Tensor):
        logits = output
    elif isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], torch.Tensor):
        # Common case for models returning more than just logits (e.g., past_key_values)
        # We assume the first element is the logits tensor. Check model docs if unsure.
        logits = output[0]
    else:
        # Cannot determine logits tensor, skip capture for this call
        print(f"Warning: Hook captured unexpected output type: {type(output)}")
        return

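# Note (added): the hook above only binds the logits to a local variable, so nothing is
# retained between calls; the generation loop in main() reads its logits from
# outputs.scores instead. The hook stays registered but its captured values are unused here.
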
parser = argparse.ArgumentParser(
    description="LifGenerator for CPU with Hugging Face models with greedy decoding",
    epilog="Help Documentation"
)
parser.add_argument(
    "-input_file", "-i",
    type=str,
    help="The path to the input file."
)

parser.add_argument(
    "-output_file", "-o",
    type=str,
    help="Name and path of output file"
)

parser.add_argument(
    "-prompt_id", "-pid",
    type=str,
    help="Overall name of item"
)

parser.add_argument(
    "-prompt_topic", "-pt",
    type=str,
    help="Topic given to LLM before stem words"
)

parser.add_argument(
    "-multi_pv", "-mpv",
    type=int,
    help="Number of options to consider at each turn"
)

parser.add_argument(
    "-num_words", "-nw",
    type=int,
    help="Cap on # of text words to iterate"
)

parser.add_argument(
    "-num_tokens", "-nt",
    type=int,
    help="# of tokens to search for text word match"
)

parser.add_argument(
    "-beam_width", "-bw",
    type=int,
    help="Width of beam search, 0 or 1 for greedy"
)

parser.add_argument(
    "-alpha_mode", "-a",
    type=int,
    help="0 = all tokens, up through 4 = alphabetic chars plus apostrophe only"
)

parser.add_argument(
    "-start_turn", "-st",
    type=int,
    help="1 by default; adds st-1 words to the prompt"
)

parser.add_argument(
    "-model",
    type=str,
    help="DS for DeepSeek, QWEN for Qwen"
)

args = parser.parse_args()
print("Welcome to the LifGenerator CPU script!")
print("This script generates .lif files using a Hugging Face model and greedy decoding.")
print(f"Input file path: {args.input_file}")
print(f"Output file path: {args.output_file}")
INPUT_FILE = args.input_file if args.input_file else "Kangaroos.txt"
INPUT_FILE_STEM = INPUT_FILE.split('.')[0]
OUTPUT_FILE = args.output_file if args.output_file else (INPUT_FILE_STEM + ".lif")
PROMPT_ID = args.prompt_id if args.prompt_id else INPUT_FILE
PROMPT_TOPIC = args.prompt_topic if args.prompt_topic else INPUT_FILE
MULTI_PV = args.multi_pv if args.multi_pv else 100
NUM_WORDS = args.num_words if args.num_words else 10000
NUM_TOKENS = args.num_tokens if args.num_tokens else 10000
BEAM_WIDTH = args.beam_width if args.beam_width else 1
ALPHA_MODE = args.alpha_mode if args.alpha_mode else 0
START_TURN = args.start_turn if args.start_turn else 1
MODEL_TAG = args.model if args.model else "Qwen"
MINUS_INF = -1000.0
# main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE, MODEL_TAG)

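# Note (added): PROMPT_TOPIC, ALPHA_MODE, and MODEL_TAG are parsed and passed to main()
# but are not yet used inside it; the model name and initial prompt are currently
# hard-coded there (see the "-model ... not yet implemented" note in the header).
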
"""
Match if arg occurs in st surrounded by string ends or non-alphanumeric chars.

Intent is e.g. for "Karp" to match "Karp, R" but not "Karpov".
Whether "Karp" matches "Karp-Lipton" depends on whether hyphen is part of name.
Works even if arg itself has non-alpha characters.
Used for player and event names AND to identify tokens in command streams.
Uses Python isalnum() (in place of C++ "isalpha") for the local definition of names.
Prefer to override it to count underscore as a non-delimiting char.
Hyphen is always part of tokens but can be used to delimit place and person names,
so "Khanty" and "Khanty-Mansiysk" can both match "Khanty-Mansiysk" and
"Vachier" can match "Vachier-Lagrave".

With LLM tokens, this allows arg="abc" to match st=" abc" but not vice-versa.
However, if called with arg.strip() then vice-versa is fine.
If the token is @-@ then it will match "--" but NOT match a hyphenated word.
"""


def borderedMatch(arg, st, hyphenDelimits=False, underscoreDelimits=False):
    fromPos = st.find(arg)
    while fromPos != -1:
        leftOK = (fromPos == 0)
        if (fromPos > 0):
            c = st[fromPos - 1]
            if c == '-':
                leftOK = hyphenDelimits
            elif c == '_':
                leftOK = underscoreDelimits
            else:
                leftOK = (not c.isalnum())

        rightEdge = fromPos + len(arg)
        rightOK = (rightEdge == len(st))
        if (not rightOK):
            d = st[rightEdge]
            if d == '-':
                rightOK = hyphenDelimits
            elif d == '_':
                rightOK = underscoreDelimits
            else:
                rightOK = (not d.isalnum())

        if rightOK and leftOK:
            return True
        else:  # try to find another match
            fromPos = st.find(arg, fromPos + 1)

    return False


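# A few illustrative cases (added for clarity; they follow from the definition above):
#   borderedMatch("Karp", "Karp, R")                                -> True
#   borderedMatch("Karp", "Karpov")                                 -> False
#   borderedMatch("Khanty", "Khanty-Mansiysk")                      -> False
#   borderedMatch("Khanty", "Khanty-Mansiysk", hyphenDelimits=True) -> True
#   borderedMatch("abc", " abc")                                    -> True
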
def reprat(tok):
    rep = unidecode(repr(tok))
    return f"@{rep.replace('@','(at)')[1:-1]}@"


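# For example (derived from the definition above): reprat(" the") returns "@ the@",
# and a literal '@' inside a token is escaped, e.g. reprat("a@b") returns "@a(at)b@".
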
hf_token = input("Enter your Hugging Face token (press Enter to skip): ")
# Or better:
# hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")

if hf_token:
    print("Logging in to Hugging Face Hub...")
    login(token=hf_token)
else:
    print("No HF token given. Gated model downloads might fail.")


def main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE,
         MODEL_TAG):
    # %% Constants and Configuration
    # MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    # MODEL_NAME = "google/gemma-3-4b-it"
    MODEL_NAME = "Qwen/Qwen3-1.7B"
    # MODEL_NAME = "Qwen/Qwen3-0.6B"
    # MODEL_NAME = "microsoft/Phi-4-mini-instruct"
    # MODEL_NAME = "meta-llama/Llama-2-7b-hf"
    # MODEL_NAME = input(f"Enter hugging face model name or press enter to default to [{MODEL_NAME}]: ") or MODEL_NAME
    DEVICE = "cpu"
    TORCH_DTYPE = torch.float32
    DEPTH_RANGE = 1
    # Ensure INPUT_FILE path is correct for your environment
    # INPUT_FILE = 'feed.txt' # Assuming it's in the same directory or provide full path
    # Create the input file if it doesn't exist for testing
    if not os.path.exists(INPUT_FILE):
        print(f"Warning: Input file '{INPUT_FILE}' not found. Creating a dummy file.")
        with open(INPUT_FILE, 'w', encoding='utf-8') as f:
            f.write("The quick brown fox jumps over the lazy dog")

    # OUTPUT_FILE = "output.lif" # Changed output filename
    MODEL_CONTEXT_WINDOW = 128_000  # Example context window, adjust if needed for the actual model
    SAFETY_THRESHOLD = 2_000
    MAX_INPUT_TOKENS = MODEL_CONTEXT_WINDOW - SAFETY_THRESHOLD  # Max tokens per model *input slice*

    # %% Load Model & Tokenizer (dynamic quantization is currently commented out)
    print("Step 1: Loading model...")
    # Add trust_remote_code=True if necessary for the specific model architecture
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
        trust_remote_code=True,  # Often needed for Qwen-based models
        token=hf_token
    ).to(DEVICE)
    print(f" Model loaded to {DEVICE}.")

    # print("Step 2: Applying dynamic quantization for faster CPU inference...")
    # Note: Quantization might slightly affect raw logit values compared to fp32/fp16
    # model = torch.quantization.quantize_dynamic(
    #     model,
    #     {torch.nn.Linear},
    #     dtype=torch.qint8
    # )
    hook_handle = model.lm_head.register_forward_hook(capture_logits_hook)

    ##KWR: NEW
    #model.generation_config.temperature=0
    #model.generation_config.top_p=1.0

    model.eval()
    print(" Model is ready for inference.\n")

    print("Step 3: Loading tokenizer...")
    # Add trust_remote_code=True if necessary for the specific model architecture
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True, token=hf_token)
    if tokenizer.pad_token is None:
        print(" Tokenizer missing pad token; setting pad_token = eos_token")
        tokenizer.pad_token = tokenizer.eos_token
        # Important: Ensure model config also reflects this if needed by generation args
        if hasattr(model, 'config'):
            model.config.pad_token_id = tokenizer.eos_token_id
    print(" Tokenizer loaded and configured.\n")

    # %% User Inputs
    print("Step 4: Prompting user for inputs...")
    # Use default values for easier testing
    promptID = PROMPT_ID  # input(" Enter Prompt ID [Default: VanityTestGreedy]: ") or "VanityTestGreedy"
    # MultiPV_str = input(" Enter MultiPV (top logits to show) [Default: 5]: ") or "5"
    MultiPV = MULTI_PV  # int(MultiPV_str) # Now only controls how many top logits to display
    # LegalNumberOfMove_str = input(" Enter Max Number of moves [Default: 10]: ") or "10"
    LegalNumberOfMove = NUM_WORDS  # int(LegalNumberOfMove_str)
    # EngineID = f"DeepSeek R1 1.5B Qwen-Distil Greedy ({DEVICE.upper()})" # Updated EngineID
    EngineID = "Qwen/Qwen3-1.7B"
    # EngineID = "Qwen/Qwen3-0.6B"
    # EngineID = f"Gemma-3-4b-it ({DEVICE.upper()})" # Indicate CPU in EngineID
    Depth = 1
    print(" User inputs captured.\n")

    # %% Pre-tokenize entire relevant input sequence
    print("Step 5: Pre-tokenizing input sequence...")
    initial_prompt = "Complete successive parts of a sentence given one word at a time:"
    initial_prompt_ids = tokenizer.encode(initial_prompt, add_special_tokens=False)

    print(f" Reading words from {INPUT_FILE}...")
    lines = []
    try:
        with open(INPUT_FILE, 'r', encoding='utf-8') as f:
            # words_from_file = f.read().split()
            lines = f.readlines()
            words_from_file = "".join(line.replace('\n', '') for line in lines)
            wordList = re.split(r'([a-zA-Z]+|\d+)', words_from_file)
            wordList = [x for x in wordList if x != ' ' and x != '']
            # print("The words are:\n", words_from_file)

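            # For illustration (not from the original): re.split(r'([a-zA-Z]+|\d+)', "Yao Ming, 7ft")
            # yields ['', 'Yao', ' ', 'Ming', ', ', '7', '', 'ft', ''], and the filter above keeps
            # ['Yao', 'Ming', ', ', '7', 'ft'], so punctuation runs like ', ' survive as "words".
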
            numChars = 0
            numTextTokens = len(wordList)
            for word in wordList:
                numChars += len(word)
            avgTokenLength = round(numChars/numTextTokens, 4)
            print(f"\nFound {numTextTokens} text word/tokens with average length {avgTokenLength}.\n")

    except FileNotFoundError:
        print(f"Error: Input file '{INPUT_FILE}' not found. Exiting.")
        exit()

    all_tokens = list(initial_prompt_ids)
    word_end_indices = [len(initial_prompt_ids)]  # Index *after* the last token of each word (or initial prompt)
    processed_words = []  # Store the actual words processed

    print(" Tokenizing words and building full sequence...")
    for word in wordList:
        word_tokens = tokenizer.encode(" " + word, add_special_tokens=False)
        all_tokens.extend(word_tokens)
        word_end_indices.append(len(all_tokens))
        processed_words.append(word)

    full_token_tensor = torch.tensor(all_tokens, dtype=torch.long).unsqueeze(0)
    print(f" Pre-tokenized {len(processed_words)} words into a sequence of {len(all_tokens)} tokens.\n")

    num_words_to_process = min(len(processed_words), LegalNumberOfMove) - (START_TURN - 1)
    if num_words_to_process < len(processed_words) - (START_TURN - 1):
        print(f" Will process the first {num_words_to_process} words due to NUM_WORDS limit.\n")
    elif num_words_to_process == 0:
        print(" Warning: No words to process based on input file or limits.\n")

    # %% Build file header
    print("Step 8: Preparing output file header...")
    header_lines = [
        f'[PromptID "{promptID}"]\n',
        f'[EngineID "{EngineID}"]\n',
        f'[MultiPV "{MultiPV}"]\n',
        f'[DepthRange "1:1"]\n\n',
    ] + lines + ['\n\n']
    print(" Header prepared.\n")

    # %% Main Generation Loop (Using Slicing & Greedy Decoding)
    print("Step 9: Entering main generation loop (using pre-tokenized slicing and greedy decoding)...\n")
    PrevEval = "n.a."
    start_time = time.time()
    current_time = start_time
    numMatchedWords = 0
    numMatchedChars = 0

    if num_words_to_process > 0:
        if (START_TURN > 1):
            OUTPUT_FILE = OUTPUT_FILE.split('.')[0]+"from"+str(START_TURN)+".lif"
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as writer:
            print(" Writing header to output file...")
            writer.write(''.join(header_lines))
            print(" Header written. Starting word-by-word prediction.\n")

            for turnCount in range(START_TURN, START_TURN + num_words_to_process):
                current_word = processed_words[turnCount - 1].strip()
                # print(f"Turn {turnCount}: Predicting after word '{current_word}'")

                slice_end_index = word_end_indices[turnCount - 1]
                slice_start_index = max(0, slice_end_index - MAX_INPUT_TOKENS)
                # print(f" 9.1/9.2: Context slice indices: [{slice_start_index}:{slice_end_index}]")

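                # word_end_indices[turnCount - 1] points just past the tokens of the prompt plus
                # words 1..turnCount-1, so the slice below is exactly the context the model sees
                # when it is asked to predict word number turnCount.
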
                input_tensor = full_token_tensor[:, slice_start_index:slice_end_index]
                current_input_len = input_tensor.shape[1]
                # print(f" 9.3: Sliced input tensor shape: {input_tensor.shape}")

                input_tensor_dev = input_tensor.to(DEVICE)

                start_time_gen = time.time()
                # 9.4 Generate next token using GREEDY DECODING
                # print(f" 9.4: Running model.generate() with {current_input_len} input tokens (Greedy Decoding)...")
                with torch.no_grad():
                    outputs = model.generate(
                        input_tensor_dev,
                        max_new_tokens=2,
                        min_new_tokens=2,  # Explicitly require 2 new tokens
                        output_scores=True,  # Get logits
                        return_dict_in_generate=True,  # Get dict output
                        do_sample=False,  # Disable sampling -> Use Greedy Decoding
                        pad_token_id=tokenizer.pad_token_id,
                        num_beams=BEAM_WIDTH,
                        num_return_sequences=BEAM_WIDTH,
                        temperature=None,
                        top_k=None,
                        top_p=None,
                        #num_return_sequences=3
                    )
                end_time_gen = time.time()
                gen_duration = end_time_gen - start_time_gen
                # print(f" Model generation took: {gen_duration:.4f} seconds")

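                # Note (added): two new tokens are requested above, but the .lif record below is
                # built from the scores of the first generated step only.
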
                if (turnCount < START_TURN):
                    print("Skipping turn", turnCount)
                    turnCount += 1
                    continue

                # ----- UPDATED LOGIC for TopK Logits (Greedy Path) -----
                # outputs.scores is a tuple with one entry per generated step (here max_new_tokens = 2)
                # Each element is a tensor of shape [batch_size, vocab_size] (one row here with BEAM_WIDTH = 1)
                logits_for_step = outputs.scores[0]  # Logits for the first generated token step. Shape: [1, vocab_size]

                # Get the logits from the single batch item (greedy path)
                logits_for_greedy_path = logits_for_step[0]  # Shape: [vocab_size]

                # Get the top K (MultiPV) logits and their corresponding token IDs
                # Note: The highest logit corresponds to the token chosen by greedy decoding
                top_k_logits_values, top_k_logits_indices = torch.topk(
                    logits_for_greedy_path, k=MultiPV, dim=-1
                )

                # Convert results to lists
                top_k_logits_values = top_k_logits_values.tolist()
                top_k_logits_indices = top_k_logits_indices.tolist()

                # Decode the top K tokens based on logits
                top_k_tokens = [tokenizer.decode(tid) for tid in top_k_logits_indices]
                """
                print(f"Top {MultiPV} Logits from greedy path (Token | Logit Value):")
                for i in range(MultiPV):
                    token_str_cleaned = top_k_tokens[i].strip()
                    print(f" - '{token_str_cleaned}': {top_k_logits_values[i]:.4f} (ID: {top_k_logits_indices[i]})")
                """
                # The last token generated by greedy decoding (the second of the two new tokens)
                greedy_selected_token_id = outputs.sequences[0, -1].item()  # Last token in the sequence
                greedy_selected_token_str = tokenizer.decode(greedy_selected_token_id).strip()
                # print(f" (Greedy search selected token: '{greedy_selected_token_str}' ID: {greedy_selected_token_id})") # Optional confirmation
                # ----- END of UPDATED LOGIC -----

                # Derive metrics
                modelToken = reprat(top_k_tokens[0])  # The model's top choice at the first new-token step
                #modelToken = modelToken.replace('@','(at)')
                #modelToken = f"@{modelToken[1:-1]}@"
                # modelEval is the highest logit value
                modelEval = round(top_k_logits_values[0], 4)
                # modelEval = round(float(modelEval)*100)
                # NextEval = (f"{top_k_logits_values[1]:.4f}" if MultiPV > 1 else "n.a.")
                # NextEval = round(float(NextEval)*100) if MultiPV > 1 and isinstance(top_k_logits_values[1], float) else "n.a."

                print("Turn ", turnCount, " now matching text word ", current_word, " ...", end='', sep='')

                topNUMTvals, topNUMTindices = torch.topk(logits_for_greedy_path, k=NUM_TOKENS, dim=-1)
                topNUMTvalList = topNUMTvals.tolist()
                topNUMTindList = topNUMTindices.tolist()
                topNUMTtokens = [reprat(tokenizer.decode(tind)) for tind in topNUMTindList]
                matchingTextToken = "@@"

                textTokenIndex = 0
                textTokenValue = 0
                for tok in topNUMTtokens:
                    # if tok.find(current_word) != -1:
                    if current_word.find("Joki") >= 0 and tok.find("J") >= 0:
                        print("Why doesn't", current_word, "match", tok, "at index", textTokenIndex, "?")
                    if borderedMatch(current_word, tok, True, True):
                        matchingTextToken = tok  # f"@{tok.replace('@','(at)')[1:-1]}@"
                        textTokenValue = topNUMTvalList[textTokenIndex]
                        if math.isinf(textTokenValue) and textTokenValue < 0.0:
                            textTokenValue = MINUS_INF
                        else:
                            textTokenValue = round(textTokenValue, 4)
                        if textTokenIndex == 0:
                            print("***matches top model token", modelToken, "with score ", textTokenValue)
                            numMatchedWords += 1
                            numMatchedChars += len(current_word)
                        else:
                            print("found at index", textTokenIndex, "in token", matchingTextToken, "with score ", textTokenValue, "; top is ", modelToken, modelEval)
                        break
                    textTokenIndex += 1

                if textTokenIndex >= NUM_TOKENS:
                    textTokenValue = round(topNUMTvalList[-1], 4)
                    print("not found, using bottom score", textTokenValue)

                NextEval = textTokenValue

                # print(
                #     f" 9.5: Top token (greedy choice): '{modelToken}' (Evaluation: {modelEval})|Logit value : {top_k_logits_values[0]:.4f}| Next best Eval: {NextEval} | Logit ")

                # Build lines for this turn
                current_stem = initial_prompt + " " + " ".join(processed_words[:turnCount])
                lines = [
                    f'[PID "{promptID}"]\n',
                    f'[EID "{MODEL_NAME}"]\n',
                    f'[Turn "{turnCount}-w"]\n',
                    f'[TextToken "@{current_word}@"]\n',
                    f'[ModelToken "{modelToken}"]\n',  # The model's greedy prediction
                    f'[TextTokenIndex "{textTokenIndex}"]\n',
                    f'[TextTokenValue "{textTokenValue}"]\n',
                    f'[Eval "{modelEval}"]\n',  # The highest raw logit value
                    f'[PrevEval "{PrevEval}"]\n',
                    f'[NextEval "{NextEval}"]\n',  # The logit found for the matching text token
                    f'[Depth "{Depth}"]\n',
                    f'[STEM "{current_stem}"]\n',
                    f'[NumLegalMoves "{MultiPV}"]\n',
                    "---------------\n",
                    f"{DEPTH_RANGE}\n",
                    "---------------\n"
                ]
                for token_str, logit_val in zip(top_k_tokens, top_k_logits_values):
                    rep = reprat(token_str)  # wraps the token in @ ... @ and escapes any literal '@'
                    lines.append(f"{rep} {logit_val:.4f}\n")

                lines.append(
                    "===========================================================================================================\n")
                lines.append("[Comments]\n")
                lines.append("[EndMove]\n\n")

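                # Schematic of one .lif record built above (illustrative placeholders, not real output):
                #   [PID "<promptID>"] ... [TextToken "@word@"] [ModelToken "@ token@"] ... tag lines ...
                #   ---------------
                #   1
                #   ---------------
                #   @ token@ <logit>      <- one line per MultiPV candidate, best first
                #   ...
                #   ===========================================================================================================
                #   [Comments]
                #   [EndMove]
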
                # print(" Lines built.")

                # 9.7 Write to file
                # print(" 9.7: Writing lines to output file...")
                writer.write(''.join(lines))
                # print(" Write complete.\n")

                # 9.8 Update state
                PrevEval = modelEval

                # 9.9 Status update
                status_interval = min(100, num_words_to_process // 2 if num_words_to_process >= 10 else 10)
                if turnCount % status_interval == 0 or turnCount == num_words_to_process:
                    last_time = current_time
                    current_time = time.time()
                    elapsed = current_time - start_time
                    elapsedLast = current_time - last_time
                    rate = (turnCount - 1) / elapsed if elapsed > 0 else 0
                    rateLast = 100.0 / elapsedLast if elapsedLast > 0 else 0
                    print()
                    print(f"Processed Turn {turnCount}. Rate: {rate:.2f} words/sec., last 100 rate: {rateLast:.2f}")

            #end-for
            averageCharsMatched = 0 if numMatchedWords == 0 else round(numMatchedChars/numMatchedWords, 4)
            matchPercent = 0.0 if numTextTokens == 0 else round(100.0*numMatchedWords/numTextTokens, 2)
            matchPercentStr = f"({matchPercent}%)"
            print("Done: matched", numMatchedWords, matchPercentStr, "tokens of average length", averageCharsMatched)
            print("from", numTextTokens, "tokens of average length", avgTokenLength)

    else:
        print("Skipping main generation loop as there are no words to process.")

    hook_handle.remove()
    print("Removed forward hook.")
    # %% Final Stats
    print("Step 10: Reporting final statistics...")
    total_time = time.time() - start_time
    avg_rate = (num_words_to_process / total_time) if total_time > 0 and num_words_to_process > 0 else 0
    print(f" Total turns processed: {num_words_to_process}")
    print(f" Total time: {total_time:.2f} seconds")
    print(f" Average speed: {avg_rate:.2f} words/second")
    print(f" Output written to {OUTPUT_FILE}")

    # Optional: Clean up memory
    print("\nCleaning up resources...")
    del model
    del tokenizer
    del full_token_tensor
    if 'outputs' in locals():
        del outputs
    if 'input_tensor' in locals():
        del input_tensor
    if 'input_tensor_dev' in locals():
        del input_tensor_dev
    gc.collect()
    if DEVICE == 'cuda':
        print("Emptying CUDA cache...")
        torch.cuda.empty_cache()
    print("\nScript finished.")


### RUN MAIN ####

main(INPUT_FILE, OUTPUT_FILE, PROMPT_ID, PROMPT_TOPIC, MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE,
     MODEL_TAG)

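# Note (added): main() takes ten positional arguments, so wiring it directly to a
# single-textbox gr.Interface as below will raise an error when the form is submitted.
# A thin wrapper along these lines (hypothetical, untested) would be needed instead:
#
#   def run_from_ui(input_file):
#       main(input_file, input_file.split('.')[0] + ".lif", input_file, input_file,
#            MULTI_PV, NUM_WORDS, NUM_TOKENS, BEAM_WIDTH, ALPHA_MODE, MODEL_TAG)
#       return f"Wrote {input_file.split('.')[0]}.lif"
#
#   demo = gr.Interface(fn=run_from_ui, inputs="text", outputs="text")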
demo = gr.Interface(fn=main, inputs="text", outputs="text")
demo.launch()