Skip to content

Commit

Permalink
Convert tensors to numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
ulupo committed Dec 2, 2023
1 parent 9b88dd7 commit 73e6239
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 43 deletions.
2 changes: 1 addition & 1 deletion diffpass/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
'diffpass/train.py'),
'diffpass.train.InformationAndMirrortree.soft_': ( 'train.html#informationandmirrortree.soft_',
'diffpass/train.py'),
'diffpass.train._dcc': ('train.html#_dcc', 'diffpass/train.py'),
'diffpass.train._dccn': ('train.html#_dccn', 'diffpass/train.py'),
'diffpass.train.apply_hard_permutation_batch_to_similarity': ( 'train.html#apply_hard_permutation_batch_to_similarity',
'diffpass/train.py'),
'diffpass.train.compute_num_correct_matchings': ( 'train.html#compute_num_correct_matchings',
Expand Down
42 changes: 21 additions & 21 deletions diffpass/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def apply_hard_permutation_batch_to_similarity(
return torch.gather(x_permuted_rows, -1, index)


def _dcc(x: torch.Tensor) -> torch.Tensor:
return x.detach().clone().cpu()
def _dccn(x: torch.Tensor) -> torch.Tensor:
return x.detach().clone().cpu().numpy()

# %% ../nbs/train.ipynb 5
@dataclass
Expand Down Expand Up @@ -243,16 +243,16 @@ def _fit(
loss_info = epoch_results["loss_info"]
perms = epoch_results["perms"]
results.log_alphas.append(
[_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]
[_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]
)
results.hard_perms.append(
[
_dcc(perms_this_group).argmax(-1).to(torch.int16)
_dccn(perms_this_group).argmax(-1).to(torch.int16)
for perms_this_group in perms
]
)
results.hard_losses[self.information_measure].append(
_dcc(loss_info)
_dccn(loss_info)
)

# Soft pass
Expand All @@ -262,10 +262,10 @@ def _fit(
loss_info = epoch_results["loss_info"]
perms = epoch_results["perms"]
results.soft_perms.append(
[_dcc(perms_this_group) for perms_this_group in perms]
[_dccn(perms_this_group) for perms_this_group in perms]
)
results.soft_losses[self.information_measure].append(
_dcc(loss_info)
_dccn(loss_info)
)

loss = loss_info.sum()
Expand Down Expand Up @@ -649,18 +649,18 @@ def _fit(
x_perm_hard = epoch_results["x_perm"]
perms = epoch_results["perms"]
results.log_alphas.append(
[_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]
[_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]
)
results.hard_perms.append(
[
_dcc(perms_this_group).argmax(-1).to(torch.int16)
_dccn(perms_this_group).argmax(-1).to(torch.int16)
for perms_this_group in perms
]
)
results.hard_losses[self.information_measure].append(
_dcc(loss_info)
_dccn(loss_info)
)
results.hard_losses["BestHits"].append(_dcc(loss_bh))
results.hard_losses["BestHits"].append(_dccn(loss_bh))

# Soft pass
if i < epochs or compute_final_soft:
Expand All @@ -670,12 +670,12 @@ def _fit(
loss_bh = epoch_results["loss_bh"]
perms = epoch_results["perms"]
results.soft_perms.append(
[_dcc(perms_this_group) for perms_this_group in perms]
[_dccn(perms_this_group) for perms_this_group in perms]
)
results.soft_losses[self.information_measure].append(
_dcc(loss_info)
_dccn(loss_info)
)
results.soft_losses["BestHits"].append(_dcc(loss_bh))
results.soft_losses["BestHits"].append(_dccn(loss_bh))

loss = (
self.information_loss.weight * loss_info
Expand Down Expand Up @@ -1020,18 +1020,18 @@ def _fit(
x_perm_hard = epoch_results["x_perm"]
perms = epoch_results["perms"]
results.log_alphas.append(
[_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]
[_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]
)
results.hard_perms.append(
[
_dcc(perms_this_group).argmax(-1).to(torch.int16)
_dccn(perms_this_group).argmax(-1).to(torch.int16)
for perms_this_group in perms
]
)
results.hard_losses[self.information_measure].append(
_dcc(loss_info)
_dccn(loss_info)
)
results.hard_losses["Mirrortree"].append(_dcc(loss_mt))
results.hard_losses["Mirrortree"].append(_dccn(loss_mt))

# Soft pass
if i < epochs or compute_final_soft:
Expand All @@ -1041,12 +1041,12 @@ def _fit(
loss_mt = epoch_results["loss_mt"]
perms = epoch_results["perms"]
results.soft_perms.append(
[_dcc(perms_this_group) for perms_this_group in perms]
[_dccn(perms_this_group) for perms_this_group in perms]
)
results.soft_losses[self.information_measure].append(
_dcc(loss_info)
_dccn(loss_info)
)
results.soft_losses["Mirrortree"].append(_dcc(loss_mt))
results.soft_losses["Mirrortree"].append(_dccn(loss_mt))

loss = (
self.information_loss.weight * loss_info
Expand Down
42 changes: 21 additions & 21 deletions nbs/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@
" return torch.gather(x_permuted_rows, -1, index)\n",
"\n",
"\n",
"def _dcc(x: torch.Tensor) -> torch.Tensor:\n",
" return x.detach().clone().cpu()"
"def _dccn(x: torch.Tensor) -> torch.Tensor:\n",
" return x.detach().clone().cpu().numpy()"
]
},
{
Expand Down Expand Up @@ -315,16 +315,16 @@
" loss_info = epoch_results[\"loss_info\"]\n",
" perms = epoch_results[\"perms\"]\n",
" results.log_alphas.append(\n",
" [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]\n",
" [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]\n",
" )\n",
" results.hard_perms.append(\n",
" [\n",
" _dcc(perms_this_group).argmax(-1).to(torch.int16)\n",
" _dccn(perms_this_group).argmax(-1).to(torch.int16)\n",
" for perms_this_group in perms\n",
" ]\n",
" )\n",
" results.hard_losses[self.information_measure].append(\n",
" _dcc(loss_info)\n",
" _dccn(loss_info)\n",
" )\n",
"\n",
" # Soft pass\n",
Expand All @@ -334,10 +334,10 @@
" loss_info = epoch_results[\"loss_info\"]\n",
" perms = epoch_results[\"perms\"]\n",
" results.soft_perms.append(\n",
" [_dcc(perms_this_group) for perms_this_group in perms]\n",
" [_dccn(perms_this_group) for perms_this_group in perms]\n",
" )\n",
" results.soft_losses[self.information_measure].append(\n",
" _dcc(loss_info)\n",
" _dccn(loss_info)\n",
" )\n",
"\n",
" loss = loss_info.sum()\n",
Expand Down Expand Up @@ -729,18 +729,18 @@
" x_perm_hard = epoch_results[\"x_perm\"]\n",
" perms = epoch_results[\"perms\"]\n",
" results.log_alphas.append(\n",
" [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]\n",
" [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]\n",
" )\n",
" results.hard_perms.append(\n",
" [\n",
" _dcc(perms_this_group).argmax(-1).to(torch.int16)\n",
" _dccn(perms_this_group).argmax(-1).to(torch.int16)\n",
" for perms_this_group in perms\n",
" ]\n",
" )\n",
" results.hard_losses[self.information_measure].append(\n",
" _dcc(loss_info)\n",
" _dccn(loss_info)\n",
" )\n",
" results.hard_losses[\"BestHits\"].append(_dcc(loss_bh))\n",
" results.hard_losses[\"BestHits\"].append(_dccn(loss_bh))\n",
"\n",
" # Soft pass\n",
" if i < epochs or compute_final_soft:\n",
Expand All @@ -750,12 +750,12 @@
" loss_bh = epoch_results[\"loss_bh\"]\n",
" perms = epoch_results[\"perms\"]\n",
" results.soft_perms.append(\n",
" [_dcc(perms_this_group) for perms_this_group in perms]\n",
" [_dccn(perms_this_group) for perms_this_group in perms]\n",
" )\n",
" results.soft_losses[self.information_measure].append(\n",
" _dcc(loss_info)\n",
" _dccn(loss_info)\n",
" )\n",
" results.soft_losses[\"BestHits\"].append(_dcc(loss_bh))\n",
" results.soft_losses[\"BestHits\"].append(_dccn(loss_bh))\n",
"\n",
" loss = (\n",
" self.information_loss.weight * loss_info\n",
Expand Down Expand Up @@ -1108,18 +1108,18 @@
" x_perm_hard = epoch_results[\"x_perm\"]\n",
" perms = epoch_results[\"perms\"]\n",
" results.log_alphas.append(\n",
" [_dcc(log_alpha) for log_alpha in self.permutation.log_alphas]\n",
" [_dccn(log_alpha) for log_alpha in self.permutation.log_alphas]\n",
" )\n",
" results.hard_perms.append(\n",
" [\n",
" _dcc(perms_this_group).argmax(-1).to(torch.int16)\n",
" _dccn(perms_this_group).argmax(-1).to(torch.int16)\n",
" for perms_this_group in perms\n",
" ]\n",
" )\n",
" results.hard_losses[self.information_measure].append(\n",
" _dcc(loss_info)\n",
" _dccn(loss_info)\n",
" )\n",
" results.hard_losses[\"Mirrortree\"].append(_dcc(loss_mt))\n",
" results.hard_losses[\"Mirrortree\"].append(_dccn(loss_mt))\n",
"\n",
" # Soft pass\n",
" if i < epochs or compute_final_soft:\n",
Expand All @@ -1129,12 +1129,12 @@
" loss_mt = epoch_results[\"loss_mt\"]\n",
" perms = epoch_results[\"perms\"]\n",
" results.soft_perms.append(\n",
" [_dcc(perms_this_group) for perms_this_group in perms]\n",
" [_dccn(perms_this_group) for perms_this_group in perms]\n",
" )\n",
" results.soft_losses[self.information_measure].append(\n",
" _dcc(loss_info)\n",
" _dccn(loss_info)\n",
" )\n",
" results.soft_losses[\"Mirrortree\"].append(_dcc(loss_mt))\n",
" results.soft_losses[\"Mirrortree\"].append(_dccn(loss_mt))\n",
"\n",
" loss = (\n",
" self.information_loss.weight * loss_info\n",
Expand Down

0 comments on commit 73e6239

Please sign in to comment.