torch.gather

函数定义

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Gathers values along an axis specified by dim.

对于一个3-D的张量,输出按照以下公式被指定为:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

函数参数

  • input (Tensor) – the source tensor

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to gather

  • sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor.
  • out (Tensoroptional) – the destination tensor

函数参数说明

  • 参数input和参数index必须拥有相同数量的维度,并且要求index.size(d) <= input.size(d)对于所有的维度d != dim。
  • out将会拥有和index一样的形状。
  • 参数input和参数index不能彼此进行广播

例子

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])
上一篇:MySQL的删除语句


下一篇:MySQL的插入语句