本文介绍了尺寸为M<;32的火炬张量分度错误?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着跟版网的小编来一起学习吧!
问题描述
我正在尝试通过索引矩阵访问pytorch张量,但我最近发现这段代码找不到无法工作的原因。
下面的代码分为两部分。前半部分被证明是有效的,而后半部分是错误的。我看不出原因。有没有人能解释一下这件事?import torch
import numpy as np
a = torch.rand(32, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx] # WORKS for a torch.tensor of size M >= 32. It doesn't work otherwise.
a = torch.rand(16, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx] # IndexError: too many indices for tensor of dimension 2
如果我更改a = np.random.rand(16, 16)
,它也可以正常工作。
推荐答案
首先,让我快速了解一下如何使用一个数值数组和另一个张量来索引张量。
示例:这是我们要索引的目标张量
numpy_indices = torch.tensor([[0, 1, 2, 7],
[0, 1, 2, 3]]) # numpy array
tensor_indices = torch.tensor([[0, 1, 2, 7],
[0, 1, 2, 3]]) # 2D tensor
t = torch.tensor([[1, 2, 3, 4], # targeted tensor
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24],
[25, 26, 27, 28],
[29, 30, 31, 32]])
numpy_result = t[numpy_indices]
tensor_result = t[tensor_indices]
使用2D数值数组编制索引:索引的读取方式类似于成对(x,y)张量[行,列],例如
t[0,0], t[1,1], t[2,2], and t[7,3]
。print(numpy_result) # tensor([ 1, 6, 11, 32])
使用2D张量进行索引:以行的方式遍历索引张量,每个值都是目标张量中一行的索引。 例如
[ [t[0],t[1],t[2],[7]] , [[0],[1],[2],[3]] ]
参见下例,索引后的tensor_result
的新形状为(tensor_indices.shape[0],tensor_indices.shape[1],t.shape[1])=(2,4,4)
。print(tensor_result) # tensor([[[ 1, 2, 3, 4], # [ 5, 6, 7, 8], # [ 9, 10, 11, 12], # [29, 30, 31, 32]], # [[ 1, 2, 3, 4], # [ 5, 6, 7, 8], # [ 9, 10, 11, 12], # [ 13, 14, 15, 16]]])
如果您尝试在numpy_indices
中添加第三行,您将收到相同的错误,因为索引将由3D表示,例如,(0,0,0).(7,3,3)。
indices = np.array([[0, 1, 2, 7],
[0, 1, 2, 3],
[0, 1, 2, 3]])
print(numpy_result) # IndexError: too many indices for tensor of dimension 2
但是,张量索引不是这种情况,形状将更大(3,4,4)。
最后,如您所见,这两种索引类型的输出完全不同。要解决您的问题,您可以使用xx = torch.tensor(xx).long() # convert a numpy array to a tensor
高级索引(NUMPY_INDEX>;3行)的情况如何,因为您的情况仍然不明确且未解决,您可以检查1、2、3。
这篇关于尺寸为M<;32的火炬张量分度错误?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持跟版网!
本站部分内容来源互联网,如果有图片或者内容侵犯了您的权益,请联系我们,我们会在确认后第一时间进行删除!