EasyMocap/myeasymocap/backbone/pare/layers/locallyconnected2d.py
2023-06-24 22:39:33 +08:00

49 lines
1.9 KiB
Python

# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
class LocallyConnected2d(nn.Module):
    """2D locally connected layer: a convolution whose weights are not shared.

    Every output location owns its own filter, so the weight tensor has shape
    ``(1, out_channels, in_channels, out_h, out_w, kh * kw)``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the layer.
        output_size: Spatial size of the output, an int or ``(out_h, out_w)``.
            It must equal the number of sliding windows that ``kernel_size``
            and ``stride`` yield on the input, otherwise the multiplication
            in ``forward`` cannot broadcast.
        kernel_size: Window size, an int or ``(kh, kw)``.
        stride: Window step, an int or ``(dh, dw)``.
        bias: If true, learn an additive bias of shape
            ``(1, out_channels, out_h, out_w)``; otherwise ``bias`` is None.
    """

    def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False):
        super(LocallyConnected2d, self).__init__()
        output_size = _pair(output_size)
        # Normalise before sizing the weight so that tuple (non-square)
        # kernels work; for an int this is the same as kernel_size ** 2.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        kh, kw = self.kernel_size
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kh * kw),
            requires_grad=True,
        )
        if bias:
            self.bias = nn.Parameter(
                torch.randn(1, out_channels, output_size[0], output_size[1]), requires_grad=True
            )
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Apply the per-location filters to ``x``.

        Args:
            x: 4D tensor; dims 2 and 3 are unfolded into sliding windows.

        Returns:
            Tensor of shape ``(N, out_channels, out_h, out_w)``.
        """
        kh, kw = self.kernel_size
        dh, dw = self.stride
        # (N, C, H, W) -> (N, C, out_h, out_w, kh, kw): one window per output location.
        x = x.unfold(2, kh, dh).unfold(3, kw, dw)
        # Flatten each window to kh * kw so it lines up with the weight's last dim.
        x = x.contiguous().view(*x.size()[:-2], -1)
        # Broadcast against out_channels, then sum over in_channels and the window.
        out = (x.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out += self.bias
        return out