import os
import sys
import json
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_path, '..'))

from nets.model import Model
from Actor.actor import Actor
from dataloader import VRP_Dataset
from google_solver.google_model import evaluate_google_model
# Load hyperparameters
with open('params.json', 'r') as f:
    params = json.load(f)

# Save a copy of the params locally for experiment tracking
with open('params_saved.json', 'w') as f:
    json.dump(params, f)
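
# For reference, a params.json this script can consume might look like the
# sketch below. The keys match exactly what is read in this file; the values
# are illustrative assumptions, not tuned settings:
#
# {
#     "device": "cuda:0", "run_tests": true, "save_results": true,
#     "dataset_path": "./data", "train_dataset_size": 10000,
#     "validation_dataset_size": 1000, "baseline_dataset_size": 1000,
#     "num_nodes": 20, "num_depots": 1, "embedding_size": 128,
#     "decoder_input_size": 128, "sample_size": 10, "gradient_clipping": true,
#     "num_neighbors_encoder": 5, "num_neighbors_action": 5, "num_movers": 2,
#     "learning_rate": 1e-4, "batch_size": 512, "test_batch_size": 512,
#     "baseline_update_period": 5, "overfit_test": false, "num_epochs": 100
# }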
# Settings
device = params['device']
run_tests = params['run_tests']
save_results = params['save_results']
dataset_path = params['dataset_path']

# Dataset sizes
train_dataset_size = params['train_dataset_size']
validation_dataset_size = params['validation_dataset_size']
baseline_dataset_size = params['baseline_dataset_size']

# Problem and model configuration
num_nodes = params['num_nodes']
num_depots = params['num_depots']
embedding_size = params['embedding_size']
sample_size = params['sample_size']
gradient_clipping = params['gradient_clipping']
num_neighbors_encoder = params['num_neighbors_encoder']
num_neighbors_action = params['num_neighbors_action']
num_movers = params['num_movers']
learning_rate = params['learning_rate']
batch_size = params['batch_size']
test_batch_size = params['test_batch_size']
baseline_update_period = params['baseline_update_period']
# Datasets: validation and baseline sets are generated once and kept fixed
validation_dataset = VRP_Dataset(validation_dataset_size, num_nodes, num_depots, dataset_path, device)
baseline_dataset = VRP_Dataset(baseline_dataset_size, num_nodes, num_depots, dataset_path, device)

# For an overfitting sanity check, train, baseline and validation all share
# the same dataset
if params['overfit_test']:
    train_dataset = VRP_Dataset(train_dataset_size, num_nodes, num_depots, dataset_path, device)
    baseline_dataset = train_dataset
    validation_dataset = train_dataset

# Reference scores from the Google solver on the validation set
google_scores = evaluate_google_model(validation_dataset)
tot_google_scores = google_scores.sum().item()

input_size = validation_dataset.model_input_length()
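
# Note: evaluate_google_model (from the google_solver package) supplies
# classical-solver reference costs; judging by the module name it presumably
# wraps Google OR-Tools' routing solver, so the Actor/Google ratio reported
# below compares learned routes against a strong heuristic baseline.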
# Models | |
model = Model(input_size=input_size, embedding_size=embedding_size, decoder_input_size=params["decoder_input_size"]) | |
actor = Actor(model=model, num_movers=num_movers, num_neighbors_encoder=num_neighbors_encoder, | |
num_neighbors_action=num_neighbors_action, device=device, normalize=False) | |
actor.train_mode() | |
baseline_model = Model(input_size=input_size, embedding_size=embedding_size, decoder_input_size=params["decoder_input_size"]) | |
baseline_actor = Actor(model=baseline_model, num_movers=num_movers, num_neighbors_encoder=num_neighbors_encoder, | |
num_neighbors_action=num_neighbors_action, device=device, normalize=False) | |
baseline_actor.greedy_search() | |
baseline_actor.load_state_dict(actor.state_dict()) | |
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device) | |
nn_actor.nearest_neighbors() | |
optimizer = optim.Adam(params=actor.parameters(), lr=learning_rate) | |
train_batch_record = 100 | |
validation_record = 100 | |
baseline_record = None | |
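
# Training follows REINFORCE with a rollout baseline: the actor samples
# solutions, the greedy baseline actor scores the same instances, and their
# cost difference serves as the advantage. A fresh training set is drawn
# every epoch unless this is an overfitting test.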
# Training loop
for epoch in range(params['num_epochs']):
    if not params['overfit_test']:
        train_dataset = VRP_Dataset(train_dataset_size, num_nodes, num_depots, dataset_path, device)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=train_dataset.collate)

    for i, batch in enumerate(train_dataloader):
        # Reference costs need no gradients
        with torch.no_grad():
            nn_actor.nearest_neighbors()
            nn_output = nn_actor(batch)
            tot_nn_cost = nn_output['total_time'].sum().item()

            baseline_actor.greedy_search()
            baseline_cost = baseline_actor(batch)['total_time']

        actor.train_mode()
        actor_output = actor(batch)
        actor_cost, log_probs = actor_output['total_time'], actor_output['log_probs']
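        # REINFORCE with baseline: the advantage (actor_cost - baseline_cost)
        # is detached so gradients flow only through the log-probabilities of
        # the sampled actions; solutions cheaper than the baseline have their
        # action probabilities pushed up, more expensive ones pushed down.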
        loss = ((actor_cost - baseline_cost).detach() * log_probs).mean()

        optimizer.zero_grad()
        loss.backward()
        if gradient_clipping:
            # Clip the L2 norm of each parameter group's gradient to 1
            for group in optimizer.param_groups:
                clip_grad_norm_(group['params'], 1, norm_type=2)
        optimizer.step()

        tot_actor_cost = actor_cost.sum().item()
        tot_baseline_cost = baseline_cost.sum().item()
        actor_nn_ratio = tot_actor_cost / tot_nn_cost
        actor_baseline_ratio = tot_actor_cost / tot_baseline_cost
        train_batch_record = min(train_batch_record, actor_nn_ratio)

        result = f"{epoch}, {i}, {actor_nn_ratio:.4f}, {actor_baseline_ratio:.4f}, {train_batch_record:.4f}"
        print(result, flush=True)

        if save_results:
            with open('train_results.txt', 'a') as f:
                f.write(result + '\n')

        del batch
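
    # Baseline self-play update: re-score the fixed baseline set with the
    # current actor in greedy mode and, if it beats the stored record on more
    # than 90% of instances, copy its weights into the baseline actor.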
    if epoch % baseline_update_period == 0:
        baseline_dataloader = DataLoader(baseline_dataset, batch_size=batch_size, collate_fn=baseline_dataset.collate)
        tot_cost = []
        for batch in baseline_dataloader:
            with torch.no_grad():
                actor.greedy_search()
                actor_output = actor(batch)
                cost = actor_output['total_time']
                tot_cost.append(cost)
            del batch

        tot_cost = torch.cat(tot_cost, dim=0)
        if baseline_record is None or (tot_cost < baseline_record).float().mean().item() > 0.9:
            baseline_record = tot_cost
            baseline_actor.load_state_dict(actor.state_dict())
            print('\nNew baseline record\n')
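
    # Periodic test: decode with beam search, which keeps sample_size partial
    # solutions alive per instance; the batch is shrunk to compensate for the
    # extra memory, and total cost is compared against the nearest-neighbor
    # heuristic and the Google solver's reference scores.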
    if (epoch % 10 == 0) and run_tests:
        b = max(int(test_batch_size // sample_size**2), 1)
        validation_dataloader = DataLoader(validation_dataset, batch_size=b, collate_fn=validation_dataset.collate)

        tot_cost = 0
        tot_nn_cost = 0
        for batch in validation_dataloader:
            with torch.no_grad():
                actor.beam_search(sample_size)
                actor_output = actor(batch)
                cost = actor_output['total_time']

                nn_actor.nearest_neighbors()
                nn_output = nn_actor(batch)
                nn_cost = nn_output['total_time']

                tot_cost += cost.sum().item()
                tot_nn_cost += nn_cost.sum().item()

        ratio = tot_cost / tot_nn_cost
        validation_record = min(validation_record, ratio)
        actor_google_ratio = tot_cost / tot_google_scores

        print(f"\nTest results:\nActor/Google: {actor_google_ratio:.4f}, "
              f"Actor/NN: {ratio:.4f}, Best NN Ratio: {validation_record:.4f}\n")
        if save_results:
            with open('test_results.txt', 'a') as f:
                f.write(f"{epoch}, {actor_google_ratio:.4f}, {ratio:.4f}, {validation_record:.4f}\n")

print("End")