# CoordConv实现

import torch
import torch.nn as nn
'''
An alternative PyTorch implementation that infers the x-y dimensions automatically.
Paper: "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution" (Liu et al., 2018)

https://zhuanlan.zhihu.com/p/443583240
https://blog.csdn.net/oYeZhou/article/details/116717210

'''

def __init__(self, with_r=False):
    """Initialise the module and store the radius-channel flag.

    NOTE(review): in the file as seen, this method has no enclosing
    ``class ...(nn.Module):`` header (the header line appears to be missing).
    The zero-argument ``super()`` call below only works when the function is
    defined inside a class body, so the header must be restored. Confirm the
    class name against version control.

    Args:
        with_r: stored as ``self.with_r``. The ``forward`` defined alongside
            this method reads it to decide whether to append an extra radius
            channel after the x/y coordinate channels.
    """
    super().__init__()
    self.with_r = with_r

def forward(self, ins_feat):
    """Append normalized coordinate channels (and optionally a radius channel).

    NOTE(review): this method appears to have lost its enclosing
    ``class ...(nn.Module):`` header; confirm against version control.

    Args:
        ins_feat: 4-D tensor of shape (batch, channel, height, width).

    Returns:
        A tensor of shape (batch, channel + 2, height, width), or
        (batch, channel + 3, height, width) when ``self.with_r`` is truthy.
        Channel ``channel`` holds the x coordinate, which varies along the
        width. Channel ``channel + 1`` holds the y coordinate, which varies
        along the height. Both span [-1, 1]. The optional last channel holds
        the Euclidean distance from the centre of the feature map.

    Raises:
        ValueError: if ``ins_feat`` is not 4-D (from the shape unpacking).
    """
    batch, _, height, width = ins_feat.shape
    # Build the coordinates in the input's floating dtype. Otherwise
    # torch.cat promotes the result (e.g. float16 -> float32), and a
    # following half-precision conv would fail. Non-float inputs keep the
    # default dtype, because an integer linspace would collapse to {-1, 0, 1}.
    dtype = ins_feat.dtype if ins_feat.is_floating_point() else None
    x_range = torch.linspace(-1, 1, width, device=ins_feat.device, dtype=dtype)
    y_range = torch.linspace(-1, 1, height, device=ins_feat.device, dtype=dtype)
    # Broadcast with view + expand. This gives the same grid as
    # torch.meshgrid(..., indexing="ij") without the deprecation warning or
    # a version-dependent keyword.
    x = x_range.view(1, 1, 1, width).expand(batch, 1, height, width)
    y = y_range.view(1, 1, height, 1).expand(batch, 1, height, width)
    ret = torch.cat([ins_feat, x, y], dim=1)
    if self.with_r:
        # r is the distance from the map centre, which is (0, 0) once the
        # coordinates span [-1, 1]. The former "- 0.5" offset measured the
        # distance from (0.5, 0.5), an off-centre point.
        rr = torch.sqrt(x * x + y * y)
        ret = torch.cat([ret, rr], dim=1)
    return ret

class CoordConv(nn.Module):
    """2-D convolution over the input plus appended coordinate channels.

    Before convolving, two channels holding the normalized x (column) and
    y (row) coordinates in [-1, 1] are concatenated to the input. When
    ``with_r`` is set, a third channel holds the distance from the centre
    of the feature map.

    NOTE(review): the original body of this class appears to have been lost.
    Its ``__init__`` ended in ``return ret`` with ``ret`` undefined, which
    raised NameError on construction, and it defined no conv layer or
    ``forward``. This is a reconstruction from the constructor signature;
    confirm against version control.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the convolution.
        with_r: when truthy, also append the radius channel.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (``kernel_size`` is
            required there; ``stride``, ``padding``, ``bias``, ... optional).
    """

    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
        super().__init__()
        self.with_r = with_r
        # Two coordinate channels, plus one for the radius if requested.
        in_size = in_channels + 2 + (1 if with_r else 0)
        self.conv = nn.Conv2d(in_size, out_channels, **kwargs)

    @staticmethod
    def _add_coords(feat, with_r):
        """Return 4-D ``feat`` with x, y (and optionally r) channels appended on dim 1."""
        batch, _, height, width = feat.shape
        # Match the input's float dtype so torch.cat does not promote it.
        dtype = feat.dtype if feat.is_floating_point() else None
        xs = torch.linspace(-1, 1, width, device=feat.device, dtype=dtype)
        ys = torch.linspace(-1, 1, height, device=feat.device, dtype=dtype)
        x = xs.view(1, 1, 1, width).expand(batch, 1, height, width)
        y = ys.view(1, 1, height, 1).expand(batch, 1, height, width)
        parts = [feat, x, y]
        if with_r:
            # The centre is (0, 0) in [-1, 1] coordinates.
            parts.append(torch.sqrt(x * x + y * y))
        return torch.cat(parts, dim=1)

    def forward(self, x):
        """Append coordinate channels to ``x`` (batch, in_channels, H, W) and convolve."""
        return self.conv(self._add_coords(x, self.with_r))