File size: 3,701 Bytes
f6d6286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from ..hellaswag import render_example, iterate_examples, get_most_likely_row
import torch.distributed as dist 
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from ..ModelGPT2 import GPT,log_file

ddp = int(os.environ.get('RANK', -1)) != -1 #will be True if ddp run
if ddp:
    assert torch.cuda.is_available()
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0 #this is the process doing checkpoint,logging,etc
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    #attempt to autodetect the device
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"  #for mac users use apple silicon cpu which allready have gpu.mps is backend for apple silicon
    print(f"Using device: {device}")
# device = "cpu" #OVERRIDE

device_type = "cuda" if device.startswith("cuda") else "cpu"

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)


#Creating model by loading the model weights
checkpoint_path = '../log/model_final.pt' 
if master_process:
        print(f"Loading checkpoint from {checkpoint_path}")

checkpoint = torch.load(checkpoint_path, map_location=device)

# Extract config and create model
model_config = checkpoint['config']
model_config.vocab_size = 50304 #for computational effciency(power of 2)
model = GPT(model_config)
# Load model state dict
model.load_state_dict(checkpoint['model'])
model = DDP(model, device_ids=[ddp_local_rank])
model.to(device)


def evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process):
         
    num_correct_norm = 0
    num_total = 0

    for i, example in enumerate(iterate_examples("val")):
        # only process example where i % ddp_world_size ==ddp_rank#this is for proper managemnt of which part is deal by which gpu
        if ddp:
            if i % ddp_world_size != ddp_rank:
                continue
        #rendering example into tokens and labels
        _, tokens, mask, label = render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)
        #get the logits
        with torch.no_grad():
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                logits, loss = model(tokens)
            pred_norm = get_most_likely_row(tokens, mask, logits)
        num_total += 1
        num_correct_norm += int(pred_norm == label)
    #reduce the stats accross all process
    if ddp:
        num_total = torch.tensor(num_total, dtype=torch.long, device=device)
        num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
        dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
        dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
        num_total = num_total.item()
        num_correct_norm = num_correct_norm.item()
    acc_norm = num_correct_norm / num_total #accuracy of hellaswag
    if master_process:
        print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
        with open(log_file, "a") as f:
            f.write(f"Final Hellaswag accuracy: {acc_norm:.4f}\n")   

evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process)
if ddp:
    destroy_process_group()