- AI嵌入式系统:算法优化与实现
- 应忍冬 刘佩林编著
- 1317字
- 2021-11-12 17:44:38
3.1 高斯朴素贝叶斯分类器
3.1.1 原理概述
朴素贝叶斯算法是相对简单的机器学习算法,它使用贝叶斯公式,从多个统计独立的观测量计算某个随机变量Y的取不同值的可能性(注意,我们这里用大写字母表示随机变量,用小写字母表示它的具体取值)。下面通过一个例子对高斯朴素贝叶斯模型进行简要说明,关于这一模型的详细理论分析可以参考本章参考文献部分。
考虑鸢尾花的分类问题,这种花有三种类型,分别为Setosa、Versicolour和Virginica,可以用取值分别为0,1,2的变量Y表示。我们希望通过测量鸢尾花的花萼长度X来区分它的类别,这可以通过计算条件概率实现:
上述概率代表了测得一朵鸢尾花的花萼长度为x时,它属于类别y的可能性。上述概率可以用贝叶斯公式表示,即
其中概率表示对于类型是y的鸢尾花,花萼长度为x的可能性。我们通常用高斯分布来描述这个条件概率,即
即X是由Y的取值y决定的高斯随机变量,X的均值和方差分别为和。
上面的例子是从一个观测量X计算出鸢尾花属于不同类别的概率。如果有多个不同的观测量,就能够更精确地判别鸢尾花的类别。我们可以测量鸢尾花的花萼长度、花萼宽度、花瓣长度、花瓣宽度这4个属性的具体数值,分别用表示它们,我们进一步假设这4个属性相互独立(统计独立),于是可以得到从这些观测量计算Y的条件概率,即
上述观测量相互独立以及高斯分布的模型就是“高斯朴素贝叶斯模型”。
3.1.2 模型训练和推理
下面我们基于Python的机器学习软件包Scikit-Learn说明如何训练高斯朴素贝叶斯模型。这里不会涉及模型训练的数学解释,仅仅是介绍训练所使用的Python代码。
我们还是以鸢尾花卉分类问题为例。Fisher于1936年收集整理了三种鸢尾花的花萼长度、花萼宽度、花瓣长度、花瓣宽度的测量值,这些数据能够从Scikit-Learn中直接获得。数据包括了3个类别共150朵鸢尾花测量数据,每朵花的测量值包括4个数值,每个数值对应前面所给出的一个属性。Python程序通过Scikit-Learn库的API读取鸢尾花数据,具体代码如下:
from sklearn import datasets iris = datasets.load_iris()
运行之后变量iris中就存储了150朵鸢尾花测量数据和对应的花的类型数据。通过下面的命令能够分别打印出对每一朵鸢尾花的测量结果。
print(iris.data) print(iris.target)
iris.data是尺寸为150×4的矩阵,每一行对应一朵花的测量数据,iris.target是存放了150个整数元素的数组,其中元素取值0、1、2分别对应Setosa、Versicolor、Virginica这三种类型。
下面是iris.data和iris.target的数据内容片段:
iris.data: [[6.4 2.9 4.3 1.3] [6.5 3. 5.5 1.8] [5. 2.3 3.3 1. ] [6.3 3.3 6. 2.5] [5.5 2.5 4. 1.3] [5.4 3.7 1.5 0.2] … [6.7 3.1 5.6 2.4] [4.9 3.6 1.4 0.1]] iris.target: [1 2 1 2 1 0 … 1 2]
下面的代码利用加载的iris数据进行训练,得到高斯朴素贝叶斯模型:
from sklearn.naive_bayes import GaussianNB model = GaussianNB() # 构建高斯朴素贝叶斯模型 model.fit(iris.data, iris.target)
高斯朴素贝叶斯模型参数存储在变量model内,其中高斯分布的方差存放在model.sigma_中,而高斯分布的均值存放在model.theta_中。
完成模型训练后,使用下面的代码实现模型的推理,即对类别未知的数据进行分类:
y_pred = model.predict(new_data)
其中new_data是存放需要分类的花的测量数据,每一行对应一朵花的4个测量值,程序中y_pred是列向量,它的元素对应了new_data中对应行的鸢尾花分类结果。
注意,在上述训练过程中,先验概率是从训练数据中统计得到的(用每种类别在训练数据集中出现的比例作为先验概率的估计值),提供的iris训练数据中三类花的数量相同,因此先验概率。如果需要使用其他先验概率,那么可以在构建模型的时候提供先验数据作为输入参数,即
model= GaussianNB(priors)
上面代码中priors是用户提供的三类花的先验概率数组。