The model definitions are now updated to match the currently used models as mentioned in issue #1
emilianavt committed Apr 30, 2020
1 parent b43f281 commit 8145e4b
Showing 2 changed files with 171 additions and 142 deletions.
2 changes: 1 addition & 1 deletion Unity/OpenSeeShowPoints.cs
@@ -187,9 +187,9 @@ void Update () {
             lineRenderers[i].enabled = true;
             lineRenderers[i].widthMultiplier = lineWidth;
             lineRenderers[i].material = lineMaterial;
-            lineRenderers[i].material.SetColor("_Color", color);
             lineRenderers[i].startColor = color;
             lineRenderers[i].endColor = color;
+            lineRenderers[i].material.SetColor("_Color", color);
             lineRenderers[i].SetPosition(0, gameObjects[a].transform.position);
             lineRenderers[i].SetPosition(1, gameObjects[b].transform.position);
         }
311 changes: 170 additions & 141 deletions model.py
@@ -1,164 +1,164 @@
 # This file is not used by the tracking application and currently outdated
 import torch
 import torch.nn as nn
+import geffnet.mobilenetv3 # geffnet.mobilenetv3._gen_mobilenet_v3 needs to be patched to return the parameters instead of instantiating the network
+from geffnet.efficientnet_builder import round_channels
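+# The patch mentioned above is roughly the following (an assumption about geffnet
+# internals, which differ between versions): _gen_mobilenet_v3 builds a dict of
+# constructor arguments and then instantiates the network; patched, it instead does
+#     return model_kwargs   # instead of: model = MobileNetV3(**model_kwargs)
+# so the subclasses below can pass the configuration to super().__init__(**kwargs).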

 class DSConv2d(nn.Module):
-    def __init__(self, in_planes, out_planes, kernels_per_layer=4, groups=1):
+    def __init__(self, in_planes, out_planes, kernels_per_layer=4, groups=1, old=0):
         super(DSConv2d, self).__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_planes, in_planes * kernels_per_layer, kernel_size=3, padding=1, groups=in_planes),
-            nn.Conv2d(in_planes * kernels_per_layer, out_planes, kernel_size=1, groups=groups)
-        )
+        if old == 2:
+            self.conv = nn.Sequential(
+                nn.Conv2d(in_planes, in_planes * kernels_per_layer, kernel_size=3, padding=1, groups=in_planes),
+                nn.Conv2d(in_planes * kernels_per_layer, out_planes, kernel_size=1, groups=groups)
+            )
+        elif old == 1:
+            self.conv = nn.Sequential(
+                nn.Conv2d(in_planes, in_planes * kernels_per_layer, kernel_size=3, padding=1, groups=in_planes, bias=False),
+                nn.BatchNorm2d(in_planes * kernels_per_layer),
+                nn.Conv2d(in_planes * kernels_per_layer, out_planes, kernel_size=1, groups=groups, bias=False),
+                nn.BatchNorm2d(out_planes),
+                nn.ReLU6(inplace=True)
+            )
+        else:
+            self.conv = nn.Sequential(
+                nn.Conv2d(in_planes, in_planes * kernels_per_layer, kernel_size=3, padding=1, groups=in_planes, bias=False),
+                nn.BatchNorm2d(in_planes * kernels_per_layer),
+                nn.ReLU6(inplace=True),
+                nn.Conv2d(in_planes * kernels_per_layer, out_planes, kernel_size=1, groups=groups, bias=False),
+                nn.BatchNorm2d(out_planes),
+                nn.ReLU6(inplace=True)
+            )
     def forward(self, x):
         x = self.conv(x)
         return x
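+    # The old parameter above appears to select among historical layer layouts so
+    # that checkpoints trained with earlier versions of this block still load;
+    # old=0 is the current layout with batch norm and ReLU6 after both convolutions.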

 class UNetUp(nn.Module):
-    def __init__(self, in_channels, residual_in_channels, out_channels, size):
+    def __init__(self, in_channels, residual_in_channels, out_channels, size, old=0):
         super(UNetUp, self).__init__()
         self.up = nn.Upsample(size=size, mode='bilinear', align_corners=True)
-        self.conv = DSConv2d(in_channels + residual_in_channels, out_channels, 1, 1)
+        self.conv = DSConv2d(in_channels + residual_in_channels, out_channels, 1, 1, old=old)
     def forward(self, x1, x2):
         x1 = self.up(x1)
         x = torch.cat([x2, x1], dim=1)
         x = self.conv(x)
         return x
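+    # UNetUp upsamples x1 bilinearly to a fixed target size, concatenates the
+    # encoder feature map x2, and fuses the result with a depthwise separable convolution.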

-# Copy of torchvision ShuffleNetV2 for type annotation
-def channel_shuffle(x, groups: int):
-    batchsize, num_channels, height, width = x.data.size()
-    channels_per_group = num_channels // groups
-
-    # reshape
-    x = x.view(batchsize, groups,
-               channels_per_group, height, width)
-
-    x = torch.transpose(x, 1, 2).contiguous()
-
-    # flatten
-    x = x.view(batchsize, -1, height, width)
-
-    return x
-
-class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride):
-        super(InvertedResidual, self).__init__()
-
-        if not (1 <= stride <= 3):
-            raise ValueError('illegal stride value')
-        self.stride = stride
-
-        branch_features = oup // 2
-        assert (self.stride != 1) or (inp == branch_features << 1)
-
-        self.branch1 = nn.Sequential(
-            self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
-            nn.BatchNorm2d(inp),
-            nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
-            nn.BatchNorm2d(branch_features),
-            nn.ReLU(inplace=True),
-        )
-
-        self.branch2 = nn.Sequential(
-            nn.Conv2d(inp if (self.stride > 1) else branch_features,
-                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
-            nn.BatchNorm2d(branch_features),
-            nn.ReLU(inplace=True),
-            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
-            nn.BatchNorm2d(branch_features),
-            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
-            nn.BatchNorm2d(branch_features),
-            nn.ReLU(inplace=True),
-        )
-
-    @staticmethod
-    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
-        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

+# This is the gaze tracking model
+class OpenSeeFaceGaze(geffnet.mobilenetv3.MobileNetV3):
+    def __init__(self):
+        kwargs = geffnet.mobilenetv3._gen_mobilenet_v3(['small'])
+        super(OpenSeeFaceGaze, self).__init__(**kwargs)
+        self.up1 = UNetUp(576, 48, 64, (2,2), old=2)
+        self.up2 = UNetUp(64, 24, 32, (4,4), old=2)
+        self.up3 = UNetUp(32, 16, 15, (8,8), old=2)
+        self.group = DSConv2d(15, 3, kernels_per_layer=4, groups=3, old=2)
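+        # The ups fuse intermediate backbone outputs (r1-r3, captured in
+        # _forward_impl below) back in at increasing spatial resolution, U-Net style.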
+    def _forward_impl(self, x):
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        x = self.act1(x)
+        r1 = None
+        r2 = None
+        r3 = None
+        for i, feature in enumerate(self.blocks):
+            x = feature(x)
+            if i == 3:
+                r3 = x
+            if i == 1:
+                r2 = x
+            if i == 0:
+                r1 = x
+        x = self.up1(x, r3)
+        x = self.up2(x, r2)
+        x = self.up3(x, r1)
+        x = self.group(x)
+        return x
     def forward(self, x):
-        if self.stride == 1:
-            x1, x2 = x.chunk(2, dim=1)
-            out = torch.cat((x1, self.branch2(x2)), dim=1)
-        else:
-            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
-
-        out = channel_shuffle(out, 2)
-
-        return out

-class ShuffleNetV2(nn.Module):
-    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
-        super(ShuffleNetV2, self).__init__()
-
-        if len(stages_repeats) != 3:
-            raise ValueError('expected stages_repeats as list of 3 positive ints')
-        if len(stages_out_channels) != 5:
-            raise ValueError('expected stages_out_channels as list of 5 positive ints')
-        self._stage_out_channels = stages_out_channels
-
-        input_channels = 3
-        output_channels = self._stage_out_channels[0]
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
-            nn.BatchNorm2d(output_channels),
-            nn.ReLU(inplace=True),
-        )
-        input_channels = output_channels
-
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-
-        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
-        for name, repeats, output_channels in zip(
-                stage_names, stages_repeats, self._stage_out_channels[1:]):
-            seq = [InvertedResidual(input_channels, output_channels, 2)]
-            for i in range(repeats - 1):
-                seq.append(InvertedResidual(output_channels, output_channels, 1))
-            setattr(self, name, nn.Sequential(*seq))
-            input_channels = output_channels
-
-        output_channels = self._stage_out_channels[-1]
-        self.conv5 = nn.Sequential(
-            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
-            nn.BatchNorm2d(output_channels),
-            nn.ReLU(inplace=True),
-        )
-
-        self.fc = nn.Linear(output_channels, num_classes)
+        return self._forward_impl(x)

+# This is the face detection model. Because the landmark model is very robust, it gets away with predicting very rough bounding boxes. It is fully convolutional and can be made to run at different resolutions. It was trained on 224x224 crops and gives the most reasonable results in the range of 224x224 to 640x640.
+class OpenSeeFaceDetect(geffnet.mobilenetv3.MobileNetV3):
+    def __init__(self, size="large", channel_multiplier=0.1):
+        kwargs = geffnet.mobilenetv3._gen_mobilenet_v3([size], channel_multiplier=channel_multiplier)
+        super(OpenSeeFaceDetect, self).__init__(**kwargs)
+        if size == "large":
+            self.up1 = UNetUp(round_channels(960, channel_multiplier), round_channels(112, channel_multiplier), 256, (14,14), old=1)
+            self.up2 = UNetUp(256, round_channels(40, channel_multiplier), 128, (28,28), old=1)
+            self.up3 = UNetUp(128, round_channels(24, channel_multiplier), 64, (56,56), old=1)
+            self.group = DSConv2d(64, 2, kernels_per_layer=4, groups=2, old=1)
+            self.r1_i = 1
+            self.r2_i = 2
+            self.r3_i = 4
+        elif size == "small":
+            self.up1 = UNetUp(round_channels(576, channel_multiplier), round_channels(40, channel_multiplier), 256, (14,14), old=1)
+            self.up2 = UNetUp(256, round_channels(24, channel_multiplier), 128, (28,28), old=1)
+            self.up3 = UNetUp(128, round_channels(16, channel_multiplier), 64, (56,56), old=1)
+            self.group = DSConv2d(64, 2, kernels_per_layer=4, groups=2, old=1)
+            self.r1_i = 0
+            self.r2_i = 1
+            self.r3_i = 2
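+        # r1_i/r2_i/r3_i pick which backbone blocks feed the decoder skip
+        # connections; the indices differ between the large and small variants.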
+        self.maxpool = nn.MaxPool2d(kernel_size=3, dilation=1, stride=1, padding=1)
+    def _forward_impl(self, x):
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        x = self.act1(x)
+        r2 = None
+        r3 = None
+        for i, feature in enumerate(self.blocks):
+            x = feature(x)
+            if i == self.r3_i:
+                r3 = x
+            if i == self.r2_i:
+                r2 = x
+            if i == self.r1_i:
+                r1 = x
+        x = self.up1(x, r3)
+        x = self.up2(x, r2)
+        x = self.up3(x, r1)
+        x = self.group(x)
+        x2 = self.maxpool(x)
+        return x, x2
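+    # Comparing x to its 3x3 max-pooled version x2 can be used to find local
+    # maxima of the detection heatmap (a common trick for heatmap peak extraction).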
     def forward(self, x):
-        x = self.conv1(x)
-        x = self.maxpool(x)
-        x = self.stage2(x)
-        x = self.stage3(x)
-        x = self.stage4(x)
-        x = self.conv5(x)
-        x = x.mean([2, 3])  # globalpool
-        x = self.fc(x)
-        return x
-
-# Facial landmark detection model
-class OpenSeeNetSNV2(ShuffleNetV2):
-    def __init__(self):
-        super(OpenSeeNetSNV2, self).__init__([4, 8, 4], [24, 116, 232, 464, 1024])
-        self.up1 = UNetUp(1024, 232, 256, (14,14))
-        self.up2 = UNetUp(256, 116, 198 * 1, (28,28))
-        self.group = DSConv2d(198 * 1, 198, kernels_per_layer=4, groups=3)
+        return self._forward_impl(x)

+# Landmark detection model
+# Models:
+# 0: "small", 0.5
+# 1: "small", 1.0
+# 2: "large", 0.75
+# 3: "large", 1.0
+class OpenSeeFaceLandmarks(geffnet.mobilenetv3.MobileNetV3):
+    def __init__(self, size="large", channel_multiplier=1.0):
+        kwargs = geffnet.mobilenetv3._gen_mobilenet_v3([size], channel_multiplier=channel_multiplier)
+        super(OpenSeeFaceLandmarks, self).__init__(**kwargs)
+        if size == "large":
+            self.up1 = UNetUp(round_channels(960, channel_multiplier), round_channels(112, channel_multiplier), 256, (14,14))
+            self.up2 = UNetUp(256, round_channels(40, channel_multiplier), 198 * 1, (28,28))
+            self.group = DSConv2d(198 * 1, 198, kernels_per_layer=4, groups=3)
+            self.r2_i = 2
+            self.r3_i = 4
+        elif size == "small":
+            self.up1 = UNetUp(round_channels(576, channel_multiplier), round_channels(40, channel_multiplier), 256, (14,14))
+            self.up2 = UNetUp(256, round_channels(24, channel_multiplier), 198 * 1, (28,28))
+            self.group = DSConv2d(198 * 1, 198, kernels_per_layer=4, groups=3)
+            self.r2_i = 1
+            self.r3_i = 2
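+        # The 198 output channels are presumably 3 groups of 66 values, one per
+        # landmark, matching groups=3 here and the threefold mask concatenation
+        # in AdapWingLoss below.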
     def _forward_impl(self, x):
-        x = self.conv1(x)
-        x = self.maxpool(x)
-        x = self.stage2(x)
-        r2 = x
-        x = self.stage3(x)
-        r3 = x
-        x = self.stage4(x)
-        x = self.conv5(x)
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        x = self.act1(x)
+        r2 = None
+        r3 = None
+        for i, feature in enumerate(self.blocks):
+            x = feature(x)
+            if i == self.r3_i:
+                r3 = x
+            if i == self.r2_i:
+                r2 = x
         x = self.up1(x, r3)
         x = self.up2(x, r2)
         x = self.group(x)
         return x

     def forward(self, x):
         return self._forward_impl(x)

@@ -196,12 +196,15 @@ def AdapWingLoss(pre_hm, gt_hm):
         if first:
             first_mask = dilated
             first = False
-        dilated[17:27] *= 1.2
-        dilated[17] *= 1.3
-        dilated[18] *= 1.4
-        dilated[25] *= 1.4
-        dilated[26] *= 1.3
-        dilated[36:47] *= 2.5
+        # These weights varied between training runs and model sizes
+        dilated[17:27][dilated[17:27] > 1] *= 1.4
+        dilated[17, dilated[17] > 1] *= 1.6
+        dilated[18, dilated[18] > 1] *= 1.8
+        dilated[25, dilated[25] > 1] *= 1.8
+        dilated[26, dilated[26] > 1] *= 1.6
+        dilated[36:48][dilated[36:48] > 1] *= 2.8
+        # Used for a very small model
+        #dilated[[37,38,40,41,43,44,46,47]][dilated[[37,38,40,41,43,44,46,47]] > 1] *= 20.8
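+        # In the usual 68-point landmark numbering, indices 17-26 are the
+        # eyebrows and 36-47 the eyes, so these regions receive extra loss weight.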
         mask[i] = torch.cat([dilated, dilated, dilated], 0)

     diff_hm = torch.abs(gt_hm - pre_hm)
@@ -216,5 +219,31 @@ def AdapWingLoss(pre_hm, gt_hm):

     return first_mask.detach(), mean_loss

 # The fast model runs on grayscale 112x112 using the ShuffleNet V2 0.5x configuration with input_channels=1. It was trained using AWL, but without the additional weighting for certain features.
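+# A configuration matching that description would be something like the following
+# (a sketch, not part of this file; the ShuffleNetV2 copy itself was removed in this commit):
+#     m = ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024])   # 0.5x channel widths
+#     m.conv1[0] = nn.Conv2d(1, 24, 3, 2, 1, bias=False)     # input_channels=1 (grayscale)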


+# Checkpoint test
+if __name__ == "__main__":
+    print("Checking gaze model")
+    m = OpenSeeFaceGaze()
+    ckpt = torch.load("gaze.pth")
+    m.load_state_dict(ckpt)
+    print("Checking detection model")
+    m = OpenSeeFaceDetect()
+    ckpt = torch.load("detection.pth")
+    m.load_state_dict(ckpt)
+    print("Checking lm_model0 model")
+    m = OpenSeeFaceLandmarks("small", 0.5)
+    ckpt = torch.load("lm_model0.pth")
+    m.load_state_dict(ckpt)
+    print("Checking lm_model1 model")
+    m = OpenSeeFaceLandmarks("small", 1.0)
+    ckpt = torch.load("lm_model1.pth")
+    m.load_state_dict(ckpt)
+    print("Checking lm_model2 model")
+    m = OpenSeeFaceLandmarks("large", 0.75)
+    ckpt = torch.load("lm_model2.pth")
+    m.load_state_dict(ckpt)
+    print("Checking lm_model3 model")
+    m = OpenSeeFaceLandmarks("large", 1.0)
+    ckpt = torch.load("lm_model3.pth")
+    m.load_state_dict(ckpt)
