islr notes and exercises from An Introduction to Statistical Learning

8. Tree-based Methods

Exercise 7: Plotting test error for parameter values of random forest model in chapter 8 lab

Preparing the data

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Read the Boston data using the CSV's first column as the index
# (index_col=0), then discard that index in favour of a fresh 0-based one.
boston = (
    pd.read_csv('../../datasets/Boston.csv', index_col=0)
    .reset_index(drop=True)
)
boston.head()
crim zn indus chas nox rm age dis rad tax ptratio black lstat medv
0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98 24.0
1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14 21.6
2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 34.7
3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94 33.4
4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33 36.2
# Summarize the frame: index range, per-column non-null counts and dtypes.
boston.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 506 entries, 0 to 505
Data columns (total 14 columns):
crim       506 non-null float64
zn         506 non-null float64
indus      506 non-null float64
chas       506 non-null int64
nox        506 non-null float64
rm         506 non-null float64
age        506 non-null float64
dis        506 non-null float64
rad        506 non-null int64
tax        506 non-null int64
ptratio    506 non-null float64
black      506 non-null float64
lstat      506 non-null float64
medv       506 non-null float64
dtypes: float64(11), int64(3)
memory usage: 55.4 KB

Grid search for random forest model

We use sklearn.ensemble.RandomForestRegressor. Its parameters n_estimators and max_features correspond to the arguments ntree and mtry, respectively, of the R function randomForest().

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

# Candidate values: 1..50 trees and 1..13 features considered per split.
n_estimators = np.arange(1, 51)
max_features = np.arange(1, 14)
params = dict(n_estimators=n_estimators, max_features=max_features)

# Exhaustive 5-fold cross-validated search over every parameter pair,
# scored by negated mean squared error (higher is better).
rf_search = GridSearchCV(
    RandomForestRegressor(),
    param_grid=params,
    scoring='neg_mean_squared_error',
    cv=5,
)
# Predict medv from all remaining columns.
rf_search.fit(boston.drop(columns=['medv']), boston['medv'])
/anaconda3/envs/islr/lib/python3.7/site-packages/sklearn/model_selection/_search.py:841: DeprecationWarning: The default of the `iid` parameter will change from True to False in version 0.22 and will be removed in 0.24. This will change numeric results when test-set sizes are unequal.
  DeprecationWarning)





GridSearchCV(cv=5, error_score='raise-deprecating',
       estimator=RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
           max_features='auto', max_leaf_nodes=None,
           min_impurity_decrease=0.0, min_impurity_split=None,
           min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators='warn', n_jobs=None,
           oob_score=False, random_state=None, verbose=0, warm_start=False),
       fit_params=None, iid='warn', n_jobs=None,
       param_grid={'n_estimators': array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
       35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]), 'max_features': array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring='neg_mean_squared_error', verbose=0)
import warnings

# Ignore FutureWarning messages from here on.
warnings.simplefilter('ignore', FutureWarning)

# Tabulate the cross-validation results: one row per parameter combination.
rf_search_results = pd.DataFrame(data=rf_search.cv_results_)
rf_search_results
mean_fit_time std_fit_time mean_score_time std_score_time param_max_features param_n_estimators params split0_test_score split1_test_score split2_test_score ... mean_test_score std_test_score rank_test_score split0_train_score split1_train_score split2_train_score split3_train_score split4_train_score mean_train_score std_train_score
0 0.003891 0.000834 0.001949 0.000402 1 1 {'max_features': 1, 'n_estimators': 1} -21.991667 -85.189010 -54.552079 ... -64.012530 23.806851 648 -14.603713 -14.161111 -11.699358 -16.489481 -12.779235 -13.946580 1.634436
1 0.007109 0.000531 0.002835 0.000614 1 2 {'max_features': 1, 'n_estimators': 2} -23.077500 -40.479134 -41.309505 ... -49.661359 20.626230 646 -9.532983 -8.897994 -4.985086 -8.101160 -8.354475 -7.974340 1.573450
2 0.005619 0.001508 0.001752 0.000513 1 3 {'max_features': 1, 'n_estimators': 3} -17.650196 -41.920968 -80.304345 ... -45.140823 21.318752 642 -7.537261 -7.789177 -6.061572 -4.895781 -7.208291 -6.698416 1.077818
3 0.006459 0.001157 0.001699 0.000412 1 4 {'max_features': 1, 'n_estimators': 4} -28.594032 -42.909629 -56.985149 ... -46.777108 12.584626 643 -5.815023 -5.466340 -3.049759 -5.406645 -4.841636 -4.915881 0.983894
4 0.006406 0.001049 0.001832 0.000544 1 5 {'max_features': 1, 'n_estimators': 5} -14.134204 -33.097687 -35.283347 ... -35.554827 12.541295 621 -5.153073 -4.932412 -3.802281 -3.102355 -3.784279 -4.154880 0.770769
5 0.007777 0.001909 0.001566 0.000146 1 6 {'max_features': 1, 'n_estimators': 6} -20.835833 -33.876378 -57.023809 ... -36.389669 15.985771 626 -5.226819 -4.592859 -5.367281 -3.909785 -4.322368 -4.683823 0.547726
6 0.008977 0.002770 0.001802 0.000529 1 7 {'max_features': 1, 'n_estimators': 7} -15.424894 -35.388693 -68.895278 ... -40.056006 17.818470 638 -3.979481 -3.187319 -3.898648 -2.693557 -3.768051 -3.505411 0.491660
7 0.011081 0.001878 0.001943 0.000419 1 8 {'max_features': 1, 'n_estimators': 8} -12.485312 -21.791623 -55.193929 ... -36.435462 21.621571 627 -4.376785 -3.511429 -3.452217 -2.734615 -3.731063 -3.561222 0.527566
8 0.010636 0.002615 0.002022 0.000730 1 9 {'max_features': 1, 'n_estimators': 9} -18.031892 -31.122899 -37.199493 ... -36.325122 13.528382 625 -3.032920 -2.808331 -2.432588 -3.095907 -4.087000 -3.091350 0.549331
9 0.012460 0.002050 0.002123 0.000522 1 10 {'max_features': 1, 'n_estimators': 10} -17.399393 -33.092368 -32.873599 ... -31.259992 14.477833 582 -3.891032 -3.469017 -2.679415 -3.634116 -2.866854 -3.308087 0.460854
10 0.019396 0.002247 0.002984 0.001326 1 11 {'max_features': 1, 'n_estimators': 11} -11.739315 -22.884135 -53.257725 ... -34.255825 18.556652 612 -2.524618 -2.831083 -3.511232 -2.794624 -3.029749 -2.938261 0.328599
11 0.013650 0.002062 0.002338 0.000655 1 12 {'max_features': 1, 'n_estimators': 12} -20.765519 -29.219963 -43.909655 ... -35.303845 14.835744 620 -3.728071 -2.489337 -2.629168 -2.102650 -3.578967 -2.905638 0.636286
12 0.013520 0.002860 0.002294 0.000803 1 13 {'max_features': 1, 'n_estimators': 13} -18.235791 -41.704494 -28.564638 ... -33.328750 14.136039 602 -3.033637 -2.431979 -2.363831 -2.386989 -2.513514 -2.545990 0.249125
13 0.014610 0.002561 0.002329 0.000904 1 14 {'max_features': 1, 'n_estimators': 14} -13.594839 -20.757757 -55.323856 ... -31.949899 17.840135 589 -3.856282 -2.813479 -2.660921 -2.461808 -3.621529 -3.082804 0.552205
14 0.016239 0.003550 0.002126 0.000099 1 15 {'max_features': 1, 'n_estimators': 15} -15.587993 -37.151651 -39.838389 ... -37.000578 14.092232 629 -2.979421 -2.745065 -2.955570 -2.302839 -2.104006 -2.617380 0.353338
15 0.021542 0.007437 0.004869 0.001939 1 16 {'max_features': 1, 'n_estimators': 16} -18.313434 -37.697324 -50.224384 ... -39.177959 13.554225 636 -2.383039 -2.313836 -2.564468 -2.085711 -2.971439 -2.463699 0.296579
16 0.031045 0.007147 0.004051 0.001197 1 17 {'max_features': 1, 'n_estimators': 17} -14.477579 -38.545166 -47.593614 ... -36.993035 14.366971 628 -3.384781 -2.466205 -2.349847 -2.080384 -2.440984 -2.544440 0.441862
17 0.020159 0.003701 0.002337 0.000288 1 18 {'max_features': 1, 'n_estimators': 18} -13.351092 -29.703280 -41.840159 ... -32.092468 12.389404 591 -2.381153 -2.327830 -2.322180 -2.293605 -3.171681 -2.499290 0.337384
18 0.023308 0.001765 0.003537 0.001120 1 19 {'max_features': 1, 'n_estimators': 19} -15.882930 -31.164088 -39.008378 ... -34.578963 14.486152 616 -3.857120 -2.922748 -2.321944 -2.458541 -3.067876 -2.925646 0.542314
19 0.024203 0.002189 0.003143 0.000700 1 20 {'max_features': 1, 'n_estimators': 20} -17.061018 -37.885674 -44.421256 ... -34.235834 11.267077 611 -3.016849 -2.616610 -2.271411 -1.724738 -2.350052 -2.395932 0.424817
20 0.022982 0.002748 0.002810 0.001135 1 21 {'max_features': 1, 'n_estimators': 21} -17.608000 -29.764712 -57.100468 ... -34.166448 15.051979 610 -2.849463 -2.687234 -2.648921 -2.026325 -2.457962 -2.533981 0.282744
21 0.025036 0.002274 0.003434 0.001300 1 22 {'max_features': 1, 'n_estimators': 22} -16.536106 -26.593730 -42.356686 ... -33.135153 14.229106 600 -2.971610 -2.912313 -2.576134 -2.170917 -3.137634 -2.753722 0.343865
22 0.027406 0.005331 0.004077 0.001138 1 23 {'max_features': 1, 'n_estimators': 23} -15.404804 -26.790607 -47.250104 ... -33.781914 15.178863 605 -2.764676 -2.730681 -2.743073 -2.136093 -3.241368 -2.723178 0.350817
23 0.024246 0.004053 0.002846 0.000299 1 24 {'max_features': 1, 'n_estimators': 24} -13.008028 -36.532921 -36.781510 ... -34.381056 15.122934 614 -3.566207 -2.849794 -2.161991 -2.167744 -2.510485 -2.651244 0.523361
24 0.024699 0.001386 0.003098 0.000754 1 25 {'max_features': 1, 'n_estimators': 25} -17.054358 -37.102535 -43.966166 ... -34.298870 14.970548 613 -2.938494 -2.180039 -2.440997 -1.917603 -2.551470 -2.405720 0.345116
25 0.027446 0.005452 0.003953 0.001620 1 26 {'max_features': 1, 'n_estimators': 26} -16.987042 -34.845853 -51.564609 ... -34.968050 15.036683 618 -3.129059 -2.227953 -3.506445 -2.467797 -1.866639 -2.639579 0.597901
26 0.027815 0.003477 0.002695 0.000157 1 27 {'max_features': 1, 'n_estimators': 27} -14.762697 -28.462289 -39.627764 ... -29.153360 10.051608 567 -3.074188 -1.793045 -3.078800 -1.674099 -1.989327 -2.321892 0.624303
27 0.028738 0.006258 0.002989 0.000205 1 28 {'max_features': 1, 'n_estimators': 28} -16.580555 -38.412260 -39.808376 ... -35.772924 13.948740 622 -2.825889 -2.647394 -2.635796 -2.118854 -2.340015 -2.513589 0.251521
28 0.027386 0.003200 0.003024 0.000625 1 29 {'max_features': 1, 'n_estimators': 29} -16.863019 -29.274269 -45.390146 ... -34.144240 11.450715 609 -2.444111 -2.820349 -2.001195 -2.113187 -2.791126 -2.433993 0.336793
29 0.028264 0.004045 0.003585 0.000949 1 30 {'max_features': 1, 'n_estimators': 30} -13.082333 -26.423729 -36.927213 ... -34.760669 14.209880 617 -2.995265 -2.269125 -2.131816 -2.082101 -2.156179 -2.326897 0.339759
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
620 0.056998 0.002920 0.002737 0.000033 13 21 {'max_features': 13, 'n_estimators': 21} -8.214533 -14.830912 -22.534429 ... -22.231679 14.579162 391 -1.754569 -2.005698 -1.794253 -1.358432 -1.487365 -1.680063 0.230307
621 0.060251 0.001883 0.003669 0.001008 13 22 {'max_features': 13, 'n_estimators': 22} -7.563142 -12.218573 -24.094603 ... -21.770627 14.921198 345 -1.771726 -1.696460 -1.623146 -1.287195 -1.534081 -1.582522 0.167329
622 0.061005 0.001704 0.003103 0.000323 13 23 {'max_features': 13, 'n_estimators': 23} -8.105958 -13.928286 -18.885602 ... -20.635471 13.060317 199 -1.891957 -1.411495 -1.325187 -1.296137 -2.003337 -1.585623 0.300092
623 0.063468 0.001463 0.002937 0.000097 13 24 {'max_features': 13, 'n_estimators': 24} -7.601690 -15.970749 -21.354981 ... -22.495433 13.645004 408 -1.731364 -1.863511 -1.403066 -1.260503 -1.713268 -1.594342 0.225125
624 0.066116 0.001926 0.003166 0.000096 13 25 {'max_features': 13, 'n_estimators': 25} -8.235568 -11.828849 -21.175247 ... -21.176898 12.580003 261 -1.836137 -1.731191 -1.778248 -1.423419 -1.764215 -1.706642 0.145622
625 0.071170 0.005523 0.003244 0.000451 13 26 {'max_features': 13, 'n_estimators': 26} -8.400417 -12.670221 -23.192632 ... -21.533916 12.765942 310 -2.220595 -2.052685 -1.643289 -1.201552 -1.447342 -1.713092 0.376842
626 0.070533 0.001690 0.003245 0.000122 13 27 {'max_features': 13, 'n_estimators': 27} -8.244153 -11.947600 -19.265464 ... -21.861377 13.279718 356 -1.534273 -1.521596 -1.813661 -1.328404 -2.152099 -1.670007 0.286423
627 0.074091 0.003326 0.003271 0.000304 13 28 {'max_features': 13, 'n_estimators': 28} -8.637694 -14.060935 -25.394871 ... -22.784116 13.353693 424 -1.669801 -1.748181 -1.623929 -1.392020 -1.375407 -1.561868 0.150884
628 0.076364 0.002998 0.003175 0.000052 13 29 {'max_features': 13, 'n_estimators': 29} -8.121697 -14.794659 -19.478598 ... -21.329425 12.810383 278 -1.664764 -1.473133 -1.515179 -1.250855 -1.580039 -1.496794 0.138944
629 0.080446 0.002324 0.003275 0.000146 13 30 {'max_features': 13, 'n_estimators': 30} -8.327223 -13.938036 -18.504142 ... -21.177455 12.801173 262 -2.062341 -1.929879 -1.695927 -1.258211 -2.019631 -1.793198 0.295995
630 0.082867 0.004171 0.003718 0.000470 13 31 {'max_features': 13, 'n_estimators': 31} -7.962718 -13.002553 -23.942338 ... -22.044556 12.910633 375 -1.644527 -1.880164 -1.529467 -1.261535 -1.878561 -1.638851 0.232404
631 0.094475 0.003793 0.003925 0.000547 13 32 {'max_features': 13, 'n_estimators': 32} -8.576633 -12.183874 -21.498011 ... -21.738878 12.744928 341 -2.025376 -1.560575 -1.220283 -1.242486 -2.247338 -1.659212 0.413767
632 0.086008 0.003725 0.003554 0.000156 13 33 {'max_features': 13, 'n_estimators': 33} -8.437380 -12.298403 -20.587260 ... -22.134070 13.272654 382 -1.808651 -1.573553 -1.464411 -1.280631 -1.754965 -1.576442 0.192798
633 0.088431 0.003652 0.003595 0.000137 13 34 {'max_features': 13, 'n_estimators': 34} -7.789550 -13.488061 -21.859775 ... -21.861308 12.894215 355 -1.679846 -1.667615 -1.791897 -1.321465 -1.660280 -1.624221 0.158779
634 0.091344 0.002211 0.003711 0.000252 13 35 {'max_features': 13, 'n_estimators': 35} -8.359150 -13.595726 -22.272514 ... -22.603360 14.050540 412 -1.726925 -1.674344 -1.427541 -1.330239 -1.826045 -1.597019 0.187191
635 0.093527 0.001836 0.003673 0.000136 13 36 {'max_features': 13, 'n_estimators': 36} -8.207619 -11.886666 -20.280017 ... -21.392191 13.865106 291 -1.981380 -1.686528 -1.272086 -1.248820 -1.550914 -1.547946 0.273003
636 0.097891 0.003279 0.003842 0.000363 13 37 {'max_features': 13, 'n_estimators': 37} -8.413852 -13.636418 -18.461581 ... -21.609566 14.287871 318 -1.748631 -1.972837 -1.519343 -1.374962 -1.448768 -1.612908 0.219219
637 0.104330 0.014696 0.003811 0.000158 13 38 {'max_features': 13, 'n_estimators': 38} -8.398189 -12.856086 -20.030461 ... -21.481719 13.813047 303 -1.703596 -1.638732 -1.540393 -1.363200 -1.611878 -1.571560 0.116582
638 0.106177 0.002897 0.004407 0.000388 13 39 {'max_features': 13, 'n_estimators': 39} -8.954183 -13.575304 -21.071506 ... -21.889710 12.731358 358 -1.499969 -1.430730 -1.587401 -1.257869 -1.530888 -1.461371 0.113629
639 0.139257 0.021265 0.005298 0.000931 13 40 {'max_features': 13, 'n_estimators': 40} -8.828245 -15.956237 -19.494517 ... -22.224991 13.058084 389 -1.680124 -1.453758 -1.493468 -1.233556 -1.403723 -1.452926 0.144089
640 0.122522 0.019294 0.004169 0.000311 13 41 {'max_features': 13, 'n_estimators': 41} -8.139146 -15.458479 -21.882525 ... -21.569182 13.141029 313 -1.855412 -1.641777 -1.475673 -1.209853 -1.872287 -1.611001 0.248269
641 0.113846 0.005494 0.004137 0.000280 13 42 {'max_features': 13, 'n_estimators': 42} -8.037270 -12.633225 -26.076467 ... -22.908325 14.112196 430 -1.897440 -1.608781 -1.485819 -1.123077 -1.209502 -1.464924 0.279392
642 0.110486 0.002739 0.004201 0.000132 13 43 {'max_features': 13, 'n_estimators': 43} -7.738024 -11.768356 -23.203926 ... -22.395667 13.987355 399 -1.503985 -1.495877 -1.375177 -1.422425 -1.366992 -1.432891 0.057973
643 0.111097 0.002981 0.004338 0.000172 13 44 {'max_features': 13, 'n_estimators': 44} -7.885614 -13.315915 -20.279256 ... -21.864941 13.759772 357 -1.731070 -1.665641 -1.559994 -1.225409 -1.964766 -1.629376 0.241721
644 0.114695 0.002069 0.004231 0.000199 13 45 {'max_features': 13, 'n_estimators': 45} -8.522087 -13.834266 -19.959325 ... -21.711338 13.139111 336 -1.752770 -1.384300 -1.472556 -1.272192 -1.593057 -1.494975 0.166411
645 0.124314 0.010494 0.004609 0.000780 13 46 {'max_features': 13, 'n_estimators': 46} -9.244762 -13.533141 -20.165793 ... -21.277579 11.760141 275 -1.890338 -1.523985 -1.429138 -1.260989 -1.434726 -1.507835 0.209304
646 0.125519 0.003614 0.004486 0.000513 13 47 {'max_features': 13, 'n_estimators': 47} -7.523758 -14.241435 -23.102508 ... -22.140001 13.534116 383 -1.688069 -1.446108 -1.646988 -1.208258 -1.689223 -1.535729 0.186770
647 0.134769 0.006030 0.004768 0.000579 13 48 {'max_features': 13, 'n_estimators': 48} -7.969312 -12.855202 -22.178862 ... -21.676694 13.110346 331 -1.810569 -1.547392 -1.517920 -1.351242 -1.432992 -1.532023 0.155249
648 0.130817 0.008154 0.005051 0.000525 13 49 {'max_features': 13, 'n_estimators': 49} -8.560279 -12.427586 -21.365338 ... -21.529505 12.947968 307 -1.643209 -1.597696 -1.413012 -1.193954 -1.778295 -1.525233 0.202755
649 0.137780 0.003193 0.004835 0.000229 13 50 {'max_features': 13, 'n_estimators': 50} -8.213808 -13.037796 -18.630593 ... -21.511986 13.648508 306 -1.693073 -1.399460 -1.560297 -1.264475 -1.516853 -1.486832 0.145540

650 rows × 22 columns

# Imported for its side effect: registers the '3d' projection on older
# matplotlib versions.
from mpl_toolkits import mplot3d

n_estimators, max_features = np.arange(1, 51), np.arange(1, 14)

# With the default 'xy' indexing, X and Y both have shape
# (len(n_estimators), len(max_features)) == (50, 13):
# X holds max_features values, Y holds n_estimators values.
X, Y = np.meshgrid(max_features, n_estimators)

# The rows of rf_search_results are ordered with max_features as the
# slow-varying parameter and n_estimators as the fast-varying one (see the
# results table), so the flat score vector is laid out as (13, 50).
# Reshape accordingly and transpose so that Z[i, j] is the score for
# (n_estimators[i], max_features[j]), matching X and Y.
# Scores are negated MSE, hence the sign flip before the square root.
rmse = np.sqrt(-rf_search_results['mean_test_score'].values)
Z = scores = rmse.reshape(len(max_features), len(n_estimators)).T

fig = plt.figure(figsize=(15, 10))
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                cmap='Greys', edgecolor='none')
# Axis labels follow the meshgrid: x is max_features, y is n_estimators.
ax.set_xlabel('max_features')
ax.set_ylabel('n_estimators')
ax.set_zlabel('rmse')
ax.view_init(60, 35)

png

# Parameter combination with the best mean cross-validated score.
rf_search.best_params_
{'max_features': 8, 'n_estimators': 19}
# The search was scored with 'neg_mean_squared_error', so negate the best
# score to recover the MSE and take its square root to express it as an RMSE.
np.sqrt(-rf_search.best_score_)
4.198781714064752