这篇文章我想讲解的是 scikit-learn 中 SVC 的二分类、多分类场景下 ovo、ovr 决策函数的计算过程,以了解 SVC 进行推理时的逻辑。从而加深对 SVC 的理解。
决策函数公式得到决策值之后,直接判断符号,可得出类别标签。
- 二分类 SVC 的决策函数 {#title-0} ===========================
from sklearn.svm import SVC
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.datasets import make_classification
import numpy as np
def test():
inputs, labels = make_classification(n_samples=1000, n_features=20, n_informative=15, n_classes=2, random_state=42)
estimator = SVC(kernel='rbf', gamma=0.3)
estimator.fit(inputs, labels)
sample = [inputs[0]]
支持向量
svecs = estimator.support_vectors_
对偶系数
dcoef = estimator.dual_coef_
截距
icept = estimator.intercept_
参数
gamma = estimator.gamma
print('模型对偶系数:', dcoef.shape)
print('模型截距参数:', icept.shape)
print('模型类别列表:', estimator.classes_)
手动计算决策值
v1 = dcoef @ rbf_kernel(sample, svecs, gamma=gamma).T + icept
print(v1.squeeze())
模型计算决策值
v2 = estimator.decision_function(sample)
print(v2.squeeze())
if name == 'main':
test()
- 多分类 SVC OVO 决策函数 {#title-1} ==============================
在 SVC 多分类的场景下,计算 ovo 决策结果,需要用到以下训练得到的属性值:
- n_support_ 表示每个分类支持向量的数量
- support_vectors_ 所有类别的支持向量
- dual_coef_ 这是是 SVC 训练得到的最重要的对偶系数,它实际等于 α * y,即:每一个样本的支持向量的朗格朗日乘子乘以该样本的标签(-1 或者 +1)
- intercept_ 是 ovo 中每一个模型的截距,对于 10 类别而言,会训练 (10 * (10 -1)) / 2 = 45 个二分类器,该属性中存储了这 45 个二分类器训练得到的截距。对应的顺序为:01 02 03 ...09 12 13 14 15... 19..23...
这个计算过程理解的重点是对偶系数,其结构含义如下(截图官网文档):
对偶系数矩阵的理解可能较为复杂一些,但是它是理解 ovo 决策值如何计算的极其重要的一部分。
from sklearn.svm import SVC
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.datasets import make_classification
import numpy as np
def test():
inputs, labels = make_classification(n_samples=100, n_features=20, n_informative=15, n_classes=10, random_state=42)
estimator = SVC(kernel='rbf', gamma=0.3)
estimator.fit(inputs, labels)
sample = [inputs[0]]
支持向量
svecs = estimator.support_vectors_
每个类别支持向量数量
nsupt = np.cumsum([0] + estimator.n_support_.tolist())
对偶系数
dcoef = estimator.dual_coef_
截距
icept = estimator.intercept_
参数
gamma = estimator.gamma
print('模型对偶系数:', dcoef.shape)
print('模型截距参数:', icept.shape)
print('模型类别列表:', estimator.classes_)
print('类别支持数量:', nsupt)
手动计算决策值
为截距生成索引
intercept_indexes = [f'{i}{j}' for i in range(10) for j in range(i + 1, 10)]
计算输入样本与所有支持向量的加权相似度
scores = dcoef * rbf_kernel(sample, svecs, gamma=gamma)
计算每个类别的支持向量与其他类别的支持向量的相似度分数
class_scores = []
for s, e in zip(nsupt[:-1], nsupt[1:]):
class_scores.append(scores[:, s:e].sum(axis=-1))
将分数展开
class_scores = np.array(class_scores)
class_scores = class_scores.ravel()
class_score_indexes = [f'{i}{j}' for i in range(10) for j in range(10) if i != j]
v1 = []
for flag in intercept_indexes:
a_index = class_score_indexes.index(flag)
b_index = class_score_indexes.index(flag[::-1])
c_index = intercept_indexes.index(flag)
v1. append(class_scores[a_index] + class_scores[b_index] + icept[c_index])
print(np.array(v1))
模型计算决策值
estimator.decision_function_shape = 'ovo'
v2 = estimator.decision_function(sample)
print(v2.squeeze())
print('手动计算和API计算结果:', np.all(np.array(v1) == v2))
if name == 'main':
test()
模型对偶系数: (9, 100)
模型截距参数: (45,)
模型类别列表: [0 1 2 3 4 5 6 7 8 9]
类别支持数量: [ 0 10 20 30 41 51 60 69 79 89 100]
[ 4.26630481e-04 1.15984298e-11 -9.09090910e-02 1.45700034e-09
1.00000000e-01 1.00000000e-01 3.20884059e-10 -1.00000000e+00
-9.09094941e-02 -3.93785796e-04 -9.09090910e-02 -3.65702293e-04
9.95395590e-02 9.95395438e-02 -3.93785411e-04 -1.00039376e+00
-9.09093704e-02 -9.09090918e-02 1.54111940e-09 1.00000000e-01
1.00000000e-01 3.12410279e-10 -1.00000000e+00 -9.09091250e-02
9.09090909e-02 1.81818182e-01 1.81818182e-01 9.09090922e-02
-9.09090904e-01 4.36981889e-09 9.99999994e-02 9.99999983e-02
1.11681132e-10 -1.00000000e+00 -9.09090851e-02 -4.46671502e-09
-1.00000006e-01 -9.99414063e-01 -1.81818177e-01 -9.99999995e-02
-9.99414063e-01 -1.81818177e-01 -1.00000000e+00 -9.09090851e-02
9.09090915e-01]
[ 4.26630481e-04 1.15984298e-11 -9.09090910e-02 1.45700034e-09
1.00000000e-01 1.00000000e-01 3.20884059e-10 -1.00000000e+00
-9.09094941e-02 -3.93785796e-04 -9.09090910e-02 -3.65702293e-04
9.95395590e-02 9.95395438e-02 -3.93785411e-04 -1.00039376e+00
-9.09093704e-02 -9.09090918e-02 1.54111940e-09 1.00000000e-01
1.00000000e-01 3.12410279e-10 -1.00000000e+00 -9.09091250e-02
9.09090909e-02 1.81818182e-01 1.81818182e-01 9.09090922e-02
-9.09090904e-01 4.36981889e-09 9.99999994e-02 9.99999983e-02
1.11681132e-10 -1.00000000e+00 -9.09090851e-02 -4.46671502e-09
-1.00000006e-01 -9.99414063e-01 -1.81818177e-01 -9.99999995e-02
-9.99414063e-01 -1.81818177e-01 -1.00000000e+00 -9.09090851e-02
9.09090915e-01]
手动计算和API计算结果: True
- 多分类 SVC OVR 决策函数 {#title-2} ==============================
在 SVC 中,当我们把 decision_function_shape 设置 ovr 时,实际内部仍然会先计算 ovo 决策值,然后再由 ovo 的决策值转换为 ovr 的决策值。ovr 中,每一个类别都对应了预测分数,我们最后将其归为分数最大的类别标签即可。
from sklearn.svm import SVC
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.datasets import make_classification
import numpy as np
def test():
inputs, labels = make_classification(n_samples=100, n_features=20, n_informative=15, n_classes=10, random_state=42)
estimator = SVC(kernel='rbf', gamma=0.3)
estimator.fit(inputs, labels)
sample = [inputs[0]]
支持向量
svecs = estimator.support_vectors_
每个类别支持向量数量
nsupt = np.cumsum([0] + estimator.n_support_.tolist())
对偶系数
dcoef = estimator.dual_coef_
截距
icept = estimator.intercept_
参数
gamma = estimator.gamma
print('模型对偶系数:', dcoef.shape)
print('模型截距参数:', icept.shape)
print('模型类别列表:', estimator.classes_)
print('类别支持数量:', nsupt)
手动计算决策值
estimator.decision_function_shape = 'ovo'
ovo = estimator.decision_function(sample)
ovr 分数是由 ovo 转换得到
from sklearn.utils.multiclass import _ovr_decision_function
第一个参数:每一个分类器预测的类别
第二个参数:预测为 +1 类别的分数或概率
第三个参数:类别的数量
v1 = ovr_decision_function(ovo < 0, -ovo, len(estimator.classes))
print(v1.squeeze())
模型计算决策值
estimator.decision_function_shape = 'ovr'
v2 = estimator.decision_function(sample)
print(v2.squeeze())
if name == 'main':
test()
模型对偶系数: (9, 100)
模型截距参数: (45,)
模型类别列表: [0 1 2 3 4 5 6 7 8 9]
类别支持数量: [ 0 10 20 30 41 51 60 69 79 89 100]
[ 5.83489857 1.83461706 4.83489581 7.97222223 3.83489343 -0.21688867
0.78311133 2.83489581 9.29938003 6.97222241]
[ 5.83489857 1.83461706 4.83489581 7.97222223 3.83489343 -0.21688867
0.78311133 2.83489581 9.29938003 6.97222241]