pytorchsampler对数据进行采样的实现-创新互联

PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

成都创新互联公司是一家专注于网站设计、网站制作与策划设计,广汉网站建设哪家好?成都创新互联公司做网站,专注于网站建设十年,网设计领域的专业建站公司;建站业务涵盖:广汉等地区。广汉做网站价格咨询:028-86922220

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。

下面举例说明。

from dataSet import *
dataset = DogCat('data/dogcat/', transform=transform)

from torch.utils.data import DataLoader
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]

print(weights)

from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
                num_samples=9,\
                replacement=True)
dataloader = DataLoader(dataset,
            batch_size=3,
            sampler=sampler)
for datas, labels in dataloader:
  print(labels.tolist())

当前文章:pytorchsampler对数据进行采样的实现-创新互联
当前URL:http://bzwzjz.com/article/jgjid.html

其他资讯

Copyright © 2007-2020 广东宝晨空调科技有限公司 All Rights Reserved 粤ICP备2022107769号
友情链接: 成都做网站建设公司 网站建设公司 成都营销网站建设 重庆手机网站建设 成都网站设计公司 网站建设 网站制作报价 企业网站设计 四川成都网站制作 定制级高端网站建设 成都网站建设 手机网站制作 响应式网站建设 成都模版网站建设 重庆网站制作 重庆网站制作 品牌网站建设 阿坝网站设计 成都响应式网站建设公司 成都网站制作 成都网站建设 成都网站制作