Skip to content

Commit

Permalink
Typing improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
kisonho committed Jan 12, 2024
1 parent aa9e061 commit b8edcde
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
3 changes: 1 addition & 2 deletions magnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from . import data, losses, networks, nn
from .managers import MonaiTargetingManager as MonaiManager, TargetingManager as Manager
from .nn import MAGNET, MAGNET2
from .version import VERSION

try:
from .builder import build, build_v1, build_v2, build_v2_unet
except ImportError:
build = build_v1 = build_v2 = build_v2_unet = NotImplemented

VERSION = "2.1"
12 changes: 10 additions & 2 deletions magnet/nn/shared.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torchmanager_core import torch
from torchmanager_core.typing import Any, Generic, Module, NamedTuple, Optional, Union
from torchmanager_core.typing import Any, Generic, Module, NamedTuple, Optional, Union, overload

from .fusion import Fusion

Expand Down Expand Up @@ -82,6 +82,14 @@ def forward(self, x_in: Union[torch.Tensor, list[torch.Tensor]], *args: Any, **k
assert not isinstance(self.target, list), "Multiple targets are assigned but only one modality is given to the input."
return self.target_modules[self.target](x_in, *args, **kwargs)

@overload
def fuse(self, preds: list[torch.Tensor]) -> torch.Tensor:
...

@overload
def fuse(self, preds: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
...

def fuse(self, preds: list[Any]) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
# check predictions
assert len(preds) > 0, "Results must contain at least one predictions"
Expand Down Expand Up @@ -151,7 +159,7 @@ def __init__(self, *encoders: Module, fusion: Optional[Fusion] = None, decoder:
self.decoder = decoder
self.return_features = return_features

def forward(self, x_in: Union[torch.Tensor, list[torch.Tensor], dict[str, Any]], *args: Any, **kwargs: Any) -> Any:
def forward(self, x_in: Union[torch.Tensor, list[torch.Tensor], dict[str, Any]], *args: Any, **kwargs: Any) -> Union[FeaturedData, torch.Tensor]:
# unpack inputs
if isinstance(x_in, dict):
x = x_in['image']
Expand Down
3 changes: 3 additions & 0 deletions magnet/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from torchmanager_core import Version

VERSION = Version("2.1.1")
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "magms"
version = "2.1"
description = "Modality-Agnostic Learning for Medical Image Segmentation Using Multi-modality Self-distillation (v2.1)"
version = "2.1.1"
description = "Modality-Agnostic Learning for Medical Image Segmentation Using Multi-modality Self-distillation (v2.1.1)"
authors = ["Qisheng He <Qisheng.He@wayne.edu>"]
repository = "https://github.com/kisonho/magnet.git"

Expand Down

0 comments on commit b8edcde

Please sign in to comment.