迹忆客 专注技术分享

当前位置:主页 > 学无止境 > 编程语言 > MATLAB >

在 MATLAB 中查找 K 最近邻

作者:迹忆客 最近更新:2023/04/22 浏览次数:

本教程将讨论使用 MATLAB 中的 knnsearch() 函数查找 k 最近邻。


在 MATLAB 中使用 knnsearch() 查找 K 最近邻

KNN,也称为k近邻,是一种分类算法,用于寻找数据集中某个点的k近邻。 例如,如果我们有一个包含医院病人数据的数据集,我们想找到一个可以猜出年龄和体重的人。

我们可以在KNN算法中传入医院所有在场病人的年龄和体重,以及我们猜测的要找人的年龄和体重,它会返回与未知人最接近的病人数据。 我们可以使用 MATLAB 的 knnsearch() 函数来完成上述任务。

我们必须将已知患者的年龄和体重作为 knnsearch() 函数中的第一个参数传递,并将未知人员的年龄和体重作为第二个参数传递。 该函数将从最接近我们未知人物的数据集中返回索引或行号。

例如,让我们使用存储在 MATLAB 中的医院数据集,根据年龄和体重搜索一个未知的人。 请参阅下面的代码。

clc

load hospital;
X_data = [hospital.Age hospital.Weight];
Y_data = [30 162];
Ind = knnsearch(X_data,Y_data);
hospital(Ind,:)

输出:

ans =

               LastName             Sex     Age    Weight    Smoker    BloodPressure      Trials
    HLE-603    {'HERNANDEZ'}        Male    36     166       false     120          83    {1×2 double}

上述代码中,医院数据集包含了100位患者的姓名、性别、年龄、体重、血压、吸烟信息。 要查看数据集的内容,我们可以通过在工作区窗口中双击它来打开它。

在这个例子中,我们只使用了年龄和体重参数,因为我们只知道关于未知人的这些信息,但我们也可以使用其他参数。

KNN 算法只返回一个最近邻,但我们也可以使用 K 参数设置最近邻的数量并定义最近邻的数量。

我们还可以使用 NSMethod 参数设置用于查找最近邻居的方法,然后定义方法名称,如 euclidean、cityblock 或 chebyshev。

我们还可以更改用于查找点之间距离的方法,默认情况下使用 Distance 参数将其设置为欧几里得; 之后,我们可以定义方法的名称,如 seuclidean、cosine 或 cityblock。

默认情况下,KNN 算法使用 50 个点作为叶节点,但我们也可以使用 BucketSize 参数更改它并传递点数。 KNN 算法根据给定的数据进行聚类,如果我们增加桶的大小,就会有更少的聚类和更多的点。

knnseach() 函数返回的索引默认排序。 但是,我们也可以通过使用 SortIndices 参数关闭排序过程来获得索引的原始顺序; 之后,我们需要传递 false。

例如,让我们更改上面讨论的属性并查看结果。 请参阅下面的代码。

clc

load hospital;
X_data = [hospital.Age hospital.Weight hospital.Smoker];
Y_data = [30 162 true];
Ind = knnsearch(X_data,Y_data,'K',2,'NSMethod','euclidean','Distance','chebychev','SortIndices',false);
hospital(Ind,:)

输出:

ans =

               LastName             Sex
    HLE-603    {'HERNANDEZ'}        Male
    VRH-620    {'MITCHELL' }        Male


               Age    Weight    Smoker
    HLE-603    36     166       false
    VRH-620    39     164       true


               BloodPressure
    HLE-603    120          83
    VRH-620    128          92


               Trials
    HLE-603    {1×2 double}
    VRH-620    {1×0 double}

在上面的代码中,我们从数据集中包含了另一个参数 Smoker,考虑我们是否知道未知的人是吸烟者。 我们可以在输出中看到,现在有两个患者接近未知人的数据。

在上面的例子中,我们只检查了一个人的最近邻居,但我们也可以找到多人的最近邻居。 上述属性可能会根据数据集更改结果。

knnsearch() 函数查找 k 个最近的点,但是如果我们想找到距给定点特定距离内的所有最近点,我们可以使用 MATLAB 中的 rangesearch() 函数。 检查此链接以获取有关 rangesearch() 函数的更多详细信息。

使用 knnsearch() 函数的问题在于,根据运行代码的机器,在大型数据集中需要花费一些时间。 但在机器学习中,我们希望我们的代码非常快,所以我们将这个过程分为训练和测试。

我们在训练过程中在给定的数据集上训练一个模型,这需要一些时间。 我们保存训练好的模型,当我们想从输入预测输出时,我们可以使用预训练模型在几秒钟内预测结果。

要使用 KNN 分类器训练模型,我们可以使用 fitcknn() 函数训练模型,然后我们可以使用 predict() 函数预测新输入的输出。

例如,让我们使用花数据集使用 KNN 分类器训练模型,然后使用 predict() 函数预测花类。 请参阅下面的代码。

clc

load fisheriris
X_data = meas;
Y_data = species;
MyModel = fitcknn(X_data,Y_data,'NumNeighbors',6,'Standardize',1)
X_new = 1;
class_name = predict(MyModel,X_new)

输出:

MyModel =

  ClassificationKNN
             ResponseName: 'Y'
    CategoricalPredictors: []
               ClassNames: {'setosa'  'versicolor'  'virginica'}
           ScoreTransform: 'none'
          NumObservations: 150
                 Distance: 'euclidean'
             NumNeighbors: 6


  Properties, Methods


class_name =

  1×1 cell array

    {'versicolor'}

X_data 包含上述代码中 150 朵鸢尾花的花瓣测量值,Y_data 包含这 150 朵鸢尾花对应的鸢尾花或类名。 正如我们在输出中看到的,该模型包括三个类名和每个类的 150 个观测值,用于查找距离的方法是欧几里得,邻居数为 6。

我们使用 predict() 函数使用新观察值预测类名,即 1。我们还可以通过创建观察值的列向量来使用多个观察值。

我们还可以更改用于查找最近邻居的方法、用于查找点之间距离的方法以及桶大小,就像我们在 knnsearch() 函数中更改它们的方式一样。

我们还可以从 predict() 函数获得另外两个输出:预测分数和类名的预期成本。 我们还可以使用 save 命令保存我们训练的模型,并在 MATLAB 中使用 load 命令随时加载它。

save 命令将创建一个 .mat 文件,其中包含我们在 MATLAB 当前目录中训练的模型,如果我们想将其加载回来,.mat 文件应该存在于 MATLAB 使用的当前目录中。

保存和加载命令的基本语法如下。

save model_name
load model_name

fitcknn() 函数的第一个输入参数是一个包含观察结果的表,第二个参数包含我们要预测的类名。 它应该是分类、字符串、逻辑、数字、元胞或字符数组。

转载请发邮件至 1244347461@qq.com 进行申请,经作者同意之后,转载请以链接形式注明出处

本文地址:

相关文章

如何在 Matplotlib Pyplot 中显示网格

发布时间:2024/02/04 浏览次数:142 分类:Python

本文演示了如何在 Python Matplotlib 中在一个图上画一个网格。使用 grid()函数来绘制网格,并解释了如何改变网格颜色和线条类型。

如何在 Matplotlib 中画一条任意线

发布时间:2024/02/04 浏览次数:166 分类:Python

本教程讲解了我们如何在 Matplotlib 中使用 matplotlib.pyplot.plot()、matplotlib.pyplot.vlines()、matplotlib.pyplot.hlines()方法和 matplotlib.collection.LineCollection 绘制任意线条。

Matplotlib 中的叠加条形图

发布时间:2024/02/04 浏览次数:182 分类:Python

本教程展示了如何使用 plt.bar()方法将某些数据集的条形图堆叠在另一个数据集上。我们在 Matplotlib 中使用 matplotlib.pyplot.bar()方法生成条形图。

设置 Matplotlib 网格间隔

发布时间:2024/02/04 浏览次数:250 分类:Python

本教程将介绍我们如何在 Matplotlib 绘图中设置网格间距,并对主要网格和次要网格应用不同的样式。

扫一扫阅读全部技术教程

社交账号
  • https://www.github.com/onmpw
  • qq:1244347461

最新推荐

教程更新

热门标签

扫码一下
查看教程更方便