脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - PyTorch中permute的用法详解

PyTorch中permute的用法详解

2021-09-13 00:41一只想入门却未入门的程 Python

今天小编就为大家分享一篇PyTorch中permute的用法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

?
1
permute(dims)

将tensor的维度换位。

参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。

例:

?
1
2
3
4
5
6
7
import torch
import numpy as np
a=np.array([[[1,2,3],[4,5,6]]])
unpermuted=torch.tensor(a)
print(unpermuted.size()) # ——> torch.Size([1, 2, 3])
permuted=unpermuted.permute(2,0,1)
print(permuted.size()) # ——> torch.Size([3, 1, 2])

再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

利用这个函数permute(1,3,2)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成

?
1
2
3
tensor([[[1., 4.],
[2., 5.],
[3., 6.]]])

如果使用view(1,3,2),可以得到

?
1
2
3
tensor([[[1., 2.],
[3., 4.],
[5., 6.]]])

以上这篇PyTorch中permute的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/qq_40231500/article/details/90606872

延伸 · 阅读

精彩推荐