diff --git a/dsm/utilities.py b/dsm/utilities.py index bc25cca..dcf3994 100644 --- a/dsm/utilities.py +++ b/dsm/utilities.py @@ -50,7 +50,7 @@ def get_optimizer(model, lr): ' is not implemented') def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, - n_iter=10000, lr=1e-2, thres=1e-4): + n_iter=10000, lr=1e-2): premodel = DeepSurvivalMachinesTorch(1, 1, dist=model.dist, @@ -59,7 +59,7 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, optimizer = get_optimizer(model, lr) - oldcost = float('inf') + best_loss = float('inf') patience = 0 costs = [] for _ in tqdm(range(n_iter)): @@ -74,15 +74,18 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, valid_loss = 0 for r in range(model.risks): valid_loss += unconditional_loss(premodel, t_valid, e_valid, str(r+1)) - valid_loss = valid_loss.detach().cpu().numpy() - costs.append(valid_loss) - #print(valid_loss) - if np.abs(costs[-1] - oldcost) < thres: + valid_loss = valid_loss.item() + + if best_loss < valid_loss: patience += 1 - if patience == 3: + if patience == 2: break - oldcost = costs[-1] - + else: + patience = 0 + best_loss = valid_loss + best_params = deepcopy(premodel.state_dict()) + + premodel.load_state_dict(best_params) return premodel def _reshape_tensor_with_nans(data): @@ -129,8 +132,7 @@ def train_dsm(model, t_valid_, e_valid_, n_iter=10000, - lr=1e-2, - thres=1e-4) + lr=1e-2) for r in range(model.risks): model.shape[str(r+1)].data.fill_(float(premodel.shape[str(r+1)])) @@ -140,7 +142,7 @@ def train_dsm(model, optimizer = get_optimizer(model, lr) patience = 0 - oldcost = float('inf') + best_loss = float('inf') nbatches = int(x_train.shape[0]/bs)+1 @@ -166,7 +168,6 @@ def train_dsm(model, _reshape_tensor_with_nans(eb), elbo=elbo, risk=str(r+1)) - #print ("Train Loss:", float(loss)) loss.backward() optimizer.step() @@ -179,30 +180,17 @@ def train_dsm(model, elbo=False, risk=str(r+1)) - valid_loss = valid_loss.detach().cpu().numpy() - costs.append(float(valid_loss)) - dics.append(deepcopy(model.state_dict())) - - if costs[-1] >= oldcost: + valid_loss = valid_loss.item() + if valid_loss > best_loss: if patience == 2: - minm = np.argmin(costs) - model.load_state_dict(dics[minm]) - - del dics - gc.collect() - - return model, i + break else: patience += 1 else: patience = 0 - - oldcost = costs[-1] - - minm = np.argmin(costs) - model.load_state_dict(dics[minm]) - - del dics - gc.collect() + best_state = deepcopy(model.state_dict()) + best_loss = valid_loss + + model.load_state_dict(best_state) return model, i