From 69462c3b2d067c7715b149c7b1aa56f95474c753 Mon Sep 17 00:00:00 2001 From: Heiru Wu Date: Tue, 28 Nov 2023 04:00:27 +0800 Subject: [PATCH] fix: fix post process --- mobilenetv2/1/model.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/mobilenetv2/1/model.py b/mobilenetv2/1/model.py index 3cfccab..1a0bdc2 100644 --- a/mobilenetv2/1/model.py +++ b/mobilenetv2/1/model.py @@ -55,12 +55,6 @@ def _image_labels(self) -> List[str]: categories.append(label.strip()) return categories - def process_model_outputs(self, output: np.array): - probabilities = torch.nn.functional.softmax(torch.from_numpy(output), dim=0) - prob, catid = torch.topk(probabilities, 1) - - return catid, prob - def ModelMetadata(self, req: ModelMetadataRequest) -> ModelMetadataResponse: resp = ModelMetadataResponse( name=req.name, @@ -103,7 +97,7 @@ async def ModelInfer(self, request: ModelInferRequest) -> ModelInferResponse: # shape=(1, batch_size, 1000) # tensor([[207], [294]]), tensor([[0.7107], [0.7309]]) - cat, score = self.process_model_outputs(out[0]) + score, cat = torch.topk(torch.from_numpy(out[0]), 1) s_out = [ bytes(f"{score[i][0]}:{self.categories[cat[i]]}", "utf-8") for i in range(cat.size(0))