forked from guangyaooo/MLTemplate
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MultiClassSVM.py
73 lines (57 loc) · 1.96 KB
/
MultiClassSVM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
from utils import common
from itertools import combinations
from SupportVectorMachine.SVM import RBF_kernel,linear_kernel,SVM,auto_scale
class MultiClassSVM:
    """Multi-class classifier built from pairwise (one-vs-one) binary SVMs.

    ``fit`` trains one binary ``SVM`` for every unordered pair of classes and
    ``predict`` assigns each sample the label that receives the most votes
    across those pairwise classifiers.

    The vote count relies on ``np.bincount``, so the values returned by the
    binary classifiers must be non-negative integers.
    NOTE(review): this presumes ``SVM.predict`` returns the original class
    label of the winning side of each pair -- confirm against ``SVM``.
    """

    # Kernel names understood by ``_make_kernel``.
    _SUPPORTED_KERNELS = ('rbf', 'linear')

    def __init__(self, C, kernel, classes=None, tol=1e-3):
        '''
        :param C: forwarded unchanged as the first argument of every binary SVM
        :param kernel: kernel name, either 'rbf' or 'linear'
        :param classes: optional sequence of class labels; when None the
                        labels are inferred from ``y`` on the first ``fit``
        :param tol: tolerance forwarded to ``SVM.fit``
        '''
        self.C = C
        self.kernel = kernel
        self.classes = classes
        self.tol = tol
        self.svms = []
        self.class_num = len(classes) if classes is not None else 0

    def _make_kernel(self, data):
        '''
        Build the kernel object for one pairwise sub-problem.

        :param data: samples of the two classes being separated
        :return: kernel instance to hand to ``SVM``
        :raises NotImplementedError: for an unsupported kernel name
        '''
        if self.kernel == 'rbf':
            # sigma is re-derived by auto_scale from each pair's own samples.
            return RBF_kernel(auto_scale(data))
        if self.kernel == 'linear':
            return linear_kernel()
        raise NotImplementedError('unsupported kernel: %r' % (self.kernel,))

    def fit(self, X, y):
        '''
        Train one binary SVM per pair of classes.

        :param X: N x d
        :param y: N
        :return: None
        :raises NotImplementedError: if ``self.kernel`` is not 'rbf' or 'linear'
        '''
        # Fail fast on a bad kernel name instead of midway through training.
        if self.kernel not in self._SUPPORTED_KERNELS:
            raise NotImplementedError('unsupported kernel: %r' % (self.kernel,))

        if self.classes is None:
            self.classes = np.sort(np.unique(y))
        self.class_num = len(self.classes)

        # Reset so that calling fit() again does not leave stale classifiers
        # from the previous run voting alongside the new ones.
        self.svms = []
        for specified in combinations(self.classes, 2):
            print('SVM: %d %d' % specified)
            data, label = common.data_filter(X, y, specified)
            svm = SVM(self.C, self._make_kernel(data),
                      show_fitting_bar=True, max_iter=1000)
            svm.fit(data, label, tol=self.tol)
            self.svms.append(svm)

    def predict(self, X):
        '''
        Predict labels by majority vote over all pairwise classifiers.

        Ties are resolved in favour of the smallest label, because
        ``np.argmax`` returns the first maximum of the vote histogram.

        :param X: N x d
        :return: N
        :raises ValueError: if no classifier has been trained yet
        '''
        if not self.svms:
            raise ValueError(
                'MultiClassSVM has no trained classifiers; call fit() first')

        # One column of votes per pairwise classifier -> shape (N, n_pairs).
        # Builtin int replaces the np.int alias removed in NumPy 1.24.
        votes = np.concatenate(
            [svm.predict(X).reshape((-1, 1)) for svm in self.svms],
            axis=1,
        ).astype(int)
        return np.array(
            [np.argmax(np.bincount(row)) for row in votes],
            dtype=np.intp,
        )
if __name__ == '__main__':
    # Demo run: fit the one-vs-one SVM on the iris data set and report the
    # accuracy measured on the same samples that were used for fitting.
    np.random.seed(1)
    from sklearn.datasets import load_iris

    features, targets = load_iris(return_X_y=True)

    classifier = MultiClassSVM(C=1.0, kernel='rbf')
    classifier.fit(features, targets)

    predictions = classifier.predict(features)
    hits = np.sum(predictions == targets)
    accuracy = hits / len(predictions)
    print('MultiClassSVM Test Acc %.4f' % accuracy)