# 2023-06-19 16:39:27 +08:00
#
# 219 lines
# 9.6 KiB

import torch
from torch import nn
from .modules import BasicBlock, Bottleneck
class StageModule(nn.Module):
    """One HRNet multi-resolution module: parallel branches followed by cross-branch fusion.

    Branch ``i`` works at ``c * 2**i`` channels and is a stack of four ``BasicBlock``s.
    After the branches run, output ``i`` is the ReLU of the sum over every input branch
    ``j`` of ``fuse_layers[i][j](x[j])``. The fuse layer resamples branch ``j`` to branch
    ``i``'s channel count and resolution; it is the identity when ``i == j``.

    Args:
        stage: number of parallel branches (inputs) of this module.
        output_branches: number of fused outputs to produce. This is ``stage`` for
            intermediate modules and ``1`` for a module that keeps only the
            highest-resolution output.
        c: channel width of branch 0.
        bn_momentum: momentum for every BatchNorm in the branches and the fuse layers.
    """

    def __init__(self, stage, output_branches, c, bn_momentum):
        super().__init__()
        self.stage = stage
        self.output_branches = output_branches

        self.branches = nn.ModuleList()
        for i in range(self.stage):
            w = c * (2 ** i)
            self.branches.append(nn.Sequential(
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
                BasicBlock(w, w, bn_momentum=bn_momentum),
            ))

        # fuse_layers[i][j] maps branch j's output onto output branch i.
        # One row per output branch (every branch, except in a final module with output_branches=1).
        self.fuse_layers = nn.ModuleList()
        for i in range(self.output_branches):
            self.fuse_layers.append(nn.ModuleList(
                [self._make_fuse_layer(i, j, c, bn_momentum) for j in range(self.stage)]
            ))

        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _make_fuse_layer(i, j, c, bn_momentum):
        """Build the module that maps branch ``j``'s features onto branch ``i``'s shape."""
        w_in = c * (2 ** j)
        w_out = c * (2 ** i)

        if i == j:
            # Identity. An empty Sequential is used in place of None because it is callable.
            return nn.Sequential()

        if i < j:
            # Lower-resolution input: 1x1 conv to match channels, then nearest upsampling
            # by 2**(j - i).
            return nn.Sequential(
                nn.Conv2d(w_in, w_out, kernel_size=(1, 1), stride=(1, 1), bias=False),
                nn.BatchNorm2d(w_out, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.Upsample(scale_factor=(2.0 ** (j - i)), mode='nearest'),
            )

        # Higher-resolution input (i > j): a chain of (i - j) stride-2 3x3 convs.
        # The intermediate steps keep branch j's width and end in ReLU. The last step
        # switches to branch i's width and has no ReLU, because the fused sum is
        # activated in forward().
        ops = [
            nn.Sequential(
                nn.Conv2d(w_in, w_in, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(w_in, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )
            for _ in range(i - j - 1)
        ]
        ops.append(nn.Sequential(
            nn.Conv2d(w_in, w_out, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(w_out, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
        ))
        return nn.Sequential(*ops)

    def forward(self, x):
        """Run every branch on its own input, then fuse the results across resolutions.

        Args:
            x: list with one tensor per branch (``len(x) == stage``). Presumably NCHW
               feature maps where branch ``i`` has ``c * 2**i`` channels (TODO confirm).

        Returns:
            list of ``output_branches`` fused, ReLU-activated tensors.
        """
        assert len(self.branches) == len(x), \
            f'expected {len(self.branches)} branch inputs, got {len(x)}'

        x = [branch(b) for branch, b in zip(self.branches, x)]

        x_fused = []
        for fuse_row in self.fuse_layers:
            fused = fuse_row[0](x[0])
            for j in range(1, len(self.branches)):
                # Out-of-place `+` on purpose. An identity fuse layer returns x[j] itself,
                # and other rows still need that tensor unchanged.
                fused = fused + fuse_row[j](x[j])
            x_fused.append(fused)

        # The ReLU is in-place, so apply it only after every row has been built. A row
        # made of a single identity term aliases a branch output that later rows read.
        return [self.relu(f) for f in x_fused]
class HRNet(nn.Module):
    """HRNet pose-estimation backbone with a 1x1 heatmap head.

    Structure:
    - A stride-4 stem and a Bottleneck ``layer1``.
    - Three transitions, each adding one lower-resolution branch.
    - Stages of ``StageModule``s that repeatedly fuse information across branches.

    Only the highest-resolution branch reaches ``final_layer``.

    Args:
        c: channel width of the highest-resolution branch. Branch ``i`` has
            ``c * 2**i`` channels (e.g. 32 or 48).
        nof_joints: number of output channels of ``final_layer``, presumably one heatmap
            per joint.
        bn_momentum: momentum for the BatchNorm layers created here and in the
            ``StageModule``s. It is not passed to the ``Bottleneck`` blocks of ``layer1``.

    The "double Sequential" wrappers exist so that parameter names line up with the
    official pretrained weights.
    """

    def __init__(self, c=48, nof_joints=17, bn_momentum=0.1):
        super().__init__()

        # Input (stem net): two stride-2 convs.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1 (layer1) - first group of bottleneck (resnet) modules, 64 -> 256 channels.
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(256, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
        )
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
        )

        # Fusion layer 1 (transition1) - creation of the first two branches
        # (one full and one half resolution).
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c, eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(256, c * (2 ** 1), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 1), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 2 (stage2) - 1 StageModule (BasicBlock branches). This has 2 branches.
        self.stage2 = nn.Sequential(
            StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum),
        )

        # Fusion layer 2 (transition2) - creation of the third branch (1/4 resolution).
        self.transition2 = nn.ModuleList([
            nn.Sequential(),  # None - used in place of "None" because it is callable
            nn.Sequential(),  # None - used in place of "None" because it is callable
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * (2 ** 1), c * (2 ** 2), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 2), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),  # NOTE: the new branch derives from the lowest-resolution ("upper") branch only.
        ])

        # Stage 3 (stage3) - 4 StageModules (BasicBlock branches). This has 3 branches.
        self.stage3 = nn.Sequential(
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
        )

        # Fusion layer 3 (transition3) - creation of the fourth branch (1/8 resolution).
        self.transition3 = nn.ModuleList([
            nn.Sequential(),  # None - used in place of "None" because it is callable
            nn.Sequential(),  # None - used in place of "None" because it is callable
            nn.Sequential(),  # None - used in place of "None" because it is callable
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * (2 ** 2), c * (2 ** 3), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c * (2 ** 3), eps=1e-05, momentum=bn_momentum, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),  # NOTE: the new branch derives from the lowest-resolution ("upper") branch only.
        ])

        # Stage 4 (stage4) - 3 StageModules (BasicBlock branches). This has 4 branches.
        # The last module keeps only the highest-resolution output.
        self.stage4 = nn.Sequential(
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum),
        )

        # Final layer (final_layer): 1x1 conv from c channels to nof_joints maps.
        self.final_layer = nn.Conv2d(c, nof_joints, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        """Compute the output maps for an image batch.

        Args:
            x: image tensor. Presumably (N, 3, H, W) given conv1's 3 input channels.
               H and W should probably be multiples of 32 so that the upsampled fusion
               terms line up with the other branches (TODO confirm).

        Returns:
            dict ``{'output': tensor}``: the ``final_layer`` output on the
            highest-resolution branch, with ``nof_joints`` channels. Its spatial size is
            about 1/4 of the input, from the two stride-2 stem convs (assuming
            ``layer1`` keeps resolution).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # From now on, x is a list (# == nof branches)

        x = self.stage2(x)
        # Existing branches pass through unchanged. The new branch derives from the
        # "upper" (lowest-resolution) branch only.
        x = [
            self.transition2[0](x[0]),
            self.transition2[1](x[1]),
            self.transition2[2](x[-1]),
        ]

        x = self.stage3(x)
        # Same pattern: the new fourth branch derives from x[-1] only.
        x = [
            self.transition3[0](x[0]),
            self.transition3[1](x[1]),
            self.transition3[2](x[2]),
            self.transition3[3](x[-1]),
        ]

        x = self.stage4(x)

        # The last StageModule has output_branches=1, so x holds only the top branch.
        x = self.final_layer(x[0])

        return {
            'output': x
        }
if __name__ == '__main__':
    # Smoke test: build the network and push a constant dummy image through it.
    # model = HRNet(48, 17, 0.1)
    model = HRNet(32, 17, 0.1)
    # print(model)
    # torch.load('./weights/pose_hrnet_w48_384x288.pth')

    # CUDA is disabled by default. Set to True to run on the GPU when one is available.
    use_cuda = False
    if use_cuda and torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = model.to(device)
    with torch.no_grad():
        # forward() returns a dict, so pull out the output tensor before reducing it.
        y = model(torch.ones(1, 3, 384, 288).to(device))['output']
    print(torch.min(y).item(), torch.mean(y).item(), torch.max(y).item())