gather(1, action)返回的是形状(batch, 1)的tensor,作用是把tensor1的元素按照tensor2中的索引取出来,也就是取出实际执行的动作的logits。 下面详述gather的功能。 torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor 对于一个三阶张量,输出为 out[i][j][k]=input[index[i][j...
首先讨论当dim=0的情况: output = input.gather(dim=0, index) """ [[0, 9, 10, 15]] """ 当dim=0的时候,可以看作按照row来选择,其中index里的数值则表示选择了哪一行,举个栗子,原index的数值为[0.2.2.3],那么表示这四个数分别来自第0行、第2行、第2行、第3行,好了,那接下来行确定了,具体选...
首先,gather函数可以用于按索引取出指定位置的元素。在目标检测任务中,我们通常需要从一个包含所有检测框坐标的张量中,提取出指定索引的检测框坐标。具体代码如下: importtorch# 输入张量,形状为(100, 4)boxes=torch.randn(100,4)# 索引张量,形状为(10,)indices=torch.tensor([5,13,27,45,62,71,80,92,98,99...
torch.gather()的定义非常简洁: 在指定dim上,从原tensor中获取指定index的数据, 看到这个核心定义,我们很容易想到gather()的基本想法就是从完整数据中按索引取值,比如下面从列表中按索引取值: lst= [1,2,3,4,5]value= lst[2]# value = 3value= lst[2:4]# value = [3, 4] 上面的取值例子是取单个值...
在Pytorch中,gather函数用于从输入张量中选择索引对应的元素,并返回一个新的张量。要正确理解并使用gather函数,首先需要了解参数的含义和用法。gather函数的参数包括:input(输入张量),dim(指定索引的维度),index(用于选择元素的索引张量)。 使用时,可以通过如下步骤进行索引操作: ...
gather,顾名思义,聚集、集合。有点像军训的时候,排队一样,把队伍按照教官想要的顺序进行排列。 还有一个更恰当的比喻:gather的作用是根据索引查找,然后讲查找结果以张量矩阵的形式返回。 1. 拿到一个张量: importtorch a = torch.arange(15).view(3,5) ...
在PyTorch中,gather方法是一个非常有用的函数,它允许你从一个张量中按照索引取出对应的值。这个方法在深度学习中经常被用来实现诸如训练数据的采样、序列生成等操作。本文将详细介绍gather方法的作用,并通过代码示例展示它的用法。 gather方法的作用 gather方法的作用是根据给定的索引,从输入张量中取出对应位置的元素。具...
gather函数的使用:gather函数要求输入数据与索引数据具有相同的维度,如输入为二维数组,则索引也应为二维数组,但两者形状可以不同。以二维张量为例,创建一个索引数组。当dim=0时,根据索引选择行,索引中的数值表示选择的行号,以此类推,确定行后再依据原索引确定具体选择的元素。对于dim=1,操作同样...
gather函数用于从张量中提取元素。要使用gather函数,输入与索引必须具有相同的维度。例如,如果输入数据为二维,索引也应为二维,尽管它们的形状可以不同。函数输出与索引相同。以二维张量为例,创建一个索引。索引同样应为二维。当dim设置为0时,gather函数操作相当于在行上进行选择。假设索引值为[0, 2, ...
torch.gather(input,dim,index,*,sparse_grad=False,out=None)→ Tensor 这个函数我大概研究了半个小时,虽然明白了基本的运算方法,但是其具体用法还理解的不够深入,如果以后有心得的话,再和大家来交流。 根据gather函数的声明,可以看到gather函数主要由三个参数,分别是: ...