点到点和图到图带代码
点到点的代码理解
点击查看代码
def load_train_data_for_rnn(cfg, x, y, aux, scaler):
# x = {nt, nf, ngrid} = {3287,9,1399}
# y = {nt, ngrid} = {3287,1399}
# aux = {nf , nt} = {1 , 1399}
# scaler = {2, 45, 90, 1 }
nt, nf, ngrid = x.shape
# print('**************************, y.shape is ' + str(y.shape))
# print('**************************, aux.shape is ' + str(aux.shape))
# print('**************************, x.shape is ' + str(x.shape))
# print('**************************, scaler.shape is ' + str(scaler.shape))
mean, std = np.array(scaler[0]), np.array(scaler[1])
# print('**************************, mean.shape is ' + str(mean.shape))
#mean = mean.reshape(mean.shape[0],mean.shape[1]*mean.shape[2])
#std = std.reshape(std.shape[0],std.shape[1]*std.shape[2])
idx_time = np.random.randint(0, nt-cfg['seq_len']-cfg["forcast_time"], 1)[0]#在这个代码中,由于最后使用了索引 [0],所以实际上只生成一个随机整数。 1 作为参数传递给了 np.random.randint(),所以实际上只生成了一个随机整数
# print('**************************, idx_time is ' + str(idx_time))
idx_grid = np.random.randint(0, ngrid, cfg['batch_size'])#会生成 cfg['batch_size'] 个位于 0(包含)和 ngrid(不包含)之间的随机整数,数组形式,大小是64,它生成了 cfg['batch_size'] 个随机整数
# print('**************************, idx_grid is ' + str(idx_grid))
# print('**************************, idx_grid.shape is ' + str(idx_grid.shape))======idx_grid,64個(0到1399)的數組
x = np.transpose(x, (2,0,1))#x is {ngrid,nt,nf} = { 1399, 3287 , 9 }
y = np.transpose(y, (1,0))# y is {ngrid,nt} = {1399 , 3287}
aux = np.transpose(aux, (1,0)) # aux is {nt, nf}={1399,1}
# print('before select form x **************************, x is ' + str(x.shape))#========x{1399,3287,9}
x = x[idx_grid, idx_time:idx_time+cfg['seq_len']]
# print('after selcect from x **************************, x is ' + str(x.shape))=========x {64,365,9}
# print('before select form y **************************, y is ' + str(y.shape))# y is (1399, 3287)
y = y[idx_grid, idx_time+cfg['seq_len']+cfg["forcast_time"]] ##
# print('after selcect from y **************************, y is ' + str(y.shape))# y is (64,)
aux = aux[idx_grid]
y[np.isinf(y)]=np.nan#将这些选中的无穷大值替换为 NaN
mask = y == y#创建一个布尔类型的掩码(mask),判断数组 y 中的元素是否为有效值
x = x[mask]
y = y[mask]
aux = aux[mask]
x[np.isinf(x)]=np.nan
x = np.nan_to_num(x)#将数组 x 中的 NaN 值替换为 0。
return x, y, aux, mean, std

浙公网安备 33010602011771号