-
Notifications
You must be signed in to change notification settings - Fork 4
/
problem_tsp_normal.py
87 lines (64 loc) · 2.75 KB
/
problem_tsp_normal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from torch.utils.data import Dataset
import torch
import os
import pickle
from problems.tsp.state_tsp import StateTSP
from utils.beam_search import beam_search
class TSP(object):
    """Problem descriptor for the Euclidean travelling salesman problem.

    Groups the tour-cost function, the dataset/state factories and a
    beam-search entry point as static methods, so the class itself can be
    handed around as a "problem" object.
    """

    NAME = 'tsp'

    @staticmethod
    def get_costs(dataset, pi):
        """Compute the closed-tour length of every instance in a batch.

        Args:
            dataset: rank-3 coordinate tensor indexed as
                ``(batch, node, coord)``.
            pi: rank-2 tour tensor ``(batch, node)``; each row must be a
                permutation of ``0 .. n-1``.

        Returns:
            Tuple ``(lengths, None)``. ``lengths`` holds one value per batch
            row: the sum of L2 distances between consecutive tour nodes plus
            the edge from the last node back to the first.

        Raises:
            AssertionError: if some row of ``pi`` is not a permutation
                (note: skipped when Python runs with ``-O``).
        """
        num_nodes = pi.size(1)
        # Sorting a valid tour row must give exactly 0, 1, ..., n-1.
        identity = torch.arange(num_nodes, out=pi.data.new()).view(1, -1).expand_as(pi)
        sorted_tours = torch.sort(pi.data, dim=1)[0]
        assert (identity == sorted_tours).all(), "Invalid tour"

        # Reorder coordinates so that position k holds the k-th visited node.
        gather_index = pi.unsqueeze(-1).expand_as(dataset)
        ordered = dataset.gather(1, gather_index)

        # Hops between consecutive nodes along the tour.
        path_length = (ordered[:, 1:] - ordered[:, :-1]).norm(p=2, dim=2).sum(1)
        # Edge that closes the cycle, from the last node back to the first.
        return_leg = (ordered[:, 0] - ordered[:, -1]).norm(p=2, dim=1)
        return path_length + return_leg, None

    @staticmethod
    def make_dataset(*args, **kwargs):
        """Construct a ``TSPDataset``, forwarding all arguments unchanged."""
        return TSPDataset(*args, **kwargs)

    @staticmethod
    def make_state(*args, **kwargs):
        """Construct an initial state via ``StateTSP.initialize``."""
        return StateTSP.initialize(*args, **kwargs)

    @staticmethod
    def beam_search(input, beam_size, expand_size=None,
                    compress_mask=False, model=None, max_calc_batch_size=4096):
        """Run beam search over ``input`` using ``model`` to propose expansions.

        Args:
            input: problem input passed to ``model.precompute_fixed`` and to
                ``TSP.make_state``.
            beam_size: beam width forwarded to the ``beam_search`` utility.
            expand_size: forwarded to ``model.propose_expansions``.
            compress_mask: when true the state's visited mask uses
                ``torch.int64``; otherwise ``torch.uint8``.
            model: required; must provide ``precompute_fixed`` and
                ``propose_expansions``.
            max_calc_batch_size: forwarded to ``model.propose_expansions``.

        Returns:
            Whatever the ``beam_search`` utility returns.
        """
        assert model is not None, "Provide model"

        fixed_context = model.precompute_fixed(input)
        visited_dtype = torch.int64 if compress_mask else torch.uint8

        def expand(beam):
            # Every beam step reuses the single precompute_fixed result.
            return model.propose_expansions(
                beam, fixed_context, expand_size,
                normalize=True, max_calc_batch_size=max_calc_batch_size,
            )

        initial_state = TSP.make_state(input, visited_dtype=visited_dtype)
        return beam_search(initial_state, beam_size, expand)
class TSPDataset(Dataset):
    """Dataset of TSP instances stored as a list of float tensors.

    Instances are either loaded from a ``.pkl`` file or generated. Generated
    instances are ``(size, 2)`` tensors: each coordinate is drawn from a
    normal distribution (mean 0.5, std 0.25), then each coordinate axis of
    each instance is min-max rescaled independently so it spans [0, 1].
    """

    def __init__(self, filename=None, size=50, num_samples=1000000, offset=0, distribution=None):
        """
        Args:
            filename: optional path to a pickle file (extension must be
                ``.pkl``) holding a sequence of instances; entries
                ``offset`` .. ``offset + num_samples`` are converted with
                ``torch.FloatTensor``.
            size: number of nodes per generated instance (unused when
                loading from a file).
            num_samples: number of instances to generate or load.
            offset: index of the first instance taken from the file.
            distribution: accepted for interface compatibility; not used by
                this class.
        """
        super(TSPDataset, self).__init__()

        # Not read anywhere in this class; kept so existing attribute access
        # elsewhere keeps working.
        self.data_set = []
        if filename is not None:
            assert os.path.splitext(filename)[1] == '.pkl'
            # NOTE(security): pickle.load can execute arbitrary code embedded
            # in the file; only load dataset files from trusted sources.
            with open(filename, 'rb') as f:
                data = pickle.load(f)
            self.data = [torch.FloatTensor(row) for row in data[offset:offset + num_samples]]
        else:
            # Draw one instance per call rather than one large batched
            # tensor: a single batched draw may consume the RNG differently,
            # which would change seeded datasets.
            self.data = [self._sample_instance(size) for _ in range(num_samples)]

        self.size = len(self.data)

    @staticmethod
    def _sample_instance(size):
        """Draw ``size`` points from N(0.5, 0.25) and min-max rescale each axis to [0, 1]."""
        points = torch.FloatTensor(size, 2).normal_(0.5, 0.25)
        lo = points.min(0, keepdim=True)[0]
        hi = points.max(0, keepdim=True)[0]
        span = hi - lo
        # An axis with zero spread (always the case when size == 1) would
        # compute 0 / 0 = NaN; dividing by 1 instead maps that axis to 0.
        span[span == 0] = 1.0
        return (points - lo) / span

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx]