def split_train_test(text_df, size=0.8, labels=None):
    """Split ``text_df`` into train and test sets, class by class.

    Each label is split separately so every class keeps the same
    train/test ratio. Within a class the split is deterministic: the first
    ``floor(n * size)`` rows (in their current order) go to the training set
    and the remaining rows go to the test set. (A random split could be
    obtained by shuffling ``text_df`` before calling this function.)

    Args:
        text_df: DataFrame with a ``'label'`` column.
        size: Fraction of each class assigned to the training set
            (default 0.8, i.e. an 80/20 split).
        labels: Iterable of label values to split; defaults to
            ``[0, 1, 2, 3]``. Rows whose label is not listed appear in
            neither output.

    Returns:
        ``(train_text_df, test_text_df)``. Each frame has a fresh 0..n-1
        index plus the bookkeeping columns that ``reset_index`` adds (named
        ``'index'`` and ``'level_0'`` when ``text_df`` has an unnamed index):
        ``'index'`` holds the row's index label in ``text_df`` and
        ``'level_0'`` holds its position within its class.
    """
    if labels is None:
        labels = [0, 1, 2, 3]

    train_parts = []
    test_parts = []
    for label in labels:
        # Renumber this class's rows from 0 so positional slicing works;
        # reset_index keeps the original index as a column.
        class_df = text_df[text_df['label'] == label].reset_index()
        split_line_no = math.floor(class_df.shape[0] * size)
        train_parts.append(class_df.iloc[:split_line_no, :])
        test_parts.append(class_df.iloc[split_line_no:, :])

    # Concatenate once instead of DataFrame.append in the loop: append was
    # removed in pandas 2.0 and repeated appends copy the accumulator each time.
    train_text_df = pd.concat(train_parts) if train_parts else pd.DataFrame()
    test_text_df = pd.concat(test_parts) if test_parts else pd.DataFrame()
    train_text_df = train_text_df.reset_index()
    test_text_df = test_text_df.reset_index()
    return train_text_df, test_text_df