Hike News

Machine Learning: Model Selection and Hyperparameter Tuning

Cross validation

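  • Purpose: to make the evaluation of a model more accurate and reliable
  • Cross validation is usually used together with grid search
  • Split the training data into n equal parts (folds). Taking n = 5 as in the figure below: one fold serves as the validation set and the other four as the training set. This is repeated 5 times, each time with a different fold as the validation set, which gives 5 model results; the mean of their accuracies is the final result. This is called 5-fold cross validation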

[Figure: Cross validation]
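The 5-fold procedure described above can be sketched in Python as follows. This is a minimal illustration, not the author's code: it assumes the iris dataset, a KNeighborsClassifier with n_neighbors=5, and a shuffled KFold with random_state=0, and it checks the hand-written loop against scikit-learn's cross_val_score using the same splitter.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
# Assumed model for illustration: KNN with k = 5
knn = KNeighborsClassifier(n_neighbors=5)

# Split the data into 5 folds; each fold serves exactly once as the validation set
kf = KFold(n_splits=5, shuffle=True, random_state=0)
fold_scores = []
for train_idx, val_idx in kf.split(X):
    knn.fit(X[train_idx], y[train_idx])                      # train on the other 4 folds
    fold_scores.append(knn.score(X[val_idx], y[val_idx]))    # accuracy on the held-out fold
manual_mean = np.mean(fold_scores)                           # final result = mean accuracy

# Built-in equivalent using the same splitter (same folds, same per-fold scores)
builtin_scores = cross_val_score(KNeighborsClassifier(n_neighbors=5), X, y, cv=kf)

print("per-fold accuracy:", fold_scores)
print("mean accuracy:", manual_mean)
```

Passing the same KFold object to cross_val_score makes both versions use identical folds, so their per-fold accuracies coincide.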

Grid search (hyperparameter search)

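  • Many parameters usually have to be specified by hand (e.g. the K of KNeighborsClassifier, i.e. n_neighbors); these are called hyperparameters
  • The goal of tuning them is to make the model's predictions as good as possible
  • Tuning by hand is tedious, so we preset several hyperparameter combinations for the model, evaluate each combination with cross validation, and finally build the model with the best-scoring combination
  • When an algorithm has more than one hyperparameter, every combination of their candidate values is evaluated (e.g. 4 values of one and 2 values of another give 8 combinations)
  • Use sklearn.model_selection.GridSearchCV (grid search combined with cross validation)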
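What GridSearchCV automates can be sketched by hand: enumerate every combination in the grid, score each one with cross validation, and keep the best. In this minimal sketch the second hyperparameter (`weights`) and all candidate values are assumptions chosen only to show how two hyperparameters are combined; ParameterGrid is the scikit-learn helper that expands such a dict into all combinations. Unlike GridSearchCV, this loop does not refit the winning model afterwards.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import ParameterGrid, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Two hyperparameters (assumed values): every combination is tried, 4 x 2 = 8 candidates
param_grid = {"n_neighbors": [3, 5, 7, 10], "weights": ["uniform", "distance"]}
candidates = list(ParameterGrid(param_grid))

results = []
for params in candidates:
    # Each candidate is scored by 5-fold cross validation (mean accuracy)
    scores = cross_val_score(KNeighborsClassifier(**params), X, y, cv=5)
    results.append((scores.mean(), params))

# Keep the combination with the highest mean validation accuracy (first one on ties)
best_score, best_params = max(results, key=lambda r: r[0])
print(len(candidates), "candidates")
print("best:", best_params, best_score)
```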

GridSearchCV

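  • sklearn.model_selection.GridSearchCV(estimator, param_grid, cv=None) (other parameters such as scoring, n_jobs and refit are omitted here)
  • estimator: the estimator object
  • param_grid: the estimator parameters to search (dict), e.g. {"n_neighbors": [1, 3, 5]}
  • cv: the number of folds for cross validation
  • The result is itself an estimator object that still supports fit, predict, score, etc. Calling fit runs the search and, with the default refit=True, refits the best parameter combination on the whole training set; predict and score then use that model
  • Attributes for analyzing the results (available after fit):
    • best_score_: the best mean validation score achieved in cross validation
    • best_estimator_: the estimator built with the best parameters
    • cv_results_: the validation-set accuracy (and, when return_train_score=True, the training-set accuracy) for every split of every candidate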
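A short sketch of how these attributes relate to each other, assuming the iris data, a fixed random_state=0 for the split, and the default refit=True. best_index_ (not listed above, but also a GridSearchCV attribute) points to the winning row of cv_results_.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=0)

gs = GridSearchCV(KNeighborsClassifier(), param_grid={"n_neighbors": [1, 3, 5]}, cv=3)
gs.fit(x_train, y_train)

# best_score_ is the mean validation score of the winning row in cv_results_
best_from_table = gs.cv_results_["mean_test_score"][gs.best_index_]

# With refit=True, the search object itself predicts with best_estimator_
same_predictions = np.array_equal(gs.predict(x_test), gs.best_estimator_.predict(x_test))

print("best_params_:", gs.best_params_)
print("best_score_:", gs.best_score_)
print("test accuracy of the refitted model:", gs.score(x_test, y_test))
```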

Example

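from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score


iris_data = load_iris()

def GridSearchCV_test():
    # Base estimator whose hyperparameter K (n_neighbors) will be searched
    knn = KNeighborsClassifier()
    # Candidate values for n_neighbors
    K_value = {"n_neighbors": [3, 5, 7, 10]}
    # 3-fold cross validation for every candidate; return_train_score=True makes
    # cv_results_ include training scores (off by default in newer scikit-learn)
    GS_CV = GridSearchCV(knn, param_grid=K_value, cv=3, return_train_score=True)

    # Hold out a test set (25% by default) that the search never sees
    x_train, x_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target)
    # Runs the search, then refits the best model on all of x_train
    GS_CV.fit(x_train, y_train)
    # Predictions come from the refitted best estimator
    y_predict = GS_CV.predict(x_test)
    score = accuracy_score(y_test, y_predict)
    print("Accuracy:", score)
    print("Best score in cross validation:", GS_CV.best_score_)
    print("Best estimator:\n", GS_CV.best_estimator_)
    print("Validation and training accuracy for each CV split:\n", GS_CV.cv_results_)

if __name__ == '__main__':
    GridSearchCV_test()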

Output

Accuracy: 0.9473684210526315
Best score in cross validation: 0.9910714285714286
Best estimator:
 KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=10, p=2,
           weights='uniform')
Validation and training accuracy for each CV split:
{'mean_fit_time': array([0.00023023, 0.0001897 , 0.00018748, 0.00018843]), 'std_fit_time': array([5.71920846e-05, 8.99132768e-07, 1.38109105e-06, 4.49566384e-07]), 'mean_score_time': array([0.00042526, 0.00037464, 0.00037185, 0.00038997]), 'std_score_time': array([7.77299011e-05, 4.85630548e-06, 8.77806426e-07, 7.62525420e-06]), 'param_n_neighbors': masked_array(data=[3, 5, 7, 10],
mask=[False, False, False, False],
fill_value='?',
dtype=object), 'params': [{'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}, {'n_neighbors': 10}], 'split0_test_score': array([0.97368421, 0.97368421, 0.97368421, 0.97368421]), 'split1_test_score': array([1. , 0.97297297, 0.97297297, 1. ]), 'split2_test_score': array([0.97297297, 0.97297297, 0.97297297, 1. ]), 'mean_test_score': array([0.98214286, 0.97321429, 0.97321429, 0.99107143]), 'std_test_score': array([0.01254582, 0.00033675, 0.00033675, 0.01245966]), 'rank_test_score': array([2, 3, 3, 1], dtype=int32), 'split0_train_score': array([1. , 1. , 1. , 0.98648649]), 'split1_train_score': array([0.98666667, 0.98666667, 0.98666667, 0.98666667]), 'split2_train_score': array([0.98666667, 0.98666667, 0.98666667, 0.98666667]), 'mean_train_score': array([0.99111111, 0.99111111, 0.99111111, 0.98660661]), 'std_train_score': array([6.28539361e-03, 6.28539361e-03, 6.28539361e-03, 8.49377515e-05])}
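The raw cv_results_ dict printed above is hard to scan. A minimal sketch that prints one row per candidate, with its rank and its mean training and validation accuracy, could look like this; the fixed random_state=0 is an assumption (the example above leaves the split random), so the numbers will not match the run above.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
# random_state fixed here (assumption) so the run is reproducible
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=0)

gs = GridSearchCV(KNeighborsClassifier(), param_grid={"n_neighbors": [3, 5, 7, 10]},
                  cv=3, return_train_score=True)
gs.fit(x_train, y_train)
res = gs.cv_results_

# One row per candidate: (rank, k, mean training accuracy, mean validation accuracy)
rows = []
for i, params in enumerate(res["params"]):
    rows.append((int(res["rank_test_score"][i]), int(params["n_neighbors"]),
                 float(res["mean_train_score"][i]), float(res["mean_test_score"][i])))

print(f"{'rank':>4} {'k':>3} {'train':>7} {'valid':>7}")
for rank, k, train, valid in sorted(rows):   # best-ranked candidate first
    print(f"{rank:>4} {k:>3} {train:>7.4f} {valid:>7.4f}")
```

A large gap between the train and valid columns for a candidate is a hint that it overfits the training folds.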