pytorch的gather函数的一些粗略的理解

pytorch的gather函数的一些理解

先给出官方文档的解释,我觉得官方的文档写的已经很清楚了,四个参数分别是input,dim,index,out,输出的tensor是以index为大小的tensor。
在这里插入图片描述
其中,这就是最关键的定义

1out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0 2out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1 3out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3 4 5

主要解释一下dim,dim=0的时候,把index的元素放入进行索引,有一点需要注意的是,参数index的tensor格式是除了第1维也就是行那一维之外,其他维的格式需与input保持一致!下面给个例子

1import torch 2 3a = torch.arange(0, 16).view(4,4) 4 5index = torch.LongTensor([[0,1,2,3]]) 6 7b = a.gather(0, index) 8print(a) 9print(index) 10print(b) 11 12#形象的理解就是在每一列的第index[]上进行索引 13for j in range(4): 14 print(a[index[0][j]][j].item()) 15-------------------------------------------------------------------- 16tensor([[ 0, 1, 2, 3], 17 [ 4, 5, 6, 7], 18 [ 8, 9, 10, 11], 19 [12, 13, 14, 15]]) 20tensor([[0, 1, 2, 3]]) 21tensor([[ 0, 5, 10, 15]]) 220 235 2410 2515 26 27

dim = 1的时候,把index的元素放入进行索引,有一点需要注意的是,参数index的tensor格式是除了第2维也就是列那一维之外,其他维的格式需与input保持一致!下面给个例子

1import torch 2 3a = torch.arange(0, 16).view(4,4) 4 5index = torch.LongTensor([[0],[1],[2],[3]]) 6 7b = a.gather(1, index) 8print(a) 9print(index) 10print(b) 11 12#形象的理解就是在每一行的第index[]列上进行索引 13for j in range(4): 14 print(a[j][index[j][0]].item()) 15-------------------------------------------------------------------- 16tensor([[ 0, 1, 2, 3], 17 [ 4, 5, 6, 7], 18 [ 8, 9, 10, 11], 19 [12, 13, 14, 15]]) 20tensor([[0], 21 [1], 22 [2], 23 [3]]) 24tensor([[ 0], 25 [ 5], 26 [10], 27 [15]]) 280 295 3010 3115 32 33

本人对矩阵的一些概念还有一些模糊不清,以上就是我的一些理解,希望有大佬可以一起交流一下,pytorch 的张量一开始很难处理清楚,还需慢慢来。

代码交流 2021