import itertools |
|
import sys |
|
import time |
|
from pathlib import Path |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
import torch._dynamo.config |
|
import torch._inductor.config |
|
|
|
def device_sync(device): |
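    # Block until all queued work on `device` has finished; CPU and MPS need no explicit sync.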
|
if "cuda" in device: |
|
torch.cuda.synchronize(device) |
|
elif ("cpu" in device) or ("mps" in device): |
|
pass |
|
else: |
|
print(f"device={device} is not yet suppported") |
|
|
|
|
|
torch._inductor.config.coordinate_descent_tuning = True |
|
torch._inductor.config.triton.unique_kernel_names = True |
|
torch._inductor.config.fx_graph_cache = True |
|
|
|
default_device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
|
wd = Path(__file__).parent.parent.resolve() |
|
sys.path.append(str(wd)) |
|
|
|
from model import Transformer |
|
from tokenizer import get_tokenizer |
|
|
|
def multinomial_sample_one_no_sync(probs_sort): |
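    # Exponential-race (Gumbel-max style) sampling: argmax(probs / Exp(1)) is distributed like
    # torch.multinomial(probs, 1) but avoids the CUDA synchronization that multinomial triggers.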
|
q = torch.empty_like(probs_sort).exponential_(1) |
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) |
|
|
|
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): |
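    # Temperature-scale the logits (temperature floored at 1e-5), optionally mask everything
    # outside the top-k, and return the softmax probabilities.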
|
logits = logits / max(temperature, 1e-5) |
|
|
|
if top_k is not None: |
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
|
pivot = v.select(-1, -1).unsqueeze(-1) |
|
logits = torch.where(logits < pivot, -float("Inf"), logits) |
|
probs = torch.nn.functional.softmax(logits, dim=-1) |
|
return probs |
|
|
|
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): |
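    # Convert the logits at the last position (batch element 0) to probabilities and draw one token.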
|
probs = logits_to_probs(logits[0, -1], temperature, top_k) |
|
idx_next = multinomial_sample_one_no_sync(probs) |
|
return idx_next, probs |
|
|
|
def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: |
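    # input_pos: [B, S]. Run the full prompt through the model once to populate the KV cache
    # and sample the first generated token.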
|
|
|
logits = model(x, input_pos) |
|
return sample(logits, **sampling_kwargs)[0] |
|
|
|
def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: |
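    # input_pos: [B, 1]. Decode exactly one token against the already-populated KV cache.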
|
|
|
assert input_pos.shape[-1] == 1 |
|
logits = model(x, input_pos) |
|
return sample(logits, **sampling_kwargs) |
|
|
|
def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): |
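    # Autoregressively decode `num_new_tokens` tokens one at a time, invoking `callback` on each new token.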
|
new_tokens, new_probs = [], [] |
|
for i in range(num_new_tokens): |
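        # Restrict SDPA to the math backend for the single-token decode step so the compiler
        # can generate the attention kernel itself instead of dispatching to flash/mem-efficient kernels.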
|
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): |
|
next_token, next_prob = decode_one_token( |
|
model, cur_token, input_pos, **sampling_kwargs |
|
) |
|
input_pos += 1 |
|
new_tokens.append(next_token.clone()) |
|
callback(new_tokens[-1]) |
|
new_probs.append(next_prob.clone()) |
|
cur_token = next_token.view(1, -1) |
|
|
|
return new_tokens, new_probs |
|
|
|
|
|
def model_forward(model, x, input_pos): |
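    # Thin wrapper so the target model's verification forward pass can be torch.compile'd separately.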
|
return model(x, input_pos) |
|
|
|
def speculative_decode( |
|
model: Transformer, |
|
draft_model: Transformer, |
|
cur_token: torch.Tensor, |
|
input_pos: int, |
|
speculate_k: int, |
|
**sampling_kwargs |
|
) -> torch.Tensor: |
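    # Draft `speculate_k` tokens with the draft model, verify them with a single target-model
    # forward pass, and return the accepted prefix plus one token sampled from the target
    # (or residual) distribution.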
|
|
|
device = cur_token.device |
|
orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) |
|
draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) |
|
|
|
draft_tokens = torch.cat(draft_tokens) |
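    # Score the current token plus all draft tokens with the target model in one parallel forward pass.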
|
|
|
target_logits = model_forward( |
|
model, |
|
torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), |
|
torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) |
|
) |
|
target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) |
|
draft_probs = torch.stack(draft_probs) |
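    # p: draft-model probability of each draft token, q: target-model probability of the same token.
    # Accept draft token i with probability min(1, q_i / p_i); the first rejection truncates the draft.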
|
|
|
|
|
|
|
p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] |
|
q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] |
|
    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
|
rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() |
|
|
|
if rejected_locations.shape[0] == 0: |
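        # Every draft token was accepted: sample one bonus token from the target's final position,
        # then feed the last draft token to the draft model so its KV cache stays in step.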
|
accept_length = speculate_k + 1 |
|
last_token = multinomial_sample_one_no_sync(target_probs[-1]) |
|
|
|
model_forward( |
|
draft_model, |
|
draft_tokens[-1].view(1, -1), |
|
orig_input_pos + speculate_k, |
|
) |
|
return torch.cat([draft_tokens, last_token]) |
|
else: |
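        # A draft token was rejected: keep the accepted prefix and resample the rejected position
        # from the residual distribution max(q - p, 0), renormalized.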
|
accept_length = rejected_locations[0].item() |
|
p = draft_probs[accept_length] |
|
q = target_probs[accept_length] |
|
new = q - p |
|
new = torch.where(new > 0, new, 0.0) |
|
new = new / new.sum() |
|
next_token = multinomial_sample_one_no_sync(new) |
|
return torch.cat([draft_tokens[:accept_length], next_token]) |
|
|
|
@torch.no_grad() |
|
def generate( |
|
model: Transformer, |
|
prompt: torch.Tensor, |
|
max_new_tokens: int, |
|
*, |
|
interactive: bool, |
|
draft_model: Transformer, |
|
speculate_k: Optional[int] = 8, |
|
callback = lambda x: x, |
|
**sampling_kwargs |
|
) -> Tuple[torch.Tensor, dict]:
|
""" |
|
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. |
|
""" |
|
|
|
is_speculative = draft_model is not None |
|
|
|
T = prompt.size(0) |
|
T_new = T + max_new_tokens |
|
if interactive: |
|
max_seq_length = 350 |
|
else: |
|
max_seq_length = min(T_new, model.config.block_size) |
|
|
|
device, dtype = prompt.device, prompt.dtype |
|
max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length |
|
with torch.device(device): |
|
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) |
|
if is_speculative and draft_model is not model: |
|
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) |
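    # Pre-allocate the full output sequence and copy the prompt into its prefix.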
|
|
|
|
|
empty = torch.empty(T_new, dtype=dtype, device=device) |
|
empty[:T] = prompt |
|
seq = empty |
|
input_pos = torch.arange(0, T, device=device) |
|
|
|
next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone() |
|
if is_speculative: |
|
prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) |
|
seq[T] = next_token |
|
|
|
input_pos = torch.tensor([T], device=device, dtype=torch.int) |
|
accept_counts = [0] * (speculate_k + 1) |
|
|
|
if is_speculative: |
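        # For speculative decoding it is simpler to track the position as a plain Python int on the host.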
|
input_pos = input_pos.item() |
|
while input_pos < T_new - 1: |
|
cur_token = next_token.view(()) |
|
|
|
next_tokens = speculative_decode( |
|
model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs |
|
) |
|
|
|
accept_counts[len(next_tokens) - 1] += 1 |
|
num_added = min(T_new - input_pos - 1, len(next_tokens)) |
|
seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] |
|
            for i in next_tokens[:num_added]:
|
callback(i) |
|
input_pos = input_pos + num_added |
|
next_token = next_tokens[-1] |
|
else: |
|
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) |
|
seq[T + 1:] = torch.cat(generated_tokens) |
|
|
|
generate_stats = { |
|
'accept_counts': accept_counts |
|
} |
|
return seq, generate_stats |
|
|
|
def encode_tokens(tokenizer, string, bos=True, device=default_device): |
|
tokens = tokenizer.encode(string) |
|
if bos: |
|
tokens = [tokenizer.bos_id()] + tokens |
|
return torch.tensor(tokens, dtype=torch.int, device=device) |
|
|
|
def _load_model(checkpoint_path, device, precision, use_tp): |
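    # Instantiate the model on the meta device (no memory allocated), optionally swap in int8/int4
    # weight-only quantized modules based on the checkpoint filename, then load the weights.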
|
use_cuda = 'cuda' in device |
|
with torch.device('meta'): |
|
model = Transformer.from_name(checkpoint_path.parent.name) |
|
|
|
if "int8" in str(checkpoint_path): |
|
print("Using int8 weight-only quantization!") |
|
from quantize import WeightOnlyInt8QuantHandler |
|
simple_quantizer = WeightOnlyInt8QuantHandler(model) |
|
model = simple_quantizer.convert_for_runtime() |
|
|
|
if "int4" in str(checkpoint_path): |
|
print("Using int4 weight-only quantization!") |
|
path_comps = checkpoint_path.name.split(".") |
|
groupsize = int(path_comps[-2][1:]) |
|
from quantize import WeightOnlyInt4QuantHandler |
|
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) |
|
model = simple_quantizer.convert_for_runtime() |
|
|
|
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) |
|
if "model" in checkpoint and "stories" in str(checkpoint_path): |
|
checkpoint = checkpoint["model"] |
|
model.load_state_dict(checkpoint, assign=True) |
|
|
|
if use_tp: |
|
from tp import apply_tp |
|
print("Applying tensor parallel to model ...") |
|
apply_tp(model) |
|
|
|
model = model.to(device=device, dtype=precision) |
|
return model.eval() |
|
|
|
def _get_model_size(model): |
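    # Total bytes of parameters and buffers, excluding embedding tables; used for the bandwidth estimate below.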
|
model_size = 0 |
|
for name, child in model.named_children(): |
|
if not isinstance(child, torch.nn.Embedding): |
|
model_size += sum( |
|
[ |
|
p.numel() * p.dtype.itemsize |
|
for p in itertools.chain(child.parameters(), child.buffers()) |
|
] |
|
) |
|
return model_size |
|
|
|
B_INST, E_INST = "[INST]", "[/INST]" |
|
|
|
def main( |
|
prompt: str = "Hello, my name is", |
|
interactive: bool = False, |
|
num_samples: int = 5, |
|
max_new_tokens: int = 100, |
|
top_k: int = 200, |
|
temperature: float = 0.8, |
|
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), |
|
compile: bool = True, |
|
compile_prefill: bool = False, |
|
profile: Optional[Path] = None, |
|
draft_checkpoint_path: Optional[Path] = None, |
|
speculate_k: int = 5, |
|
device=default_device, |
|
) -> None: |
|
"""Generates text samples based on a pre-trained Transformer model and tokenizer. |
|
""" |
|
assert checkpoint_path.is_file(), checkpoint_path |
|
|
|
tokenizer_path = checkpoint_path.parent / "tokenizer.model" |
|
assert tokenizer_path.is_file(), str(tokenizer_path) |
|
|
|
global print |
|
from tp import maybe_init_dist |
|
rank = maybe_init_dist() |
|
use_tp = rank is not None |
|
if use_tp: |
|
if rank != 0: |
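            # Only rank 0 prints; silence all other ranks.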
|
|
|
print = lambda *args, **kwargs: None |
|
|
|
print(f"Using device={device}") |
|
precision = torch.bfloat16 |
|
is_speculative = draft_checkpoint_path is not None |
|
is_chat = "chat" in str(checkpoint_path) |
|
|
|
print("Loading model ...") |
|
t0 = time.time() |
|
model = _load_model(checkpoint_path, device, precision, use_tp) |
|
|
|
if is_speculative: |
|
draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) |
|
else: |
|
draft_model = None |
|
|
|
device_sync(device=device) |
|
print(f"Time to load model: {time.time() - t0:.02f} seconds") |
|
|
|
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) |
|
|
|
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) |
|
prompt_length = encoded.size(0) |
|
|
|
torch.manual_seed(1234) |
|
model_size = _get_model_size(model) |
|
if compile: |
|
if is_speculative and use_tp: |
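            # cudagraph trees are problematic when tensor parallelism and speculative decoding are combined.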
|
torch._inductor.config.triton.cudagraph_trees = False |
|
|
|
if is_speculative: |
|
            global model_forward, logits_to_probs
|
model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) |
|
|
|
global decode_one_token, prefill |
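        # mode="reduce-overhead" captures the single-token decode step with CUDA graphs to cut per-step launch overhead.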
|
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) |
|
|
|
|
|
if compile_prefill: |
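            # Prompt length varies between calls, so compile prefill with dynamic shapes.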
|
prefill = torch.compile(prefill, fullgraph=True, dynamic=True) |
|
|
|
|
|
aggregate_metrics = { |
|
'tokens_per_sec': [], |
|
'accept_counts': [], |
|
} |
|
start = -1 if compile else 0 |
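    # With compilation enabled, run one extra warm-up iteration (i == -1) whose time is reported
    # as compilation time and excluded from the aggregate metrics.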
|
|
|
for i in range(start, num_samples): |
|
device_sync(device=device) |
|
if i >= 0 and interactive: |
|
prompt = input("What is your prompt? ") |
|
if is_chat: |
|
prompt = f"{B_INST} {prompt.strip()} {E_INST}" |
|
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) |
|
|
|
if interactive and i >= 0: |
|
buffer = [] |
|
period_id = tokenizer.encode('.')[0] |
|
done_generating = False |
|
def callback(x): |
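                # Decoding a lone token drops its leading space, so decode it together with a known
                # token (a period) and strip that token's text back off.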
|
nonlocal done_generating |
|
if done_generating: |
|
return |
|
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) |
|
if x.item() == tokenizer.eos_id(): |
|
done_generating = True |
|
if len(buffer) == 4 or done_generating: |
|
print(''.join(buffer), end='', flush=True) |
|
buffer.clear() |
|
|
|
else: |
|
callback = lambda x : x |
|
t0 = time.perf_counter() |
|
import contextlib |
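        # Profile only the last sample, and only on rank 0 when tensor parallelism is active.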
|
if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): |
|
prof = contextlib.nullcontext() |
|
else: |
|
torch.profiler._utils._init_for_cuda_graphs() |
|
prof = torch.profiler.profile() |
|
with prof: |
|
y, metrics = generate( |
|
model, |
|
encoded, |
|
max_new_tokens, |
|
draft_model=draft_model, |
|
speculate_k=speculate_k, |
|
interactive=interactive, |
|
callback=callback, |
|
temperature=temperature, |
|
top_k=top_k, |
|
) |
|
aggregate_metrics['accept_counts'].append(metrics['accept_counts']) |
|
if i == -1: |
|
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") |
|
continue |
|
if hasattr(prof, "export_chrome_trace"): |
|
if use_tp: |
|
prof.export_chrome_trace(f"{profile}_rank_{rank}.json") |
|
else: |
|
prof.export_chrome_trace(f"{profile}.json") |
|
device_sync(device=device) |
|
t = time.perf_counter() - t0 |
|
|
|
if not interactive: |
|
print(tokenizer.decode(y.tolist())) |
|
else: |
|
print() |
|
tokens_generated = y.size(0) - prompt_length |
|
tokens_sec = tokens_generated / t |
|
aggregate_metrics['tokens_per_sec'].append(tokens_sec) |
|
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") |
|
print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") |
|
print("==========") |
|
if is_speculative: |
|
counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] |
|
acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] |
|
print(f"Acceptance probs: {acceptance_probs}") |
|
print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") |
|
|
|
print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") |
|
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") |
|
|
|
|
|
if __name__ == '__main__': |
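    # Example invocation (script name and paths illustrative):
    #   python generate.py --compile --prompt "Hello, my name is" \
    #       --checkpoint_path checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth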
|
import argparse |
|
    parser = argparse.ArgumentParser(description='Generate text with a Transformer checkpoint, optionally using speculative decoding.')
|
|
|
parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') |
|
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') |
|
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') |
|
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') |
|
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') |
|
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') |
|
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') |
|
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') |
|
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') |
|
parser.add_argument('--profile', type=Path, default=None, help='Profile path.') |
|
parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') |
|
parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.') |
|
parser.add_argument('--device', type=str, default=default_device, help='Device to use') |
|
|
|
args = parser.parse_args() |
|
main( |
|
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, |
|
args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, |
|
args.speculate_k, args.device |
|
) |
|
|