[ML-17]Stacking

17. Stacking

what?

  • 앙상블 학습에서 각 모델의 예측값을 가지고 새로운 메타 모델(meta learner)을 학습시켜 최종 예측 모델을 만드는 방법

  • base-level classifier를 통해 도출된 예측값을 메타 모델을 학습시키는 input data로 사용한다.

stacking

why?

  • 단일 모델을 사용하는 경우보다 더 높은 정확도를 보인다

  • 기존에 학습시킨 모델을 활용하는 것이 가능해 협업에 유리하다

why not?

  • 연산량이 많아 속도가 느리고 computational cost가 높아 현업에서 사용하기 힘들다

how?

Input

1) Data \(D = {(x_i, y_i)}_{i=1}^m\)

Step 1 for t = 1 to T(number of base-level classifiers) train \(h_t\) based on \(D\)

  • base-level classifier(ex SVM, KNN, Random Forest…)를 학습시킨다

Step 2 train meta classifier

1) for i = 1 to m(number of samples) construct new data set

  • base-level classifier의 예측값을 meta classifier를 학습시키기 위한 input data로 사용한다

2) train meta classifier

  • learn \(H\) based on \(D_h\)

Step 3 output ensemble classifier \(H\)

Code usage


from sklearn.model_selection import train_test_split 

from sklearn.metrics import accuracy_score 

from sklearn.ensemble import ExtraTreesClassifier 

from sklearn.ensemble import RandomForestClassifier 

from xgboost import XGBClassifier 

from vecstack import stacking



iris = load_iris()

X, y = iris.data, iris.target



X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2, random_state=0)



models = [ ExtraTreesClassifier(random_state = 0, n_jobs = -1, n_estimators = 100, max_depth = 3), 

          RandomForestClassifier(random_state = 0, n_jobs = -1, n_estimators = 100, max_depth = 3), 

          XGBClassifier(seed = 0, n_jobs = -1, learning_rate = 0.1, n_estimators = 100, max_depth = 3)]



S_train, S_test = stacking(models, X_train, y_train, X_test, 

                           regression = False, metric = accuracy_score, 

                           n_folds = 4, stratified = True, shuffle = True, 

                           random_state = 0, verbose = 2)



model = XGBClassifier(seed = 0, n_jobs = -1, learning_rate = 0.1, n_estimators = 100, max_depth = 3) 



# Fit 2-nd level model 

model = model.fit(S_train, y_train) 



# Predict 

y_pred = model.predict(S_test) 



# Final prediction score 

print('Final prediction score: [%.8f]' % accuracy_score(y_test, y_pred))

Reference

핸즈온 머신러닝

stacking 코드

stacking 설명

Continue reading

Pagination


© 2017. by herbwood

Powered by aiden