From 6ba7637f4de50afcdff69896e47e7be0de764f45 Mon Sep 17 00:00:00 2001 From: Gajendra Jung Katuwal Date: Thu, 29 Mar 2018 12:15:23 -0400 Subject: [PATCH] Change num_capusle to num_conv2D_per_capsule MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change the input variable num_capusle to num_conv2D_per_capsule in the class PrimaryCaps because 8 is not the number of capsules, it is the number of conv2D units in one primary capsule. From the paper--"each primary capsule contains 8 convolutional units with a 9×9 kernel and a stride of 2" --- Capsule Network.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Capsule Network.ipynb b/Capsule Network.ipynb index f3cd682..2d8efc7 100644 --- a/Capsule Network.ipynb +++ b/Capsule Network.ipynb @@ -73,12 +73,12 @@ "outputs": [], "source": [ "class PrimaryCaps(nn.Module):\n", - " def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):\n", + " def __init__(self, num_conv2D_per_capsule=8, in_channels=256, out_channels=32, kernel_size=9):\n", " super(PrimaryCaps, self).__init__()\n", "\n", " self.capsules = nn.ModuleList([\n", " nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) \n", - " for _ in range(num_capsules)])\n", + " for _ in range(num_conv2D_per_capsule)])\n", " \n", " def forward(self, x):\n", " u = [capsule(x) for capsule in self.capsules]\n",