在 Python 中计算马氏距离
本教程将介绍在 Python 中求两个 NumPy 数组之间的马氏距离的方法。
使用 Python 中 scipy.spatial.distance
库中的 cdist()
函数计算马氏距离
马氏距离是点与分布之间距离的度量。如果我们想找到两个数组之间的马氏距离,我们可以使用 Python 中 scipy.spatial.distance
库中的 cdist()
函数。cdist()
函数 计算两个集合之间的距离。我们可以在输入参数中指定 mahalanobis
来查找 Mahalanobis 距离。请参考以下代码示例。
import numpy as np
from scipy.spatial.distance import cdist
x = np.array([[[1, 2, 3], [3, 4, 5], [5, 6, 7]], [[5, 6, 7], [7, 8, 9], [9, 0, 1]]])
i, j, k = x.shape
xx = x.reshape(i, j * k).T
y = np.array([[[8, 7, 6], [6, 5, 4], [4, 3, 2]], [[4, 3, 2], [2, 1, 0], [0, 1, 2]]])
yy = y.reshape(i, j * k).T
results = cdist(xx, yy, "mahalanobis")
results = np.diag(results)
print(results)
输出:
[3.63263583 2.59094773 1.97370848 1.97370848 2.177978 3.04256456
3.04256456 1.54080605 2.58298363]
我们使用上述代码中的 cdist()
函数计算并存储了数组 x
和 y
之间的马氏距离。我们首先使用 np.array()
函数创建了两个数组。然后我们重新调整两个数组的形状并将转置保存在新数组 xx
和 yy
中。然后我们将这些新数组传递给 cdist()
函数,并在参数中使用 cdist(xx,yy,'mahalanobis')
指定 mahalanobis
。
在 Python 中使用 numpy.einsum()
方法计算马氏距离
我们还可以使用 numpy.einsum()
方法 计算两个数组之间的马氏距离。numpy.einsum()
方法用于评估输入参数的爱因斯坦求和约定。
import numpy as np
x = np.array([[[1, 2, 3], [3, 4, 5], [5, 6, 7]], [[5, 6, 7], [7, 8, 9], [9, 0, 1]]])
i, j, k = x.shape
xx = x.reshape(i, j * k).T
y = np.array([[[8, 7, 6], [6, 5, 4], [4, 3, 2]], [[4, 3, 2], [2, 1, 0], [0, 1, 2]]])
yy = y.reshape(i, j * k).T
X = np.vstack([xx, yy])
V = np.cov(X.T)
VI = np.linalg.inv(V)
delta = xx - yy
results = np.sqrt(np.einsum("nj,jk,nk->n", delta, VI, delta))
print(results)
输出:
[3.63263583 2.59094773 1.97370848 1.97370848 2.177978 3.04256456
3.04256456 1.54080605 2.58298363]
我们将数组传递给 np.vstack()
函数并将值存储在 X
中。之后,我们将 X
的转置传递给 np.cov()
函数并将结果存储在 V
中。然后我们计算了矩阵 V
的乘法逆矩阵,并将结果存储在 VI
中。我们计算了 xx
和 yy
之间的差异,并将结果存储在 delta
中。最后,我们使用 results = np.sqrt(np.einsum('nj,jk,nk->n', delta, VI, delta))
计算并存储了 x
和 y
之间的马氏距离。
相关文章
Pandas DataFrame DataFrame.shift() 函数
发布时间:2024/04/24 浏览次数:133 分类:Python
-
DataFrame.shift() 函数是将 DataFrame 的索引按指定的周期数进行移位。
Python pandas.pivot_table() 函数
发布时间:2024/04/24 浏览次数:82 分类:Python
-
Python Pandas pivot_table()函数通过对数据进行汇总,避免了数据的重复。
Pandas read_csv()函数
发布时间:2024/04/24 浏览次数:254 分类:Python
-
Pandas read_csv()函数将指定的逗号分隔值(csv)文件读取到 DataFrame 中。
Pandas 多列合并
发布时间:2024/04/24 浏览次数:628 分类:Python
-
本教程介绍了如何在 Pandas 中使用 DataFrame.merge()方法合并两个 DataFrames。
Pandas loc vs iloc
发布时间:2024/04/24 浏览次数:837 分类:Python
-
本教程介绍了如何使用 Python 中的 loc 和 iloc 从 Pandas DataFrame 中过滤数据。
在 Python 中将 Pandas 系列的日期时间转换为字符串
发布时间:2024/04/24 浏览次数:894 分类:Python
-
了解如何在 Python 中将 Pandas 系列日期时间转换为字符串