张量在内存中的存储(reshape/permute操作理解)
在阅读flux代码的时候,看到这段处理latent的代码有些懵逼,追根溯源就是自己对于pytorch Tensor数据组织的方式理解不透彻,因此写下这篇博客开云破雾:
1 |
|
在底层实现中,PyTorch 中的张量(Tensor)实际上是以一维的连续内存块存储的,只是通过不同的 stride 来控制数据在内存中的访问顺序。我习惯从右往左去看待张量,因此我说的层的顺序是从右往左的
1 |
|
- stride[2] = 1(最里层的stride都是1)代表最里层每个元素是紧邻的,例如0马上就接着1然后是2:
0,1,2,3,… - stride[1] = A.shape[2] = 4,代表每相邻4个元素是一组 ,例如0,1,2,3四个元素是一组,此时4,5,6,7又是新的一组:
[0,1,2,3], [4,5,6,7], … - stride[0] = A.shape[2]*A.shape[1] = 3*4,代表每相邻12个元素是一个更大的组,因此0-11这12个元素会形成最外层的大组:
[ [0,1,2,3], [4,5,6,7], [8,9,10,11] ], …
由此总结stride的规律:
- 对于有N个维度的Tensor A, $$ stride[i] = \prod_{k=i+1}^{N-1} A.shape[k] \text{ for } i < N-1 \quad\text{and}\quad stride[N-1]=1$$
我们可以把print出来的Tensor从上到下,从左往右展平成一维向量, 他的张量视图就是根据上述得到的。
理解了Tensor数据的组织方法,最重要的作用就是可以推导出reshape(),view()这俩操作后新的数据视图(注意没有permute)。
1 |
|
上述的讨论其实就是view()和reshape()操作的原理,因此对于只含view和reshape的操作,不管中间维度怎么变换,只要输入相同,且输出的维度相同最后的结果就是一样的
对于permute() 操作,他的原理和前两个不同,举一个例子:
1 |
|
现在我们彻底理解了Tensor数据组织和形状变换的原理,让我们重新回到开头flux的源码部分进行解读:
1 |
|
相关博客:https://blog.csdn.net/qq_43391414/article/details/120798955
张量在内存中的存储(reshape/permute操作理解)
https://deadsmither5.github.io/2024/12/28/张量在内存中的存储/