219 lines
9.6 KiB
Python
219 lines
9.6 KiB
Python
import torch
|
|
from torch import nn
|
|
from .modules import BasicBlock, Bottleneck
|
|
|
|
|
|
class StageModule(nn.Module):
    """One multi-branch stage: parallel per-branch blocks followed by cross-branch fusion.

    Branch ``i`` operates on ``c * 2**i`` channels and consists of four
    ``BasicBlock(w, w)`` modules applied in sequence. After the branches, every
    requested output ``i`` is computed as the sum over all branches ``j`` of
    ``fuse_layers[i][j](branch_j_output)``, followed by a ReLU.

    Args:
        stage: number of parallel input branches.
        output_branches: number of fused outputs to produce (outputs
            ``0 .. output_branches - 1``).
        c: channel width of branch 0; branch ``i`` uses ``c * 2**i`` channels.
        bn_momentum: momentum passed to the ``BasicBlock``s and to the
            ``BatchNorm2d`` layers of the fuse paths.
    """

    def __init__(self, stage, output_branches, c, bn_momentum):
        super(StageModule, self).__init__()
        self.stage = stage
        self.output_branches = output_branches

        # One stack of four BasicBlocks per branch; the channel count doubles per branch.
        self.branches = nn.ModuleList()
        for i in range(self.stage):
            w = c * (2 ** i)
            self.branches.append(nn.Sequential(
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
            ))

        # fuse_layers[i][j] maps the output of branch j to the channel count (and,
        # via up/down-sampling, the spatial size) of output i.
        self.fuse_layers = nn.ModuleList()
        for i in range(self.output_branches):
            fuse_row = nn.ModuleList()
            for j in range(self.stage):
                if i == j:
                    # Empty Sequential acts as identity; used instead of None because it is callable.
                    fuse_row.append(nn.Sequential())
                elif i < j:
                    fuse_row.append(self._make_upsample_fuse(
                        c * (2 ** j), c * (2 ** i), 2.0 ** (j - i), bn_momentum))
                else:
                    fuse_row.append(self._make_downsample_fuse(
                        c * (2 ** j), c * (2 ** i), i - j, bn_momentum))
            self.fuse_layers.append(fuse_row)

        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _make_upsample_fuse(in_channels, out_channels, scale, bn_momentum):
        """Build a 1x1 conv + BN + nearest-neighbour upsampling (by ``scale``) fuse path."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(out_channels, eps=1e-05, momentum=bn_momentum, affine=True,
                           track_running_stats=True),
            nn.Upsample(scale_factor=scale, mode='nearest'),
        )

    @staticmethod
    def _make_downsample_fuse(in_channels, out_channels, steps, bn_momentum):
        """Build a fuse path made of ``steps`` stride-2 3x3 convolutions.

        The first ``steps - 1`` convolutions keep ``in_channels`` and are followed
        by BN + ReLU; the last one changes to ``out_channels`` and has BN only
        (the ReLU is applied after the summation in ``forward``).
        """
        ops = []
        for _ in range(steps - 1):
            ops.append(nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
                          bias=False),
                nn.BatchNorm2d(in_channels, eps=1e-05, momentum=bn_momentum, affine=True,
                               track_running_stats=True),
                nn.ReLU(inplace=True),
            ))
        ops.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
                      bias=False),
            nn.BatchNorm2d(out_channels, eps=1e-05, momentum=bn_momentum, affine=True,
                           track_running_stats=True),
        ))
        return nn.Sequential(*ops)

    def forward(self, x):
        """Run every branch, then fuse the branch outputs.

        Args:
            x: sequence with one input per branch (``len(x) == len(self.branches)``).

        Returns:
            A list with ``len(self.fuse_layers)`` fused outputs, each passed through ReLU.

        Raises:
            ValueError: if the number of inputs does not match the number of branches.
        """
        # Explicit check instead of `assert`: asserts vanish under `python -O`, and
        # zip() below would then silently drop the extra inputs/branches.
        if len(self.branches) != len(x):
            raise ValueError(
                'StageModule expected %d inputs (one per branch), got %d'
                % (len(self.branches), len(x))
            )

        x = [branch(b) for branch, b in zip(self.branches, x)]

        x_fused = []
        for fuse_row in self.fuse_layers:
            for j, (fuse, b) in enumerate(zip(fuse_row, x)):
                if j == 0:
                    x_fused.append(fuse(b))
                else:
                    x_fused[-1] = x_fused[-1] + fuse(b)

        # The ReLU is in-place, so it is applied only after all sums are built: with
        # an identity fuse path an output may alias a branch output that other fuse
        # rows still need to read unmodified.
        for i in range(len(x_fused)):
            x_fused[i] = self.relu(x_fused[i])

        return x_fused
|
|
|
|
|
|
class HRNet(nn.Module):
    """HRNet backbone with a 1x1 convolutional head producing ``nof_joints`` channels.

    Args:
        c: channel width of the first branch; branch ``i`` uses ``c * 2**i`` channels.
        nof_joints: number of output channels of the final layer.
        bn_momentum: momentum used for the BatchNorm layers built here and
            forwarded to every ``StageModule``.
    """

    def __init__(self, c=48, nof_joints=17, bn_momentum=0.1):
        super().__init__()

        # Small local factories so the layer hyper-parameters are spelled out once.
        def batch_norm(channels):
            return nn.BatchNorm2d(channels, eps=1e-05, momentum=bn_momentum, affine=True,
                                  track_running_stats=True)

        def conv3x3(in_channels, out_channels, stride):
            return nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3),
                             stride=(stride, stride), padding=(1, 1), bias=False)

        def new_branch(in_channels, out_channels):
            # Double Sequential to fit with official pretrained weights
            return nn.Sequential(nn.Sequential(
                conv3x3(in_channels, out_channels, 2),
                batch_norm(out_channels),
                nn.ReLU(inplace=True),
            ))

        def stage_module(branches, outputs):
            return StageModule(stage=branches, output_branches=outputs, c=c, bn_momentum=bn_momentum)

        # Input (stem net): two stride-2 3x3 convolutions, each followed by BN (+ ReLU in forward)
        self.conv1 = conv3x3(3, 64, 2)
        self.bn1 = batch_norm(64)
        self.conv2 = conv3x3(64, 64, 2)
        self.bn2 = batch_norm(64)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1 (layer1) - First group of bottleneck (resnet) modules
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False),
            batch_norm(256),
        )
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
        )

        # Fusion layer 1 (transition1) - Creation of the first two branches
        # (one at the incoming resolution, one produced by a stride-2 convolution)
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                conv3x3(256, c, 1),
                batch_norm(c),
                nn.ReLU(inplace=True),
            ),
            new_branch(256, c * (2 ** 1)),
        ])

        # Stage 2 (stage2) - a single StageModule working on 2 branches
        self.stage2 = nn.Sequential(stage_module(2, 2))

        # Fusion layer 2 (transition2) - Creation of the third branch.
        # Empty Sequentials are pass-throughs, used in place of "None" because they are callable.
        # TODO: why does the new branch derive from the last existing branch only?
        self.transition2 = nn.ModuleList([
            *(nn.Sequential() for _ in range(2)),
            new_branch(c * (2 ** 1), c * (2 ** 2)),
        ])

        # Stage 3 (stage3) - four StageModules working on 3 branches
        self.stage3 = nn.Sequential(*[stage_module(3, 3) for _ in range(4)])

        # Fusion layer 3 (transition3) - Creation of the fourth branch.
        # TODO: why does the new branch derive from the last existing branch only?
        self.transition3 = nn.ModuleList([
            *(nn.Sequential() for _ in range(3)),
            new_branch(c * (2 ** 2), c * (2 ** 3)),
        ])

        # Stage 4 (stage4) - three StageModules on 4 branches; the last one fuses
        # everything into a single output (branch 0).
        self.stage4 = nn.Sequential(*[stage_module(4, outputs) for outputs in (4, 4, 1)])

        # Final layer (final_layer): 1x1 convolution from c to nof_joints channels
        self.final_layer = nn.Conv2d(c, nof_joints, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        """Run the network on ``x`` and return ``{'output': tensor}``.

        The returned tensor is ``final_layer`` applied to the first (index 0)
        output of the last stage.
        """
        # Stem
        stem = self.relu(self.bn1(self.conv1(x)))
        stem = self.relu(self.bn2(self.conv2(stem)))

        features = self.layer1(stem)

        # From here on we carry a list with one tensor per branch.
        branches = [transition(features) for transition in self.transition1]
        branches = self.stage2(branches)

        # Existing branches pass through; the new branch is fed from the last
        # existing branch only, hence the duplicated trailing element.
        branches = [
            transition(tensor)
            for transition, tensor in zip(self.transition2, (*branches, branches[-1]))
        ]
        branches = self.stage3(branches)

        branches = [
            transition(tensor)
            for transition, tensor in zip(self.transition3, (*branches, branches[-1]))
        ]
        branches = self.stage4(branches)

        return {
            'output': self.final_layer(branches[0])
        }
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: build the model, load pretrained weights, run one
    # dummy forward pass and print a few statistics of the output.

    # For the w48 variant use HRNet(48, 17, 0.1) together with
    # './weights/pose_hrnet_w48_384x288.pth'.
    model = HRNet(32, 17, 0.1)

    # map_location='cpu' lets a checkpoint that was saved from GPU tensors be
    # loaded on a machine without CUDA; the model is moved to `device` below.
    model.load_state_dict(
        torch.load('./weights/pose_hrnet_w32_256x192.pth', map_location='cpu')
    )
    print('ok!!')

    # GPU execution is deliberately disabled for this check; flip the flag to enable it.
    use_cuda = False
    if use_cuda and torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    print(device)

    model = model.to(device)

    # HRNet.forward returns a dict; the tensor lives under the 'output' key.
    y = model(torch.ones(1, 3, 384, 288).to(device))['output']
    print(y.shape)
    print(torch.min(y).item(), torch.mean(y).item(), torch.max(y).item())
|