pytorch中的广播语义

pytorch中的广播语义

目录

1、什么是广播语义?

2、广播语义的规则

3、不符合广播语义的例子

4、符合广播语义的例子

pytorch的广播语义(broadcasting semantics),和numpy的很像,所以可以先看看numpy的文档:

1、什么是广播语义?

官方文档有这样一个解释:

In short, if a PyTorch operation supports broadcast, then its Tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).

这句话的意思大概是:简单的说,如果一个pytorch操作支持广播,那么它的Tensor参数可以自动的扩展为相同的尺寸(不需要复制数据)。

按照我的理解,应该是指算法计算过程中,不同的Tensor如果size不同,但是符合一定的规则,那么可以自动的进行维度扩展,来实现Tensor的计算。在维度扩展的过程中,并不是真的把维度小的Tensor复制为和维度大的Tensor相同,因为这样太浪费内存了。

2、广播语义的规则

首先来看标准的情况,两个Tensor的size相同,则可以直接计算:

x = torch.empty((4, 2, 3)) y = torch.empty((4, 2, 3))  print((x+y).size()) 

输出:

torch.Size([4, 2, 3]) 

但是,如果两个Tensor的维度并不相同,pytorch也是可以根据下面的两个法则进行计算:

(1)Each tensor has at least one dimension.

(2)When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.

每个Tensor至少有一个维度。

迭代标注尺寸时,从后面的标注开始

第一个规则要求每个参与计算的Tensor至少有一个维度,第二个规则是指在维度迭代时,从最后一个维度开始,可以有三种情况:

维度相等

其中一个维度是1

其中一个维度不存在

3、不符合广播语义的例子 x = torch.empty((0, )) y = torch.empty((2, 3))  print((x + y).size())

输出:

RuntimeError: The size of tensor a (0) must match  the size of tensor b (3) at non-singleton dimension 1 

这里,不满足第一个规则“每个参与计算的Tensor至少有一个维度”。

x = torch.empty(5, 2, 4, 1)  y = torch.empty(3, 1, 1)  print((x + y).size())

输出:

RuntimeError: The size of tensor a (2) must match 
the size of tensor b (3) at non-singleton dimension 1 

这里,不满足第二个规则,因为从最后的维度开始迭代的过程中,倒数第三个维度:x是2,y是3。这并不符合第二条规则的三种情况,所以不能使用广播语义。

4、符合广播语义的例子 x = torch.empty(5, 3, 4, 1)  y = torch.empty(3, 1, 1)  print((x + y).size()) 

输出:

torch.Size([5, 3, 4, 1]) 

x是四维的,y是三维的,从最后一个维度开始迭代:

最后一维:x是1,y是1,满足规则二 

倒数第二维:x是4,y是1,满足规则二 

倒数第三维:x是3,y是3,满足规则一

倒数第四维:x是5,y是0,满足规则一 

 到此这篇关于pytorch中的广播语义的文章就介绍到这了,更多相关pytorch广播语义内容请搜索易知道(ezd.cc)以前的文章或继续浏览下面的相关文章希望大家以后多多支持易知道(ezd.cc)!

推荐阅读