如何使用 PyTorch torch.max()
在本文中,我们将了解如何使用 PyTorch torch.max()
函数。
正如大家所料,这是一个非常简单的功能,但有趣的是,它的功能比想象的要多。
让我们通过一些简单的例子来看看如何使用这个函数。
注意
:在撰写本文时,使用的 PyTorch 版本是 PyTorch 1.5.0
PyTorch torch.max() - 基本语法
要使用 PyTorch torch.max()
,首先导入 torch。
import torch
现在,此函数返回 Tensor 中元素的最大值。
PyTorch torch.max() 的默认行为
默认行为是返回单个元素和一个索引,对应于全局最大元素。
max_element = torch.max(input_tensor)
下面是一个例子:
p = torch.randn([2, 3])
print(p)
max_element = torch.max(p)
print(max_element)
输出
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
tensor(2.7976)
事实上,这给了我们 Tensor 中的全局最大元素!
沿维度使用 torch.max()
但是,大家可能希望获得沿特定维度的最大值,作为张量,而不是单个元素。
要指定维度(轴 - 在 numpy 中),还有另一个可选的关键字参数,称为 dim
这代表了我们取最大值的方向。
这将返回一个元组 max_elements
和 max_indices
。
- max_elements -> Tensor的所有最大元素。
- max_indices -> 对应于最大元素的索引。
max_elements, max_indices = torch.max(input_tensor, dim)
这将返回一个 Tensor,它具有沿维度 dim
的最大元素。
现在让我们看一些例子。
p = torch.randn([2, 3])
print(p)
# 沿 dim = 0 (axis = 0) 获取最大值
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)
输出如下所示
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])
如你所见,我们找到了沿维度 0 的最大值(沿列的最大值)。
此外,我们得到与元素对应的索引。 例如,0.0688 在第 0 列的索引为 1
同样,如果要沿行查找最大值,请使用 dim=1
。
# 沿 dim = 1(axis = 1)获取最大值
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)
输出如下所示
tensor([2.7976, 1.4443])
tensor([1, 2])
实际上,我们得到了沿行的最大元素,以及相应的索引(沿行)。
使用 torch.max() 进行比较
我们还可以使用 torch.max() 来获取两个 Tensor 之间的最大值。
output_tensor = torch.max(a, b)
在这里,a 和 b 必须具有相同的维度,或者必须是“可广播的” Tensor。
这是一个比较两个具有相同维度的Tensor的简单示例。
p = torch.randn([2, 3])
q = torch.randn([2, 3])
print("p =", p)
print("q =",q)
# 比较 p 和 q 的元素并得到最大值
max_elements = torch.max(p, q)
print(max_elements)
结果输出如下所示
p = tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
q = tensor([[-0.0678, 0.2042, 0.8254],
[-0.1530, 0.0581, -0.3694]])
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, 0.0581, 1.4443]])
实际上,我们得到的输出 Tensor 在 p 和 q 之间具有最大元素。
总结
在本文中,我们学习了如何使用 torch.max()
函数来找出 Tensor 的最大元素。
我们还使用这个函数来比较两个 Tensor 并获得其中的最大值。
相关文章
Pandas DataFrame DataFrame.shift() 函数
发布时间:2024/04/24 浏览次数:133 分类:Python
-
DataFrame.shift() 函数是将 DataFrame 的索引按指定的周期数进行移位。
Python pandas.pivot_table() 函数
发布时间:2024/04/24 浏览次数:82 分类:Python
-
Python Pandas pivot_table()函数通过对数据进行汇总,避免了数据的重复。
Pandas read_csv()函数
发布时间:2024/04/24 浏览次数:254 分类:Python
-
Pandas read_csv()函数将指定的逗号分隔值(csv)文件读取到 DataFrame 中。
Pandas 多列合并
发布时间:2024/04/24 浏览次数:628 分类:Python
-
本教程介绍了如何在 Pandas 中使用 DataFrame.merge()方法合并两个 DataFrames。
Pandas loc vs iloc
发布时间:2024/04/24 浏览次数:837 分类:Python
-
本教程介绍了如何使用 Python 中的 loc 和 iloc 从 Pandas DataFrame 中过滤数据。
在 Python 中将 Pandas 系列的日期时间转换为字符串
发布时间:2024/04/24 浏览次数:894 分类:Python
-
了解如何在 Python 中将 Pandas 系列日期时间转换为字符串