Pytorch中torch.gather和torch.scatter函数理解

torch.gather()

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

参数解释:

  • input (Tensor) – the source tensor

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to gather

  • sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.

  • out (Tensor, optional) – the destination tensor

示例1:

t = torch.tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1,  1],
        [ 4,  3]])

解释:

gather的意思是聚集和取,即从input这个张量中取元素,而index则对应所取元素的下标。如果dim=0,那么index中的数值表示行坐标,如果dim=1,那么index中的数值表示列坐标。另外,index的shape和output的shape应该要一致。

以上述示例来说就是:index的第一行对应输出的第一行,其元素[0,0]就是从t中的第一行的下标为0的位置取其元素

示例2:

t = torch.tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
                  [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
                  [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
torch.gather(t, 0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]))
tensor([[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
        [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])

torch.scatter()

torch.scatter_(input, dim, index, src) → Tensor

参数解释:

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.

  • src (Tensor or float) – the source element(s) to scatter.

  • reduce (str, optional) – reduction operation to apply, can be either ‘add’ or ‘multiply’.

示例1:

x = torch.rand(2, 5)
x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

解释:

scatter可以理解为gather的反操作,即用src中的元素去替换input中的元素,而index中的数值则对应input元素的下标。如果dim=0,那么index中的数值表示横坐标,如果dim=1,那么index中的数值表示纵坐标。另外,output的shape和input的shape是一致的。

src = torch.arange(1, 11).reshape((2, 5))
src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
------------------------------------------------
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
------------------------------------------------
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])
------------------------------------------------
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
------------------------------------------------
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

参考链接:

https://zhuanlan.zhihu.com/p/187401278
https://www.cnblogs.com/dogecheng/p/11938009.html
https://wmathor.com/index.php/archives/1457/

上一篇:tf.gather,tf.gather_nd,tf.boolean_mask


下一篇:CSS父级边框塌陷问题