# Fix the NumPy RNG seed so the random masking in combineGenerator is reproducible.
np.random.seed(1)
def combineGenerator(x_l, x_a, x_v, batch_size):
    """Infinite batch generator for masked-reconstruction pre-training.

    Slices aligned batches from the language (x_l), acoustic (x_a) and visual
    (x_v) arrays, applies a random binary mask to each modality, and yields
    (inputs, targets) pairs in the layout expected by ``pre_model.fit``:

        inputs  : [bat_l, bat_a, bat_v, l_mask, a_mask, v_mask]
        targets : {'mult__model': masked_l, 'mult__model_1': masked_a,
                   'mult__model_2': masked_v}

    Parameters
    ----------
    x_l, x_a, x_v : np.ndarray
        Modality feature arrays with the same leading (sample) dimension.
    batch_size : int
        Number of samples per yielded batch.  NOTE(review): when
        ``start + batch_size`` runs past the end of the array the slice is
        silently truncated (a short batch) rather than wrapping around —
        matches the original behavior; confirm it is intended.

    Yields
    ------
    tuple(list, dict)
        Masked inputs and reconstruction targets, forever (use with
        ``steps_per_epoch``).
    """

    def _mask_batch(batch):
        """Return (binary mask, masked batch as float32) for one modality.

        Combines two independent random masks (span lengths 1 and 2, each at
        mask_rate=0.1) from the project-local ``get_mask`` helper, then
        binarizes the sum so overlapping spans still yield a 0/1 mask.
        """
        mask = (get_mask(batch, mask_rate=0.1, mask_lenth=1)
                + get_mask(batch, mask_rate=0.1, mask_lenth=2))
        mask = np.where(mask >= 1, 1, 0)
        return mask, (mask * batch).astype(np.float32)

    index = 0
    while True:
        start = index % x_l.shape[0]
        end = start + batch_size
        bat_l = x_l[start:end]
        bat_a = x_a[start:end]
        bat_v = x_v[start:end]
        l_mask, l_output = _mask_batch(bat_l)
        a_mask, a_output = _mask_batch(bat_a)
        v_mask, v_output = _mask_batch(bat_v)
        index += batch_size
        yield ([bat_l, bat_a, bat_v, l_mask, a_mask, v_mask],
               {'mult__model': l_output,
                'mult__model_1': a_output,
                'mult__model_2': v_output})
# ---- Batch generators for pre-training (train / validation splits) ----
train_generator = combineGenerator(l_train,a_train,v_train,batch_size=bat)
test_generator = combineGenerator(l_val,a_val,v_val, batch_size=bat)
# ---- Symbolic inputs: one feature tensor + one timestep mask per modality ----
# Shapes are hard-coded here: visual (500, 35), acoustic (500, 74),
# language (50, 300) — presumably CMU-MOSEI-style features; confirm against
# the preprocessing that produced l_train/a_train/v_train.
v_input = tf.keras.layers.Input(shape=(500,35))
a_input = tf.keras.layers.Input(shape=(500,74))
l_input = tf.keras.layers.Input(shape=(50,300))
v_mask_input = tf.keras.layers.Input(shape=(500, 1))
a_mask_input = tf.keras.layers.Input(shape=(500, 1))
l_mask_input = tf.keras.layers.Input(shape=(50, 1))
# `model` is a project-local builder (defined elsewhere) returning per-modality
# reconstructions plus a fused output; only the reconstructions are trained here.
out_l,out_a,out_v,out = model(l_input,a_input,v_input,l_mask =l_mask_input,a_mask =a_mask_input,v_mask =v_mask_input)
pre_model = keras.models.Model(inputs = [l_input,a_input,v_input,l_mask_input,a_mask_input,v_mask_input],outputs = [out_l,out_a,out_v])
# NOTE(review): `lr` is the legacy kwarg name (deprecated in newer TF in favor
# of `learning_rate`); left as-is for compatibility with the TF version in use.
opt = tf.keras.optimizers.Adam(lr=lr_rate,clipvalue=1.)
history2 = LossHistory_early_stop(which_test, epochs, bat, lr_rate,)
# early_stoping = EarlyStopping(monitor='val_loss', patience=patience, restore_best_weights=True, mode='min')
# NOTE(review): 'val_weighted_accuracy' is not produced by the compile() call
# below (no metrics are passed), so this monitor may never be populated —
# verify the metric is registered elsewhere or EarlyStopping will warn/no-op.
early_stoping = EarlyStopping(monitor='val_weighted_accuracy', patience=patience, restore_best_weights=True, mode='max')
# Per-output MSE reconstruction losses, weighted 100 / 0.5 / 30 for the
# language / acoustic / visual heads respectively (order follows `outputs`).
pre_model.compile(loss={'mult__model':tf.losses.MSE,'mult__model_1':tf.losses.MSE,'mult__model_2':tf.losses.MSE}, optimizer=opt, loss_weights=[100,0.5,30])
# my_model.compile(loss=tf.losses.MSE, optimizer=opt, metrics=[weighted_accuracy])
# my_model.compile(loss=tf.losses.MSE, optimizer=opt)
# NOTE(review): `batch_size` is ignored by fit() when the data is a generator —
# the generator already batches; harmless but redundant here.
pre_model.fit(train_generator,validation_data=test_generator,steps_per_epoch=v_train.shape[0]//bat, validation_steps=v_val.shape[0]//bat,
epochs=epochs, batch_size=bat,callbacks=[early_stoping,history2])