Skip to content

Commit

Permalink
fix: conditional loss breaking for batch size one (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilman151 authored Mar 14, 2024
1 parent fbad168 commit 31ce4ba
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
2 changes: 2 additions & 0 deletions rul_adapt/approach/pseudo_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def generate_pseudo_labels(
values. It is recommended to clip them to zero and `max_rul` respectively before
using them to patch a reader.
The model is assumed to reside on the CPU where the calculation will be performed.
Args:
dm: The data module to generate pseudo labels for.
model: The model to use for generating the pseudo labels.
Expand Down
2 changes: 1 addition & 1 deletion rul_adapt/loss/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def compute(self) -> torch.Tensor:


def _membership(preds: torch.Tensor, fuzzy_set: Tuple[float, float]) -> torch.Tensor:
preds = preds.squeeze() if len(preds.shape) > 1 else preds
preds = preds.squeeze(-1) if preds.ndim > 1 else preds
membership = (preds >= fuzzy_set[0]) & (preds < fuzzy_set[1])

return membership
12 changes: 12 additions & 0 deletions tests/test_loss/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,18 @@ def test__membership():
assert torch.all(_membership(inputs, fuzzy_set) == expected)


@pytest.mark.parametrize("loss_fixture", ["cdann", "cmmd"])
def test_forward_batch_size_one(loss_fixture, request):
    """Should not fail for batch size of one."""
    loss_func = request.getfixturevalue(loss_fixture)
    # A single-sample batch: features of dim 10, one prediction per sample.
    source_features, target_features = torch.rand(1, 10), torch.rand(1, 10)
    source_labels = torch.zeros(1, 1)
    target_labels = torch.zeros(1, 1)

    loss_func(source_features, source_labels, target_features, target_labels)


def test_backward_cdann(cdann):
source = torch.rand(10, 10)
source_preds = torch.zeros(10, 1)
Expand Down

0 comments on commit 31ce4ba

Please sign in to comment.