-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
39 lines (31 loc) · 1.07 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Sample Model class for track 1."""
from collections.abc import Sequence

import torch
from torch import nn
from torchvision.transforms import v2

# We use PatchcoreModel for an example. You can replace it with your model.
from anomalib.models.image.patchcore.torch_model import PatchcoreModel
class Patchcore(nn.Module):
    """Wrap an anomalib ``PatchcoreModel`` behind a fixed preprocessing pipeline.

    Every batch passed to :meth:`forward` is resized to 256x256, center-cropped
    to 224x224 and normalized with fixed per-channel mean/std values before it
    is handed to the wrapped ``PatchcoreModel``.

    Args:
        backbone: Name of the feature-extractor backbone; forwarded unchanged
            to ``PatchcoreModel``.
        layers: Names of the backbone layers to take features from. A new
            ``list`` is built from it before forwarding, so neither the
            caller's sequence nor the default is shared with the model.
        pre_trained: Forwarded unchanged to ``PatchcoreModel`` (presumably
            selects pretrained backbone weights -- confirm against anomalib).
        num_neighbors: Forwarded unchanged to ``PatchcoreModel``.
    """

    def __init__(
        self,
        backbone: str = "wide_resnet50_2",
        # Immutable default: a list literal here would be evaluated once and
        # the same object reused by every instantiation.
        layers: Sequence[str] = ("layer1", "layer2", "layer3"),
        pre_trained: bool = True,
        num_neighbors: int = 9,
    ) -> None:
        super().__init__()
        # NOTE(review): the mean/std values look like the usual ImageNet
        # statistics, which assume a float input scaled to [0, 1]; Normalize is
        # applied directly to the raw input, so integer tensors are not
        # converted here -- confirm what callers pass in.
        self.transform = v2.Compose(
            [
                v2.Resize((256, 256)),
                v2.CenterCrop((224, 224)),
                v2.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    inplace=False,
                ),
            ],
        )
        self.model = PatchcoreModel(
            backbone=backbone,
            # Fresh list per instance: same value the original passed, but
            # never aliased with the default or the caller's object.
            layers=list(layers),
            pre_trained=pre_trained,
            num_neighbors=num_neighbors,
        )

    def forward(self, batch: torch.Tensor):
        """Apply ``self.transform`` to ``batch`` and run the wrapped model on it.

        Args:
            batch: Image batch. Assumed to be a float image tensor such as
                ``(N, C, H, W)`` -- NOTE(review): not validated here.

        Returns:
            Whatever ``self.model`` returns for the transformed batch.
        """
        batch = self.transform(batch)
        return self.model(batch)