pytorch中masked_fill报错怎么办-创新互联

小编给大家分享一下pytorch中masked_fill报错怎么办,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!

创新互联专注于西平企业网站建设,响应式网站开发,商城建设。西平网站建设公司,为西平等地区提供建站服务。全流程定制开发,专业设计,全程项目跟踪,创新互联专业和态度为您提供的服务

如下所示:

import torch.nn.functional as F
import numpy as np
a = torch.Tensor([1,2,3,4])
a = a.masked_fill(mask = torch.ByteTensor([1,1,0,0]), value=-np.inf)
 
print(a)
b = F.softmax(a)

print(b)

tensor([-inf, -inf, 3., 4.])
d:/pycharmdaima/star-transformer/ceshi.py:8: UserWarning: Implicit dimension choice for softmax has been deprecated. Change
the call to include dim=X as an argument.
b = F.softmax(a)
tensor([0.0000, 0.0000, 0.2689, 0.7311])

容易报错:

Expected object of scalar type Byte but got scalar type Long for argument #2 'mask'

原因,

mask = torch.LongTensor()

解决方法:

mask = torch.ByteTensor()

在mask值为1的位置处用value填充。mask的元素个数需和本tensor相同,但尺寸可以不同

以上是“pytorch中masked_fill报错怎么办”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注创新互联行业资讯频道!


标题名称:pytorch中masked_fill报错怎么办-创新互联
分享网址:http://bzwzjz.com/article/dhedch.html

其他资讯

Copyright © 2007-2020 广东宝晨空调科技有限公司 All Rights Reserved 粤ICP备2022107769号
友情链接: 泸州网站建设 手机网站制作 成都网站建设 企业网站设计 重庆网站建设 成都网站建设 成都网站建设公司 成都网站设计 成都品牌网站建设 成都响应式网站建设 企业手机网站建设 成都网站建设 高端网站建设 网站建设费用 成都网站制作公司 重庆手机网站建设 商城网站建设 成都网站设计公司 定制网站设计 专业网站设计 企业网站设计 成都网站建设公司