Fisher Classifier Design Based on the Iris Dataset


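This article walks through the design of a Fisher linear classifier on the Iris dataset, restricted to two classes and two features (a two-dimensional feature space).
Dataset download

1. Preprocessing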

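# Imports
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math  # math.log is used for the prior-ratio term in step 6
import sympy as sp  # symbolic math library
from sklearn.model_selection import train_test_split  # utility for splitting the dataset

# Load the data (reading a .xls file with pandas requires the xlrd package)
data = pd.read_excel('3-iris 数据集(2类).xls', header=None)
X = np.array(data.iloc[:, 2:4])  # keep two feature columns (columns 2 and 3)
y = np.array(data.iloc[:, 4])    # class labels, 1 or 2
y_c = np.unique(y)               # sorted array of the distinct class labels
# Hold out 30% of the samples as a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)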

2. Compute the class mean vectors

np.set_printoptions(precision=4)
mean_vector = []  # mean vector of each class, in the order of y_c
for i in y_c:
    mean_vector.append(np.mean(X_train[y_train == i], axis=0))

Result

Class 1: [1.445714285714285507e+00, 2.457142857142857462e-01]

Class 2: [4.265714285714285126e+00, 1.342857142857142749e+00]
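For reference, the loop in step 2 computes the sample mean of each class, where $N_i$ denotes the number of training samples in class $\omega_i$:

$$
\mathbf{m}_i = \frac{1}{N_i}\sum_{\mathbf{x}\in\omega_i}\mathbf{x}, \qquad i = 1, 2
$$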

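3. Compute the within-class scatter matrix

S_W = np.zeros((X_train.shape[1], X_train.shape[1]))
for k, i in enumerate(y_c):
    Xi = X_train[y_train == i] - mean_vector[k]  # center class-i samples on their class mean
    S_W += Xi.T @ Xi  # sum of outer products (x - m_i)(x - m_i)^T over class i
print('S_W:', S_W)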

Result

[[9.6257 3.0683]
 [3.0683 1.9126]]
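The loop in step 3 accumulates $S_W = \sum_{i=1}^{2}\sum_{\mathbf{x}\in\omega_i}(\mathbf{x}-\mathbf{m}_i)(\mathbf{x}-\mathbf{m}_i)^T$. As a sanity check, each class's term equals $(N_i - 1)$ times that class's unbiased sample covariance, so the same matrix can also be obtained from `np.cov`. The sketch below assumes randomly generated points (hypothetical data, not the Iris file) and uses two hypothetical helper names, `within_class_scatter` and `within_class_scatter_cov`.

```python
import numpy as np

def within_class_scatter(X, y):
    """Within-class scatter matrix S_W, built from explicit outer products."""
    S_W = np.zeros((X.shape[1], X.shape[1]))
    for label in np.unique(y):
        Xi = X[y == label] - X[y == label].mean(axis=0)  # center on the class mean
        S_W += Xi.T @ Xi
    return S_W

def within_class_scatter_cov(X, y):
    """Same matrix via np.cov: (N_i - 1) times the unbiased covariance of each class."""
    return sum((np.sum(y == label) - 1) * np.cov(X[y == label], rowvar=False)
               for label in np.unique(y))

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 2))   # hypothetical 2-D samples
y = np.repeat([1, 2], 10)      # hypothetical labels, 10 per class
print(np.allclose(within_class_scatter(X, y), within_class_scatter_cov(X, y)))  # True
```

The `np.cov` form is convenient when per-class covariances are already available; the explicit loop mirrors the textbook definition.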

4. Compute the between-class scatter matrix

S_B = np.zeros((X_train.shape[1], X_train.shape[1]))
mu = np.mean(X_train, axis=0)  # mean of all training samples
for k, i in enumerate(y_c):
    Ni = np.sum(y_train == i)  # number of training samples in class i
    d = mean_vector[k] - mu
    S_B += Ni * np.outer(d, d)  # N_i (m_i - m)(m_i - m)^T
print('S_B:', S_B)

Result

[[139.167   54.144 ]
 [ 54.144   21.0651]]
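With only two classes, $S_B$ has rank one. Substituting $\mathbf{m} = (N_1\mathbf{m}_1 + N_2\mathbf{m}_2)/N$ with $N = N_1 + N_2$ gives $S_B = \frac{N_1 N_2}{N}(\mathbf{m}_1-\mathbf{m}_2)(\mathbf{m}_1-\mathbf{m}_2)^T$, so $S_B\mathbf{w}$ always points along $\mathbf{m}_1-\mathbf{m}_2$. This is why $S_B$ does not appear explicitly in the formula for $\mathbf{w}$ in step 5. A quick numerical check of the identity, assuming randomly generated points (hypothetical data, not the Iris file):

```python
import numpy as np

rng = np.random.default_rng(1)
X1 = rng.normal(loc=0.0, size=(12, 2))  # hypothetical class-1 samples
X2 = rng.normal(loc=3.0, size=(8, 2))   # hypothetical class-2 samples
N1, N2 = len(X1), len(X2)
m1, m2 = X1.mean(axis=0), X2.mean(axis=0)
m = np.vstack([X1, X2]).mean(axis=0)    # mean of all samples

# Definition used in step 4: sum_i N_i (m_i - m)(m_i - m)^T
S_B = N1 * np.outer(m1 - m, m1 - m) + N2 * np.outer(m2 - m, m2 - m)
# Two-class closed form: (N1 N2 / N) (m1 - m2)(m1 - m2)^T
S_B_closed = N1 * N2 / (N1 + N2) * np.outer(m1 - m2, m1 - m2)

print(np.allclose(S_B, S_B_closed))     # True
```

Unequal class sizes are used on purpose, so the check does not rely on $N_1 = N_2$.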

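5. Compute the optimal projection direction w

# w = S_W^{-1} (m_1 - m_2); solve() avoids forming the inverse, and a 1-D array needs no transpose
w = np.linalg.solve(S_W, mean_vector[0] - mean_vector[1])
print(w)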

Result

[-0.2253 -0.2121]
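Step 5 uses the closed-form maximizer of Fisher's criterion $J(\mathbf{w}) = \dfrac{\mathbf{w}^T S_B \mathbf{w}}{\mathbf{w}^T S_W \mathbf{w}}$, namely $\mathbf{w}^* \propto S_W^{-1}(\mathbf{m}_1 - \mathbf{m}_2)$. The sketch below plugs in the rounded $S_W$, $S_B$ and class means printed above and checks two things: solving $S_W\mathbf{w} = \mathbf{m}_1 - \mathbf{m}_2$ reproduces the printed $\mathbf{w}$, and $\mathbf{w}$ is parallel to the leading eigenvector of $S_W^{-1}S_B$. Because the matrices were printed to four decimals, agreement is only expected to about three decimals; `fisher_J` is a hypothetical helper name.

```python
import numpy as np

# Values copied from the printed results above (S_W and S_B are rounded)
S_W = np.array([[9.6257, 3.0683],
                [3.0683, 1.9126]])
S_B = np.array([[139.167, 54.144],
                [54.144, 21.0651]])
m1 = np.array([1.445714285714285507, 0.2457142857142857462])
m2 = np.array([4.265714285714285126, 1.342857142857142749])

def fisher_J(w):
    """Fisher criterion: between-class over within-class scatter along direction w."""
    return (w @ S_B @ w) / (w @ S_W @ w)

# Closed form: w = S_W^{-1}(m1 - m2), computed without an explicit inverse
w = np.linalg.solve(S_W, m1 - m2)
print(w)

# Leading eigenvector of S_W^{-1} S_B (generalized eigenproblem S_B v = lambda S_W v)
vals, vecs = np.linalg.eig(np.linalg.solve(S_W, S_B))
v = np.real(vecs[:, np.argmax(np.real(vals))])
cos = abs(w @ v) / (np.linalg.norm(w) * np.linalg.norm(v))
print(cos)  # close to 1: same direction up to sign and scale

# w scores higher than either coordinate axis
print(fisher_J(w) > max(fisher_J(np.array([1.0, 0.0])), fisher_J(np.array([0.0, 1.0]))))  # True
```

Since $J$ is invariant to rescaling $\mathbf{w}$, only the direction matters; the sign of $\mathbf{w}$ is fixed by the choice $\mathbf{m}_1 - \mathbf{m}_2$, which makes class 1 the positive side in steps 6 to 8.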

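6. Compute the constant term of the decision boundary

Three prior ratios $P(\omega_2)/P(\omega_1)$ are compared:

P = [1, 3/7, 1/9]  # prior ratios P(w_2)/P(w_1)
w_0 = []
# -1/2 (m_1 + m_2)^T S_W^{-1} (m_1 - m_2)
const1 = -0.5 * (mean_vector[0] + mean_vector[1]).dot(np.linalg.inv(S_W)).dot(mean_vector[0] - mean_vector[1])
for i in P:
    w_0.append(const1 - math.log(i))  # shift the threshold by -ln(P(w_2)/P(w_1))
print(w_0)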

Result

[0.8120180310967475, 1.6593158914839512, 3.009242608432967]
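Written out, steps 5 and 6 define the linear discriminant

$$
g(\mathbf{x}) = \mathbf{w}^T\mathbf{x} + w_0, \qquad
w_0 = -\tfrac{1}{2}(\mathbf{m}_1+\mathbf{m}_2)^T S_W^{-1}(\mathbf{m}_1-\mathbf{m}_2) - \ln\frac{P(\omega_2)}{P(\omega_1)}
$$

and a sample is assigned to class 1 when $g(\mathbf{x}) > 0$ and to class 2 when $g(\mathbf{x}) < 0$, which is the rule step 8 applies. Because $\mathbf{w} = S_W^{-1}(\mathbf{m}_1-\mathbf{m}_2)$, the first term of $w_0$ is just $-\tfrac{1}{2}\mathbf{w}^T(\mathbf{m}_1+\mathbf{m}_2)$, minus the projection of the midpoint of the two class means. Making class 1 more likely a priori (a smaller ratio) raises $w_0$ and enlarges the class-1 region. The sketch below recomputes the three printed values from the rounded $S_W$ and the class means; `bias_terms` is a hypothetical helper name.

```python
import numpy as np

S_W = np.array([[9.6257, 3.0683],
                [3.0683, 1.9126]])                 # rounded S_W from step 3
m1 = np.array([1.445714285714285507, 0.2457142857142857462])
m2 = np.array([4.265714285714285126, 1.342857142857142749])

def bias_terms(S_W, m1, m2, prior_ratios):
    """w_0 = -1/2 w^T (m1 + m2) - ln(P(w2)/P(w1)) for each prior ratio P(w2)/P(w1)."""
    w = np.linalg.solve(S_W, m1 - m2)
    midpoint_term = -0.5 * w @ (m1 + m2)
    return [midpoint_term - np.log(r) for r in prior_ratios]

w_0 = bias_terms(S_W, m1, m2, [1, 3/7, 1/9])
print(np.round(w_0, 3))  # [0.812 1.659 3.009]
```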

7. Plot the decision boundaries for the different prior ratios

fig, ax = plt.subplots(1, 1)
ax.scatter(X_train[y_train == 1][:, 0], X_train[y_train == 1][:, 1], c='b', label='1')
ax.scatter(X_train[y_train == 2][:, 0], X_train[y_train == 2][:, 1], c='r', label='2')
# Evaluate g(x) = w_0 + w^T x on a grid and draw its zero level set (the decision line)
xx, yy = np.linspace(0, 10, 7), np.linspace(0, 10, 7)
gx, gy = np.meshgrid(xx, yy)
for c in w_0:  # one line per prior ratio, in the order of P
    ax.contour(gx, gy, c + w[0] * gx + w[1] * gy, [0])
ax.legend()
plt.show()
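Because each decision boundary $g(\mathbf{x}) = 0$ is a straight line, it can also be drawn directly rather than through `contour` on a grid: solving $w_0 + w_1 x + w_2 y = 0$ for $y$ gives $y = -(w_0 + w_1 x)/w_2$, valid when $w_2 \neq 0$. The sketch below uses the rounded $\mathbf{w}$ and $w_0$ values printed above and a hypothetical helper `boundary_line`. It also shows why the lines stack from bottom to top as the prior ratio shrinks: with $w_2 < 0$, a larger $w_0$ shifts the line upward.

```python
import numpy as np

w = np.array([-0.2253, -0.2121])      # rounded w from step 5
w_0 = [0.8120, 1.6593, 3.0092]        # rounded w_0 from step 6 (ratios 1, 3/7, 1/9)

def boundary_line(w, w0, xs):
    """Return y such that w0 + w[0]*x + w[1]*y = 0 for each x (requires w[1] != 0)."""
    return -(w0 + w[0] * xs) / w[1]

xs = np.linspace(0, 10, 50)
lines = [boundary_line(w, c, xs) for c in w_0]
# With matplotlib, each line could then be drawn via: ax.plot(xs, ys)
print([round(float(ys[0]), 3) for ys in lines])  # y-intercepts at x = 0
```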

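Figure

From bottom to top, the decision lines correspond to prior ratios $P(\omega_2)/P(\omega_1)$ of 1, 3/7 and 1/9:
[Figure: training samples of class 1 (blue) and class 2 (red) with the three decision lines]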

8. Compute the error rate on the test set

# Decision rule for P(w_2)/P(w_1) = 1: class 1 if g(x) > 0, class 2 if g(x) < 0
g = X_test.dot(w) + w_0[0]
correct = ((g > 0) & (y_test == 1)) | ((g < 0) & (y_test == 2))
print('errorRate:', 1 - correct.mean())

Result

errorRate: 0.0
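Step 8 only scores the boundary for $P(\omega_2)/P(\omega_1) = 1$. A vectorized helper (hypothetical name `error_rates`) can evaluate every $w_0$ at once. The two toy points below are made up for illustration and are not taken from the test set; they show how a larger $w_0$ pushes a class-2 point into the class-1 region. Note that this sketch assigns ties ($g(\mathbf{x}) = 0$) to class 2, whereas step 8 counts them as errors.

```python
import numpy as np

def error_rates(X, y, w, w0_list):
    """Error rate of the rule 'class 1 if w^T x + w0 > 0, else class 2' for each w0."""
    g = X @ w                                       # projections w^T x, shape (n,)
    rates = []
    for w0 in w0_list:
        pred = np.where(g + w0 > 0, 1, 2)
        rates.append(float(np.mean(pred != y)))
    return rates

# Toy data (hypothetical): one point near each class mean
X_toy = np.array([[1.4, 0.2],
                  [4.5, 1.5]])
y_toy = np.array([1, 2])
w = np.array([-0.2253, -0.2121])                    # rounded w from step 5
print(error_rates(X_toy, y_toy, w, [0.8120, 1.6593, 3.0092]))  # [0.0, 0.5, 0.5]
```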