51工具盒子

依楼听风雨
笑看云卷云舒,淡观潮起潮落

SVM 决策函数(Decision Function)

这篇文章我想讲解的是 scikit-learn 中 SVC 的二分类、多分类场景下 ovo、ovr 决策函数的计算过程,以了解 SVC 进行推理时的逻辑。从而加深对 SVC 的理解。

决策函数公式得到决策值之后,直接判断符号,可得出类别标签。

  1. 二分类 SVC 的决策函数 {#title-0} ===========================
from sklearn.svm import SVC
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.datasets import make_classification
import numpy as np

def test(): inputs, labels = make_classification(n_samples=1000, n_features=20, n_informative=15, n_classes=2, random_state=42) estimator = SVC(kernel='rbf', gamma=0.3) estimator.fit(inputs, labels)

sample = [inputs[0]]

支持向量

svecs = estimator.support_vectors_

对偶系数

dcoef = estimator.dual_coef_

截距

icept = estimator.intercept_

参数

gamma = estimator.gamma

print('模型对偶系数:', dcoef.shape) print('模型截距参数:', icept.shape) print('模型类别列表:', estimator.classes_)

手动计算决策值

v1 = dcoef @ rbf_kernel(sample, svecs, gamma=gamma).T + icept print(v1.squeeze())

模型计算决策值

v2 = estimator.decision_function(sample) print(v2.squeeze())

if name == 'main': test()

  1. 多分类 SVC OVO 决策函数 {#title-1} ==============================

在 SVC 多分类的场景下,计算 ovo 决策结果,需要用到以下训练得到的属性值:

  1. n_support_ 表示每个分类支持向量的数量
  2. support_vectors_ 所有类别的支持向量
  3. dual_coef_ 这是是 SVC 训练得到的最重要的对偶系数,它实际等于 α * y,即:每一个样本的支持向量的朗格朗日乘子乘以该样本的标签(-1 或者 +1)
  4. intercept_ 是 ovo 中每一个模型的截距,对于 10 类别而言,会训练 (10 * (10 -1)) / 2 = 45 个二分类器,该属性中存储了这 45 个二分类器训练得到的截距。对应的顺序为:01 02 03 ...09 12 13 14 15... 19..23...

这个计算过程理解的重点是对偶系数,其结构含义如下(截图官网文档):

对偶系数矩阵的理解可能较为复杂一些,但是它是理解 ovo 决策值如何计算的极其重要的一部分。

from sklearn.svm import SVC
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.datasets import make_classification
import numpy as np

def test(): inputs, labels = make_classification(n_samples=100, n_features=20, n_informative=15, n_classes=10, random_state=42) estimator = SVC(kernel='rbf', gamma=0.3) estimator.fit(inputs, labels)

sample = [inputs[0]]

支持向量

svecs = estimator.support_vectors_

每个类别支持向量数量

nsupt = np.cumsum([0] + estimator.n_support_.tolist())

对偶系数

dcoef = estimator.dual_coef_

截距

icept = estimator.intercept_

参数

gamma = estimator.gamma

print('模型对偶系数:', dcoef.shape) print('模型截距参数:', icept.shape) print('模型类别列表:', estimator.classes_) print('类别支持数量:', nsupt)

手动计算决策值

为截距生成索引

intercept_indexes = [f'{i}{j}' for i in range(10) for j in range(i + 1, 10)]

计算输入样本与所有支持向量的加权相似度

scores = dcoef * rbf_kernel(sample, svecs, gamma=gamma)

计算每个类别的支持向量与其他类别的支持向量的相似度分数

class_scores = [] for s, e in zip(nsupt[:-1], nsupt[1:]): class_scores.append(scores[:, s:e].sum(axis=-1))

将分数展开

class_scores = np.array(class_scores) class_scores = class_scores.ravel() class_score_indexes = [f'{i}{j}' for i in range(10) for j in range(10) if i != j]

v1 = [] for flag in intercept_indexes: a_index = class_score_indexes.index(flag) b_index = class_score_indexes.index(flag[::-1]) c_index = intercept_indexes.index(flag) v1. append(class_scores[a_index] + class_scores[b_index] + icept[c_index]) print(np.array(v1))

模型计算决策值

estimator.decision_function_shape = 'ovo' v2 = estimator.decision_function(sample) print(v2.squeeze())

print('手动计算和API计算结果:', np.all(np.array(v1) == v2))

if name == 'main': test()

模型对偶系数: (9, 100)
模型截距参数: (45,)
模型类别列表: [0 1 2 3 4 5 6 7 8 9]
类别支持数量: [  0  10  20  30  41  51  60  69  79  89 100]
[ 4.26630481e-04  1.15984298e-11 -9.09090910e-02  1.45700034e-09
  1.00000000e-01  1.00000000e-01  3.20884059e-10 -1.00000000e+00
 -9.09094941e-02 -3.93785796e-04 -9.09090910e-02 -3.65702293e-04
  9.95395590e-02  9.95395438e-02 -3.93785411e-04 -1.00039376e+00
 -9.09093704e-02 -9.09090918e-02  1.54111940e-09  1.00000000e-01
  1.00000000e-01  3.12410279e-10 -1.00000000e+00 -9.09091250e-02
  9.09090909e-02  1.81818182e-01  1.81818182e-01  9.09090922e-02
 -9.09090904e-01  4.36981889e-09  9.99999994e-02  9.99999983e-02
  1.11681132e-10 -1.00000000e+00 -9.09090851e-02 -4.46671502e-09
 -1.00000006e-01 -9.99414063e-01 -1.81818177e-01 -9.99999995e-02
 -9.99414063e-01 -1.81818177e-01 -1.00000000e+00 -9.09090851e-02
  9.09090915e-01]
[ 4.26630481e-04  1.15984298e-11 -9.09090910e-02  1.45700034e-09
  1.00000000e-01  1.00000000e-01  3.20884059e-10 -1.00000000e+00
 -9.09094941e-02 -3.93785796e-04 -9.09090910e-02 -3.65702293e-04
  9.95395590e-02  9.95395438e-02 -3.93785411e-04 -1.00039376e+00
 -9.09093704e-02 -9.09090918e-02  1.54111940e-09  1.00000000e-01
  1.00000000e-01  3.12410279e-10 -1.00000000e+00 -9.09091250e-02
  9.09090909e-02  1.81818182e-01  1.81818182e-01  9.09090922e-02
 -9.09090904e-01  4.36981889e-09  9.99999994e-02  9.99999983e-02
  1.11681132e-10 -1.00000000e+00 -9.09090851e-02 -4.46671502e-09
 -1.00000006e-01 -9.99414063e-01 -1.81818177e-01 -9.99999995e-02
 -9.99414063e-01 -1.81818177e-01 -1.00000000e+00 -9.09090851e-02
  9.09090915e-01]
手动计算和API计算结果: True
  1. 多分类 SVC OVR 决策函数 {#title-2} ==============================

在 SVC 中,当我们把 decision_function_shape 设置 ovr 时,实际内部仍然会先计算 ovo 决策值,然后再由 ovo 的决策值转换为 ovr 的决策值。ovr 中,每一个类别都对应了预测分数,我们最后将其归为分数最大的类别标签即可。

from sklearn.svm import SVC
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.datasets import make_classification
import numpy as np

def test(): inputs, labels = make_classification(n_samples=100, n_features=20, n_informative=15, n_classes=10, random_state=42) estimator = SVC(kernel='rbf', gamma=0.3) estimator.fit(inputs, labels)

sample = [inputs[0]]

支持向量

svecs = estimator.support_vectors_

每个类别支持向量数量

nsupt = np.cumsum([0] + estimator.n_support_.tolist())

对偶系数

dcoef = estimator.dual_coef_

截距

icept = estimator.intercept_

参数

gamma = estimator.gamma

print('模型对偶系数:', dcoef.shape) print('模型截距参数:', icept.shape) print('模型类别列表:', estimator.classes_) print('类别支持数量:', nsupt)

手动计算决策值

estimator.decision_function_shape = 'ovo' ovo = estimator.decision_function(sample)

ovr 分数是由 ovo 转换得到

from sklearn.utils.multiclass import _ovr_decision_function

第一个参数:每一个分类器预测的类别

第二个参数:预测为 +1 类别的分数或概率

第三个参数:类别的数量

v1 = ovr_decision_function(ovo < 0, -ovo, len(estimator.classes)) print(v1.squeeze())

模型计算决策值

estimator.decision_function_shape = 'ovr' v2 = estimator.decision_function(sample) print(v2.squeeze())

if name == 'main': test()

模型对偶系数: (9, 100)
模型截距参数: (45,)
模型类别列表: [0 1 2 3 4 5 6 7 8 9]
类别支持数量: [  0  10  20  30  41  51  60  69  79  89 100]
[ 5.83489857  1.83461706  4.83489581  7.97222223  3.83489343 -0.21688867
  0.78311133  2.83489581  9.29938003  6.97222241]
[ 5.83489857  1.83461706  4.83489581  7.97222223  3.83489343 -0.21688867
  0.78311133  2.83489581  9.29938003  6.97222241]
赞(3)
未经允许不得转载:工具盒子 » SVM 决策函数(Decision Function)