在机器学习多分类任务中,有时需要按类别进行分层采样,例如处理类别不均衡的数据时:随机采样会使训练集、验证集、测试集中各类别的数据比例不一致,这会在一定程度上影响分类器的性能。此时就需要进行分层采样,保证训练集、验证集、测试集中每个类别的数据比例大致相同。
下面是 Python 代码。
import os


# Split the labelled data into train/val/test files, stratified by class.


def _read_documents(ssdfile_dir):
    """Return {file stem: file text with all newlines removed} for ssdfile_dir.

    The case descriptions span several lines, so each file is collapsed into a
    single line to fit the one-record-per-line output format.
    """
    documents = {}
    for file_name in os.listdir(ssdfile_dir):
        path = os.path.join(ssdfile_dir, file_name)
        if not os.path.isfile(path):
            continue  # skip sub-directories instead of crashing on open()
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read().replace('\n', '')
        # Strip only a trailing ".txt" (str.replace would also hit inner ones).
        key = file_name[:-len('.txt')] if file_name.endswith('.txt') else file_name
        documents[key] = text
    return documents


def _iter_labelled_rows(filename, documents):
    """Yield (name, label) for each row of `filename` whose name is in `documents`.

    Rows are comma-separated: column 0 is the document name, column 1 the
    label. The line terminator is stripped before splitting; otherwise a label
    in the last column keeps its trailing newline and never matches a
    category. A row with no second column is reported with an empty label.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.rstrip('\r\n').split(',')
            name = fields[0]
            if name in documents:
                yield name, (fields[1] if len(fields) > 1 else '')


def _distribution(class_counts):
    """Return each class's share of a split in percent (all 0.0 if the split is empty)."""
    total = sum(class_counts)
    if total == 0:
        return [0.0] * len(class_counts)
    return [(count / total) * 100 for count in class_counts]


def save_file_stratified(filename, ssdfile_dir, categories,
                         output_dir='../data/usefuldata-711depart',
                         train_percent=70, val_percent=15):
    """Write documents to train/val/test files, stratified by class.

    Args:
        filename: comma-separated index file; each row is ``name,label[,...]``.
            Rows whose name has no file in ``ssdfile_dir`` are ignored.
        ssdfile_dir: directory of ``<name>.txt`` document files.
        categories: class labels to keep. Rows with a non-empty label that is
            not in this list are skipped; rows with an empty label are counted
            and reported.
        output_dir: directory that receives train.txt, val.txt and test.txt.
            Each output line is ``label<TAB>content``.
        train_percent, val_percent: integer percentage of every class written
            to the train and val files; the remainder goes to test.

    Within each class, rows are assigned in file order (first
    ``train_percent`` % to train, next ``val_percent`` % to val, the rest to
    test), so ``filename`` should already be shuffled.

    Returns:
        (train_counts, val_counts, test_counts): per-category row counts in
        ``categories`` order.

    Raises:
        ValueError: if a percentage is negative or they sum to more than 100.
    """
    if train_percent < 0 or val_percent < 0 or train_percent + val_percent > 100:
        raise ValueError('train_percent and val_percent must be >= 0 and sum to <= 100')

    documents = _read_documents(ssdfile_dir)
    # The first position of a label wins, like an if/elif chain over categories.
    category_index = {}
    for i, category in enumerate(categories):
        category_index.setdefault(category, i)
    num_classes = len(categories)

    # Pass 1: number of usable rows per class (doc_count in the original).
    doc_counts = [0] * num_classes
    for _, label in _iter_labelled_rows(filename, documents):
        idx = category_index.get(label) if label else None
        if idx is not None:
            doc_counts[idx] += 1

    # Pass 2: route each row to a split and record per-class split sizes.
    train_class_tag = [0] * num_classes
    val_class_tag = [0] * num_classes
    test_class_tag = [0] * num_classes
    seen = [0] * num_classes
    blank_tag = 0  # rows whose label is empty
    val_end_percent = train_percent + val_percent

    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'train.txt'), 'w', encoding='utf-8') as f_train, \
            open(os.path.join(output_dir, 'val.txt'), 'w', encoding='utf-8') as f_val, \
            open(os.path.join(output_dir, 'test.txt'), 'w', encoding='utf-8') as f_test:
        for name, label in _iter_labelled_rows(filename, documents):
            if label == '':
                blank_tag += 1
                continue
            idx = category_index.get(label)
            if idx is None:
                continue
            # 0-based position inside its class. Comparing before counting
            # avoids the off-by-one that sent only 69 of 100 rows to train.
            position = seen[idx]
            seen[idx] += 1
            # Integer cross-multiplication keeps the cut points exact.
            scaled = position * 100
            if scaled < doc_counts[idx] * train_percent:
                out_file, class_tag = f_train, train_class_tag
            elif scaled < doc_counts[idx] * val_end_percent:
                out_file, class_tag = f_val, val_class_tag
            else:
                out_file, class_tag = f_test, test_class_tag
            out_file.write(label + '\t' + documents[name] + '\n')
            class_tag[idx] += 1

    print("有" + str(blank_tag) + "个文书的行业标记为空!")
    for header, class_tag in (("train:", train_class_tag),
                              ("val:", val_class_tag),
                              ("test:", test_class_tag)):
        print(header)
        print(class_tag)
        print("分布:")
        print(_distribution(class_tag))
    return train_class_tag, val_class_tag, test_class_tag


if __name__ == '__main__':
    categories = [
        "class1", "class2", "class3", "class4", "class5", "class6", "class7",
        "class8", "class9", "class10", "class11", "class12", "class13"
    ]
    save_file_stratified('../data/qwdata/shuffle-try3/classified_table_ms.txt',
                         '../data/qwdata/ms-ygscplusssdqw', categories)
运行后可以看到各类别的划分情况(每个类别在训练集、验证集、测试集中的数量及其分布)。
需要注意的是:这是我早期写的文章。通常我们只需在训练集和验证集上做分层采样即可,测试集最好保持原始数据分布,不要改动。