概念
支持向量机(Support Vector Machine,缩写SVM)是一种监督式学习方法,广泛应用于统计分类以及回归分析,和逻辑回归同属于线性分类器。SVM计算出的决策边界与正、负样本保持了足够大的距离,因此SVM是一种大间距分类器。
SVM定义的最小化预测代价的过程为与逻辑回归相似,如果将逻辑回归的代价函数简要描述为$cost=A+\lambda B$,那么SVM就是$ cost=CA_1+B_1 $。所以SVM中的参数C可以被认为扮演了逻辑回归中$ \lambda $的角色。
在逻辑回归中,可以通过多项式扩展处理非线性分类问题。而SVM的处理方法是选择一些标记点,将样本与标记点的相似程度作为新的训练特征。距离度量的方式就称为核函数,包括线性核函数、多项式核函数、高斯核函数等。
针对参数C和核函数的选择,给出以下建议:
- 低偏差,高方差,过拟合时:减小C值。
- 高偏差,低方差,欠拟合时:增大C值。
- 当特征维度n较高,而样本规模m较小时,不宜使用核函数,否则容易引起过拟合。
- 当n较低,而m较大时,考虑使用高斯核函数,需进行特征缩放。
线性分类
绘制训练数据和决策边界的函数如下:
def plotData(X, y):
plt.scatter(X[:, 0], X[:, 1], c=y.flatten(), cmap='rainbow')
def plotBoundary(clf, X):
x_min, x_max = X[:, 0].min(), X[:, 0].max()
y_min, y_max = X[:, 1].min(), X[:, 1].max()
X, Y = np.meshgrid(np.linspace(x_min, x_max, 500),
np.linspace(y_min, y_max, 500))
Z = clf.predict(np.c_[X.ravel(), Y.ravel()]) # 预测分类
plt.contour(X, Y, Z.reshape(X.shape)) # 用绘制等高线的方法绘制决策边界
分别令C等于1和100带入模型,生成的决策边界如下图:
svc = svm.LinearSVC(C=100, max_iter=10000)
# svm.LinearSVC(C=1)
svc.fit(X, y.ravel())
plotData(X, y)
plotBoundary(svc, X)
plt.show()
可见当C比较小时,样本离决策边界的距离较大。当C比较大时距离较小,分类较严格。
非线性分类
使用高斯核函数计算非线性分类,结果如下:
clf = svm.SVC(C=1, kernel='rbf', gamma=50) # rbf即为径向基核函数(高斯核函数)
model = clf.fit(X, y.flatten())
plotData(X, y)
plotBoundary(model, X)
plt.show()
参数调优
我们实现一个简单的网格搜索寻找最优的C和$\gamma$参数。
C_values = [0.01, 0.03, 0.1, 0.3, 1, 3, 10, 30, 100]
gamma_values = [0.01, 0.03, 0.1, 0.3, 1, 3, 10, 30, 100]
best_score = 0
best_params = {'C': None, 'gamma': None}
for C in C_values:
for gamma in gamma_values:
svc = svm.SVC(C=C, gamma=gamma)
svc.fit(X, y.ravel())
score = svc.score(Xval, yval)
if score > best_score:
best_score = score
best_params['C'] = C
best_params['gamma'] = gamma
model = svm.SVC(C=best_params['C'], kernel='rbf', gamma=best_params['gamma'])
model.fit(X, y.flatten())
plotData(X, y)
plotBoundary(model, X)
plt.show()
根据最优参数计算出的决策边界如下:
垃圾邮件过滤器
本小节使用SVM建立一个垃圾邮件过滤器。我们需要将每个邮件变成一个n维的特征向量,过滤器负责判断给定邮件是否为垃圾邮件。
首先需要对邮件内容做一些基础处理,仅适用于本例中的英文邮件:
- 邮件中的所有字母转化为小写。
- 移除所有HTML标签。
- 所有URL替换为“httpaddr”。
- 所有地址替换为“emailaddr”。
- 所有“$”符号替换为“dollar”。
- 所有数字替换为“number”。
- 词干提取,所有单词还原为词源。
- 移除所有非文字类型,空格缩减。
词汇表vocab.txt存储了在实际中经常使用的单词。我们要算出处理后的电子邮件中含有多少词汇表中的单词,并得到单词的索引。存在单词的相应位置的值置为1,其余为0。下面利用已经提取好的特征向量以及相应的标签进行测试,最终的预测精度为95.3%。
X = spam_train['X']
X_test = spam_test['Xtest']
y = spam_train['y'].ravel()
y_test = spam_test['ytest'].ravel()
svc = svm.SVC(C=1, gamma='auto')
svc.fit(X, y)
print(np.round(svc.score(X_test, y_test) * 100, 2))