import torch
from torch import nn

from .modules import BasicBlock, Bottleneck


def _bn(num_features, momentum):
    """Return a BatchNorm2d configured like every BN layer in this file.

    Args:
        num_features: Number of channels to normalize.
        momentum: Running-statistics momentum (the caller's ``bn_momentum``).
    """
    return nn.BatchNorm2d(num_features, eps=1e-05, momentum=momentum, affine=True, track_running_stats=True)


class StageModule(nn.Module):
    """One HRNet exchange unit: parallel multi-resolution branches followed by cross-branch fusion.

    Branch ``i`` works on ``c * 2 ** i`` channels at half the spatial resolution of branch ``i - 1``
    (the fuse layers upsample by ``2 ** (j - i)`` or downsample with ``i - j`` stride-2 convolutions).

    Args:
        stage: Number of input branches (one stack of four BasicBlocks per branch).
        output_branches: Number of fused outputs to produce. In :class:`HRNet` this equals ``stage``
            except for the very last module, which keeps only the full-resolution branch.
        c: Channel width of the full-resolution branch.
        bn_momentum: Momentum used by every BatchNorm layer of this module (branches and fuse layers).
    """

    def __init__(self, stage, output_branches, c, bn_momentum):
        super().__init__()
        self.stage = stage
        self.output_branches = output_branches

        # Each branch is four BasicBlocks that keep its width unchanged.
        self.branches = nn.ModuleList()
        for i in range(self.stage):
            w = c * (2 ** i)
            self.branches.append(nn.Sequential(*(BasicBlock(w, w, bn_momentum=bn_momentum) for _ in range(4))))

        # fuse_layers[i][j] maps the output of branch j onto the resolution and width of output branch i.
        self.fuse_layers = nn.ModuleList([
            nn.ModuleList([self._make_fuse_layer(i, j, c, bn_momentum) for j in range(self.stage)])
            for i in range(self.output_branches)
        ])

        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _make_fuse_layer(i, j, c, bn_momentum):
        """Build the module that brings branch ``j``'s output to output branch ``i``.

        Returns an identity for ``i == j``, a 1x1 conv + BN + nearest upsampling for a lower-resolution
        source (``i < j``), and a chain of ``i - j`` stride-2 3x3 convs for a higher-resolution source
        (``i > j``). Keep the module nesting unchanged: the state_dict keys must match the pretrained
        checkpoints loaded with ``load_state_dict``.
        """
        if i == j:
            return nn.Sequential()  # identity; used in place of None because it is callable
        if i < j:
            return nn.Sequential(
                nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(1, 1), stride=(1, 1), bias=False),
                _bn(c * (2 ** i), bn_momentum),
                nn.Upsample(scale_factor=(2.0 ** (j - i)), mode='nearest'),
            )
        # i > j: every downsampling step but the last keeps branch j's width and ends in a ReLU;
        # the last step switches to branch i's width and has no ReLU (it is applied after the sum).
        ops = [
            nn.Sequential(
                nn.Conv2d(c * (2 ** j), c * (2 ** j), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                _bn(c * (2 ** j), bn_momentum),
                nn.ReLU(inplace=True),
            )
            for _ in range(i - j - 1)
        ]
        ops.append(nn.Sequential(
            nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
            _bn(c * (2 ** i), bn_momentum),
        ))
        return nn.Sequential(*ops)

    def forward(self, x):
        """Run every branch on its input, then fuse across resolutions.

        Args:
            x: List with one tensor per branch (``len(x) == stage``); entry ``i`` is expected to have
                ``c * 2 ** i`` channels.

        Returns:
            List of ``output_branches`` tensors; entry ``i`` is
            ``relu(sum_j fuse_layers[i][j](branch_j(x[j])))``.

        Raises:
            ValueError: If ``len(x)`` does not match the number of branches.
        """
        if len(x) != len(self.branches):
            raise ValueError(f'expected {len(self.branches)} branch inputs, got {len(x)}')

        x = [branch(b) for branch, b in zip(self.branches, x)]

        x_fused = []
        for fuse_row in self.fuse_layers:
            fused = fuse_row[0](x[0])
            for fuse, branch_out in zip(fuse_row[1:], x[1:]):
                # Out-of-place add: for row 0 `fused` initially aliases x[0] through the identity layer.
                fused = fused + fuse(branch_out)
            x_fused.append(fused)

        # Apply the in-place ReLU only once every row has been summed, so no row reads a modified input.
        return [self.relu(f) for f in x_fused]


class HRNet(nn.Module):
    """HRNet pose-estimation network producing ``nof_joints`` heatmap channels.

    Args:
        c: Width of the full-resolution branch (e.g. 32 for the W32 weights, 48 for W48).
        nof_joints: Number of output heatmap channels (one per joint).
        bn_momentum: Momentum of the BatchNorm layers built here and inside the StageModules.
    """

    def __init__(self, c=48, nof_joints=17, bn_momentum=0.1):
        super().__init__()

        # Input (stem net): two stride-2 3x3 convolutions.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.bn1 = _bn(64, bn_momentum)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.bn2 = _bn(64, bn_momentum)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1 (layer1) - four Bottleneck (ResNet) modules, 64 -> 256 channels.
        # NOTE(review): bn_momentum is not forwarded to Bottleneck; check whether .modules.Bottleneck accepts it.
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False),
            _bn(256, bn_momentum),
        )
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
        )

        # Fusion layer 1 (transition1) - creates the first two branches (full and half resolution).
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                _bn(c, bn_momentum),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(256, c * (2 ** 1), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                _bn(c * (2 ** 1), bn_momentum),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 2 (stage2) - one StageModule (BasicBlock branches) with 2 branches.
        self.stage2 = nn.Sequential(
            StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum),
        )

        # Fusion layer 2 (transition2) - existing branches pass through unchanged; the third branch
        # (1/4 resolution) is created from the lowest-resolution existing branch (see _apply_transition).
        self.transition2 = nn.ModuleList([
            nn.Sequential(),  # identity, used in place of None because it is callable
            nn.Sequential(),  # identity
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * (2 ** 1), c * (2 ** 2), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                _bn(c * (2 ** 2), bn_momentum),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 3 (stage3) - four StageModules with 3 branches.
        self.stage3 = nn.Sequential(
            *(StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum) for _ in range(4))
        )

        # Fusion layer 3 (transition3) - existing branches pass through unchanged; the fourth branch
        # (1/8 resolution) is created from the lowest-resolution existing branch.
        self.transition3 = nn.ModuleList([
            nn.Sequential(),  # identity, used in place of None because it is callable
            nn.Sequential(),  # identity
            nn.Sequential(),  # identity
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * (2 ** 2), c * (2 ** 3), kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
                _bn(c * (2 ** 3), bn_momentum),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 4 (stage4) - three StageModules with 4 branches; the last one fuses into the
        # full-resolution branch only.
        self.stage4 = nn.Sequential(
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum),
        )

        # Final layer (final_layer) - 1x1 conv from the full-resolution branch to one channel per joint.
        self.final_layer = nn.Conv2d(c, nof_joints, kernel_size=(1, 1), stride=(1, 1))

    @staticmethod
    def _apply_transition(transition, x):
        """Apply a transition layer to the list of branch outputs ``x``.

        ``transition`` holds one module per existing branch (applied to that branch) plus one extra
        module, at index ``len(x)``, that creates the new lower-resolution branch from ``x[-1]``.
        """
        out = [trans(branch) for trans, branch in zip(transition, x)]
        out.append(transition[len(x)](x[-1]))
        return out

    def forward(self, x):
        """Compute joint heatmaps for an image batch.

        Args:
            x: Image tensor with 3 channels, presumably (N, 3, H, W).

        Returns:
            Tensor with ``nof_joints`` channels computed from the full-resolution branch. Its spatial
            size is reduced by the two stride-2 stem convs (presumably 1/4 of the input, assuming
            layer1 keeps resolution - its stride-1 residual downsample suggests so).
        """
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # from here on, x is a list with one tensor per branch

        x = self.stage2(x)
        x = self._apply_transition(self.transition2, x)

        x = self.stage3(x)
        x = self._apply_transition(self.transition3, x)

        x = self.stage4(x)

        return self.final_layer(x[0])


if __name__ == '__main__':
    # Smoke test: load official pretrained weights and run a single forward pass.
    # For the W48 model use HRNet(48, 17, 0.1) with './weights/pose_hrnet_w48_384x288.pth'.
    model = HRNet(32, 17, 0.1)

    # NOTE: torch.load unpickles the checkpoint - only load weight files from a trusted source.
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines.
    model.load_state_dict(
        torch.load('./weights/pose_hrnet_w32_256x192.pth', map_location='cpu')
    )
    print('ok!!')

    use_cuda = False  # set to True to run the forward pass on the first GPU when one is available
    if use_cuda and torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    print(device)

    # eval(): use the pretrained BatchNorm running statistics instead of batch statistics,
    # and do not overwrite them during this forward pass.
    model = model.to(device).eval()

    with torch.no_grad():
        y = model(torch.ones(1, 3, 384, 288).to(device))
    print(y.shape)
    print(torch.min(y).item(), torch.mean(y).item(), torch.max(y).item())