# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import torch
|
|
import torch.nn as nn
|
|
from torch.nn.modules.utils import _pair
|
|
|
|
|
|
class LocallyConnected2d(nn.Module):
    """2D locally connected layer: a convolution whose weights are not shared.

    The kernel slides over the input with the given stride and no padding,
    as in ``nn.Conv2d``. The difference is that every output location
    ``(i, j)`` has its own filter, and its own bias when ``bias=True``. A
    single filter is not reused across the image.

    Args:
        in_channels: Number of input channels ``C``.
        out_channels: Number of output channels ``O``.
        output_size: Output spatial size ``(H_out, W_out)``, an int or a pair.
            It must equal the size obtained by sliding the kernel over the
            input with ``stride`` and no padding.
        kernel_size: Kernel size, an int or a ``(kh, kw)`` pair.
        stride: Stride, an int or a ``(sh, sw)`` pair.
        bias: If True, learn one bias per output channel and output location.

    Shapes:
        input: ``(N, C, H_in, W_in)``
        output: ``(N, O, H_out, W_out)``

    Parameters are initialised from a standard normal (``torch.randn``).
    """

    def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False):
        super().__init__()
        output_size = _pair(output_size)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        kh, kw = self.kernel_size
        # One filter per output location. The last dim holds the flattened
        # kh * kw window, which also works for non-square (tuple) kernels.
        # For an int kernel this equals kernel_size ** 2.
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kh * kw)
        )
        if bias:
            self.bias = nn.Parameter(
                torch.randn(1, out_channels, output_size[0], output_size[1])
            )
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(N, C, H_in, W_in)``.

        Returns:
            Tensor of shape ``(N, O, H_out, W_out)``.

        Raises:
            ValueError: if the input's channel count, or the spatial size
                produced by unfolding it, does not match the layer's
                ``in_channels`` / ``output_size``.
        """
        kh, kw = self.kernel_size
        dh, dw = self.stride
        # Extract every kernel window: (N, C, H_out, W_out, kh, kw).
        patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
        # Flatten each window row-major: (N, C, H_out, W_out, kh * kw).
        patches = patches.reshape(*patches.shape[:-2], kh * kw)

        # Without this check, size-1 dims would silently broadcast against
        # the per-location weights and give a wrong-but-plausible result.
        expected = self.weight.shape[2:5]  # (C, H_out, W_out)
        if patches.shape[1:4] != expected:
            raise ValueError(
                'LocallyConnected2d expected unfolded input of shape '
                f'(C, H_out, W_out) = {tuple(expected)}, got {tuple(patches.shape[1:4])}; '
                'check in_channels, output_size, kernel_size and stride against the input'
            )

        # For each location (h, w), sum over input channels (c) and window
        # elements (k). einsum does this without materialising the
        # (N, O, C, H_out, W_out, k) broadcast product.
        out = torch.einsum('nchwk,ochwk->nohw', patches, self.weight[0])
        if self.bias is not None:
            out = out + self.bias
        return out