## pytorch计算图扩大，反传变慢问题debug

# 为了更加集成，给定两个角度，生成compact的倾斜图片
class Compact_Homo(nn.Module):
    """Build a sampling grid that frames a tilted unit circle compactly.

    Given per-sample tilt angles ``alpha`` (rotation about the y axis) and
    ``beta`` (rotation about the x axis), plus the object-to-optical-centre
    distance ``d``, the module does the following (camera intrinsics K are
    assumed to be the identity):

    1. Build ``M = [r1 r2 t]``. Here ``r1, r2`` are the first two columns of
       ``(Rx(beta) @ Ry(alpha))^-1`` and ``t = (0, 0, d)``.
    2. Project the unit circle ``C = diag(1, 1, -1)`` through ``M``. Take the
       bounding box of the resulting conic and use its centre and its larger
       side as the framing window.
    3. Map every normalised output pixel of that window through ``M^-1``.

    All per-sample quantities are kept at shape ``(B, 1, ...)`` and broadcast
    over the ``H*W`` pixels. An earlier version repeated the 3x3 homography
    ``H*W`` times per sample. That copied B*H*W matrices and made the backward
    pass correspondingly heavier.
    """

    def __init__(self, device):
        super().__init__()
        # The camera intrinsics K are assumed to be the identity matrix.
        # Default object-to-optical-centre distance; used when forward() gets d=None.
        self.d = 5
        self.device = device

    def forward(self, alpha, beta, size=None, d=None):
        """Return the compact sampling grid for a batch of tilt angles.

        Args:
            alpha: rotation angle about the y axis, shape (B,) or (B, 1).
            beta: rotation angle about the x axis, shape (B,) or (B, 1).
            size: (N, C, H, W) of the output. Only H and W are used; the batch
                size always comes from ``alpha``. Defaults to (B, 3, 1024, 1024).
            d: object-to-optical-centre distance. Either a tensor of shape (B,)
                or (B, 1), or a scalar shared by all samples. Defaults to
                ``self.d``.

        Returns:
            Tensor of shape (B, H, W, 2) on ``self.device``. It holds, for each
            output pixel, the (u, v) coordinates obtained by applying ``M^-1``.
        """
        device = self.device
        B = alpha.shape[0]
        alpha = alpha.to(device).reshape(B)
        dtype = alpha.dtype
        beta = beta.to(device=device, dtype=dtype).reshape(B)

        if size is None:
            size = (B, 3, 1024, 1024)
        _, _, H, W = size

        # d is deliberately not stored on self. Keeping a graph-carrying tensor
        # on the module would keep the previous iteration's autograd graph alive.
        if d is None:
            d = self.d
        d = torch.as_tensor(d, dtype=dtype, device=device)
        d = d.reshape(1).expand(B) if d.numel() == 1 else d.reshape(B)

        temp_homo = self._plane_matrix(alpha, beta, d)
        homo = torch.inverse(temp_homo)

        center, scale = self._framing_box(temp_homo)

        # (1, H*W, 2) grid, broadcast against the (B, 1, ...) scale and centre.
        base_grid = self._base_grid(H, W, dtype, device)
        base_grid = base_grid * scale.view(B, 1, 1) / 2
        base_grid = base_grid + center.view(B, 1, 2)

        return self._apply_homography(homo, base_grid).reshape(B, H, W, 2)

    @staticmethod
    def _plane_matrix(alpha, beta, d):
        """Return M = [r1 r2 t] of shape (B, 3, 3) from R = Rx(beta) @ Ry(alpha)."""
        zeros = torch.zeros_like(alpha)
        ones = torch.ones_like(alpha)
        ca, sa = torch.cos(alpha), torch.sin(alpha)
        cb, sb = torch.cos(beta), torch.sin(beta)
        rot_x = torch.stack((ones, zeros, zeros,
                             zeros, cb, -sb,
                             zeros, sb, cb), dim=1).view(-1, 3, 3)
        rot_y = torch.stack((ca, zeros, sa,
                             zeros, ones, zeros,
                             -sa, zeros, ca), dim=1).view(-1, 3, 3)
        # R is a rotation, so R^-1 == R^T (exact and cheaper than torch.inverse).
        r_inv = torch.bmm(rot_x, rot_y).transpose(1, 2)
        # Replace the third column with the translation (0, 0, d) without
        # in-place writes, so no defensive clones are needed for autograd.
        t = torch.stack((zeros, zeros, d), dim=1).unsqueeze(2)
        return torch.cat((r_inv[:, :, :2], t), dim=2)

    @staticmethod
    def _framing_box(temp_homo):
        """Return (center (B, 2), scale (B,)) of the box around the projected unit circle."""
        # The image of the point conic C under M is M^-T C M^-1. Its dual
        # (inverse) is M C^-1 M^T = M C M^T because C = diag(1, 1, -1) is its own
        # inverse. Right-multiplying by diag(sign) is a column-wise scaling.
        sign = temp_homo.new_tensor([1., 1., -1.])
        dual = torch.bmm(temp_homo * sign, temp_homo.transpose(1, 2))
        width, center_x = Compact_Homo._tangent_extent(
            dual[:, 0, 0], dual[:, 0, 2] + dual[:, 2, 0], dual[:, 2, 2])
        height, center_y = Compact_Homo._tangent_extent(
            dual[:, 1, 1], dual[:, 1, 2] + dual[:, 2, 1], dual[:, 2, 2])
        return torch.stack((center_x, center_y), dim=1), torch.max(width, height)

    @staticmethod
    def _tangent_extent(a, b, c):
        """Return (extent, midpoint) along one axis from the dual-conic quadratic.

        A tangent line (l, 0, 1) (or (0, l, 1) for the y axis) satisfies
        a*l^2 + b*l + c = 0, and the line is x = -1/l (resp. y = -1/l).
        The result is NaN when the discriminant is negative (no real tangents).
        """
        root = torch.sqrt(b * b - 4 * (a * c))
        edge_a = -1. / ((-b - root) / (2 * a))
        edge_b = -1. / ((-b + root) / (2 * a))
        return edge_a - edge_b, (edge_a + edge_b) / 2

    @staticmethod
    def _base_grid(H, W, dtype, device):
        """Return normalised [-1, 1] pixel coordinates as (x, y) pairs, shape (1, H*W, 2)."""
        # linspace with a single step yields [-1], matching the W == 1 / H == 1 case.
        xs = torch.linspace(-1, 1, W, dtype=dtype, device=device)
        ys = torch.linspace(-1, 1, H, dtype=dtype, device=device)
        grid = torch.stack((xs.view(1, W).expand(H, W),
                            ys.view(H, 1).expand(H, W)), dim=-1)
        return grid.reshape(1, H * W, 2)

    @staticmethod
    def _apply_homography(homo, points):
        """Apply (B, 3, 3) homographies to (B, P, 2) points; return (B, P, 2)."""
        ones = torch.ones_like(points[..., :1])
        projected = torch.bmm(torch.cat((points, ones), dim=-1), homo.transpose(1, 2))
        return projected[..., :2] / projected[..., 2:]


        # Excerpt (before the fix), textually identical to the homography step at
        # the end of Compact_Homo.forward. There, h = homo.unsqueeze(1).repeat(1, W*H, 1, 1),
        # so each h[:, :, i, j] is a (B, H*W) tensor. The repeat materialises
        # B*H*W copies of the 3x3 matrix. Presumably this is the cause of the large
        # graph / slow backward that this note is debugging.
        # (The excerpt stops before v1 = temp3 / temp4; temp4 duplicates temp2.)
        temp1 = (h[:, :, 0, 0] * base_grid[:, :, 0] + h[:, :, 0, 1] * base_grid[:, :, 1] + h[:, :, 0, 2])
temp2 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])
u1 = temp1 / temp2

temp3 = (h[:, :, 1, 0] * base_grid[:, :, 0] + h[:, :, 1, 1] * base_grid[:, :, 1] + h[:, :, 1, 2])
temp4 = (h[:, :, 2, 0] * base_grid[:, :, 0] + h[:, :, 2, 1] * base_grid[:, :, 1] + h[:, :, 2, 2])


        # Excerpt (after the fix): apply the homography without repeating it H*W times.
        # homo is the (B, 3, 3) homography from forward(). base_grid is presumably
        # (B or 1, H*W, 2), so base_grid[:, :, k] is (B or 1, H*W).
        # Coefficients are indexed with a length-1 slice (e.g. 0:1) so they stay
        # (B, 1) and broadcast along the H*W axis. A plain h[:, 0, 0] is (B,) and
        # would broadcast against the H*W axis instead. That errors whenever
        # B > 1 (or is silently wrong if B happens to equal H*W).
        h = homo
        x = base_grid[:, :, 0]
        y = base_grid[:, :, 1]

        temp1 = h[:, 0, 0:1] * x + h[:, 0, 1:2] * y + h[:, 0, 2:3]
        temp2 = h[:, 2, 0:1] * x + h[:, 2, 1:2] * y + h[:, 2, 2:3]
        u1 = temp1 / temp2

        temp3 = h[:, 1, 0:1] * x + h[:, 1, 1:2] * y + h[:, 1, 2:3]
        # v shares u's denominator (third row of h), so reuse temp2 instead of
        # recomputing an identical temp4.
        v1 = temp3 / temp2

        # Excerpt (after the fix): transform center and scale by broadcasting
        # instead of repeat(1, H*W, ...).
        # center_x, center_y and scale are (B,) in forward(). Give them explicit
        # singleton axes so the batch axis lines up with base_grid's
        # (B or 1, H*W, 2) layout. Without this, a (B, 2) center and a (B,) scale
        # broadcast against the trailing (H*W, 2) axes. That is wrong (an error,
        # or a silent per-axis mix-up when B == 2) whenever B > 1.
        center_x = center_x.unsqueeze(1)  # (B, 1)
        center_y = center_y.unsqueeze(1)  # (B, 1)
        center = torch.cat((center_x, center_y), 1).unsqueeze(1)  # (B, 1, 2)
        scale = scale.view(-1, 1, 1)  # (B, 1, 1)
        base_grid = base_grid * scale / 2.
        base_grid = base_grid + center