如何使用 PyTorch torch.max()
在本文中,我们将了解如何使用 PyTorch torch.max()
函数。
正如大家所料,这是一个非常简单的功能,但有趣的是,它的功能比想象的要多。
让我们通过一些简单的例子来看看如何使用这个函数。
注意
:在撰写本文时,使用的 PyTorch 版本是 PyTorch 1.5.0
PyTorch torch.max() - 基本语法
要使用 PyTorch torch.max()
,首先导入 torch。
import torch
现在,此函数返回 Tensor 中元素的最大值。
PyTorch torch.max() 的默认行为
默认行为是返回单个元素和一个索引,对应于全局最大元素。
max_element = torch.max(input_tensor)
下面是一个例子:
p = torch.randn([2, 3])
print(p)
max_element = torch.max(p)
print(max_element)
输出
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
tensor(2.7976)
事实上,这给了我们 Tensor 中的全局最大元素!
沿维度使用 torch.max()
但是,大家可能希望获得沿特定维度的最大值,作为张量,而不是单个元素。
要指定维度(轴 - 在 numpy 中),还有另一个可选的关键字参数,称为 dim
这代表了我们取最大值的方向。
这将返回一个元组 max_elements
和 max_indices
。
- max_elements -> Tensor的所有最大元素。
- max_indices -> 对应于最大元素的索引。
max_elements, max_indices = torch.max(input_tensor, dim)
这将返回一个 Tensor,它具有沿维度 dim
的最大元素。
现在让我们看一些例子。
p = torch.randn([2, 3])
print(p)
# 沿 dim = 0 (axis = 0) 获取最大值
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)
输出如下所示
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])
如你所见,我们找到了沿维度 0 的最大值(沿列的最大值)。
此外,我们得到与元素对应的索引。 例如,0.0688 在第 0 列的索引为 1
同样,如果要沿行查找最大值,请使用 dim=1
。
# 沿 dim = 1(axis = 1)获取最大值
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)
输出如下所示
tensor([2.7976, 1.4443])
tensor([1, 2])
实际上,我们得到了沿行的最大元素,以及相应的索引(沿行)。
使用 torch.max() 进行比较
我们还可以使用 torch.max() 来获取两个 Tensor 之间的最大值。
output_tensor = torch.max(a, b)
在这里,a 和 b 必须具有相同的维度,或者必须是“可广播的” Tensor。
这是一个比较两个具有相同维度的Tensor的简单示例。
p = torch.randn([2, 3])
q = torch.randn([2, 3])
print("p =", p)
print("q =",q)
# 比较 p 和 q 的元素并得到最大值
max_elements = torch.max(p, q)
print(max_elements)
结果输出如下所示
p = tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
q = tensor([[-0.0678, 0.2042, 0.8254],
[-0.1530, 0.0581, -0.3694]])
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, 0.0581, 1.4443]])
实际上,我们得到的输出 Tensor 在 p 和 q 之间具有最大元素。
总结
在本文中,我们学习了如何使用 torch.max()
函数来找出 Tensor 的最大元素。
我们还使用这个函数来比较两个 Tensor 并获得其中的最大值。
相关文章
Python 中的第一类函数
发布时间:2023/04/25 浏览次数:113 分类:Python
-
第一类函数是被语言视为对象或变量的函数。 我们可以将它们分配给变量或将它们作为对象传递给其他函数。Python 支持第一类函数的功能。
Python 函数参数类型
发布时间:2023/04/25 浏览次数:140 分类:Python
-
在这篇 Python 文章中,我们将学习 Python 中使用的函数参数类型。 我们还将学习如何编写不带参数的 Python 函数。
Python 生成器中的 send 函数
发布时间:2023/04/25 浏览次数:111 分类:Python
-
本教程将介绍如何在 Python 中使用生成器的 send() 函数。我们可以创建一个像迭代器一样运行的函数,并且可以通过 Python 生成器函数在 for 循环中使用。
Python Functools 偏函数
发布时间:2023/04/25 浏览次数:80 分类:Python
-
本文介绍了我们如何使用分部函数,该函数随 functools 库一起提供,并附有示例。 这显示了调用时如何传递属性和部分函数。
Python main() 函数中的参数
发布时间:2023/04/25 浏览次数:157 分类:Python
-
在本教程结束时,我们应该了解Python 中在 main() 中使用参数是否是一种好的做法。
Python 中的内置 identity 函数
发布时间:2023/04/25 浏览次数:88 分类:Python
-
identity 函数只是一个返回其参数的函数。 当我们定义一个恒等函数并赋值时,它会返回该值。在本教程结束时,我们将了解 Python 是否具有内置的 identity 函数。
在 Python 中拟合阶跃函数
发布时间:2023/04/25 浏览次数:177 分类:Python
-
阶跃函数是带有看起来像一系列步骤的图形的方法。 它们由一系列中间有间隔的水平线段组成,也可以称为阶梯函数。本文给出了阶跃函数的简单演示。
在 Python 中创建双向链表
发布时间:2023/04/25 浏览次数:54 分类:Python
-
双向链表是指由称为节点的顺序链接的记录集组成的链接数据结构。 每个节点包含一个前一个指针、一个下一个指针和一个数据字段。
将 Python 类对象序列化为 JSON
发布时间:2023/04/25 浏览次数:152 分类:Python
-
本教程介绍序列化过程。 它还说明了我们如何使用 toJSON() 方法使 JSON 类可序列化,并包装 JSON 以转储到其类中。