From 73e6239207085be3f108ffb9498ad2b5a881aa42 Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Sat, 2 Dec 2023 13:05:36 +0100 Subject: [PATCH] Convert tensors to numpy --- diffpass/_modidx.py | 2 +- diffpass/train.py | 42 +++++++++++++++++++++--------------------- nbs/train.ipynb | 42 +++++++++++++++++++++--------------------- 3 files changed, 43 insertions(+), 43 deletions(-) diff --git a/diffpass/_modidx.py b/diffpass/_modidx.py index 6e7468f..1fd7a78 100644 --- a/diffpass/_modidx.py +++ b/diffpass/_modidx.py @@ -194,7 +194,7 @@ 'diffpass/train.py'), 'diffpass.train.InformationAndMirrortree.soft_': ( 'train.html#informationandmirrortree.soft_', 'diffpass/train.py'), - 'diffpass.train._dcc': ('train.html#_dcc', 'diffpass/train.py'), + 'diffpass.train._dccn': ('train.html#_dccn', 'diffpass/train.py'), 'diffpass.train.apply_hard_permutation_batch_to_similarity': ( 'train.html#apply_hard_permutation_batch_to_similarity', 'diffpass/train.py'), 'diffpass.train.compute_num_correct_matchings': ( 'train.html#compute_num_correct_matchings', diff --git a/diffpass/train.py b/diffpass/train.py index 7bce135..828ed1a 100644 --- a/diffpass/train.py +++ b/diffpass/train.py @@ -69,8 +69,8 @@ def apply_hard_permutation_batch_to_similarity( return torch.gather(x_permuted_rows, -1, index) -def _dcc(x: torch.Tensor) -> torch.Tensor: - return x.detach().clone().cpu() +def _dccn(x: torch.Tensor) -> torch.Tensor: + return x.detach().clone().cpu().numpy() # %% ../nbs/train.ipynb 5 @dataclass @@ -243,16 +243,16 @@ def _fit( loss_info = epoch_results["loss_info"] perms = epoch_results["perms"] results.log_alphas.append( - [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas] + [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas] ) results.hard_perms.append( [ - _dcc(perms_this_group).argmax(-1).to(torch.int16) + _dccn(perms_this_group).argmax(-1).to(torch.int16) for perms_this_group in perms ] ) results.hard_losses[self.information_measure].append( - _dcc(loss_info) + _dccn(loss_info) ) # Soft pass @@ -262,10 +262,10 @@ def _fit( loss_info = epoch_results["loss_info"] perms = epoch_results["perms"] results.soft_perms.append( - [_dcc(perms_this_group) for perms_this_group in perms] + [_dccn(perms_this_group) for perms_this_group in perms] ) results.soft_losses[self.information_measure].append( - _dcc(loss_info) + _dccn(loss_info) ) loss = loss_info.sum() @@ -649,18 +649,18 @@ def _fit( x_perm_hard = epoch_results["x_perm"] perms = epoch_results["perms"] results.log_alphas.append( - [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas] + [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas] ) results.hard_perms.append( [ - _dcc(perms_this_group).argmax(-1).to(torch.int16) + _dccn(perms_this_group).argmax(-1).to(torch.int16) for perms_this_group in perms ] ) results.hard_losses[self.information_measure].append( - _dcc(loss_info) + _dccn(loss_info) ) - results.hard_losses["BestHits"].append(_dcc(loss_bh)) + results.hard_losses["BestHits"].append(_dccn(loss_bh)) # Soft pass if i < epochs or compute_final_soft: @@ -670,12 +670,12 @@ def _fit( loss_bh = epoch_results["loss_bh"] perms = epoch_results["perms"] results.soft_perms.append( - [_dcc(perms_this_group) for perms_this_group in perms] + [_dccn(perms_this_group) for perms_this_group in perms] ) results.soft_losses[self.information_measure].append( - _dcc(loss_info) + _dccn(loss_info) ) - results.soft_losses["BestHits"].append(_dcc(loss_bh)) + results.soft_losses["BestHits"].append(_dccn(loss_bh)) loss = ( self.information_loss.weight * loss_info @@ -1020,18 +1020,18 @@ def _fit( x_perm_hard = epoch_results["x_perm"] perms = epoch_results["perms"] results.log_alphas.append( - [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas] + [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas] ) results.hard_perms.append( [ - _dcc(perms_this_group).argmax(-1).to(torch.int16) + _dccn(perms_this_group).argmax(-1).to(torch.int16) for perms_this_group in perms ] ) results.hard_losses[self.information_measure].append( - _dcc(loss_info) + _dccn(loss_info) ) - results.hard_losses["Mirrortree"].append(_dcc(loss_mt)) + results.hard_losses["Mirrortree"].append(_dccn(loss_mt)) # Soft pass if i < epochs or compute_final_soft: @@ -1041,12 +1041,12 @@ def _fit( loss_mt = epoch_results["loss_mt"] perms = epoch_results["perms"] results.soft_perms.append( - [_dcc(perms_this_group) for perms_this_group in perms] + [_dccn(perms_this_group) for perms_this_group in perms] ) results.soft_losses[self.information_measure].append( - _dcc(loss_info) + _dccn(loss_info) ) - results.soft_losses["Mirrortree"].append(_dcc(loss_mt)) + results.soft_losses["Mirrortree"].append(_dccn(loss_mt)) loss = ( self.information_loss.weight * loss_info diff --git a/nbs/train.ipynb b/nbs/train.ipynb index a76b211..c8e5ea5 100644 --- a/nbs/train.ipynb +++ b/nbs/train.ipynb @@ -98,8 +98,8 @@ " return torch.gather(x_permuted_rows, -1, index)\n", "\n", "\n", - "def _dcc(x: torch.Tensor) -> torch.Tensor:\n", - " return x.detach().clone().cpu()" + "def _dccn(x: torch.Tensor) -> torch.Tensor:\n", + " return x.detach().clone().cpu().numpy()" ] }, { @@ -315,16 +315,16 @@ " loss_info = epoch_results[\"loss_info\"]\n", " perms = epoch_results[\"perms\"]\n", " results.log_alphas.append(\n", - " [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]\n", + " [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]\n", " )\n", " results.hard_perms.append(\n", " [\n", - " _dcc(perms_this_group).argmax(-1).to(torch.int16)\n", + " _dccn(perms_this_group).argmax(-1).to(torch.int16)\n", " for perms_this_group in perms\n", " ]\n", " )\n", " results.hard_losses[self.information_measure].append(\n", - " _dcc(loss_info)\n", + " _dccn(loss_info)\n", " )\n", "\n", " # Soft pass\n", @@ -334,10 +334,10 @@ " loss_info = epoch_results[\"loss_info\"]\n", " perms = epoch_results[\"perms\"]\n", " results.soft_perms.append(\n", - " [_dcc(perms_this_group) for perms_this_group in perms]\n", + " [_dccn(perms_this_group) for perms_this_group in perms]\n", " )\n", " results.soft_losses[self.information_measure].append(\n", - " _dcc(loss_info)\n", + " _dccn(loss_info)\n", " )\n", "\n", " loss = loss_info.sum()\n", @@ -729,18 +729,18 @@ " x_perm_hard = epoch_results[\"x_perm\"]\n", " perms = epoch_results[\"perms\"]\n", " results.log_alphas.append(\n", - " [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]\n", + " [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]\n", " )\n", " results.hard_perms.append(\n", " [\n", - " _dcc(perms_this_group).argmax(-1).to(torch.int16)\n", + " _dccn(perms_this_group).argmax(-1).to(torch.int16)\n", " for perms_this_group in perms\n", " ]\n", " )\n", " results.hard_losses[self.information_measure].append(\n", - " _dcc(loss_info)\n", + " _dccn(loss_info)\n", " )\n", - " results.hard_losses[\"BestHits\"].append(_dcc(loss_bh))\n", + " results.hard_losses[\"BestHits\"].append(_dccn(loss_bh))\n", "\n", " # Soft pass\n", " if i < epochs or compute_final_soft:\n", @@ -750,12 +750,12 @@ " loss_bh = epoch_results[\"loss_bh\"]\n", " perms = epoch_results[\"perms\"]\n", " results.soft_perms.append(\n", - " [_dcc(perms_this_group) for perms_this_group in perms]\n", + " [_dccn(perms_this_group) for perms_this_group in perms]\n", " )\n", " results.soft_losses[self.information_measure].append(\n", - " _dcc(loss_info)\n", + " _dccn(loss_info)\n", " )\n", - " results.soft_losses[\"BestHits\"].append(_dcc(loss_bh))\n", + " results.soft_losses[\"BestHits\"].append(_dccn(loss_bh))\n", "\n", " loss = (\n", " self.information_loss.weight * loss_info\n", @@ -1108,18 +1108,18 @@ " x_perm_hard = epoch_results[\"x_perm\"]\n", " perms = epoch_results[\"perms\"]\n", " results.log_alphas.append(\n", - " [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]\n", + " [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]\n", " )\n", " results.hard_perms.append(\n", " [\n", - " _dcc(perms_this_group).argmax(-1).to(torch.int16)\n", + " _dccn(perms_this_group).argmax(-1).to(torch.int16)\n", " for perms_this_group in perms\n", " ]\n", " )\n", " results.hard_losses[self.information_measure].append(\n", - " _dcc(loss_info)\n", + " _dccn(loss_info)\n", " )\n", - " results.hard_losses[\"Mirrortree\"].append(_dcc(loss_mt))\n", + " results.hard_losses[\"Mirrortree\"].append(_dccn(loss_mt))\n", "\n", " # Soft pass\n", " if i < epochs or compute_final_soft:\n", @@ -1129,12 +1129,12 @@ " loss_mt = epoch_results[\"loss_mt\"]\n", " perms = epoch_results[\"perms\"]\n", " results.soft_perms.append(\n", - " [_dcc(perms_this_group) for perms_this_group in perms]\n", + " [_dccn(perms_this_group) for perms_this_group in perms]\n", " )\n", " results.soft_losses[self.information_measure].append(\n", - " _dcc(loss_info)\n", + " _dccn(loss_info)\n", " )\n", - " results.soft_losses[\"Mirrortree\"].append(_dcc(loss_mt))\n", + " results.soft_losses[\"Mirrortree\"].append(_dccn(loss_mt))\n", "\n", " loss = (\n", " self.information_loss.weight * loss_info\n",