%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Load the Boston housing data. The CSV's first column is read as the index
# and then discarded, leaving a fresh 0-based RangeIndex (see info() below).
boston = (
    pd.read_csv('../../datasets/Boston.csv', index_col=0)
    .reset_index(drop=True)
)
boston.head()
| | crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | black | lstat | medv |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.00632 | 18.0 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296 | 15.3 | 396.90 | 4.98 | 24.0 |
1 | 0.02731 | 0.0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.90 | 9.14 | 21.6 |
2 | 0.02729 | 0.0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 | 34.7 |
3 | 0.03237 | 0.0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222 | 18.7 | 394.63 | 2.94 | 33.4 |
4 | 0.06905 | 0.0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222 | 18.7 | 396.90 | 5.33 | 36.2 |
# Summarise column dtypes and non-null counts for the loaded frame.
boston.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 506 entries, 0 to 505
Data columns (total 14 columns):
crim 506 non-null float64
zn 506 non-null float64
indus 506 non-null float64
chas 506 non-null int64
nox 506 non-null float64
rm 506 non-null float64
age 506 non-null float64
dis 506 non-null float64
rad 506 non-null int64
tax 506 non-null int64
ptratio 506 non-null float64
black 506 non-null float64
lstat 506 non-null float64
medv 506 non-null float64
dtypes: float64(11), int64(3)
memory usage: 55.4 KB
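We’re using `sklearn.ensemble.RandomForestRegressor`. Its parameters `n_estimators` and `max_features` correspond, respectively, to the parameters `ntree` (the number of trees grown) and `mtry` (the number of predictors sampled as split candidates at each node) of the R function `randomForest()`.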
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV, KFold

# Grid over the number of trees (R: ntree) and the number of predictors tried
# at each split (R: mtry). The frame has 13 predictors besides medv, so
# max_features=13 corresponds to bagging.
n_estimators, max_features = np.arange(1, 51), np.arange(1, 14)
params = {'n_estimators': n_estimators, 'max_features': max_features}

# A bare cv=5 uses contiguous, unshuffled folds. The rows of Boston.csv are in
# a fixed file order, and the unshuffled per-fold MSEs varied wildly, so
# shuffle the folds. Seed both the folds and the forest so the whole search
# (including best_params_) is reproducible from run to run.
kfold = KFold(n_splits=5, shuffle=True, random_state=0)
rf_search = GridSearchCV(
    RandomForestRegressor(random_state=0),
    param_grid=params,
    scoring='neg_mean_squared_error',  # scores are negated MSE
    cv=kfold,
    n_jobs=-1,  # 650 candidates x 5 folds; use all cores
)
rf_search.fit(boston.drop(columns=['medv']), boston['medv'])
/anaconda3/envs/islr/lib/python3.7/site-packages/sklearn/model_selection/_search.py:841: DeprecationWarning: The default of the `iid` parameter will change from True to False in version 0.22 and will be removed in 0.24. This will change numeric results when test-set sizes are unequal.
DeprecationWarning)
GridSearchCV(cv=5, error_score='raise-deprecating',
estimator=RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators='warn', n_jobs=None,
oob_score=False, random_state=None, verbose=0, warm_start=False),
fit_params=None, iid='warn', n_jobs=None,
param_grid={'n_estimators': array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]), 'max_features': array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])},
pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
scoring='neg_mean_squared_error', verbose=0)
import warnings

# Collect the per-candidate CV results into a DataFrame. Reading every key of
# cv_results_ may emit FutureWarnings with this scikit-learn version
# (presumably tied to return_train_score='warn' in the estimator repr above).
# Suppress them only for this call rather than installing a process-wide
# filter that would hide unrelated FutureWarnings for the rest of the session.
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    rf_search_results = pd.DataFrame(rf_search.cv_results_)
rf_search_results
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_max_features | param_n_estimators | params | split0_test_score | split1_test_score | split2_test_score | ... | mean_test_score | std_test_score | rank_test_score | split0_train_score | split1_train_score | split2_train_score | split3_train_score | split4_train_score | mean_train_score | std_train_score |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.003891 | 0.000834 | 0.001949 | 0.000402 | 1 | 1 | {'max_features': 1, 'n_estimators': 1} | -21.991667 | -85.189010 | -54.552079 | ... | -64.012530 | 23.806851 | 648 | -14.603713 | -14.161111 | -11.699358 | -16.489481 | -12.779235 | -13.946580 | 1.634436 |
1 | 0.007109 | 0.000531 | 0.002835 | 0.000614 | 1 | 2 | {'max_features': 1, 'n_estimators': 2} | -23.077500 | -40.479134 | -41.309505 | ... | -49.661359 | 20.626230 | 646 | -9.532983 | -8.897994 | -4.985086 | -8.101160 | -8.354475 | -7.974340 | 1.573450 |
2 | 0.005619 | 0.001508 | 0.001752 | 0.000513 | 1 | 3 | {'max_features': 1, 'n_estimators': 3} | -17.650196 | -41.920968 | -80.304345 | ... | -45.140823 | 21.318752 | 642 | -7.537261 | -7.789177 | -6.061572 | -4.895781 | -7.208291 | -6.698416 | 1.077818 |
3 | 0.006459 | 0.001157 | 0.001699 | 0.000412 | 1 | 4 | {'max_features': 1, 'n_estimators': 4} | -28.594032 | -42.909629 | -56.985149 | ... | -46.777108 | 12.584626 | 643 | -5.815023 | -5.466340 | -3.049759 | -5.406645 | -4.841636 | -4.915881 | 0.983894 |
4 | 0.006406 | 0.001049 | 0.001832 | 0.000544 | 1 | 5 | {'max_features': 1, 'n_estimators': 5} | -14.134204 | -33.097687 | -35.283347 | ... | -35.554827 | 12.541295 | 621 | -5.153073 | -4.932412 | -3.802281 | -3.102355 | -3.784279 | -4.154880 | 0.770769 |
5 | 0.007777 | 0.001909 | 0.001566 | 0.000146 | 1 | 6 | {'max_features': 1, 'n_estimators': 6} | -20.835833 | -33.876378 | -57.023809 | ... | -36.389669 | 15.985771 | 626 | -5.226819 | -4.592859 | -5.367281 | -3.909785 | -4.322368 | -4.683823 | 0.547726 |
6 | 0.008977 | 0.002770 | 0.001802 | 0.000529 | 1 | 7 | {'max_features': 1, 'n_estimators': 7} | -15.424894 | -35.388693 | -68.895278 | ... | -40.056006 | 17.818470 | 638 | -3.979481 | -3.187319 | -3.898648 | -2.693557 | -3.768051 | -3.505411 | 0.491660 |
7 | 0.011081 | 0.001878 | 0.001943 | 0.000419 | 1 | 8 | {'max_features': 1, 'n_estimators': 8} | -12.485312 | -21.791623 | -55.193929 | ... | -36.435462 | 21.621571 | 627 | -4.376785 | -3.511429 | -3.452217 | -2.734615 | -3.731063 | -3.561222 | 0.527566 |
8 | 0.010636 | 0.002615 | 0.002022 | 0.000730 | 1 | 9 | {'max_features': 1, 'n_estimators': 9} | -18.031892 | -31.122899 | -37.199493 | ... | -36.325122 | 13.528382 | 625 | -3.032920 | -2.808331 | -2.432588 | -3.095907 | -4.087000 | -3.091350 | 0.549331 |
9 | 0.012460 | 0.002050 | 0.002123 | 0.000522 | 1 | 10 | {'max_features': 1, 'n_estimators': 10} | -17.399393 | -33.092368 | -32.873599 | ... | -31.259992 | 14.477833 | 582 | -3.891032 | -3.469017 | -2.679415 | -3.634116 | -2.866854 | -3.308087 | 0.460854 |
10 | 0.019396 | 0.002247 | 0.002984 | 0.001326 | 1 | 11 | {'max_features': 1, 'n_estimators': 11} | -11.739315 | -22.884135 | -53.257725 | ... | -34.255825 | 18.556652 | 612 | -2.524618 | -2.831083 | -3.511232 | -2.794624 | -3.029749 | -2.938261 | 0.328599 |
11 | 0.013650 | 0.002062 | 0.002338 | 0.000655 | 1 | 12 | {'max_features': 1, 'n_estimators': 12} | -20.765519 | -29.219963 | -43.909655 | ... | -35.303845 | 14.835744 | 620 | -3.728071 | -2.489337 | -2.629168 | -2.102650 | -3.578967 | -2.905638 | 0.636286 |
12 | 0.013520 | 0.002860 | 0.002294 | 0.000803 | 1 | 13 | {'max_features': 1, 'n_estimators': 13} | -18.235791 | -41.704494 | -28.564638 | ... | -33.328750 | 14.136039 | 602 | -3.033637 | -2.431979 | -2.363831 | -2.386989 | -2.513514 | -2.545990 | 0.249125 |
13 | 0.014610 | 0.002561 | 0.002329 | 0.000904 | 1 | 14 | {'max_features': 1, 'n_estimators': 14} | -13.594839 | -20.757757 | -55.323856 | ... | -31.949899 | 17.840135 | 589 | -3.856282 | -2.813479 | -2.660921 | -2.461808 | -3.621529 | -3.082804 | 0.552205 |
14 | 0.016239 | 0.003550 | 0.002126 | 0.000099 | 1 | 15 | {'max_features': 1, 'n_estimators': 15} | -15.587993 | -37.151651 | -39.838389 | ... | -37.000578 | 14.092232 | 629 | -2.979421 | -2.745065 | -2.955570 | -2.302839 | -2.104006 | -2.617380 | 0.353338 |
15 | 0.021542 | 0.007437 | 0.004869 | 0.001939 | 1 | 16 | {'max_features': 1, 'n_estimators': 16} | -18.313434 | -37.697324 | -50.224384 | ... | -39.177959 | 13.554225 | 636 | -2.383039 | -2.313836 | -2.564468 | -2.085711 | -2.971439 | -2.463699 | 0.296579 |
16 | 0.031045 | 0.007147 | 0.004051 | 0.001197 | 1 | 17 | {'max_features': 1, 'n_estimators': 17} | -14.477579 | -38.545166 | -47.593614 | ... | -36.993035 | 14.366971 | 628 | -3.384781 | -2.466205 | -2.349847 | -2.080384 | -2.440984 | -2.544440 | 0.441862 |
17 | 0.020159 | 0.003701 | 0.002337 | 0.000288 | 1 | 18 | {'max_features': 1, 'n_estimators': 18} | -13.351092 | -29.703280 | -41.840159 | ... | -32.092468 | 12.389404 | 591 | -2.381153 | -2.327830 | -2.322180 | -2.293605 | -3.171681 | -2.499290 | 0.337384 |
18 | 0.023308 | 0.001765 | 0.003537 | 0.001120 | 1 | 19 | {'max_features': 1, 'n_estimators': 19} | -15.882930 | -31.164088 | -39.008378 | ... | -34.578963 | 14.486152 | 616 | -3.857120 | -2.922748 | -2.321944 | -2.458541 | -3.067876 | -2.925646 | 0.542314 |
19 | 0.024203 | 0.002189 | 0.003143 | 0.000700 | 1 | 20 | {'max_features': 1, 'n_estimators': 20} | -17.061018 | -37.885674 | -44.421256 | ... | -34.235834 | 11.267077 | 611 | -3.016849 | -2.616610 | -2.271411 | -1.724738 | -2.350052 | -2.395932 | 0.424817 |
20 | 0.022982 | 0.002748 | 0.002810 | 0.001135 | 1 | 21 | {'max_features': 1, 'n_estimators': 21} | -17.608000 | -29.764712 | -57.100468 | ... | -34.166448 | 15.051979 | 610 | -2.849463 | -2.687234 | -2.648921 | -2.026325 | -2.457962 | -2.533981 | 0.282744 |
21 | 0.025036 | 0.002274 | 0.003434 | 0.001300 | 1 | 22 | {'max_features': 1, 'n_estimators': 22} | -16.536106 | -26.593730 | -42.356686 | ... | -33.135153 | 14.229106 | 600 | -2.971610 | -2.912313 | -2.576134 | -2.170917 | -3.137634 | -2.753722 | 0.343865 |
22 | 0.027406 | 0.005331 | 0.004077 | 0.001138 | 1 | 23 | {'max_features': 1, 'n_estimators': 23} | -15.404804 | -26.790607 | -47.250104 | ... | -33.781914 | 15.178863 | 605 | -2.764676 | -2.730681 | -2.743073 | -2.136093 | -3.241368 | -2.723178 | 0.350817 |
23 | 0.024246 | 0.004053 | 0.002846 | 0.000299 | 1 | 24 | {'max_features': 1, 'n_estimators': 24} | -13.008028 | -36.532921 | -36.781510 | ... | -34.381056 | 15.122934 | 614 | -3.566207 | -2.849794 | -2.161991 | -2.167744 | -2.510485 | -2.651244 | 0.523361 |
24 | 0.024699 | 0.001386 | 0.003098 | 0.000754 | 1 | 25 | {'max_features': 1, 'n_estimators': 25} | -17.054358 | -37.102535 | -43.966166 | ... | -34.298870 | 14.970548 | 613 | -2.938494 | -2.180039 | -2.440997 | -1.917603 | -2.551470 | -2.405720 | 0.345116 |
25 | 0.027446 | 0.005452 | 0.003953 | 0.001620 | 1 | 26 | {'max_features': 1, 'n_estimators': 26} | -16.987042 | -34.845853 | -51.564609 | ... | -34.968050 | 15.036683 | 618 | -3.129059 | -2.227953 | -3.506445 | -2.467797 | -1.866639 | -2.639579 | 0.597901 |
26 | 0.027815 | 0.003477 | 0.002695 | 0.000157 | 1 | 27 | {'max_features': 1, 'n_estimators': 27} | -14.762697 | -28.462289 | -39.627764 | ... | -29.153360 | 10.051608 | 567 | -3.074188 | -1.793045 | -3.078800 | -1.674099 | -1.989327 | -2.321892 | 0.624303 |
27 | 0.028738 | 0.006258 | 0.002989 | 0.000205 | 1 | 28 | {'max_features': 1, 'n_estimators': 28} | -16.580555 | -38.412260 | -39.808376 | ... | -35.772924 | 13.948740 | 622 | -2.825889 | -2.647394 | -2.635796 | -2.118854 | -2.340015 | -2.513589 | 0.251521 |
28 | 0.027386 | 0.003200 | 0.003024 | 0.000625 | 1 | 29 | {'max_features': 1, 'n_estimators': 29} | -16.863019 | -29.274269 | -45.390146 | ... | -34.144240 | 11.450715 | 609 | -2.444111 | -2.820349 | -2.001195 | -2.113187 | -2.791126 | -2.433993 | 0.336793 |
29 | 0.028264 | 0.004045 | 0.003585 | 0.000949 | 1 | 30 | {'max_features': 1, 'n_estimators': 30} | -13.082333 | -26.423729 | -36.927213 | ... | -34.760669 | 14.209880 | 617 | -2.995265 | -2.269125 | -2.131816 | -2.082101 | -2.156179 | -2.326897 | 0.339759 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
620 | 0.056998 | 0.002920 | 0.002737 | 0.000033 | 13 | 21 | {'max_features': 13, 'n_estimators': 21} | -8.214533 | -14.830912 | -22.534429 | ... | -22.231679 | 14.579162 | 391 | -1.754569 | -2.005698 | -1.794253 | -1.358432 | -1.487365 | -1.680063 | 0.230307 |
621 | 0.060251 | 0.001883 | 0.003669 | 0.001008 | 13 | 22 | {'max_features': 13, 'n_estimators': 22} | -7.563142 | -12.218573 | -24.094603 | ... | -21.770627 | 14.921198 | 345 | -1.771726 | -1.696460 | -1.623146 | -1.287195 | -1.534081 | -1.582522 | 0.167329 |
622 | 0.061005 | 0.001704 | 0.003103 | 0.000323 | 13 | 23 | {'max_features': 13, 'n_estimators': 23} | -8.105958 | -13.928286 | -18.885602 | ... | -20.635471 | 13.060317 | 199 | -1.891957 | -1.411495 | -1.325187 | -1.296137 | -2.003337 | -1.585623 | 0.300092 |
623 | 0.063468 | 0.001463 | 0.002937 | 0.000097 | 13 | 24 | {'max_features': 13, 'n_estimators': 24} | -7.601690 | -15.970749 | -21.354981 | ... | -22.495433 | 13.645004 | 408 | -1.731364 | -1.863511 | -1.403066 | -1.260503 | -1.713268 | -1.594342 | 0.225125 |
624 | 0.066116 | 0.001926 | 0.003166 | 0.000096 | 13 | 25 | {'max_features': 13, 'n_estimators': 25} | -8.235568 | -11.828849 | -21.175247 | ... | -21.176898 | 12.580003 | 261 | -1.836137 | -1.731191 | -1.778248 | -1.423419 | -1.764215 | -1.706642 | 0.145622 |
625 | 0.071170 | 0.005523 | 0.003244 | 0.000451 | 13 | 26 | {'max_features': 13, 'n_estimators': 26} | -8.400417 | -12.670221 | -23.192632 | ... | -21.533916 | 12.765942 | 310 | -2.220595 | -2.052685 | -1.643289 | -1.201552 | -1.447342 | -1.713092 | 0.376842 |
626 | 0.070533 | 0.001690 | 0.003245 | 0.000122 | 13 | 27 | {'max_features': 13, 'n_estimators': 27} | -8.244153 | -11.947600 | -19.265464 | ... | -21.861377 | 13.279718 | 356 | -1.534273 | -1.521596 | -1.813661 | -1.328404 | -2.152099 | -1.670007 | 0.286423 |
627 | 0.074091 | 0.003326 | 0.003271 | 0.000304 | 13 | 28 | {'max_features': 13, 'n_estimators': 28} | -8.637694 | -14.060935 | -25.394871 | ... | -22.784116 | 13.353693 | 424 | -1.669801 | -1.748181 | -1.623929 | -1.392020 | -1.375407 | -1.561868 | 0.150884 |
628 | 0.076364 | 0.002998 | 0.003175 | 0.000052 | 13 | 29 | {'max_features': 13, 'n_estimators': 29} | -8.121697 | -14.794659 | -19.478598 | ... | -21.329425 | 12.810383 | 278 | -1.664764 | -1.473133 | -1.515179 | -1.250855 | -1.580039 | -1.496794 | 0.138944 |
629 | 0.080446 | 0.002324 | 0.003275 | 0.000146 | 13 | 30 | {'max_features': 13, 'n_estimators': 30} | -8.327223 | -13.938036 | -18.504142 | ... | -21.177455 | 12.801173 | 262 | -2.062341 | -1.929879 | -1.695927 | -1.258211 | -2.019631 | -1.793198 | 0.295995 |
630 | 0.082867 | 0.004171 | 0.003718 | 0.000470 | 13 | 31 | {'max_features': 13, 'n_estimators': 31} | -7.962718 | -13.002553 | -23.942338 | ... | -22.044556 | 12.910633 | 375 | -1.644527 | -1.880164 | -1.529467 | -1.261535 | -1.878561 | -1.638851 | 0.232404 |
631 | 0.094475 | 0.003793 | 0.003925 | 0.000547 | 13 | 32 | {'max_features': 13, 'n_estimators': 32} | -8.576633 | -12.183874 | -21.498011 | ... | -21.738878 | 12.744928 | 341 | -2.025376 | -1.560575 | -1.220283 | -1.242486 | -2.247338 | -1.659212 | 0.413767 |
632 | 0.086008 | 0.003725 | 0.003554 | 0.000156 | 13 | 33 | {'max_features': 13, 'n_estimators': 33} | -8.437380 | -12.298403 | -20.587260 | ... | -22.134070 | 13.272654 | 382 | -1.808651 | -1.573553 | -1.464411 | -1.280631 | -1.754965 | -1.576442 | 0.192798 |
633 | 0.088431 | 0.003652 | 0.003595 | 0.000137 | 13 | 34 | {'max_features': 13, 'n_estimators': 34} | -7.789550 | -13.488061 | -21.859775 | ... | -21.861308 | 12.894215 | 355 | -1.679846 | -1.667615 | -1.791897 | -1.321465 | -1.660280 | -1.624221 | 0.158779 |
634 | 0.091344 | 0.002211 | 0.003711 | 0.000252 | 13 | 35 | {'max_features': 13, 'n_estimators': 35} | -8.359150 | -13.595726 | -22.272514 | ... | -22.603360 | 14.050540 | 412 | -1.726925 | -1.674344 | -1.427541 | -1.330239 | -1.826045 | -1.597019 | 0.187191 |
635 | 0.093527 | 0.001836 | 0.003673 | 0.000136 | 13 | 36 | {'max_features': 13, 'n_estimators': 36} | -8.207619 | -11.886666 | -20.280017 | ... | -21.392191 | 13.865106 | 291 | -1.981380 | -1.686528 | -1.272086 | -1.248820 | -1.550914 | -1.547946 | 0.273003 |
636 | 0.097891 | 0.003279 | 0.003842 | 0.000363 | 13 | 37 | {'max_features': 13, 'n_estimators': 37} | -8.413852 | -13.636418 | -18.461581 | ... | -21.609566 | 14.287871 | 318 | -1.748631 | -1.972837 | -1.519343 | -1.374962 | -1.448768 | -1.612908 | 0.219219 |
637 | 0.104330 | 0.014696 | 0.003811 | 0.000158 | 13 | 38 | {'max_features': 13, 'n_estimators': 38} | -8.398189 | -12.856086 | -20.030461 | ... | -21.481719 | 13.813047 | 303 | -1.703596 | -1.638732 | -1.540393 | -1.363200 | -1.611878 | -1.571560 | 0.116582 |
638 | 0.106177 | 0.002897 | 0.004407 | 0.000388 | 13 | 39 | {'max_features': 13, 'n_estimators': 39} | -8.954183 | -13.575304 | -21.071506 | ... | -21.889710 | 12.731358 | 358 | -1.499969 | -1.430730 | -1.587401 | -1.257869 | -1.530888 | -1.461371 | 0.113629 |
639 | 0.139257 | 0.021265 | 0.005298 | 0.000931 | 13 | 40 | {'max_features': 13, 'n_estimators': 40} | -8.828245 | -15.956237 | -19.494517 | ... | -22.224991 | 13.058084 | 389 | -1.680124 | -1.453758 | -1.493468 | -1.233556 | -1.403723 | -1.452926 | 0.144089 |
640 | 0.122522 | 0.019294 | 0.004169 | 0.000311 | 13 | 41 | {'max_features': 13, 'n_estimators': 41} | -8.139146 | -15.458479 | -21.882525 | ... | -21.569182 | 13.141029 | 313 | -1.855412 | -1.641777 | -1.475673 | -1.209853 | -1.872287 | -1.611001 | 0.248269 |
641 | 0.113846 | 0.005494 | 0.004137 | 0.000280 | 13 | 42 | {'max_features': 13, 'n_estimators': 42} | -8.037270 | -12.633225 | -26.076467 | ... | -22.908325 | 14.112196 | 430 | -1.897440 | -1.608781 | -1.485819 | -1.123077 | -1.209502 | -1.464924 | 0.279392 |
642 | 0.110486 | 0.002739 | 0.004201 | 0.000132 | 13 | 43 | {'max_features': 13, 'n_estimators': 43} | -7.738024 | -11.768356 | -23.203926 | ... | -22.395667 | 13.987355 | 399 | -1.503985 | -1.495877 | -1.375177 | -1.422425 | -1.366992 | -1.432891 | 0.057973 |
643 | 0.111097 | 0.002981 | 0.004338 | 0.000172 | 13 | 44 | {'max_features': 13, 'n_estimators': 44} | -7.885614 | -13.315915 | -20.279256 | ... | -21.864941 | 13.759772 | 357 | -1.731070 | -1.665641 | -1.559994 | -1.225409 | -1.964766 | -1.629376 | 0.241721 |
644 | 0.114695 | 0.002069 | 0.004231 | 0.000199 | 13 | 45 | {'max_features': 13, 'n_estimators': 45} | -8.522087 | -13.834266 | -19.959325 | ... | -21.711338 | 13.139111 | 336 | -1.752770 | -1.384300 | -1.472556 | -1.272192 | -1.593057 | -1.494975 | 0.166411 |
645 | 0.124314 | 0.010494 | 0.004609 | 0.000780 | 13 | 46 | {'max_features': 13, 'n_estimators': 46} | -9.244762 | -13.533141 | -20.165793 | ... | -21.277579 | 11.760141 | 275 | -1.890338 | -1.523985 | -1.429138 | -1.260989 | -1.434726 | -1.507835 | 0.209304 |
646 | 0.125519 | 0.003614 | 0.004486 | 0.000513 | 13 | 47 | {'max_features': 13, 'n_estimators': 47} | -7.523758 | -14.241435 | -23.102508 | ... | -22.140001 | 13.534116 | 383 | -1.688069 | -1.446108 | -1.646988 | -1.208258 | -1.689223 | -1.535729 | 0.186770 |
647 | 0.134769 | 0.006030 | 0.004768 | 0.000579 | 13 | 48 | {'max_features': 13, 'n_estimators': 48} | -7.969312 | -12.855202 | -22.178862 | ... | -21.676694 | 13.110346 | 331 | -1.810569 | -1.547392 | -1.517920 | -1.351242 | -1.432992 | -1.532023 | 0.155249 |
648 | 0.130817 | 0.008154 | 0.005051 | 0.000525 | 13 | 49 | {'max_features': 13, 'n_estimators': 49} | -8.560279 | -12.427586 | -21.365338 | ... | -21.529505 | 12.947968 | 307 | -1.643209 | -1.597696 | -1.413012 | -1.193954 | -1.778295 | -1.525233 | 0.202755 |
649 | 0.137780 | 0.003193 | 0.004835 | 0.000229 | 13 | 50 | {'max_features': 13, 'n_estimators': 50} | -8.213808 | -13.037796 | -18.630593 | ... | -21.511986 | 13.648508 | 306 | -1.693073 | -1.399460 | -1.560297 | -1.264475 | -1.516853 | -1.486832 | 0.145540 |
650 rows × 22 columns
from mpl_toolkits import mplot3d  # imported so projection='3d' is available on older matplotlib

# Arrange the mean CV scores as a grid with rows = n_estimators and
# columns = max_features, keyed by the recorded parameter values.
# GridSearchCV enumerates n_estimators fastest (see the table above), so a
# plain reshape(50, 13) of mean_test_score would pair scores with the wrong
# (max_features, n_estimators) coordinates. Pivoting avoids depending on that
# enumeration order altogether.
score_grid = rf_search_results.pivot(
    index='param_n_estimators',
    columns='param_max_features',
    values='mean_test_score',
)
# meshgrid(x, y) gives arrays shaped (len(y), len(x)), matching score_grid.
X, Y = np.meshgrid(score_grid.columns.astype(int), score_grid.index.astype(int))
# mean_test_score is a negated MSE, so negate it before taking the square root.
Z = np.sqrt(-score_grid.values.astype(float))

fig = plt.figure(figsize=(15, 10))
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                cmap='Greys', edgecolor='none')
# X holds max_features and Y holds n_estimators (previously mislabelled).
ax.set_xlabel('max_features')
ax.set_ylabel('n_estimators')
ax.set_zlabel('rmse');
ax.view_init(60, 35)
# Parameter combination GridSearchCV selected (highest mean neg-MSE score).
best_params = rf_search.best_params_
best_params
{'max_features': 8, 'n_estimators': 19}
# Cross-validated RMSE of the selected combination. best_score_ is a negated
# MSE (scoring='neg_mean_squared_error'), so flip its sign before the root.
best_cv_rmse = np.sqrt(-rf_search.best_score_)
best_cv_rmse
4.198781714064752