博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
python 多分类任务中按照类别分层采样
阅读量:4920 次
发布时间:2019-06-11

本文共 13381 字,大约阅读时间需要 44 分钟。

 

在机器学习的多分类任务中,有时需要按类别进行分层采样。比如在类别不均衡的数据上,随机采样会使训练集、验证集、测试集中各类别的数据比例不一致,这会在一定程度上影响分类器的性能。这时就需要进行分层采样,保证训练集、验证集、测试集中每个类别的数据比例大致相同。

下面是 Python 代码。

import os


def _load_documents(ssdfile_dir):
    """Return {file name with '.txt' removed: file text with newlines removed}.

    Documents may span several lines, so all newlines are dropped to keep
    each document on a single ``label<TAB>content`` output line.
    """
    documents = {}
    for entry in os.listdir(ssdfile_dir):
        path = os.path.join(ssdfile_dir, entry)
        if not os.path.isfile(path):
            continue  # skip sub-directories instead of crashing in open()
        with open(path, 'r', encoding='utf-8') as f:
            documents[entry.replace('.txt', '')] = f.read().replace('\n', '')
    return documents


def _read_rows(filename, documents):
    """Return (name, label) pairs from the index file for names that have a document."""
    rows = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            # Strip the line ending first: otherwise a "name,label" line yields
            # "label\n", which matches neither a category nor the blank check.
            fields = line.rstrip('\r\n').split(',')
            if fields[0] in documents:
                rows.append((fields[0], fields[1] if len(fields) > 1 else ''))
    return rows


def _distribution(counts):
    """Return each count as a percentage of the total (all 0.0 when the total is 0)."""
    total = sum(counts)
    if total == 0:
        return [0.0] * len(counts)
    return [(c / total) * 100 for c in counts]


# Stratified split of the data by class.
def save_file_stratified(filename, ssdfile_dir, categories,
                         out_dir='../data/usefuldata-711depart',
                         train_pct=70, val_pct=15):
    """Split labelled documents into train/val/test files, stratified by class.

    Each class is split independently, in the order its rows appear in
    ``filename``: the first ``train_pct`` percent of the class's rows go to
    train.txt, the next ``val_pct`` percent to val.txt, the rest to test.txt.
    Every split therefore keeps roughly the same class proportions. Rows are
    not shuffled here; shuffle the index file beforehand if its order is not
    random.

    Args:
        filename: comma-separated index file, one ``name,label[,...]`` row per
            line. Rows whose ``name`` has no document in ``ssdfile_dir`` are
            ignored. Rows with an empty label are counted and reported. Rows
            whose label is not in ``categories`` are skipped.
        ssdfile_dir: directory containing one ``<name>.txt`` document per row.
        categories: class labels; a label's first position in this list is
            its index in the reported per-class counts.
        out_dir: directory that receives train.txt, val.txt and test.txt. Each
            line is written as the label, a TAB, then the document text.
        train_pct: percentage of each class assigned to the training split.
        val_pct: percentage of each class assigned to the validation split.

    Returns:
        ``(train_counts, val_counts, test_counts)``: per-class row counts,
        indexed like ``categories``.
    """
    documents = _load_documents(ssdfile_dir)
    rows = _read_rows(filename, documents)

    # The first occurrence wins, like an if/elif chain over ``categories``.
    class_index = {}
    for i, category in enumerate(categories):
        class_index.setdefault(category, i)
    n_classes = len(categories)

    # Pass 1: how many usable rows each class has.
    doc_count = [0] * n_classes
    for _, label in rows:
        if label and label in class_index:
            doc_count[class_index[label]] += 1

    # Per-class split boundaries, as (possibly fractional) row positions.
    train_end = [c * train_pct / 100 for c in doc_count]
    val_end = [c * (train_pct + val_pct) / 100 for c in doc_count]

    train_class_tag = [0] * n_classes
    val_class_tag = [0] * n_classes
    test_class_tag = [0] * n_classes
    tallies = (train_class_tag, val_class_tag, test_class_tag)
    seen = [0] * n_classes
    # Count of documents whose label is empty.
    blank_tag = 0

    # Pass 2: route each row to its split, preserving input order within files.
    with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f_train, \
            open(os.path.join(out_dir, 'val.txt'), 'w', encoding='utf-8') as f_val, \
            open(os.path.join(out_dir, 'test.txt'), 'w', encoding='utf-8') as f_test:
        outputs = (f_train, f_val, f_test)
        for name, label in rows:
            if label == '':
                blank_tag += 1
                continue
            i = class_index.get(label)
            if i is None:
                continue  # label not listed in ``categories``: not written anywhere
            # 0-based position within the class. Testing it *before* counting
            # gives exactly train_pct percent to train (e.g. 70 of 100 rows).
            position = seen[i]
            seen[i] += 1
            if position < train_end[i]:
                split = 0
            elif position < val_end[i]:
                split = 1
            else:
                split = 2
            outputs[split].write(label + '\t' + documents[name] + '\n')
            tallies[split][i] += 1

    print("有" + str(blank_tag) + "个文书的行业标记为空!")
    for header, counts in (("train:", train_class_tag),
                           ("val:", val_class_tag),
                           ("test:", test_class_tag)):
        print(header)
        print(counts)
        print("分布:")
        print(_distribution(counts))
    return train_class_tag, val_class_tag, test_class_tag


if __name__ == '__main__':
    categories = [
        "class1",
        "class2",
        "class3",
        "class4",
        "class5",
        "class6",
        "class7",
        "class8",
        "class9",
        "class10",
        "class11",
        "class12",
        "class13",
    ]
    save_file_stratified('../data/qwdata/shuffle-try3/classified_table_ms.txt',
                         '../data/qwdata/ms-ygscplusssdqw', categories)
View Code

 

运行后会打印出各类别在训练集、验证集、测试集中的数量及所占比例,可以据此检查类别划分情况。

 


 

需要注意的是:这是我早期写的文章。通常我们只需在训练集和验证集上做分层采样,测试集最好保持原样,不要改动。

转载于:https://www.cnblogs.com/zhouxiaosong/p/11113959.html

你可能感兴趣的文章
浏览器内核引擎
查看>>
SqlServer中怎么删除重复的记录(表中没有id)
查看>>
操作系统基础知识之————单线程(Thread)与多线程的区别
查看>>
PAT 1022 Digital Library[map使用]
查看>>
由于目标计算机积极拒绝,无法连接。
查看>>
hive常用命令
查看>>
Nmap使用教程 - 一
查看>>
java深入解析
查看>>
js返回上一页并刷新的几种方法
查看>>
POJ 3320 Jessica's Reading Problem 尺取法
查看>>
Unity Json 之三
查看>>
linux java -jar startup.sh
查看>>
DDD的思考
查看>>
类型转换及返回json对象的问题
查看>>
模拟题 找出不能拼凑的最小数
查看>>
ivew实现table的编辑保存追加删除
查看>>
poj 1904(强连通分量+输入输出外挂)
查看>>
Ubuntu重启关机命令使用
查看>>
第5章 不要让线程成为脱缰的野马(Keeping your Threads on Leash) ---干净的终止一个线程...
查看>>
shell $() vs ${}
查看>>