mirror of
https://github.com/donnemartin/data-science-ipython-notebooks.git
synced 2024-03-22 13:30:56 +08:00
1055 lines
256 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Validation and Model Selection"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Credits: Forked from [PyCon 2015 Scikit-learn Tutorial](https://github.com/jakevdp/sklearn_pycon2015) by Jake VanderPlas\n",
"\n",
"In this section, we'll look at *model evaluation* and the tuning of *hyperparameters*: parameters that define the model and are chosen before fitting rather than learned from the data (for example, the number of neighbors in a k-nearest-neighbors classifier)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from __future__ import print_function, division\n",
"\n",
"%matplotlib inline\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Use seaborn for plotting defaults\n",
"import seaborn as sns; sns.set()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Validating Models\n",
"\n",
"One of the most important pieces of machine learning is **model validation**: that is, checking how well your model fits a given dataset. But there are some pitfalls you need to watch out for.\n",
"\n",
"Consider the digits example we've been looking at previously. How might we check how well our model fits the data?"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.datasets import load_digits\n",
"digits = load_digits()\n",
"X = digits.data\n",
"y = digits.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's fit a k-nearest-neighbors classifier that uses a single neighbor:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_neighbors=1, p=2, weights='uniform')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"knn = KNeighborsClassifier(n_neighbors=1)\n",
"knn.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll use this classifier to *predict* labels for the same data it was trained on:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y_pred = knn.predict(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we can check how well our prediction did:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1797 / 1797 correct\n"
]
}
],
"source": [
"print(\"{0} / {1} correct\".format(np.sum(y == y_pred), len(y)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It seems we have a perfect classifier!\n",
"\n",
"**Question: what's wrong with this?**"
]
},
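To see concretely what goes wrong, here is a small NumPy-only sketch (not scikit-learn's implementation) of a 1-nearest-neighbor rule applied to data whose labels are pure coin flips, so there is nothing real to learn. The helper ``one_nn_predict`` and the synthetic data are illustrative assumptions. Because every training point is its own nearest neighbor, training accuracy comes out perfect, while accuracy on held-out points should stay near chance.

```python
import numpy as np

def one_nn_predict(X_train, y_train, X_query):
    """Hypothetical helper: give each query point the label of its
    nearest training point (squared Euclidean distance)."""
    # Pairwise squared distances, shape (n_query, n_train)
    d2 = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)
    return y_train[np.argmin(d2, axis=1)]

rng = np.random.RandomState(0)
X_rand = rng.rand(400, 5)               # features carry no information about...
y_rand = rng.randint(0, 2, size=400)    # ...the labels, which are random coin flips

X_tr, y_tr = X_rand[:200], y_rand[:200]  # first half: training set
X_te, y_te = X_rand[200:], y_rand[200:]  # second half: held-out set

train_acc = np.mean(one_nn_predict(X_tr, y_tr, X_tr) == y_tr)
test_acc = np.mean(one_nn_predict(X_tr, y_tr, X_te) == y_te)
print("train accuracy:", train_acc)   # 1.0: each point finds itself
print("test accuracy: ", test_acc)    # expected to be near chance (0.5)
```

The perfect training score says nothing about generalization; only the held-out score does.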
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Validation Sets\n",
"\n",
"Above we made the mistake of evaluating our model on the same data that was used to train it. **This is not generally a good idea**. A 1-neighbor classifier is an extreme case: every training point is its own nearest neighbor, so it scores (essentially) perfectly on its training data regardless of how well it generalizes. More generally, if we optimize our estimator this way, we will tend to **over-fit** the data: that is, we learn the noise along with the signal.\n",
"\n",
"A better way to test a model is to use a hold-out set that doesn't enter the training. We've seen this before using scikit-learn's train/test split utility:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"((1347, 64), (450, 64))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sklearn.cross_validation was removed in scikit-learn 0.20; model_selection replaces it\n",
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
"X_train.shape, X_test.shape"
]
},
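For intuition, the core of a hold-out split can be sketched in a few lines of NumPy: shuffle the row indices, then cut them into a training part and a test part. This is a simplified stand-in, not scikit-learn's actual ``train_test_split``; the helper name ``holdout_split`` is hypothetical. Rounding the test size up reproduces the shapes printed above for the 1797 digits samples with the default 25% test fraction.

```python
import numpy as np

def holdout_split(X, y, test_size=0.25, seed=None):
    """Hypothetical helper: randomly split (X, y) into train and test parts."""
    n = len(X)
    n_test = int(np.ceil(test_size * n))          # round the test size up
    perm = np.random.RandomState(seed).permutation(n)
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

# Same shape as the digits data: 1797 samples with 64 features each
X_demo = np.zeros((1797, 64))
y_demo = np.arange(1797)
X_tr, X_te, y_tr, y_te = holdout_split(X_demo, y_demo, seed=0)
print(X_tr.shape, X_te.shape)   # (1347, 64) (450, 64)
```

Every sample lands in exactly one of the two parts, so no test point is ever seen during training.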
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we train on the training data, and validate on the test data:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"438 / 450 correct\n"
]
}
],
"source": [
"knn = KNeighborsClassifier(n_neighbors=1)\n",
"knn.fit(X_train, y_train)\n",
"y_pred = knn.predict(X_test)\n",
"print(\"{0} / {1} correct\".format(np.sum(y_test == y_pred), len(y_test)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This gives us a more reliable estimate of how our model is doing.\n",
"\n",
"The metric we're using here, the number of correct predictions divided by the total number of samples, is known as the **accuracy score**, and can be computed with scikit-learn's ``accuracy_score`` function:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.97333333333333338"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"accuracy_score(y_test, y_pred)"
]
},
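Accuracy is simply the fraction of predictions that match the true labels. A minimal NumPy version (the name ``accuracy`` is hypothetical, and this is not scikit-learn's code) makes the arithmetic explicit: 438 correct predictions out of 450 gives 438/450, about 0.9733, the value reported above.

```python
import numpy as np

def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return np.mean(y_true == y_pred)

# Toy example: 450 labels, of which the first 12 predictions are wrong
labels_true = np.zeros(450, dtype=int)
labels_pred = labels_true.copy()
labels_pred[:12] = 1
print(accuracy(labels_true, labels_pred))   # equals 438 / 450
```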
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This can also be computed directly from the ``model.score`` method:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.97333333333333338"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"knn.score(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using this, we can see how the validation score changes as we vary a model hyperparameter, in this case the number of neighbors:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 0.973333333333\n",
"5 0.982222222222\n",
"10 0.971111111111\n",
"20 0.955555555556\n",
"30 0.96\n"
]
}
],
"source": [
"for n_neighbors in [1, 5, 10, 20, 30]:\n",
"    knn = KNeighborsClassifier(n_neighbors)\n",
"    knn.fit(X_train, y_train)\n",
"    print(n_neighbors, knn.score(X_test, y_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that in this case, a small number of neighbors seems to be the best option: 5 neighbors gives the highest validation score."
]
},
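The ``n_neighbors`` hyperparameter controls how many training points vote on each prediction. As a rough NumPy sketch of the idea (not scikit-learn's implementation, which also offers distance weighting and fast neighbor search), assuming non-negative integer labels and plain majority voting with ties broken toward the smallest label:

```python
import numpy as np

def knn_predict(X_train, y_train, X_query, k=1):
    """Hypothetical helper: majority vote among the k nearest training points."""
    X_train = np.asarray(X_train, dtype=float)
    X_query = np.asarray(X_query, dtype=float)
    y_train = np.asarray(y_train)
    # Pairwise squared Euclidean distances, shape (n_query, n_train)
    d2 = ((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)
    # Indices of the k closest training points for each query
    nearest = np.argsort(d2, axis=1)[:, :k]
    votes = y_train[nearest]
    # bincount(...).argmax() picks the most common label (smallest on ties)
    return np.array([np.bincount(v).argmax() for v in votes])

X_toy = [[0], [1], [2], [10], [11], [12]]
y_toy = [0, 0, 0, 1, 1, 1]
print(knn_predict(X_toy, y_toy, [[1.4], [10.6]], k=3))  # [0 1]
```

Larger ``k`` averages over more neighbors, which smooths out noise, until ``k`` grows so large that distant, unrelated points start to dominate the vote.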
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cross-Validation\n",
"\n",
"One problem with validation sets is that you \"lose\" some of the data. Above, we've only used 3/4 of the data for the training, and used 1/4 for the validation. Another option is to use **2-fold cross-validation**, where we split the sample in half and perform the validation twice, with each half serving once as the training set and once as the validation set:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"((898, 64), (899, 64))"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X1, X2, y1, y2 = train_test_split(X, y, test_size=0.5, random_state=0)\n",
"X1.shape, X2.shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.983296213808\n",
"0.982202447164\n"
]
}
],
"source": [
"print(KNeighborsClassifier(1).fit(X2, y2).score(X1, y1))\n",
"print(KNeighborsClassifier(1).fit(X1, y1).score(X2, y2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thus a two-fold cross-validation gives us two estimates of the score for that parameter.\n",
"\n",
"Because this is tedious to do by hand, scikit-learn provides a utility routine, ``cross_val_score``. Here we ask it for 10 folds and average the resulting scores:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.97614938602520218"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"cv = cross_val_score(KNeighborsClassifier(1), X, y, cv=10)\n",
"cv.mean()"
]
},
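Conceptually, $K$-fold cross-validation only needs a way to partition the sample indices into $K$ folds. Below is a minimal NumPy sketch of plain (unstratified, unshuffled) folds; the helper names are hypothetical. For classifiers, ``cross_val_score`` with an integer ``cv`` uses *stratified* folds, which keep class proportions similar in each fold, so its folds differ from these. The example shows why that matters: on data sorted by label, unshuffled folds can put an entire class into a single validation fold.

```python
import numpy as np

def kfold_indices(n_samples, k):
    """Hypothetical helper: yield (train_idx, test_idx) for k contiguous folds.
    The first n_samples % k folds get one extra sample."""
    fold_sizes = np.full(k, n_samples // k)
    fold_sizes[: n_samples % k] += 1
    indices = np.arange(n_samples)
    start = 0
    for size in fold_sizes:
        test_idx = indices[start:start + size]
        train_idx = np.concatenate([indices[:start], indices[start + size:]])
        yield train_idx, test_idx
        start += size

def cross_val_accuracy(predict_fn, X, y, k):
    """Score predict_fn(X_train, y_train, X_test) on each of the k folds."""
    scores = []
    for train_idx, test_idx in kfold_indices(len(y), k):
        y_hat = predict_fn(X[train_idx], y[train_idx], X[test_idx])
        scores.append(np.mean(y_hat == y[test_idx]))
    return np.array(scores)

def majority_class(X_train, y_train, X_test):
    """Baseline: always predict the most common training label."""
    return np.full(len(X_test), np.bincount(y_train).argmax())

# 80 samples of class 0 followed by 20 of class 1, sorted by label
X_sorted = np.zeros((100, 1))
y_sorted = np.array([0] * 80 + [1] * 20)
scores = cross_val_accuracy(majority_class, X_sorted, y_sorted, k=5)
print(scores)          # [1. 1. 1. 1. 0.]: the last fold holds only class 1
print(scores.mean())   # 0.8
```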
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### K-fold Cross-Validation\n",
"\n",
"The two-fold split we did by hand is just one specialization of $K$-fold cross-validation, in which we split the data into $K$ chunks and perform $K$ fits, with each chunk taking a turn as the validation set.\n",
"\n",
"The ``cv`` parameter of ``cross_val_score`` sets $K$; the call above used ``cv=10`` but reported only the mean score. Let's look at all ten scores from 10-fold cross-validation:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.93513514, 0.99453552, 0.97237569, 0.98888889, 0.96089385,\n",
" 0.98882682, 0.99441341, 0.98876404, 0.97175141, 0.96590909])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(KNeighborsClassifier(1), X, y, cv=10)"
]
},
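{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the individual fold scores gives us an even better idea of how well our model is doing: besides the average, we can see how much the score varies from one fold to another (here, from about 0.94 to 0.99)."
]
},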
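One common way to summarize fold scores is their mean together with their spread. The snippet below just copies the ten scores printed above (rounded to eight decimals) into a NumPy array; their mean reproduces the ``cv.mean()`` value of about 0.976 shown earlier. Treat the standard deviation as a rough indication rather than a formal error bar, since the folds share most of their training data.

```python
import numpy as np

# The ten 10-fold cross-validation scores printed above (rounded)
fold_scores = np.array([0.93513514, 0.99453552, 0.97237569, 0.98888889, 0.96089385,
                        0.98882682, 0.99441341, 0.98876404, 0.97175141, 0.96590909])

mean = fold_scores.mean()
std = fold_scores.std()
print("accuracy: {0:.3f} +/- {1:.3f}".format(mean, std))
print("range: {0:.3f} to {1:.3f}".format(fold_scores.min(), fold_scores.max()))
```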
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overfitting, Underfitting and Model Selection"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we've gone over the basics of validation and cross-validation, it's time to go into even more depth regarding model selection.\n",
"\n",
"The issues associated with validation and\n",
"cross-validation are some of the most important\n",
"aspects of the practice of machine learning. Selecting the optimal model\n",
"for your data is vital, and is a piece of the problem that is not often\n",
"appreciated by machine learning practitioners.\n",
"\n",
"Of core importance is the following question:\n",
"\n",
"**If our estimator is underperforming, how should we move forward?**\n",
"\n",
"- Use a simpler or a more complicated model?\n",
"- Add more features to each observed data point?\n",
"- Add more training samples?\n",
"\n",
"The answer is often counter-intuitive. In particular, **sometimes using a\n",
"more complicated model will give _worse_ results.** Also, **sometimes adding\n",
"training data will not improve your results.** The ability to determine\n",
"what steps will improve your model is what separates the successful machine\n",
"learning practitioners from the unsuccessful."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Illustration of the Bias-Variance Tradeoff\n",
"\n",
"For this section, we'll work with a simple 1D regression problem. This will help us to\n",
"easily visualize the data and the model, and the lessons carry over to higher-dimensional\n",
"datasets. We'll explore a simple **linear regression** problem,\n",
"which can be handled in scikit-learn with the `sklearn.linear_model` module.\n",
"\n",
"We'll create a simple nonlinear function that we'd like to fit:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def test_func(x, err=0.5):\n",
"    # Noiseless underlying curve\n",
"    y = 10 - 1. / (x + 0.1)\n",
"    if err > 0:\n",
"        # Add Gaussian noise with standard deviation err\n",
"        y = np.random.normal(y, err)\n",
"    return y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's create a realization of this dataset:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def make_data(N=40, error=1.0, random_seed=1):\n",
"    # randomly sample the data (seeded so the realization is reproducible)\n",
"    np.random.seed(random_seed)\n",
"    X = np.random.random(N)[:, np.newaxis]\n",
"    y = test_func(X.ravel(), error)\n",
"\n",
"    return X, y"
]
},
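Written out, each noisy sample produced by ``make_data`` follows the model below, where $\sigma$ is the ``error`` argument (passed to ``test_func`` as ``err``) and the $x_i$ are drawn uniformly from $[0, 1)$. The fits that follow are compared using the mean squared error over the $N$ training points, which is what ``mean_squared_error`` computes from the true values $y_i$ and the predictions $\hat{y}_i$:

$$
y_i = 10 - \frac{1}{x_i + 0.1} + \epsilon_i, \qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2)
$$

$$
\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - \hat{y}_i \right)^2
$$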
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFVCAYAAAA+OJwpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAF2xJREFUeJzt3X+Q3Hddx/HnlaZdLNfCwCKoKCLyUQcZRrFwAZJ2kFgI\nLQmQyRpbuA4/W+jUwnDQooyDMtMpAkIECgVykA5eptiLZIIxiOUinQNGBCkjvLE46IyjEiolEbsN\nbdc/dq/ZXHN7t3v747Pf7/Mx0+nu3X1vP+/udV/fz4/v5zvRaDSQJEmjdcaoGyBJkgxkSZKyYCBL\nkpQBA1mSpAwYyJIkZcBAliQpA2eu5YdSSs8Ero+IC1NKTwfeD9wP3Au8PCK+P8A2SpJUeKv2kFNK\nM8BNwNmtL/0Z8IaIuBC4FXjL4JonSVI5rGXI+k7gJcBE63ktIr7RerwBuGcQDZMkqUxWDeSIuBW4\nr+35fwGklDYCrwfeO7DWSZJUEmuaQ14upbQTuA54YUTctdrPNxqNxsTExGo/JklSUXQdel0Hckrp\nUuA1wAUR8cM1tWpigqNHj3f7UoVQrU6Wtnawfusvb/1lrh2sv1qd7PqYbi57aqSUzgDeBzwCuDWl\ndFtK6Y+6flVJknSKNfWQI+J7wMbW00cPrDWSJJWUG4NIkpQBA1mSpAwYyJIkZcBAliQpAwayJEkZ\nMJAlScqAgSxJUgYMZEmSMmAgS5KUAQNZkqQMGMiSJGXAQJYkKQMGsiRJGTCQJUnKgIEsSVIGDGRJ\nkjJgIEuSlAEDWZKkDJw56gZIUj/U63Xm5o4AUKttolKpjLhFUncMZEljr16vs3PnPIuLlwMwP7+H\nffu2G8oaKw5ZSxp7c3NHWmG8AdjA4uL0g71laVwYyJIkZcBAljT2arVNTE3tAU4AJ5iamqVW2zTq\nZkldcQ5Z0tirVCrs27edubkDANRqzh9r/BjIkgqhUqkwPb1l1M2QeuaQtSRJGTCQJUnKgIEsSVIG\nDGRJkjJgIEuSlAEDWZKkDBjIkiRlwECWJCkDbgwiSSPiLSPVzkCWpBHwlpFaziFrSRoBbxmp5dbU\nQ04pPRO4PiIuTCk9GZgFHgC+Cbw+IhqDa6IkScW3ag85pTQD3ASc3frSe4DrImITMAG8eHDNk6Ri\n8paRWm4tPeQ7gZcAe1vPfyMilsZV/hrYAuwfQNskqbC8ZaSWWzWQI+LWlNIT27400fb4f4Hz+t0o\nSSqD3G8Z6Srw4epllfUDbY8ngbvXclC1OtnDSxVDmWsH67f+8tY/zrXX63V27LiFhYXLADh4cC+H\nDu3qKpTHuf5R6CWQv5ZS2hwRC8ALgM+v5aCjR4/38FLjr1qdLG3tYP3WX976x7322dnDrTDeAMDC\nwqXs3n1gzT36ca9/vXo5GekmkJdWUr8JuCmldBbwz8Cnu35VSZJ0ijUFckR8D9jYevwvwAWDa5Ik\nadRqtU3Mz+9hcXEaoLUKfPtoG1Vw7tQlSXoIV4EPn4EsSQXR71XRua8CLxoDWZIKwL2xx597WUtS\nAbg39vgzkCVJyoCBLGls1Ot1ZmcPMzt7mHq9PurmZMW9scefc8iSxoJzpJ25Knr8GchSiY3TXsWn\nzpHSmiNd+85RZeCq6PFmIEslZY9TyotzyFJJjduqXOdIVXT2kCWNBedITxqnqQatnYEsldQ47lXs\nHKlTDUVmIEslZY9zPLm4rbgMZKnE7HGOqzpwuPX4uaNsiPrIRV2SCqEsm4Zs23Y+5577PmALsIVz\nz30/27adP+pmqQ8MZGlIyhIYo7A0rzozcwkzM5ewc+d8Yf8b79//FY4dewtLq+OPHZth//6vjLpZ\n6gMDWRqCMgXGKOR2CVe9XufGGw968qWuGMjSEOQWGBqcpZOvK67YMpCTL6/HLi4DWdLYyymkBn3y\ntbQ6/oYbDnDDDQe85KlAXGUtDcE4XvM7Tsp2
CZer44tpotFoDON1GkePHh/G62SnWp2krLWD9bfX\nX8bdlcr4/p/cuGMaaJ58lbEXW8b3vl21OjnR7TH2kKUhsVczGLmd6Cz11g8e/BzHj9cL31tX/xjI\n0pDlFiDjLNdtJCuVCq973dZS9xDVPRd1SUPk5U/9VdbV617TXkwGsjREZQ0Q9Y8ndcVlIEsaW4O6\n3CnnHqgndcXlHLJKI4e5Wy9/6q9BXO6U67y0is/LngbMpf951L/8Q3ZqajgfsqerP4cTg2HJ5f3v\nxuzsYWZmLmHp9oZwghtu6P72hoOqfVSXVXX7dzuO730/edmTtIKc7iHr5U9aj1FsguKowXA4h6yB\nWJqDu/HGg9nNwUmd5LQN50qWTuqmp7cMJRSdtx4Oe8jqu1END3fi3K3WqmzbcCofziEPWBnnUfo1\nB9dvo5i7LeP7367M9Rep9l7mrYtUfy+cQ5Y6cO5W6o2jBsNhIKvvHB6WiscT2sEzkNV37WfTk5MV\ntm71bFp5KNMlZxo/BrIGYulsuuzzSMqHl+4odz0FckrpDOCjwFOAB4BXR0T0s2GS1E85XYsunU6v\n1yFvAc6JiOcA7wDe2b8mSdLa5bzvtNSNXgP5HuC8lNIEcB7NK+glaai6ufPROGz4oXLrdQ75dqAC\nfBt4NHBx31okFZQLivqvm2FoL91R7noN5Bng9oh4W0rp54C/Syk9NSJW7ClXq5M9vtT4K3PtYP3V\n6iT1ep0dO25hYeEyAA4e3MuhQ7tKEQiDfP8nJx/6329ystLhNSd585tfOrD2LOfffrnr71avgXwO\ncKz1+Ic0T08f1umAsq60LfsqY+tv1j87e7gVxs2e3MLCpezeXfwFRYN+/7duPZ+pqVOved+6dXsW\nf3P+7Vt/t3oN5HcBe1JKf0/zE+baiLinx98lDZzDxcXkMLSKpKdAjoi7Abde0ljI4fpTdy8bHHeQ\nUlG4MYgKL4frT+3JSVqNgSwNybj25Bzul4aj1+uQpbHh9aerW2lzjW6u85W0PvaQVXgOF3fWaY49\nh+F+qSwMZJXCuA4XD4OhK+XBIWtJK3K4Xxoee8jSEOW4QKrTJVkO90vDYyBLywwqNHO4Hvp0Vgtd\nh/ul4TCQpTb9CM32QL/qqq0Pfj3nuVpDVxo955ClNqeG5oZWaB5Z8/HLLxO66KJPeZmQpDUxkKU+\nWh7oCwuXnjL87QIpSStxyFpqM8g9p10gJamTiUajMYzXaZT1Nlzegmz86l/Poq6Tc9DTAGzefDN7\n915c2uAdx/e/X8pcO1h/tTo50e0x9pClZdazwGl5L/iqq3Zx/PhP+tk8SQVlIEt91h7olUrFQJa0\nJi7qkiQpAwayJEkZMJAlScqAgSxJUgYMZEmSMmAgS5KUAQNZkqQMGMiSJGXAQJYkKQPu1KXCWs+e\n1JI0bAayCunkTR4uB2B+fg/79nl3JUn5cshahbT8vsSLi9MP9pYlKUcGsiRJGTCQVUi12iampvYA\nJ4ATTE3NUqttGnWzJGlFziGrkJbfl7hWc/5YUt4MZBVW+32JJSl3DllLkpQBA1mSpAwYyJIkZcBA\nliQpAwayJEkZ6HmVdUrpWuBimlsh/XlEfKJvrZIkqWR66iGnlC4ApiJiI3AB8KQ+tkmSpNLptYe8\nBbgjpbQfOBd4c/+aJElS+Uw0Go2uD0op3QQ8AXgRzd7xZyLiVzoc0v2LSJI0via6PaDXHvIPgG9F\nxH3Ad1JK9ZTSYyLiBysdcPTo8R5farxVq5OlqH2lew+Xpf6VWH956y9z7WD91epk18f0GshfBK4G\n3pNS+hngHOCuHn+Xxpz3Hpak9etpUVdEHAS+llL6CvAZ4MqIcFi6pLz3sCStX8+XPUXEW/rZEOVj\npeFnSdLguDGITrE0/DwzcwkzM5ewc+c89Xq94zHee1iS1s/bL+oUpw4/0xp+PtDxNoad7j1cr9eZ\nnT3c+rq9
bUlaiYGsvjjdvYfr9To7dtzCwsJlgIu9JKkTh6x1in4OP8/NHWmFsYu9JGk19pB1ik7D\nz5KkwTGQ9RCnG37uRa2
"text/plain": [
"<matplotlib.figure.Figure at 0x10935d710>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X, y = make_data(40, error=1)\n",
"plt.scatter(X.ravel(), y);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now say we want to perform a regression on this data. Let's use scikit-learn's built-in ``LinearRegression`` estimator to compute a fit:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFeCAYAAABU/2zqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4lNdh7/GvVoZF7GLVsHjh2BizSjbCZjNI8RI7Jolj\nxbVbcpM2S5ubJmmcJl1unyZt09w2aeq2SZomJnXaQJ0GO37oJRJghO3ItsS+Hhu8MBIgiVUsGm0z\n94/3FQwChDTM8s7M7/M8fqxllnMYmN+7nPm9WeFwGBEREUmu7GQPQERERBTIIiIinqBAFhER8QAF\nsoiIiAcokEVERDxAgSwiIuIBuckegIgknjGmGHjeWjs12WO5FmPM/wIetdY+cpXfTQf+I+JHOcAM\n4MPW2heMMd8EHgPagFeAL1lr2xIwbJGoKZBFxFOMMSOBvwaeBDZd7TbW2n3AnIj7/D2wyw3jTwAP\nAPOsteeMMd8Fvgl8Je6DF7kBCmRJK8aYJcDfAA3AHcAF4P8A/xswwH9ba7/k3vZh4E+AfPd2f2St\nfd0YMxb4ITAGGAe8D3zMWttsjHkPeBZYBkwC1lhrv3qVcXwW+DTQDgSBT1tr9xtjFgHPAF3A6zjB\nsQSYCjxjrb0zYh7PWGvv7MN4XgdmAl8D6tzHnwTkAauttX8TMaY/BM4Ae3v5M1wAfAsYDISAv7DW\nrjPGrAQ+CQxyH+OnwKfc709ba5cZY/4MqAA6gbeAP7DWNhpjNgMngNuA77tz+Iy19qGrDOExnNfv\nj4Cr/b7neBcCH8HZQwaYC7xorT3nfv8CsAoFsnicziFLOioGvmGtvR1oxAmqB3HeqH/fGDPOGHMr\n8FfAA9bauTjh+UtjzCDgceA1a+0Ca+1NOGH9lPvYYWCwtXYRsAD4vDFmcuSTG2NygO8CH7DW3gX8\nK3CPMWYA8AvgD93nrAEmu4/Zm+uNZ7e1drq19kXgOeAn1tpi4G6gzBjzmDFmNs6GyUJ3TOev9kTG\nmBHAT4AnrbXzgA8B3zfG+N2bTAcWW2vvA7Iivl/m7pneDxRba2cBe3CCsHucJ621d1hr/8la+9I1\nwhhr7Q+ttd/A2ZDpi78Dvh4RwG8CjxhjRrqvxceB8X18LJGkUSBLOnrXWrvT/foQsMla22mtPQG0\nAKOAMpw36U3GmO3Az3D2Wm+21v4j8Lox5kvGmO/j7HkNjnj8FwGstUeAJmBk5JNba7uA54EaY8wz\nOHuTPwHuBILW2pfd2z3n/q5XfRjPKwDGmMHAYuAb7pxqgCJgFnAf8GtrbZN7nx9e4+lK3T+XF93H\nWIezl3wnTqjuigg+enx/P87GQKv7/T8Cy4wxeZHjjCV3b36UtfY/u3/m/rm+ALwMbAZ24BypEPE0\nHbKWdNRz8U7nVW6TDWy01lZ0/8AYMwmoN8b8LVAC/BjnHGYuzt5gt9aIr8M9fgeAtfYpd+FRGfBV\nnEO9X7/KbbuDoufj5EeM63rj6Q7EHPf/pdbaoHvf0e54f4/LN8C7eo7ZlQ3st9bOj3j+iThHGp6M\neK6ez91936we30eOted9Y+FxnEPnFxljhgM/tdZ+0/3+A8DBODy3SExpD1kyURgn2MqNMQbAGHM/\nzp6UDygH/sFa+x9AM06o5lzjsa5gjBltjDmMc4j2e8Cf4ZzjtUCrMeaD7u0eAgrduzUDk4wxhcaY\nLODRiIfs03istS0455O/7D7+MJy90keAKne+E92br7zG8N8AbnXPdWOMmQkcoG+HfH8NfMI97A/O\neftqa233RscVGy4xsBjY2ONndwMvGGNyjTH5wNM4R0BEPE2BLOmo5znZK87Ruqt0fw9YbYzZAXwD\neNhaewH4S+DvjDGv4yxA+gVwS1+f3Fp7HGdV70ZjTB3OIrNPuYeyPwx81RizDfgo0BExnh/iLMqq\nAY5EjLs/43kCmG+M2YUTrj+31v7cWrsHJ5g2
GmNqcQ55X+3PpRlngdS33T+XnwFPWWsD7u0j79Pz\n+x8DG4A3jTH7gNnAb/W4PQDGmEeMMeuuMYdrPT7GmO3GmLkRP7oFeK/HHH4NrAd24pzHfhPnnL6I\np2Xp8osiyWOMaQWMtfZwssciIsnVp3PIxpi7gW9Za5e6qzX/EeccVBvw2xELRUSkf7RFLCJAHw5Z\nG2OeBn4EDHB/9A84ny1cCvwSZ8GKiETBWjtIe8ciAn07h3wQ57xX94KMCmvtLvfrPC5fcSoiIiJR\nuG4gW2t/ScTHRqy1x+Di5/9+Hy2WEBERuWFRfQ7ZGPM4zmcqH3TLFnoVDofDWVnx+MSDiIiIJ/U7\n9PodyMaYJ3E+LrLEWnuqT6PKyqK5+Wx/nyotFBYWZOzcQfPX/DN3/pk8d9D8CwsL+n2f/nwOOWyM\nyQa+BwzB6f192RjzF/1+VhEREblMn/aQrbXv4RTpg9MDLCIiIjGkpi4REREPUCCLiIh4gAJZRETE\nAxTIIiIiHqBAFhER8QAFsoiIiAcokEVERDxAgSwiIuIBCmQREREPiOriEiIiIl7T0dnF6/sa2bi1\nnkEDcnn6ibnJHlK/KJBFRCSlnTnfzubtDby8rZ6WCx1kZ2WxbF5RsofVbwpkERFJSYGmc1TVBnh9\n3zE6u8IMGpDLA3dPYtm8IkYO9SV7eP2mQBYRkZQRCofZdegEVbUB9r/vXAF47IiBLC/2c8+d4/Dl\np26spe7IRUQkY7S1d/Hq7qNsqAvQeKoVgNsnj6CsxM/Mm0eRnZWV5BHeOAWyiIh41smWIBu31lO9\n4wgX2jrJzcni3jvHU1bixz9mSLKHF1MKZBER8ZxDR85QVRug7kAzoXCYoYPy+NC9U1kyZyLDBucn\ne3hxoUAWERFP6AqF2GqbqaoLcKihBYCiwiGUlRQxf/pY8nJzkjzC+FIgi4hIUl0IdrBl51E2bg1w\noqUNgFk3j6K8xM9tk0eQlQbnh/tCgSwiaSEYDLJ69RYAKioW4fOl3sdeMk3jqQtsqK3n1d1Haevo\nIj8vm/vmTmR5sZ9xIwcle3gJp0AWkZQXDAZ5/PG11NR8AoC1a59lzZoVCmUPCofDHDh8mqraADsP\nHicMjCgYwCP3TGHR7AkM9uUle4hJo0AWkZS3evUWN4ydN/OampWsXv0SK1eWJ3dgclFHZ4g39jVS\nVRcg0HQOgJsmDKW8xM/caYXk5ujSCgpkERGJmxa31nLT9gZazreTnZVFyW1jKC/xc/PEYckenqco\nkEUk5VVULGLt2mepqVkJQGnpKioqViR3UBnuvaMtrKk8wOt7G+nsCjFwQC733z2JZXOLGDVMpxKu\nRoEsIinP5/OxZs0KVq9+CYCKCp0/ToZQOMzuQyeoqguw7z2n1nLMiIGUpUGtZSLoT0dE0oLP59M5\n4yRpa+/iN3uOUlVXz7GTFwCYectolsyawMxb0qPWMhEUyCIiEpWTLUE2bqtny44jnA86tZb33DmO\nsmI/82ZMoLn5bLKHmFIUyCIi0i/vHGmhsvbwxVrLgkF5PHLPFJbOmciwIQOSPbyUpUAWEZHr6gqF\n2PbWcapqAxxsOANAUeFgyor9zL8j/WstE0GBLCIi13Sp1rKeEy1BAGa6tZa3Z1CtZSIokEVE5AqN\npy6woc6ttWx3ai2Xzp3I8nlFjB81ONnDS0sKZBERAZxaS3v4NJU9ai0fXjCFRbMmMGRg5tZaJoIC\nWUQkw3V0hnhzfyNVtQEOu7WWU8c7tZbzjGotE0WBLCKSoVouOLWWL29r4Mz5drKyoLi71nLCUJ0f\nTjAFsohIkiTrkpH1zeeoqg1Qc7HWMocP3OVn2bwiRg8bmJAxyJUUyCIiSZDoS0aGwmH2vHOCqtoA\ne7trLYcPZHlxEffcOZ6BAxQHyaZXQEQkCRJ1yci29i5+s/cYG+oCHD3h1FreNmk4ZSV+Zt08muxs\nHZb2ij4F
sjHmbuBb1tqlxphbgFVACNgD/L61Nhy/IYqISH+dOtvGxq31VO9o4Hywk5zsLO6ZMY6y\nEj+TxhYke3hyFdcNZGP
"text/plain": [
"<matplotlib.figure.Figure at 0x1096035c0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Dense grid of x values for plotting the fitted curve\n",
"X_test = np.linspace(-0.1, 1.1, 500)[:, None]\n",
"\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_squared_error\n",
"model = LinearRegression()\n",
"model.fit(X, y)\n",
"y_test = model.predict(X_test)\n",
"\n",
"plt.scatter(X.ravel(), y)\n",
"plt.plot(X_test.ravel(), y_test)\n",
"plt.title(\"mean squared error: {0:.3g}\".format(mean_squared_error(y, model.predict(X))));"
]
},
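For a single input feature, the straight line that ``LinearRegression`` fits is the ordinary least-squares solution, which has a closed form: slope $= \mathrm{cov}(x, y) / \mathrm{var}(x)$ and intercept $= \bar{y} - \text{slope} \cdot \bar{x}$. The NumPy sketch below (the helper ``fit_line`` is hypothetical; scikit-learn uses a general least-squares solver instead of this formula) checks the formula against ``np.polyfit`` on data of the same form as ``test_func`` and computes the training mean squared error.

```python
import numpy as np

def fit_line(x, y):
    """Hypothetical helper: ordinary least-squares slope and intercept in 1D."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    slope = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    intercept = y_mean - slope * x_mean
    return slope, intercept

rng = np.random.RandomState(1)
x = rng.rand(40)
y_obs = 10 - 1.0 / (x + 0.1) + rng.normal(0, 1.0, size=40)  # same form as test_func

slope, intercept = fit_line(x, y_obs)
print(slope, intercept)
print(np.polyfit(x, y_obs, deg=1))   # [slope, intercept] from NumPy's solver
mse = np.mean((y_obs - (slope * x + intercept)) ** 2)
print("training MSE:", mse)
```

No straight line can follow the curvature of $10 - 1/(x + 0.1)$, which is why this error stays high no matter how the line is fit.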
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have fit a straight line to the data, but clearly this model is not a good choice. We say that this model is **biased**, or that it **under-fits** the data.\n",
"\n",
"Let's try to improve this by creating a more complicated model. We can do this by adding degrees of freedom, that is, by computing a polynomial regression over the inputs. Scikit-learn makes this easy with the ``PolynomialFeatures`` preprocessor, which can be pipelined with a linear regression.\n",
"\n",
"Let's make a convenience routine to do this:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"def PolynomialRegression(degree=2, **kwargs):\n",
"    return make_pipeline(PolynomialFeatures(degree),\n",
"                         LinearRegression(**kwargs))"
]
},
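For a single input feature, ``PolynomialFeatures(degree)`` (with its default ``include_bias=True``) maps each $x$ to the columns $1, x, x^2, \ldots, x^{d}$, and the linear regression then fits one coefficient per column. A minimal NumPy sketch of the same two steps, with hypothetical helper names and ``np.linalg.lstsq`` standing in for ``LinearRegression``:

```python
import numpy as np

def poly_features(x, degree):
    """Hypothetical helper: columns [1, x, x**2, ..., x**degree] for 1D input."""
    x = np.asarray(x, dtype=float).ravel()
    return np.vander(x, degree + 1, increasing=True)

def fit_poly(x, y, degree):
    """Least-squares coefficients for the polynomial feature expansion."""
    coef, *_ = np.linalg.lstsq(poly_features(x, degree), y, rcond=None)
    return coef

print(poly_features([2.0], 3))   # [[1. 2. 4. 8.]]

# Exact quadratic data y = 1 - 2x + 3x^2 is recovered by a degree-2 fit
x = np.linspace(0, 1, 20)
y_quad = 1 - 2 * x + 3 * x ** 2
print(fit_poly(x, y_quad, 2))    # approximately [1, -2, 3]
```

The model is still linear in its coefficients; only the features are nonlinear in $x$, which is why an ordinary linear regression can fit curves this way.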
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll use this to fit a quadratic curve to the data."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFeCAYAAABU/2zqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8VPW9//HXZB0SEiAhEPadLzuIuAQVRMUNF3AjWm2x\n7a2trbXW1v7a+7v39/vd2/u4vUt769W2Wm+FVluDGyIFKYoYFUEF2Zcv+xYghLAlIZP1/P6YCYYQ\nQjLMZM7MvJ+PBw+yzJn5fDPJeZ/v+X7P93gcx0FEREQiKyHSBYiIiIgCWURExBUUyCIiIi6gQBYR\nEXEBBbKIiIgLKJBFRERcICnSBYhI+zPGTABes9YOiHQtTRljhgAvAllAOfBVa61t5nFTgP/Evx87\nDnzfWrs+8L0ngYeBWqAEeMRau6t9WiASHPWQRcRt/gz8xlo7Evg/wBtNH2CM6QS8DjxhrR0L/AiY\nb4xJMcbcAHwduNJaOw54E5jdbtWLBEk9ZIkpxphrgX8FioCRwGn8O/XvAwZ4w1r7w8Bjbwf+HkgJ\nPO5H1tqVxpjuwPNANyAX2AvcZ60tMcbswb9zvx7oC8y11v6kmTq+AzwCVAM+/D20LcaYScAzQB2w\nErgFuBYYADxjrR3dqB3PWGtHt6KelcAY4KfAqsDz9wWSgQJr7b82qukHwElgUws/w4nAL4B0oB74\nv9bahcaYWcA3gLTAc/wR+Gbg8xPW2uuNMf8A5OPvmW4DvmetLTbGfACUAsOA3wXa8G1r7bQmr90L\nMNbaAgBr7WJjzO+MMZdYa9c0eugQ4KS19sPA41YZYxwgDzgUeO7ywGNXA+e8RyJuox6yxKIJwD9b\na4cDxfiD6lZgPPBdY0xu4LTovwC3WGvH4w/PN40xacBMYLm1dqK1diD+sH4o8NwOkG6tnQRMBB4z\nxvRr/OLGmETgv4CbrLWXA78HrjLGpOLv1f0g8JorgH6B52zJherZYK0dYa2dD7wEvGitnQBcAUw1\nxtxrjBmH/8DkmkBNFc29kDGmC/7TxQ9aay8F7gR+Z4zpE3jICGCytfY6wNPo8+uNMQ8DNwMTAr3W\njcCcRnUes9aOtNY+a61d0DSMA/oAB5t87QDQq8nXtgEZxpjrAnVfC/QGcq21m6y1HwW+nor/4OLV\n5tor4iYKZIlFu6216wIf7wTet9bWWmtLgVNANjAV6AG8b4xZA7yMv9c6yFr738BKY8wPjTG/A0bh\n7y02mA9grT0IHME/1nmGtbYOeA1YYYx5Bn9v8kVgNOCz1i4LPO6lwPda1Ip6GsInHZgM/HOgTSvw\nh9RY4Drgb9baI4Ftnj/Py+UFfi7zA8+xEH8veTT+UF3fqOdJk89vxn8wUBn4/L+B640xyY3rvIDz\n7ZPqGn9irT0FzAD+T6DOW4BP8Z+RAMAYkwMswf+e/6wVry0SUTplLbGoqsnntc08JgFYaq3Nb/iC\nMaYvcMAY82/AZcAfgPfx/514Gm1b2ehjp8n3ALDWPmSMGYE/+H+C/1Tvz5p5bEOANH2elEZ1Xaie\nhkBMDPyfZ631BbbtGqj3W5wddmcFXCMJwBZr7ZWNXr8X/jMNDzZ6raav3bCtp8nnjWttum1z9uE/\nLd9YL/y95DOMMR78p6wnN/paEbAj8PEY/AdOb+IfitCi/eJ66iFLPHLwB9uNxhgDYIy5GVgLeIEb\ngV9ba/+Mf4buVL4MuwsyxnQ1xuzDf4r2aeAf8I/xWqDSGHNb4HHTgJzAZiVAX2NMTiBspjd6ylbV\nE+g1rgSeDDx/J/y90juAdwPtbTj1O+s85X8KDAmMdTcE21b8veYL+RvwcOC0P/jH7QuttQ0HHecc\nuDTThgPATmPMzMDr3wTUWWs3NPPwd4wxlwYe902gxFq7wRgzGFgG/D9r7ZMKY4kWCmSJRU13wOfs\nkK21m/H3GguMMWuBfwZut9aeBv4J+E9jzEr8
E5BeBwa39sWttUeBnwNLjTGr8E8y+2bgVPZdwE+M\nMV8A9wA1jep5Hv+krBX4x1Eb6m5LPQ8AVxpj1uMP11esta9YazcCTwVq+hz/Ke/mfi4lwN3Avwd+\nLi8DD1lr9wce33ibpp//AXgP+MwYsxkYB3ylyeMBMMbcYYxZeJ425APfNsZswP++3NtouzXGmPGB\nkH0AeMEYswn/OHvDQcxP8B9YPR54/BpjzIrzvJaIa3h0+0WRyDHGVOKfVbwv0rWISGS1agzZGHMF\n8Atr7ZTAbM3/xj8GVYX/ov0jLT6BiJyPjohFBGjFKWtjzFPAC0Bq4Eu/xn9t4RT8EyZ0fZ9IkKy1\naeodiwi0bgx5B/5xr4YJGfkNy9PhX3igstmtREREpNUuGMjW2jdpdNmItfYwnFnN57v4F0AQERGR\nixDUdciBSxJ+BtwaWGyhRY7jOB7PBa94EBERiRVtDr02B7Ix5kH8l4tca6093qqqPB5KSsra+lIx\nIScnI27bDmq/2h+/7Y/ntoPan5OT0eZt2nIdsmOMSQCeBjriX/d3mTHm/7b5VUVEROQsreohW2v3\n4F9IH/zrAIuIiEgIaaUuERERF1Agi4iIuIACWURExAUUyCIiIi6gQBYREXEBBbKIiIgLKJBFRERc\nQIEsIiLiAgpkERERF1Agi4iIuIACWURExAUUyCIiIi6gQBYREXEBBbKIiIgLKJBFRERcQIEsIiLi\nAgpkERERF1Agi4iIuEBSpAsQEQkFn89HQcGHAOTnT8Lr9Ua4IpG2USCLSNTz+XzMnDmPFSseBmDe\nvNnMnTtDoSxRRaesRSTqFRR8GAjjZCCZFStmnekti0QLBbKIiIgLKJBFJOrl508iL282UA1Uk5c3\nh/z8SZEuS6RNNIYsIlHP6/Uyd+4MCgoWAJCfr/FjiT4KZBGJCV6vl1mzbox0GSJB0ylrERERF1Ag\ni4iIuIACWURExAUUyCIiIi6gQBYREXEBBbKIiIgLKJBFRERcQIEsIiLiAloYREQkQnTLSGlMgSwi\nEgG6ZaQ0pVPWIiIRoFtGSlOt6iEbY64AfmGtnWKMGQzMAeqBjcB3rbVO+EoUERGJfRfsIRtjngJe\nAFIDX/oV8DNr7STAA9wZvvJERGKTbhkpTbWmh7wDuAt4KfD5eGttw3mVd4AbgbfCUJuISMzSLSOl\nqQsGsrX2TWNM/0Zf8jT6uBzoFOqiRETigdtvGalZ4O0rmFnW9Y0+zgBOtGajnJyMIF4qNsRz20Ht\nV/vjt/3R3Hafz8e9975GYeFDACxc+BKLFz/QplCO5vZHQjCBvMYYM9laWwjcAixtzUYlJWVBvFT0\ny8nJiNu2g9qv9sdv+6O97XPmLAmEcTIAhYUP8swzC1rdo4/29l+sYA5G2hLIDTOpnwReMMakAJuB\n19v8qiIiInKWVgWytXYPMDHw8Xbg2vCVJCIikZafP4l582azYsUsgMAs8BmRLSrGaaUuERE5h2aB\ntz8FsohIjAj1rGi3zwKPNQpkEZEYoLWxo5/WshYRiQFaGzv6KZBFRERcQIEsIlHD5/MxZ84S5sxZ\ngs/ni3Q5rqK1saOfxpBFJCpojLRlmhUd/RTIInEsmtYqPnuMlMAYaetXjooHmhUd3RTIInFKPU4R\nd1Egi8SpaOtxtmXlqNq6enzVdfiqa6mqrvN/XFNHXZ1Dfb1DveP/v67e/7/HA4mJCSQmeEhI8JAU\n+D8lKRFvSiLe1ES8KUl4UxJJStTUGwkPBbKIRIeEJH79u5uYt+CvVNXB0BETeP3DvZRX1lB2OvCv\nspqKyhpq65wLP1+QkhITSPcmkZGWTEZaChlpyWSmpZCRnkKn9BSyMlPJzvSSkdkhbDVE01CDtJ4C\nWSROuW2tYsdxOFVRzeFjpyk+XknxsdMcPnaakhM+jp3ycbqqNvDIRAC2r9x/1vYdUhPJ6JBCdncv\nHVISSU1J
IjU50LtNTiQ1OZHERA+JCQkkJHjO9IYTPOA4UBfoMdfV1/t7z3UOVbWB3nWVv7fd0Ouu\n8NVSeqqKAyUVLbapY4f
"text/plain": [
"<matplotlib.figure.Figure at 0x109451b00>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = PolynomialRegression(2)\n",
"model.fit(X, y)\n",
"y_test = model.predict(X_test)\n",
"\n",
"plt.scatter(X.ravel(), y)\n",
"plt.plot(X_test.ravel(), y_test)\n",
"plt.title(\"mean squared error: {0:.3g}\".format(mean_squared_error(y, model.predict(X))));"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This reduces the mean squared error and gives a much better fit. What happens if we use an even higher-degree polynomial?"
]
},
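Before running it, one property of these nested models is worth noting: each higher-degree polynomial contains the lower-degree ones as special cases, so its *training* error can never be larger. The NumPy sketch below (hypothetical names; data drawn from the same form as ``test_func``, with ``np.linalg.lstsq`` standing in for the pipeline) checks this, and also reports the error on a separate validation sample, which is the quantity that can reveal when extra flexibility stops helping.

```python
import numpy as np

def true_func(x):
    # Same noiseless curve as test_func in the notebook
    return 10 - 1.0 / (x + 0.1)

rng = np.random.RandomState(0)
x_fit = rng.rand(40)
y_fit = true_func(x_fit) + rng.normal(0, 1.0, size=40)
x_val = rng.rand(40)
y_val = true_func(x_val) + rng.normal(0, 1.0, size=40)

train_mse, val_mse = [], []
for degree in range(1, 7):
    # Columns 1, x, ..., x**degree for the training and validation inputs
    A_fit = np.vander(x_fit, degree + 1, increasing=True)
    A_val = np.vander(x_val, degree + 1, increasing=True)
    coef, *_ = np.linalg.lstsq(A_fit, y_fit, rcond=None)
    train_mse.append(np.mean((A_fit @ coef - y_fit) ** 2))
    val_mse.append(np.mean((A_val @ coef - y_val) ** 2))
    print(degree, round(train_mse[-1], 3), round(val_mse[-1], 3))
```

A falling training error is therefore not evidence that a higher degree is better; the validation column is the one to compare across degrees.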
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFeCAYAAABU/2zqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xd8XHed7//XmVEZdVnFsiz3dtzj2E5xip0CISGFGAgW\nEIJDWWB3WRbYXxaWbXcX7u5m97L3LrtLWUIMIUROICYYpzgNO8WOU5zE9bj3pmb1kTQz5/fHzMiy\nrDIazTmjkd7PxyNgSXPOfI9mdD7z+XybYds2IiIiklyeZDdAREREFJBFRESGBQVkERGRYUABWURE\nZBhQQBYRERkGFJBFRESGgbRkN0BE3GGa5lLgCcuypia7LT2ZpjkT+BlQBDQD91mWZfXyuBuBB4F0\noA34M8uy3jRN81vAqm4PHQvkWpZVYJpmAXAW2NPt539uWdYmZ65GJD4KyCIyHDwKfN+yrCrTNG8F\nfgPM7/4A0zQzgCrgFsuy3jNN83bgEWC2ZVn/DPxz5HEFwDbgc5FDrwY2WZb1IXcuRSQ+CsiS8kzT\nvAH4J+AkMA9oBf4O+DPABH5jWdY3Io+9E/gOkBF53F9YlrXVNM0y4MeEM6txwFHgE5ZlVZumeQR4\nGLgZmASstSzrL3tpx1eALwEdgB/4kmVZe0zTXA78AAgCW4HbgBuAqcAPLMta0O06fmBZ1oIY2rMV\nWAh8G3grcv5JhDPHKsuy/qlbm/4caAB29fM7vIZwQMsBQsDfW5a1wTTN1cDngezIOX4OfCHy9XnL\nsm42TfNvgEogAOwD/tSyrLOmaf4BqAVmAz+MXMOXLcu6vcdzVwCmZVlVAJZlPWua5g9N07zcsqzt\n0cdZltVhmuZ4y7KCpmkawHSgppfL+T/A05ZlPRf5+hqgyDTNVyLX9xPLsn7U1+9CJFnUhywjxVLg\nHy3LmkO4PPlt4MPAYuBPTNMcFymLfg+4zbKsxYSD55OmaWYTLne+ZlnWNZZlTSMcrD8TObcN5FiW\ntZzwzf2rpmlO7v7kpml6gX8HPmRZ1pXAT4BrTdPMBH5NuES6GNgCTI6csz8DtWeHZVlzLct6inCW\n+DPLspYCVwEfNE3zHtM0FxH+YHJ9pE0tvT2RaZpjCJeL77UsawnwEeCHpmlOjDxkLrDCsqybAKPb\n1zebpnk/cCuw1LKsy4CdwJpu7ayzLGueZVn/aVnW+p7BOGIicKrH904AFT0fGAnGZZGf/wvwrz2u\nZV6k/X/b7dudwO+A5cAdwNdN0/xIb78LkWRShiwjxWHLst6L/Psg4ewtANSaptkIFAMrgHLgJdM0\no8cFgemWZf2HaZrXm6b5DWAm4XLp1m7nfwrAsqxTpmmeI9zXeTT6w0igeALYYprmBmAj8CvCHwj8\nlmW9HHncI6Zp/sdAFxNDe14BME0zJ3JdY0zT/MfIz3KAywgHuucsyzoX+f6Pgd4C4rLI7+Wpbr+X\nELCAcFB937Ks5m6P7/71rYQ/DLRFvv4P4DumaaZ3b+cA+koMgr1907Kss0CFaZqXAy+apnmVZVn7\nIz/+GuEqQ1O3x3+32+GnTNP8MbCSyGsqMlwoIMtI0d7j60Avj/EAL1qWVRn9hmmak4ATpmn+C3AF\n8BDwEuG/DaPbsW3d/m33+BkAlmV9xjTNucAHgb8kXOr9q14e29HHeTK6tWug9kQDojfy/8ssy/JH\nji2JtPePuDjY9RrgIo/ZY1nW1d2ev4JwpeHebs/V87mjxxo9vu7e1p7H9uYY4bJ8dxWEs+Aupmnm\nAzdblrUOwLKs7aZpvkf4w8r+SJXio4Q/BHU/7qvAby3LOt6tjR2IDDMqWctoYRMObLeYkTQwMnjo\nXcAH3AL8X8uyHgWqCQdVbx/nuoRpmiWmaR4jXKL9f8DfEO7jtYA20zTviDzudqA0clg1MMk0zdJI\nn+jd3U4ZU3ssy2oknDl/M3L+AsJZ6V3A85Hr
jZZ+V/fR/DeAmZG+bkzTXAjsJZw1D+Q54P5I2R/C\n/fabLMuKBrxLPrj0cg0ngIOmaa6KPP+HgKBlWTt6PDQEPBTp746Wp2dH2g/hjL7esqxjPY67Fvj/\nIscUER7stTaGaxNxlQKyjBQ9+2Qv6aO1LGs34ayxyjTNd4F/BO60LKsV+Afg30zT3Ep4ANKvgRmx\nPrllWTXAdwmXUN8iPMjsC5ZlBQlnbX9pmuY7wMcJ92lG2/NjwoOythDuR422ezDt+RRwtWma7xMO\nTo9ZlvWYZVk7gQcibXqTcCm7t99LNfAx4MHI7+WXwGciGaXd45ieXz8EvABsM01zN7AI+HSPxwNg\nmuZdkXJ+byqBL5umuYPw63JPt+O2m6a5OFImvxv4v6Zpbo889ycty4r2P88ADvdy7j8FJpimuZPw\n7/m/Lct6sY92iCSNoe0XRdxlmmYb4VHFPTM5ERnFYsqQTdO8yjTNl3t871Omab7uTLNERjR9ChaR\nSww4qMs0zQfoMbAjMrrxc30eJCJ9siwre+BHichoE0uGfIBwH5gBYJpmMeG5nH9ODAM2REREZGAD\nBmTLsp4kMoXENE0P4YEU3yC26QwiIiISg8GOsl5CeCTjD4HHgLmmaX5/oIPs8MixUfPfL5/dY9/5\nzafsHQdrkt4W/af/9N/o+++h3+207/zmU/a+Y/VDOk9tQ5t95zefsv/tl2/bgP2PD71h3/nNp+yW\nts6kX2MK/Ddog1oYxLKsN4ks+B5ZOrAqukZwfwzDoLq6aaCHjRhtreEpmOfrWwFG1bX3VFqap+vX\n9Se7GUmTrOtvbgmvk9PU2EZ1dfzrP9U3hc/T3t5JdXUTHR3h9XZqa5tpzez/vHrt8wZ9zGAy5J4R\n3+jle9KNfjkikgyhUPju4/EkZpiP7mXuiOmjk2VZRwgvqt/v9yTMMCJ/BJrjLSJJEIwEZG+CArK4\nQyt1OSD6J6BwLCLJkOiArAWk3KGA7IRogpzcVojIKJWokrWhBNtVCsgO6HoPKyKLSBJcyJB1i08l\nerWcEPlYaSsii0gSONWHrNK1sxSQHeBRJ7KIJFHCStaJaIzETAHZQSEFZBFJgmAwBCRyUFdCTiMD\nUEB2QNe0J6XIIpIEQTtBJWuN6nKVArIDusKx4rGIJEGiFwYRdyggO0HTnkQkiYLBBM9DTshZZCAK\nyA4wFJFFJImCto3HMLp1n8VH+bW7FJAdpGlPIpIMoZCd2HJ1pP9NdzRnKSA7QGO6RCSZgkE7MeXq\nPk6hsV7OUEB2gOKxiCRTMJSggCyuUkB2QLTfRqOsRSQZQnZiS9a6lblDAdkJXbsv6m0sIu4LBkMJ\nyZCVY7tLAdkBehOLSDIFQzZebyIHdSXuVNI3BWQHdJWsk9wOERmdQpFpT0M11GlTMjgKyA5SyVpE\nkiFho6zFVQrIDtDfgYgkU7hknbjbe8/UwlDHnCMUkJ0QKfOElCGLSBKEQokpWYu7FJAdoIVBRCSZ\nEj0PWd1v7lBAdoBH85BFJImCCVo6U0m2uxSQHaSStYgkQyjR057EFQrIDtCnShFJFtu2Cdk2Xt2I\nUo4CsgM8GtQlIkkSDIXvOwkpWff4Wrc0ZykgO0hvXhFxWygSkBNZsr7kXqbk2xFpsTzINM2rgH+2\nLOtG0zQXAf8BBIF24D7Lss452MaUo+kGIpIs0Qw5MSVr3cvcNGCGbJrmA8D/AJmRb/1f4E8ty7oR\neBL4S+eal5qifwcqWYuI2xJZshZ3xVKyPgB8lAsflSoty3o/8u90oM2JhqW0rt2ektsMERl9LpSs\nh94jqWKfuwZ8xSzLehIIdPv6DIBpmtcAfwL8u2OtS1FdJWtFZBFxWVfJWhlyyompD7kn0zRXAX8F\nfNiyrNpYjiktzYvnqVJSwalGAHJywlX+0XTtvdH16/pHM7evP+gJ51k52RlDfu5WfycA6RleSkvz\nyMjwAlBS
kosvY+DwMdpf+8EadEA2TfNe4I+AGyzLqo/1uOrqpsE+VcpqavQD0NgU/v/RdO09lZbm\n6fp1/cluRtIk4/rP1bY
|
||
|
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x1093c27b8>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model = PolynomialRegression(30)\n",
|
||
|
"model.fit(X, y)\n",
|
||
|
"y_test = model.predict(X_test)\n",
|
||
|
"\n",
|
||
|
"plt.scatter(X.ravel(), y)\n",
|
||
|
"plt.plot(X_test.ravel(), y_test)\n",
|
||
|
"plt.title(\"mean squared error: {0:.3g}\".format(mean_squared_error(y, model.predict(X))))\n",
|
||
|
"plt.ylim(-4, 14);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"When we increase the degree to this extent, it's clear that the resulting fit no longer reflects the true underlying distribution; instead, it is highly sensitive to the noise in the training data. For this reason, we call it a **high-variance model**, and we say that it **over-fits** the data."
|
||
|
]
|
||
|
},
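{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see this numerically is to compare the error on the training points with the error on fresh points drawn from the same distribution. The cell below is a minimal sketch under our own assumptions: ``noisy_data`` is a hypothetical stand-in for the notebook's ``make_data`` (a smooth curve plus Gaussian noise of standard deviation 1.0), and the degrees 1, 4 and 15 are arbitrary choices. The exact numbers depend on the random seed, but we would expect the training error to keep shrinking as the degree grows while the held-out error eventually climbs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_squared_error\n",
"\n",
"rng = np.random.RandomState(0)\n",
"\n",
"def noisy_data(n, error=1.0):\n",
"    # Hypothetical generator (an assumption, not the notebook's make_data):\n",
"    # a smooth curve sampled on [0, 1] plus Gaussian noise of scale `error`.\n",
"    X = rng.rand(n, 1) ** 2\n",
"    y = 10 - 1. / (X.ravel() + 0.1) + error * rng.randn(n)\n",
"    return X, y\n",
"\n",
"X_fit, y_fit = noisy_data(40)    # points the model is trained on\n",
"X_new, y_new = noisy_data(1000)  # fresh points the model never sees\n",
"\n",
"for d in [1, 4, 15]:\n",
"    model_d = make_pipeline(PolynomialFeatures(d), LinearRegression()).fit(X_fit, y_fit)\n",
"    train_mse = mean_squared_error(y_fit, model_d.predict(X_fit))\n",
"    new_mse = mean_squared_error(y_new, model_d.predict(X_new))\n",
"    print('degree {0:2d}: training MSE {1:.2f}, held-out MSE {2:.2f}'.format(d, train_mse, new_mse))"
]
},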
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Just for fun, let's use the ``interact`` function from ipywidgets (in IPython 2.x and 3.x it lived in ``IPython.html.widgets``) to explore this interactively:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFeCAYAAABU/2zqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8XOV97/HPaB1blmXZGnnfbT1gsLElNkMwUBKyENIQ\nkqCk0JgmTZukudkaUpLbe9umS5o2aRN6S9M0wSlpK5o0TsIlJQtQSLmGFslgs/3k3cZga7RZkq3R\nOvePcyTLwpI14xnNmZnv+/XihWY0y/PMyOd7znPO73lC8XgcERERyayCTDdAREREFMgiIiKBoEAW\nEREJAAWyiIhIACiQRUREAkCBLCIiEgBFmW6AiEw/59ylwPfMbGWm2zKec+4y4K+BmUAh8Odm9k9n\nedzNwDbgsH9XHLjGzE465/4N2AD0+L971Mw+k+62i5wPBbKIBIZzLgR8H7jTzB51zi0GmpxzT5vZ\n3nEPvwr4CzP70lle6kqgzsyOpbnJIimjQJac4py7Dvgz4ChwEXAK+N/A/wAc8G9m9mn/sTcDXwBK\n/Mf9rpk95ZybD3wDqAYWAIeA95pZ1Dl3ELgPuAFYBjxgZp87Szs+AvwW0A/EgN8ys5ecc1uAe4Ah\n4CngrcB1wErgHjNbP6Yf95jZ+im05ym8o8G7gWf8118GFAMNZvZnY9r0SeAE8MIkn+FVwJeAMmAY\n+AMze8g5txX4IN6R6wngO8CH/NudZnaDc+73gXpgEGgGfsfMjjvn/gNoAy4A7vX78NtmdtO4ty/x\n3+9RADM76pxrBRYDZwvkfufcu/G+vy+Y2S+dcyuBcuAbzrkVQCPwGTPrmKjPIkGgc8iSiy4Fvmhm\nFwLH8YLqbUAt8DHn3ALn3FrgT4C3mlktXnj+wDk3E7gNeNLMrjKzVXgb+zv8144DZWa2BS8QPu6c\nWz72zZ1zhcBfAW82s8uBvweuds6V4h39fdJ/zx3Acv81J3Ou9uw2s3Vm9iPgfuDbZnYpcAXwJufc\ne5xzG/F2TK7x23TybG/knKsEvg3cbmZ1wK8C9zrnlvoPWQdca2a/AoTG3L7BOXcn8BbgUjO7BHge\nb0h5pJ3tZnaRmf2NmT14ljDGzPrM7L4x7fkw3o7BU2dpbivwN35f7wa2+0fUEeDnwIeBTXjD1t+e\n6MMVCQodIUsuOmBmz/k/78M7ehsE2pxzXcA84FpgIfCoc27keUPAajP7unPuGufcp4G1wMWcGQg/\nAjCzV51zLcBcvCM+/PuHnHPfA3Y45x4Cfgb8M94OQczMHvMfd79z7uvn6swU2vNLAOdcmd+vSufc\nF/3flQGXAEuBn5pZi3//N4DXBSKw2f9cfjTmcxkG1uOF6i4z6xnz+LG334K3M9Dr3/468AXnXPHY\ndk6Vc+738EY23mxmfeN/b2a3jvn5Sefc/wPeZGbbgFvHvM4fAMecc0X+34FIICmQJReN33ifbSNc\nADxiZvUjdzjnlgGvOOf+HLgM+BbwKN6/k9CY5/aO+Tk+7ncAmNkdzrl1wJuAz+EN9X7+LI/tn+B1\nSsa061ztGQnEQv//m80s5j+3ym/vhzlzRGxofJt9BcBLZnblmPdfjDfScPuY9xr/3iPPDY27Pbat\n4597Vv5Iwja84e0rzezwWR5TAXzMzP503Pv1O+feAMw1sx+PuX+YifssEggaspZ8FMcLthudfxjo\nnHsL8CwQBm4E/tq/sjeKF6qFE7zW6zjnqpxzh/GGaL8G/D7eOV4Dep1zb/cfdxPe8Cr++yxzzkX8\nC5veOeYlp9QeM+vCO3L+jP/6FXhHpe/AG8K90Q9XgK0TNP9pYK1/rhvn3AbgZbyj5nP5KXCnP+wP\n3tHt42Y2stPxuh2XCXwP7xzw1WcLY18P8FHn3Lv8dm7C22l52H/u1/3hd4DP4l1RrpV0JNAUyJKL\nxm94X7chNrMX8Y4aG5xzzwJfBG42s1PAHwF/
6Zx7Cu8CpO8Da6b65mbWCvwx8Ihz7hm8i8w+ZGZD\nwLuAzznnmoB3AwNj2vMNvIuydgCvjml3Iu15P3Clc24XXrj+i5n9i5k9D9zlt+m/8Yayz/a5RPGG\ne7/sfy7fBe4wsyP+48c+Z/ztbwG/AP7LOfcisBH4tXGPB8A59w5/OP8MzrmrgbcDq4EnnXM7/f/e\n5P9+p3Ou1v8sfxX4XefcbrxzxO81s3Yz+3e8C9uedM69jHfB3O9M8HmJBEZIyy+KZI5zrhdwkxwJ\nikiemNIRsnPuCufcY+Pue79/EYWIJE97xCICTOGiLufcXYy7mMM/X/MbaWyXSF4ws5nnfpSI5IOp\nHCHvxTvvFQJwzs3Dq9/8JFO/SENEREQmcc5ANrMf4JeNOOcK8C7c+DRTLGEQERGRc0u0DrkO7+rO\ne/HKQ9Y55746MhXhROLxeDwU0sG0iIjkjYRDL6FANrP/xpslCH+6wIZzhTFAKBQiGu1OtG05IRIp\nz9u+g/qv/qv/+dr/fO47eP1PVCJ1yOOvBg2d5T4RERFJwpSOkM3sIN5E+pPeJyIiIsnRTF0iIiIB\noEAWEREJAAWyiIhIACiQRUREAkCBLCIiEgAKZBERkQBQIIuIiASAAllERCQAFMgiIiIBoEAWEREJ\nAAWyiIhIACiQRUREAkCBLCIiEgAKZBERkQBQIIuIiASAAllERCQAFMgiIiIBoEAWEREJAAWyiIhI\nACiQRUREAkCBLCIiEgAKZBERkQBQIIuIiASAAllERCQAFMgiIiIBoEAWEREJgKKpPMg5dwXwJTO7\n3jm3Efg6MAT0Ab9uZi1pbKOIiEjOO+cRsnPuLuCbQKl/118Dv2Nm1wM/AD6XvuaJiIjkh6kMWe8F\n3gWE/Nv1ZrbL/7kY6E1Hw0RERPLJOQPZzH4ADI65fQzAOXcV8DHgr9LWOhERkTwxpXPI4znnbgM+\nD7zNzNqm8pxIpDyZt8oJ+dx3UP/Vf/U/X+Vz35ORcCA7524HPgxcZ2YdU31eNNqd6FvlhEikPG/7\nDuq/+q/+52v/87nvkNzOSCJlT3HnXAHwNWAW8APn3GPOuT9I+F1FRETkDFM6Qjazg8BV/s15aWuN\niIhIntLEICIiIgGgQBYREQkABbKIiEgAKJBFREQCQIEsIiISAApkERGRAFAgi4iIBIACWUREJACS\nmstaREQkiIaGh9n7ygnicbhgeWWmm5MQBbKIiGS1gcFhXjrUTqNF2bmnlZ7eAUqKCrj3M9cSCoXO\n/QIBoUAWEZGsE+sfZPf+dpqaozy3t5VY/xAAs8tKuG7TYt6wfmFWhTEokEVEJEv09A7w3N5WGi3K\n8wfaGRwaBqCqIsyWSxZR5yKsXlRBQUF2BfEIBbKIZLVYLEZDwxMA1NdvIRwOZ7hFkkod3X3s3BOl\n0aLY4U6G43EAFleVUVsToc5FWFo9K+uOhs9GgSwiWWN8+ALcdtt2duy4E4Dt2+/jgQduUShnuZaO\nUzQ1t9LY3MK+o12j969cOJs6F6G2JsKCuTMz2ML0UCCLSFaIxWKvC9+3v322f7sYgB07ttLQ8CBb\nt96YwZZKouLxOEejJ2ls9o6EX4n2ABAKwQXL5lBb44Xw3Nm5vaOlQBaRrNDQ8MTrwnfBgj/PbKMk\nacPxOAde66LJojQ2R2np6AWgqDDEhtXzqKuJcMnaKmbPLMlwS6ePAllEstall67h2LH72LFjKwCb\nN2+jvv6WzDZKJjQ0PEzz4U4am73ypI7uPgBKiwu59IJq6moibFg9jxml+RlN+dlrEck69fVb2L79\nzPC9445buOMOaGh40H+Mzh8HzcDgEC8c7KDJojy716sRBigLF3H1+gXU1VSzbkUlJcWFGW5p5imQ\nRSQrhMNhHnjglrOGr84ZB0tv3yC/3HmUx545zK79bfT5NcIVs0q4vnYxdTURapbOoahQszePpUAW\nkawRDocV
vgHVfaqfZ/e20mRRXjjYMVojHJkTpm6TF8IrF82mIAfKk9JFgSwiWU11yJnT3hVj555W\nGq0FO9KJXyLMkkgZ12x
|
||
|
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x109ee9e48>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from ipywidgets import interact  # formerly IPython.html.widgets (IPython 2.x-3.x)\n",
|
||
|
"\n",
|
||
|
"def plot_fit(degree=1, Npts=50):\n",
|
||
|
" X, y = make_data(Npts, error=1)\n",
|
||
|
" X_test = np.linspace(-0.1, 1.1, 500)[:, None]\n",
|
||
|
" \n",
|
||
|
" model = PolynomialRegression(degree=degree)\n",
|
||
|
" model.fit(X, y)\n",
|
||
|
" y_test = model.predict(X_test)\n",
|
||
|
"\n",
|
||
|
" plt.scatter(X.ravel(), y)\n",
|
||
|
" plt.plot(X_test.ravel(), y_test)\n",
|
||
|
" plt.ylim(-4, 14)\n",
|
||
|
" plt.title(\"mean squared error: {0:.2f}\".format(mean_squared_error(y, model.predict(X))))\n",
|
||
|
" \n",
|
||
|
"interact(plot_fit, degree=(1, 30), Npts=(2, 100));  # (min, max) tuples give integer sliders"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Detecting Over-fitting with Validation Curves\n",
|
||
|
"\n",
|
||
|
"As we saw previously, computing the error on the training data alone is not enough. As before, we can use **cross-validation** to get a better estimate of how well the model generalizes to unseen data.\n",
|
||
|
"\n",
|
||
|
"Let's do this here, again using the ``validation_curve`` utility. To make things clearer, we'll use a slightly larger dataset:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAFVCAYAAADc5IdQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAH9FJREFUeJzt3X+Q3Gd92PH3ORwshcNk6FHoFJJLIU+bSTtM68o6J0gx\nThyMLEXCUF1cy75MbQImHo+TmQPs4ukAnoLSmLpu6wABny0POY+D5aIIFCeGSA2jmCENDZm0T+v2\nOp3pDKC4GKspiw64/rG30t7qbne/u/vd5/vj/ZrxWCfd3j7P97u3n32e5/N8nqn19XUkSVI6l6Ru\ngCRJdWcwliQpMYOxJEmJGYwlSUrMYCxJUmIGY0mSEntBv28IIVwOfDjGeGUI4Q3Avwa+D3wXuDHG\n+M2c2yhJUqX1HBmHEJaATwAv2virfwX8SozxSuBx4D35Nk+SpOrrN039DPBWYGrj64UY459t/Hka\n+E5eDZMkqS56BuMY4+PA9zq+/jpACOEK4N3AR3NtnSRJNdB3zbhbCOEgcCfwlhjjs/2+f319fX1q\naqrft0mSVBWZg16mYBxCuAF4B/AzMcZvDdSiqSnOnDmbtV2VMTs7U9v+17nvYP/tf337X+e+Q6v/\nWQ26tWk9hHAJcB/wUuDxEMIXQwj/PPMzSpKkTfqOjGOM/xO4YuPLV+TaGkmSasiiH5IkJWYwliQp\nMYOxJEmJGYwlSUrMYCxJUmIGY0mSEjMYS5KUmMFYkqTEDMaSJCVmMJYkKTGDsSRJiRmMJUlKzGAs\nSVJiBmNJkhLre4SiJKl4ms0mKyunAFhY2EWj0UjcIo3CYCxJJdNsNjl48CinT/8SAEePPsijjx4w\nIJeY09SSVDIrK6c2AvE0MM3p04vnR8kqJ4OxJEmJGYwlqWQWFnYxP/8gcA44x/z8MgsLu1I3SyNw\nzViSSqbRaPDoowdYWTkGwMKC68VlZzCWpBJqNBosLl6duhkaE4OxJFWQW5/KxWAsSRXj1qfyMYFL\nkirGrU/l48hYkro4xatJc2QsSR3aU7xLS/tYWtrHwYNHaTabqZuViVufyseRsSR12DzFy8YU77FS\nZS679al8DMaSVEFufSoXp6mlGms2mywvP8ny8pOlm4rNi1O8SsGRsVRTbn/ZmlO8SsFgLNVUFdZG\n8+IUrybNaWpJkhIzGEs15dqoVBxOU0s15dqoVBwGY6nGXBuVisFpakmSEjMYS5KUmMFYkqTEDMaS\nJCVmApcklYxHPFaPwViSEhkmqG4uY9rkgQfu4eab38ChQ1cZlEvMaWpJSmDYc5MvlDH9PvAYq6sf\n5K67rivlucu6wGAsSQlsrg0+vVEb/FSGn/AUcGiEx6tI+k5ThxAuBz4cY7wyhPA6YBn4AfDnwLtj\njOv5NlFSnbk+utnCwi6OHn2Q06dfmbopGqOeI+MQwhLwCeBFG391L3BnjHEXMAX8Qr7Nk1Rng07l\nlvFc5mFrg7fLmN5zz3eYm7s38+NVTP1Gxs8AbwWObHz9D2KM7XmQzwNXA0/k1DZJNTfIMY9lPZd5\nlNrgjUaDW265lkOHmtYWr4iewTjG+HgI4Uc7/mqq48//F7g0j0ZJ0qAGPZe5iNPdo9YGt7Z4dWTd\n2vSDjj/PAM8N8qDZ2ZmMT1Mtde5/nfsO9n/U/t922x6OHz/CyZM3ALB79yPcdtv1mwLpzMzFQXVm\nprHpuZvNJm9/+2OcPHkIgOPHj3DixPW5B+Q63/86930YWYPxn4YQdscYTwLX0Ern6+vMmbOZG1YV\ns7Mzte1/nfsO9n9c/T9yZG/HVOxezp5d4+zZtfP/vmfPDubnH+T06UUA5ueX2bPnwKbnXl5+ciMQ\nt0bPJ0/ewP33Xzx6Hqc63/869x2G+yAyaDBuZ0z/GvCJEMILgb8AfifzM0ojKOJUo/LVbyp23Ocy\nF/01Nmj7it4PbTa1vp77zqT1un9Cqmv/x933
7kSd+fliJ+rU+d5Dsfp/4bWzCLRGz1u9dsb5Gsuj\n/4O2L/XvSpHufQqzszNT/b9rM4t+qDRGL5Kg1FJtQWqPng8fPsbhw8e2DUxFf40N2r6i90MXsza1\npIlIvQVpkMzjtbW1nv8u5cWRsUpj2CIJKoaij9aazSaf/eyzwEO0X2M7d36qUK+xQX8H/F0pH0fG\nKo1xJ+pInVZWTvH00++gdQDD7wNr7N17aaFeY4P+Dvi7Uj4GY5XKsEUOzCxN70JN5UWAjdHagbE+\nx3jucwPYA5xjevrYOJs3FoP+DlgQpFwMxqq81GuVasl7tDbqfZ7Eh4Ui8INpMRmMVXmDlktU/vIc\nrY16n+swtZv3B9N2oJ+ZabBnz47KXb88GYwlaUPVp3bz/GCaem9z2ZlNrcozs7S6Ovct79+/w/uc\nUNGz5YvOkbEqrw7Tj3W01ZTrQw9dw2OPPc5XvvLfuOyy1yVuYfHUZV28jCyHmbM6l4Wrc9/B/ufd\n/+XlJ1la2kdrJNYETrBv31f5xjdmN7YopZ0qLer9zyuBa9CSo3UwTDlMR8aSSqE7iHT8C/Bp4BCf\n/eweWkU7vg80TNbbQl7r4p0zUK0ErnoG4mEZjCUV3nZT0q0p11cCF45HhBtpFe3Yk6axNdYO9EWd\nFSgyE7gkFd5WyUFPPPFlHn30AAcOfG2LR6xhEpfKxJGxVAF1LeTQaDS477538vWvX0hK2rnzU+zd\neynT08eGStar67VUWgZjqeSqXmGs2WyytrbG3Nz7WV29E2hsygK+OFv+uqH7XvVrqeIyGEslV+UK\nY5uD43XMzd3LzTe/lkOHNgfIcSUlVflajsLZgvwZjCVlMsk35u7guLp6B9PTxy56ToNFfpwtmAwT\nuKSSm2SFsfYb89LSPpaW9nHw4FGazWYuz5WiTVWp1tZZmWzYa9H+Gbff/pucPv2LWFkrX46MpZKb\nZIWxSU/jDlIxKmubeo2iq1CtbRwj2c0/Yx+tvds30DpeUnkwGEsVUNUDDsYdHAcJVGW/luP4wNT9\nM1p7tz8HvMUSmjlxmlrSwFJM47aD4+Li1VsG4ixtqs9hBk3g+MZ/g01Td05tr62du+jfDxz4GocP\nH3O9OCeOjCUNrIjTuEVsU0r79+/gQx+6j+effw8AL3vZR9i//6aej+meMdi585NcfvnHefrpW4DW\n8sB9972z1tc1bx4UkbM6l4Wrc9+h2v0fJHu5s/9FyXae5GEGqe7/5gM0AM5x+HDvaeqtHnPPPY8z\nPT29MUqeYnp6euB7V+XX/iA8KEJS7rImCBVpa4yj6MG1g29R7l3VuWYsTcA4tpoURdZ11zzWaUe5\nnv3WoMuqfU3W1s6xc+cnybKuv926e33W2NNzZCzlLMvIsCjTuUVWpJF2UXRfk8sv//j5aeZBRv/t\nGYMjRz7DV77yDJdd9vpJNFsdHBlLXcY9ih10dFHEghpbyZpRPe4MbEdrF+u+Jk8/fQvT09OZR/+/\n+7tnOXr0vdx113UcPHiU/ft3VKIIShk4MpY6pBx1FbEu8lYj9e511/37r+k5mq/bOm37ms3MNNiz\nZ0dh+9p9b7d6/T3xxLFa3buUDMZShzwC4iBVpIqo1weT9rrroB9exllIo8jXs/t6zM9P5sNc1muy\n1X279tqZLb+37EVQysJpailn7ZHh4cPHehZNKFpd5EGmg1NMGQ96PVNINYWe9Zps1U6YKtTrr24c\nGUsdxj3qypKQVbfp3FE4WrvYqNdkenra119CFv3IWZ03v5e17+PKaJ6ZmeZNb3p405TlQw9dwxNP\nfHnknz1O2/V3kAIZvb6nrPd/FJMsKjKKvNtZx3vfaZiiHwbjnNX5RVnnvgN85jOneNe7rqazqtHc\n3N2srn4QmNx6Yi/91jgH+WDy3HPP8Z73LAPwkY8s8vKXvxyY/P0vyrawsiZwjbOddf/dtwKXVHCr\nq/MUKVu6
X8Jav6nPZrPJTTd9ntOn3wvA17+e5gNGkfYet69Z0QOSU/3FYgKXlJPFxas2JcTMzX0U\neFPiVo1XUfb8FqUd0rA
|
||
|
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x109edfba8>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"X, y = make_data(120, error=1.0)\n",
|
||
|
"plt.scatter(X, y);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import validation_curve  # sklearn.learning_curve in versions before 0.18\n",
"\n",
"def rms_error(model, X, y):\n",
"    # Custom scorer with the (estimator, X, y) signature scikit-learn accepts.\n",
"    # It returns an error (lower is better), which is fine for plotting.\n",
"    y_pred = model.predict(X)\n",
"    return np.sqrt(np.mean((y - y_pred) ** 2))\n",
"\n",
"degree = np.arange(0, 18)\n",
"val_train, val_test = validation_curve(PolynomialRegression(), X, y,\n",
"                                       param_name='polynomialfeatures__degree',\n",
"                                       param_range=degree, cv=7,\n",
"                                       scoring=rms_error)"
|
||
|
]
|
||
|
},
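{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make explicit what ``validation_curve`` computes, here is a rough hand-rolled equivalent; it is a sketch, not scikit-learn's actual implementation. For each degree and each of the 7 folds, it fits the model on the training folds and records the RMS error on those training folds and on the held-out fold. It reuses ``X``, ``y``, ``degree``, ``rms_error`` and ``PolynomialRegression`` from the cells above. We assume an unshuffled ``KFold`` with 7 splits, which is what an integer ``cv`` means for a regressor, so the two printed rows should agree closely."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.model_selection import KFold\n",
"\n",
"kfold = KFold(n_splits=7)  # unshuffled, like cv=7 for a regressor\n",
"manual_train = np.zeros((len(degree), kfold.get_n_splits()))\n",
"manual_test = np.zeros((len(degree), kfold.get_n_splits()))\n",
"\n",
"for i, d in enumerate(degree):\n",
"    for j, (train_idx, test_idx) in enumerate(kfold.split(X)):\n",
"        fold_model = PolynomialRegression(d).fit(X[train_idx], y[train_idx])\n",
"        manual_train[i, j] = rms_error(fold_model, X[train_idx], y[train_idx])\n",
"        manual_test[i, j] = rms_error(fold_model, X[test_idx], y[test_idx])\n",
"\n",
"# Mean validation error per degree: hand-rolled vs. validation_curve\n",
"print(np.round(manual_test.mean(1), 3))\n",
"print(np.round(val_test.mean(1), 3))"
]
},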
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Now let's plot the validation curves:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfQAAAFgCAYAAABNIYvfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XecZGd95/vPiZW7uzpNjppRjTIS0UILEsE2NtiYe8EB\nX4EBAwsY2yAjhFYSxmDtlVkEFgYjDItYkm0ul12EcVpjkkVSnpGmRhpNT+5cHSqe9OwfVR2qq6rD\nzFR3V/Xv/XrVVPU5Vaef0zPT3/OE8zyaUgohhBBCtDZ9rQsghBBCiPMngS6EEEK0AQl0IYQQog1I\noAshhBBtQAJdCCGEaAMS6EIIIUQbMNe6AMvleb7KZPJrXYzzlkxGkfNYH9rhHKA9zqMdzgHkPNaT\ndjgHgL6+hLbc97ZMDd00jbUuwgUh57F+tMM5QHucRzucA8h5rCftcA4r1TKB7vneWhdBCCGEWLda\nJtAd313rIgghhBDrVgsFurPWRRBCCCHWrdYJ9ECa3IUQQohGWibQAxXgSagLIYQQdbVMoAO4gfSj\nCyGEEPW0VKA7MtJdCCGEqKulAt2TGroQQqxrjuNw//3fXPb7v/Od+/nhD7/fcP+XvvQFnnzy0IUo\nWttrmZniQJrchRBivRsbG+Vb3/qfvPKVr17W+1/xilcuuv93f/eNF6BUG0OLBbo0uQshxHL93b89\nzc8OD6/4c4ah4fuq7r7nHujndS/Z1/CzX/zi5xkYeIYvfOFvCIKAxx9/lGKxwPvffzvf+c79pNNP\nMjk5yb59+/nAB+7gc5/7DD09vezatZsvfek+bNvizJnTvPSlv8iNN76Jj3zkg7zsZb/E2NgoDzzw\nI0qlEmfOnOL1r38Dr3jFK3niiYPcffddRKMxurqShEIhPvCBO2bLc+LEce68808xDBOlFHfc8WH6\n+vq5++67ePLJJ/A8lze/+W1cd92Lueeeu3n88UcBePnLf5nXvva3+MhHPsjU1CRTU1PcddfH+fKX\n7+Oxxx4hCAJ+8zd/hxtueBnf+Mbf84//+G10XefAgUv5oz+6acU/8wuhpQJdKYXru1iGtdZFEUII\nUccb3vBmnnnmKG9841v4/OfvZc+evbz73e8ln8/R0dHB3Xf/FUEQcOONv8no6AiaNjdV+dDQIF/8\n4tdwHIdXv/qXufHGN83u1zSNXC7Hxz52D6dOneTmm/+YV7zilXz0o3dy++0fZvfuPdx776cYHR2p\nKs/Pf/5TLr30Cv7zf/4DHnvsEbLZLE8++QSTk5N89rP3MT09zd/+7ZfRdYPBwTPce+8X8DyPd7zj\nLTz72c9B0zSe/ezn8brX/TYPPPAjzp49w6c+9TeUSiXe/vbf47nPfQHf+c63eO97b+HAgUv45je/\nju/7GMbqTz3bUoEO5Vq6BLoQQiztdS/Zt2htupG+vgQjI9Pn9D2Vqq7Z79ixCwDbDpHJZPjgB28l\nEomSz+fxvOpW14suughd1wmHw4RCoZpj799/caV8/ThOebKxsbFRdu/eA8BVV13N//7f/1z1mVe+\n8tf58pfv473vfTfxeIy3ve2dnDx5nMsvvxKARCLBW97ydr7ylf/BVVddDYBpmlx22RUcO3YMgJ07\ny+fwzDNPk04f5g/+4G0A+L7P2bNnuOWWO/ja177EmTOnufzyK2t+BqulpQbFgfSjCyHEeqZpOkEQ\nAOVwn6lh//jHP2J4eJAPfvAjvPWt78BxSnWCb/GFxebX5mf0929iYKAcvAcPPlaz/wc/+B5XXXU1\nn/jEp7j++pfypS/dx+7dezh8uDzQLpvNctNN72b37j089tgjAHiex8GDj7Jjx46q77tr1x6uuebZ\n3HPPZ7j77r/ihhtexrZt2/nWt77JTTfdwic/eS9HjqQ5dOjxZf60LqwWrKFLoAshxHrV3d2N57l8\n+tP3EAqFZsPw0ksv5777Pse73/12urt7uPTS
y2ebx+c3q8+pDe96+9/73vdz550fIhKJYFkWvb39\nVZ85cOASPvKRD2JZFr7v84d/+F7270/x85//lHe84y34vs+b3vRWnv/8X+Dhhx/k7W9/E67r8tKX\nvpyLLz5Q9X2vu+5FPPzwg7zznb9PoZDnRS+6gWg0ykUXXcQ73/kWotEYfX39XHrp5RfiR7li2lo1\nDazUYHZEnR4axdRNtsY3r3Vxztn5NGWtJ+1wHu1wDtAe59EO5wByHmvhG9/4e17ykpfT1dXFZz/7\naSzL4o1vfEtLncNiVrIeesvV0D3lVTXjCCGE2Li6u7t5z3veSSQSJR6Pc+utf7rWRVozLRfoqPLA\nOFsGxgkhxIZ3/fUv5frrX7rWxVgXWm5QHEg/uhBCCLGQBLoQQgjRBloz0GWRFiGEEKJKawa61NCF\nEEKIKi0Z6J7yCFSw1sUQQghxHv7gD97GiRMDDVdcu/HG31z089/73ncZHR1lfHyM//bf/t9mFbNl\ntN4od5gd6R4y7LUuiRBCiPOiLbniWiNf//rX2LNnDzt37ua97735Aper9bRmoFNudpdAF0KIxr7x\n9P08PLzyaUgNXcMP6k86dnX/FbxmX+MAvvXWP+G1r/1tnvWsazh8+Anuu+9z3Hbbh7jzzj8jl8sy\nOjrCa17zWl796v+78gk1u+Lar//6a7jrrj/n6NGn6O/fRC6XA8pzqH/ykx/H9wMmJye46ab3MzU1\nxVNPHeHDH/4gt932IT784Tv4zGf+Oz/72Y/57Gf/mng8SiQS45Zb7uDIkcN8+ctfrFnJbb7PfOav\neOSRB/E8n+uvfwmvf/0bOHToIPfc8zGCIKCvr4/bb/8wx48f4+Mf/yi6rmPbIW6++VaCIODmm/+Y\nzs4ufuEXXsjzn38tn/jER1FK0dnZyS233I7juNxxxy0opXAch5tuumV2bvoLpXUD3XdBbkUXQoh1\n5VWv+g2+8537edazruHb3/4Wv/Zrv8GpUyd52ct+iRe/+AZGR0d417veNi/Q56ZW/f73v0upVOTe\ne7/AxMQEv/Vb5TXVjx07xrve9Ufs3buPf/mXf+Tb3/4WN998K/v3X8yf/MkHMM25KLvrrjv59Kc/\nxyWX7OFTn/os9933Oa699rq6K7nN96//+k/cc8+99PT08A//8C0A/uIv/pwPfejP2blzN9/+9v/i\n+PFj3HXXn/P+99/Gvn37+eEPv8c999zNu971R4yPj/P5z38Z0zR561vfyK23fpBdu3Zz//3/ky9/\n+YtcccWVdHZ28V/+y58yMHCMYrFwwX/2rRvoMjBOCCEW9Zp9r1y0Nt3I+Uyb+rznvYBPfeoTTE1N\n8dhjj/Ce97yP0dER/u7vvsr3v/9vRKNxfN+v+9kTJ45z4MClAHR1dbFrV3kVtd7ePr7whc8RCoXI\n53PEYvG6n5+YmCAWi9Hb2wvAVVc9i3vv/RTXXnvdkiu53X77n/HpT/8l4+NjvOAF1wKQyYyxc+du\nAH71V38NgNHREfbt2w/AlVdezV//9ScB2LJl6+yFxYkTA3z0o3cC5YVeduzYyQte8EJOnjzJLbe8\nF9M0ufHGN6/sB7sMLTkoDsp96EIIIdYXXde54YaX8dGP3smLXnQ9mqbxta99icsvv4Lbbvszbrjh\npagGg5p3794zu2La1NQUJ0+eAOATn/gob37z27j11g+yd+++2VXadH1uZTcoXwTkcjnGxkYBeOSR\nh2aXPl1sJTfXdfnud/+VP/3TP+cv//Kv+c537mdwcJCenj5OnToJwFe+8kW+//1/p7e3j6NHn549\n/szysLo+F6c7duzitts+xD33fIa3ve2dvPCF5UVdenp6+djHPsmNN76Je+/9q5X+aJfUsjV0P/AJ\nVICutew1iRBCtKVf+ZVX8Vu/9Ru84x3/PwAvfOGL+PjH/4Lvf//f2bNnL9FoFNetbmXVNI3/9J+u\n56GHHuT3
f/8N9Pb20d3dA8Av/dIruO22m+nv38SBA5fOBvbll1/JRz5yB3/yJx+Ybba/+eZbufXW\n92HbJpFIjFtv/SBHjz6
|
||
|
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x10a3f2518>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"def plot_with_err(x, data, **kwargs):\n",
|
||
|
" mu, std = data.mean(1), data.std(1)\n",
|
||
|
" lines = plt.plot(x, mu, '-', **kwargs)\n",
|
||
|
" plt.fill_between(x, mu - std, mu + std, edgecolor='none',\n",
|
||
|
" facecolor=lines[0].get_color(), alpha=0.2)\n",
|
||
|
"\n",
|
||
|
"plot_with_err(degree, val_train, label='training scores')\n",
|
||
|
"plot_with_err(degree, val_test, label='validation scores')\n",
|
||
|
"plt.xlabel('degree'); plt.ylabel('rms error')\n",
|
||
|
"plt.legend();"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Notice the trend here, which is common for this type of plot.\n",
|
||
|
"\n",
|
||
|
"1. For low model complexity, the training and validation errors are very similar, and both are relatively high. This indicates that the model is **under-fitting** the data: it doesn't have enough complexity to represent the data. Another way of putting it is that this is a **high-bias** model.\n",
|
||
|
"\n",
|
||
|
"2. As the model complexity grows, the training and validation errors diverge. This indicates that the model is **over-fitting** the data: it has so much flexibility that it fits the noise rather than the underlying trend. Another way of putting it is that this is a **high-variance** model.\n",
|
||
|
"\n",
|
||
|
"3. Note that the training error (nearly) always decreases as model complexity grows. This is because a more complicated model can follow the training points, noise included, more closely. The validation error, by contrast, generally has a sweet spot, which here is around 5 terms (a degree-4 polynomial, counting the constant term).\n",
|
||
|
"\n",
|
||
|
"Here's our best-fit model according to the cross-validation:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAFVCAYAAAA+OJwpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xd8W/W9//GXPOU9Ejt7mBBOFjMhicMKlJ0wAoWkzHTQ\nC73t7Xo0nb+2ty1t4d7b3pb2dtFiCLeYthAumwAtCQ1JIEDI5JuEDLJjO962bNnW7w/ZwXY8JFnS\nOZLez8cj4HGO9PnqyPqc73b5fD5ERETEXkl2ByAiIiJKyCIiIo6ghCwiIuIASsgiIiIOoIQsIiLi\nAErIIiIiDpASyEGWZc0BfmqMudiyrLOAXwLtQAtwhzHmWARjFBERiXuD1pAty1oG/AFI7/zRfwOf\nN8ZcDDwJfD1y4YmIiCSGQJqsdwE3AK7O75cYYzZ1fp0KNEciMBERkUQyaEI2xjwJtHX7/giAZVnz\ngH8Ffh6x6ERERBJEQH3IvVmWtRj4FnC1MaZqsON9Pp/P5XINdpiIiEi8CDrpBZ2QLcu6DfgsMN8Y\nUx1QVC4XFRX1wT5VXCgqyknYsoPKr/Kr/Ila/kQuO/jLH6xgpj35LMtKAn4BZANPWpb1D8uyvh/0\ns4qIiEgPAdWQjTF7gXmd3w6LWDQiIiIJSguDiIiIOIASsoiIiAMoIYuIiDiAErKIiIgDKCGLiIg4\ngBKyiIiIAyghi4iIOEBIS2eKiIj9PB4P5eWrAViy5ELcbrfNEclQKCGLiMQgj8fD4sUrWLv2kwCs\nWPEQjz++SEk5hqnJWkQkBpWXr+5MxqlAKmvXLj1RW5bYpIQsIiLiAErIIiIxaMmSCyktfQhoBVop\nLS1jyZILexzj8XgoK1tJWdlKPB6PLXFK4NSHLCISg9xuN48/vojy8mcAWLKkZ/+x+phjjxKyiEiM\ncrvdLF16eZ+/69nHTGcf8zP9Hi/2U0IWEemDphRJtKkPWUSkl67m3mXLrmXZsmtZvHhFzPXBBtLH\nLM6iGrKISC/x0Nw7WB+zOI8SskiCU9Ns/Bqoj1mcR03WIgksHppmI0HNvWIH1ZBFElg8NM1Ggpp7\nxQ5KyCIifVBzr0SbmqxFEpiaZkWcQzVkkQSmplkR51BCFklwapoVcQYlZBERG4Uy7UxT1eKTErKI\niE1C2QCi9zlPPPF7rr12GKmpqUrOMS6ghGxZ1hzgp8aYiy3LOhUoAzqALcC/GmN8kQtRRCQ+hTLt\nrOc5HtavT2f9+hsB7egU6wYdZW1Z1jLgD0B6549+BnzLGHMh4AKui1x4IiKB7eubmHv/vgrciT85\np3Ym9NU2xyShCmTa0y7gBvzJF+AcY0zXFX8BuDQSgYmIQGCricXqimOhTDvreY43ClFKtAyakI0x\nTwJt3X7k6vZ1A5AX7qBERLr0bKLtuxYYyDHgvFp017Sz++9/hvvvfyag5ubu59x7bzNz5/4RzSOP\nD6EM6uro9nUOUBPISUVFOSE8VXxI5LJD+Mvv/1B9FYClSz/m+P4yXf+hlT8n5+Trm5Pj7vG4gRzj\n8Xi46aa/smrV7QA899xyXnzxlj7fP+F8jw1e/hy+9rUbg3pMjye1s8xuXnjhKsrLXwZg6dI7esRq\n999Kor/3gxVKQn7XsqyLjDGrgKvwd2IMqqKiPoSnin1FRTkJW3YIf/l7jzB95BFnD2LR9R96+Rcs\nmE1p6UOsXbsUgNLSMhYsWNTjcQM5pqxsZWcy9g+gWrXqNh544OQBVOF8j0Xi+g8UX329l/p676DH\nRYPe+8HfjASzdGbXSOqvAv9uWdYb+BP634J+VpEQBdo0Kc4VbLNxIM26oTT99mf58lcd/R4L9G9A\nfyuxJ6AasjFmLzCv8+udwPzIhSQi8SqUebcQ2Gpigx2zZMmFrFjRsxa9ZMmik+J78MGNQHBNyCLh\noM0lJKYMZTMEpw3oSUTRqLX1d50DqUWXl69m
z55vA8vpeo+VlPzcUQOlAv0b0MYhsUcrdUlMCXUz\nhFBrZhJbBrvOga3b7QZuAV4GvHzmM+Md9T4J9G9goOMitfRm98f9whcWhOUxE4nL54vKIlu+RO3c\n18AGZ5S/rGwly5ZdS9eAHmjl/vsHXhEpHJxSfrv0Lv9HCXMp4G82DueN0VCvc7jjc+L1733TUloa\nnpvT3o970UXLWb78GkfdzERTUVGOa/CjelINWUSiJhLbPXavlXm9Q1soIxG2owxluc5QHnfVqtvC\n8riJRAlZEkIgA3okOsK53WPPWpmHCRP+nYkT97B371cBKCn5Oddff4tt8YkEQ4O6JCGEc1qM2Kf3\ngK2PamXtwF/Zt+8n7N37edLTvwQ8w549n+POO1/QIL5uIjXYq/fjXnTRoxpEFiT1IUeYE/uQoknl\nV/nDVf6++j4XLszh29/+OLASuJzufcf+QVkLiNZ4gb449fpHa1BX1yIliUh9yCIJLN43re+r73Ph\nwic7V+gqDutzRfu17Ojw0d7hA3z4fB+twpTkcpGS7MLlCvqzfUCRapbv/rhdK4dJ4JSQReJAvE/r\n8ng8rF27Df9H1hX4pyZBamoqjz++iOXLX+HBB3/Gnj1fBiA3937q6v6Nj5pkAx8vMJTXssXbzvE6\nD9X1LdQ1tUJSBYeP1VPf7KWhqZWGZi/Nre20ettp8bbT0tpOi7eDtvaOfh/T5YL01GTSUpNJT00i\nPTWFnMxUcrPSyM1MIzcrlfzsdIryMyguyCAvKy3sCRzi/4bPCZSQReJApEbO9ieaH84fJchvdP7k\nYeBmSkvLT4yCvuuuhdx0Uw1f//p9APzgB0t44YWXO+ML7sZkoNey0ePl6PFmjlU3UVnr4Xh9C8fr\nPByva6G63kOjp23gBwfSUpJIT0smPTWZ/Oz0E4k2Odk/pMfV+R8XLjo6Omhp66C1td3/f287Dc3N\nHKho6P/xU5Mozs9g9PAsJozIYdyIbMaPyCE3My3g16C3eL/hcwolZBEJSrQ/nHsnSLiDRYvu4xe/\nuLvHQhd33vnCiaR95MjAMQ10Q9HWAfkja8jK95CZ30hWQT3vViXz9n+v7jfhpqclU5iTzsRRuRTm\npFOQk05eVhqjR+bi87aTnZlKTmYaWe4UUpKHPpa2rb2DusZW6ppaqWts5Xh9CxXVzRyraaaiupmj\n1U0cqGjkze3HTpxTmJvO5LH5nDYuH2tcPqOGZQ5ak/5o4Nw21q79EtG64UtUSsgicSCa07qiXRvv\nS2nptB5JNJiYum4o1r91BzmFDTz/z+e57uapHK32cKCigZqGFM6/5Y0e5zS1J1GUk8akMXmMKMik\nuCCDonw3hTluCnPTyUhP6TO5RWpQV0pyEoW5bgpz+946cvHiFWzcupi84jqmnbOa2ReM48OjDazf\ndpT1244CkJuVxhmThnHmpOFMLynAnZbS5+P4X9dr8bdM3EZXd4GEnxKySByI5wUthnqz0d7RwdHj\nzRysbOTAsQbeeHs37iklXDX3Jfw5NI9X3zkEQEFOOjNOKaQ4L50Pdx3AnQI3X38uo4pySU6KjVmi\n/puTTwCraa6DI7uu4cpZL/PlL1zGkeNNmP017Piwhm17j/PPTYf556bDpCS7mDaxkHMmF7Lzne0k\nJ4HX23pSywQ8D1ytefwRooQsEieitaDFRwlyCfB3SkrWcf31/zroeaH2Owdys9EV08Yti8kZXs/0\nc16jOb+E7/3pTQ5XNdLW3n16ZxKpbi/HDwyjviqXuspMbr/5Pe6642Iy3akfHXbF1IDicxqvtxV4\nHLiz8ycP4/Vm4XK5GDUsi1HDsph/1hg6fD72HK7jvV1VvLerkk0fVLHpgyravOkc3TUCb9U6oBl/\nQvYAL3LmmSu5+eZmbr89fm74nETzkCPMqfMQo0Xlj8/y19TUcMUVf2bPnq8A/a+H3FX+cK+fXNfU\nyp6D1Ty7
chONbS4y8vM4XNWEp7W9x3FpqUmMGZ7FmOHZjCnKYmxRNsNzU/jsp56L2Hra3UX7+ns8\nHj7/+V/z9NPfpvuc7Hv
|
||
|
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x10a2f5ba8>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model = PolynomialRegression(4).fit(X, y)\n",
|
||
|
"plt.scatter(X, y)\n",
|
||
|
"plt.plot(X_test, model.predict(X_test));"
|
||
|
]
|
||
|
},
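{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rather than reading the sweet spot off the plot, we can select it programmatically. This sketch reuses ``degree``, ``val_test``, ``X``, ``y`` and ``PolynomialRegression`` from earlier cells. The first option takes the degree with the lowest mean validation error; the second uses scikit-learn's ``GridSearchCV``, which treats higher scores as better, so we pass the built-in negated-MSE scorer. Because that criterion averages squared errors rather than RMS errors, the two choices could differ slightly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# Option 1: the degree with the lowest mean cross-validated RMS error\n",
"print('best degree from validation curve:', degree[np.argmin(val_test.mean(1))])\n",
"\n",
"# Option 2: let GridSearchCV search the same range of degrees\n",
"grid = GridSearchCV(PolynomialRegression(),\n",
"                    param_grid={'polynomialfeatures__degree': degree},\n",
"                    scoring='neg_mean_squared_error', cv=7)\n",
"grid.fit(X, y)\n",
"print('best degree from GridSearchCV:', grid.best_params_['polynomialfeatures__degree'])"
]
},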
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Detecting Data Sufficiency with Learning Curves\n",
|
||
|
"\n",
|
||
|
"As you might guess, the exact turning point of the tradeoff between bias and variance is highly dependent on the number of training points used. Here we'll illustrate the use of *learning curves*, which display this property.\n",
"\n",
"The idea is to plot the error (here, the RMS error) on the training and validation sets as a function of the *number of training points*:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import learning_curve  # sklearn.learning_curve in versions before 0.18\n",
"\n",
"def plot_learning_curve(degree=3):\n",
"    train_sizes = np.linspace(0.05, 1, 20)\n",
"    N_train, val_train, val_test = learning_curve(PolynomialRegression(degree),\n",
"                                                  X, y, train_sizes=train_sizes, cv=5,\n",
"                                                  scoring=rms_error)\n",
"    plot_with_err(N_train, val_train, label='training scores')\n",
"    plot_with_err(N_train, val_test, label='validation scores')\n",
"    plt.xlabel('Training Set Size'); plt.ylabel('rms error')\n",
"    plt.ylim(0, 3)\n",
"    plt.xlim(5, 80)\n",
"    plt.legend()"
|
||
|
]
|
||
|
},
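{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about what ``learning_curve`` does, here is a simplified hand-rolled version that uses a single train/validation split; ``learning_curve`` does roughly this for each cross-validation split. It fits a degree-3 model on growing prefixes of the training split and tracks the RMS error on the points used for fitting and on the fixed validation split. It reuses ``X``, ``y``, ``rms_error`` and ``PolynomialRegression`` from earlier cells; the 25% split and the step of 10 points are arbitrary choices for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25, random_state=0)\n",
"sizes = np.arange(5, len(X_tr) + 1, 10)  # growing training-set sizes\n",
"train_err, val_err = [], []\n",
"for n in sizes:\n",
"    # fit on the first n training points only\n",
"    m = PolynomialRegression(3).fit(X_tr[:n], y_tr[:n])\n",
"    train_err.append(rms_error(m, X_tr[:n], y_tr[:n]))\n",
"    val_err.append(rms_error(m, X_val, y_val))\n",
"\n",
"plt.plot(sizes, train_err, label='training error')\n",
"plt.plot(sizes, val_err, label='validation error')\n",
"plt.xlabel('Training Set Size'); plt.ylabel('rms error')\n",
"plt.ylim(0, 3)\n",
"plt.legend();"
]
},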
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Let's see what the learning curves look like for a linear model:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfUAAAFkCAYAAAA5cqL3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XmAJGVh//93VfU903PsHLsLLLvAQi3IvQgEkUPwazQY\nj0SNMaJRI0ZjNGpE5Asag+H7M0Zj8Igg/EAlmkjQBIiamKiAXwG5zy1Y2OVednZ2dq4+6/j+Ud09\nPbMzuz273dM93Z+XDtNV1V391OxMf57nqarnMYIgQERERJY/s9kFEBERkfpQqIuIiLQJhbqIiEib\nUKiLiIi0CYW6iIhIm1Coi4iItIlIo3Zs27YFXAUcAQTABxzHeaRq++uBSwAXuMZxnG81qiwiIiKd\noJEt9fMA33Gc04H/DXy+vMG27SjwJeDVwJnA+23bHm5gWURERNpew0LdcZx/Ay4oLa4Dxqo2Hwls\ndhxn3HGcInA7cEajyiIiItIJGtb9DuA4jmfb9rXAm4Dfr9rUA4xXLU8CvY0si4iISLtraKgDOI7z\nbtu2LwTutG37SMdxsoSBnq56WprZLfnd7JweD5KxOFEzQsRqeLFFRERagbGYJzfyQrl3Agc5jnM5\nkAV8wgvmADYBh9u23Q9ME3a9/+2e9rdl2wuVx4ZhEDEjYcCbESJGpLJsmVZdj2NoKM3IyGRd99nK\nOul4O+lYobOOt5OOFTrreDvpWCE83sVoZJP3BuBa27Z/CUSBjwBvsm2723Gcq2zb/hjwU8Lz+lc7\njvNirTsOgoCiV6ToFXfbZhgmUdMiYkaJmFYp7KNEDKvugS8iItJKGhbqpW72t+1h+83AzfV+3yDw\nKXg+hXkC3zTMsGVf1covfzcN3bIvIiLLW0ednPYDn4JXoOAVdtu2ItlPd7SrCaUSERGpDzVPSzLF\nTLOLICIisl8U6iU5L4/ne80uhoiIyD5TqJcFkHVzzS6FiIjIPlOoV8m46oIXEZHlS6FeRV3wIiKy\nnCnUq6kLXkRkvxQKBW6++Uc1P//HP76Z22+/dcHt3/3utTz22CMLbpfZOuqWtlpMuxm6Y7q1TUSW\nv3/5n838ZtP2uu7z5RuGeeur1i+4fXR0Bzfd9G+cd94ba9rfa1973h63/9EfvXsxxet4CvU58qUu\neI0+JyKyeN/+9jVs3foU1177LXzf56GHHiCXy/KpT13Kj398M47zGOPj46xffzif/vRnuPrqbzIw\nMMjatev47nevIxaL8sILz3POOf+L889/D5///Gc599zXMDq6g1//+leAx5YtW3nHO97Fa197Ho8+\n+jBf/vIXSKW66OvrJx6P8+lPf6ZSnmeeeZrLL/8rLCtCEAR85jOXMTQ0zJe//AUee+xRXLfIe997\nAaeffiZXXPFlHnroAQBe/erf5i1v+QM+//nPMjExzsTEBF/4wt9z/fXX8eCD9+P7Pm972x9y9tnn\ncuONP+AnP7kF0zTZsOEoPvrRTzTpp69Q310AGTdLOtbd7JKIiOyXt75q/R5b1Y3wrne9l6eeepJ3\nv/t9XHPNlRxyyKH8+Z9/nExmmp6eHr785a/h+z7nn/82duwYwTBm5it56aVtfPvb36dQKPDGN/42\n55//nsp2wzCYnp7mO9+5lvvue5QLL/wLXvva8/jiFy/n0ksvY926Q7jyyq+zY8fIrPLcffddHHXU\nMfzpn36YBx+8n6mpKR577FHGx8e56qrrmJyc5J//+XpM02Lbthe48sprcV2XD37wfWzceBKGYbBx\n48m89a1v59e//hUvvvgCX//6t8jn83zgA3/My19+Kj/+8U18/OMXsWHDkfzoRzfgeR6W1ZyGoUJ9\nHgp1EZF9EwTBrOU1a9YCEIvFGRsb47OfvZhkMkUmk8F13VnPPeywwzBNk0QiQTwe323fhx9+BABD\nQ8MUCuHIoKOjO1i37hAAjjvuBP77v/9z1mvO
O+8NXH/9dXz8439Od3cXF1zwIZ599mmOPvpYANLp\nNO973wf4p3/6DscddwIAkUiEl73sGLZs2QLAwQeHx/DUU5txnE18+MMXAOB5Hi+++AIXXfQZvv/9\n7/LCC89z9NHH7vYzWEq6UG4eeV0FLyKyTwzDxPd9IAz4ckv7jjt+xfbt2/jsZz/P+9//QQqF/Dzh\nt+dZRqtb9WXDwyvZujUM34cffnC37bfd9kuOO+4EvvKVr3PWWefw3e9ex7p1h7BpU3jx3dTUFJ/4\nxJ+zbt0hPPjg/QC4rsvDDz/AmjVrZr3v2rWHcOKJG7niim/y5S9/jbPPPpcDDzyIm276EZ/4xEV8\n9atX8vjjDo888lCNP636U0t9PuqCFxHZJytWrMB1i3zjG1cQj8crgXjUUUdz3XVX8+d//gFWrBjg\nqKOOrnSVV3exz9g9wOfb/vGPf4rLL/8cyWSSaDTK4ODwrNds2HAkn//8Z4lGo3iex0c+8nEOP9zm\n7rvv4oMffB+e5/Ge97yfU075Le677x4+8IH3UCwWOeecV3PEERtmve/pp5/Bfffdw4c+9CdksxnO\nOONsUqkUhx12GB/60PtIpboYGhrmqKOOrsePcp8YzewmWIx7nnxsSQsaj8RZmRrqyLl7O+V4O+lY\nobOOt5OOFTrreOce6403/oBXverV9PX1cdVV3yAajfLud7+viSWsr6Gh9J67L+ZQS30B6oIXEWl9\nK1as4GMf+xDJZIru7m4uvvivml2kplKoL6TUBQ99zS6JiIgs4KyzzuGss85pdjFahi6U2wNNxyoi\nIsuJQn0P8n4BV13wIiKyTCjU9ySA6cJ0s0shIiJSE4X6XkwX1AUvIiLLg0J9L3JuQVfBi4g0wIc/\nfAHPPLN1wZnazj//bXt8/S9/+XN27NjBzp2j/N3f/X+NKuayoqvf9yrQQDQisizduPlm7tte39HN\nThg+hjev3/PMaotj7HWmtoXccMP3OeSQQzj44HV8/OMX1rFMy5dCvQbTxYxCXUSkBhdf/Je85S1v\n5/jjT2TTpke57rqrueSSz3H55X/N9PQUO3aM8OY3v4U3vvH3S68IKjO1veENb+YLX/gbnnzyCYaH\nVzI9HV7T9NRTm/nqV/8ez/OZnp7gox/9JBMTEzzxxONcdtlnueSSz3HZZZ/hm9/8//nNb+7gqqv+\nkVgsRm9vLxdd9Bkef3wT11//7d1mgKv2zW9+jfvvvwfX9TjrrFfxjne8i0ceeZgrrvgSvu8zNDTE\npZdextNPb+Hv//6LmKZJLBbnwgsvxvd9LrzwL+jt7eO3fusVnHLKaXzlK18kCIJSGS6lUCjymc9c\nRBAEFAoFPvGJiypj2deTQr0GBa+A67tETP24RGT5ePP68+rcqt6717/+Tfz4xzdz/PEncsstN/G7\nv/smnnvuWc499zWceebZ7Ngxwp/92QVVoT4zDOutt/6cfD7HlVdey65du/iDPwjnZN+yZQt/9mcf\n5dBD13Pnnb/klltu4sILL+bww4/gL//y00QiM5/NX/jC5XzjG1czODjID37wfa677mpOO+30eWeA\nq/azn/2UK664koGBAf7jP24C4G//9m/43Of+hoMPXsctt/w7Tz+9hS984W/41KcuYf36w7n99l9y\nxRVf5s/+7KPs3LmTa665nkgkwvvf/24uvvizrF27jptv/jeuv/7bHHPMsfT29vG///dfsXXrFnK5\nbEN+/kqpGmXcLD2xdLOLISLS0k4++VS+/vWvMDExwYMP3s/HPvZJduwY4V/+5Xvceuv/kEp143nz\nX6f0zDNPs2HDUQD09fWxdm04+9rg4BDXXns18XgczysQiew+gxvArl276OrqYnBwEIDjjjueK6/8\nOqeddvpeZ4C79NK/5hvf+Ad27hzl1FNPA2BsbJSDD14HwO/8zu8CsGPHCOvXHw7AsceewD/+41cB\nWL36gErl
4plntvLFL14OhJPDrFlzMKee+gqeffZZLrro40QiEc4//72L+8HWSBfK1ShTbEytSkSk\nnZimydlnn8sXv3g5Z5x
|
||
|
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x10a42d400>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plot_learning_curve(1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"This shows a typical learning curve: for very few training points, there is a large separation between the training and validation error, which indicates **over-fitting**. Given the same model, for a large number of training points, the training and validation errors converge, which indicates potential **under-fitting**.\n",
|
||
|
"\n",
|
||
|
"As you add more data points, the training error will, on average, never decrease, and the validation error will, on average, never increase (why do you think this is?)\n",
|
||
|
"\n",
|
||
|
"This plot makes it clear that if you'd like to reduce the RMS error down to the nominal value of 1.0 (the magnitude of the scatter we put in when constructing the data), adding more samples will *never* get you there. For $d=1$, the two curves have converged and cannot move lower. What about a larger value of $d$?"
|
||
|
]
|
||
|
},
|
||
|
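To see the mechanics behind a learning curve, here is a minimal NumPy-only sketch for a degree-$d$ polynomial. It is not the notebook's own `plot_learning_curve`: the data generator `make_toy_data` (the curve $10 - 1/(x + 0.1)$ plus Gaussian scatter of size 1.0) and the helpers `rms_error` and `poly_learning_curve` are hypothetical stand-ins, and `np.polyfit` plays the role of the polynomial regression model. For each training-set size $N$ we fit on the first $N$ points and record the rms error on those points and on a fixed held-out set.

```python
import numpy as np

def make_toy_data(n, noise=1.0, seed=0):
    """Hypothetical data generator: a smooth nonlinear curve plus Gaussian
    scatter with standard deviation `noise` (assumed for illustration)."""
    rng = np.random.RandomState(seed)
    x = rng.uniform(0, 1, n)
    y = 10 - 1.0 / (x + 0.1) + noise * rng.randn(n)
    return x, y

def rms_error(coeffs, x, y):
    """Root-mean-square error of the polynomial `coeffs` on the points (x, y)."""
    return np.sqrt(np.mean((np.polyval(coeffs, x) - y) ** 2))

def poly_learning_curve(degree, x_train, y_train, x_test, y_test, sizes):
    """For each n in `sizes`, fit on the first n training points and return
    arrays of training rms error and test rms error."""
    train_err, test_err = [], []
    for n in sizes:
        coeffs = np.polyfit(x_train[:n], y_train[:n], degree)
        train_err.append(rms_error(coeffs, x_train[:n], y_train[:n]))
        test_err.append(rms_error(coeffs, x_test, y_test))
    return np.array(train_err), np.array(test_err)

x_train, y_train = make_toy_data(2000, seed=0)
x_test, y_test = make_toy_data(2000, seed=1)
# Start well above degree + 1 points so every least-squares fit is well posed.
sizes = np.arange(10, 2001, 10)

train_err, test_err = poly_learning_curve(1, x_train, y_train, x_test, y_test, sizes)
for i in (0, -1):
    print("N=%4d: train rms %.2f, test rms %.2f" % (sizes[i], train_err[i], test_err[i]))
```

With $d=1$, both errors level off for large $N$ well above the noise level of 1.0 (roughly 1.4 for this toy curve), which is the plateau described above.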
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfUAAAFkCAYAAAA5cqL3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XmcZFV9///X3Wqv7q6e7llgVga4gCibCypBQHGFRE3U\nqIkxasRojIkaEfiCxmjIjxiNwSVCMGJEjRo0AcQtIqhxF0S2i8AMszBL90x3177c5fdHLV3VS3VV\nd+31eT4e0FV1b1WdW91T73vOPYvieR5CCCGE6H9qtwsghBBCiNaQUBdCCCEGhIS6EEIIMSAk1IUQ\nQogBIaEuhBBCDAgJdSGEEGJA6O16YdM0NeB64ETAA95iWdb9VdsvBq4EbOAzlmX9W7vKIoQQQgyD\ndtbULwJcy7LOAf4f8KHyBtM0DeAjwIXAc4A3m6a5vo1lEUIIIQZe20Ldsqz/Bi4p3d0OzFRtPhl4\nxLKsOcuyCsAPgXPbVRYhhBBiGLSt+R3AsizHNM3PAi8D/qBq0wgwV3U/AYy2syxCCCHEoGtrqANY\nlvV60zQvBX5qmubJlmVlKAZ6tGq3KLU1+UUentrl6Vrn+vV99f5v8Isn7uWvn/kmNkQmVtw/6o8w\nGV7XgZIJIYQYIkozO7ezo9wfA5sty7oayAAuxQ5zAA8BJ5imGQNSFJve/7He6yXi2XYVdUkRtXjO\nMZU+gq8QWnH/jO5C2tfuYrXd5GSUqalEt4vREcN0rDBcxztMxwrDdbzDdKxQPN5mtLPq+1XgdNM0\n7wS+CbwDeJlpmn9Wuo7+TuBbwP8BN1iWdaCNZWnaeCAGwHTqaEP7u57TzuIIIYQQK2pbTb3UzP6q\nOttvBW5t1/uvVTnUp9JHYWzl/R3PbXOJhBBCiPpk8plljPlGUBWVqQZr6o7U1IUQQnSZhPoyNFVj\nzDfCdPpIY0/wwHEl2IUQQnSPhHod44EY6UKWdCHT0P5SWxdCCNFNEup1lK+rH83WHW1XIaEuhBCi\nmyTU6xgPFHvIHc01GOqudJYTQgjRPRLqdczX1Gcb2t/27HYWRwghhKhLQr2OppvfpaYuhBhy+Xye\nW2/9esP73377rfzwh3ctu/3zn/8sDz54/7LbRa22TxPbz0J6kIDu54hcUxdC9KEvf+8Rfv7Q4Za+\n5tNOWs8rLzh+2e1Hjkxzyy3/zUUXvbSh13vRiy6qu/2P/uj1zRRv6Emo16EoCpPhdeyPH8T1XFSl\nfsOGzConhBh2n/vcZ9i9+zE++9l/w3VdfvObX5PNZnjve6/i9ttvxbIeZG5ujuOPP4HLL38fN9zw\nadatm2Dbtu18/vM34vMZPPHEfp773Ofzute9gQ996P0873kv4MiRaX784x8BDrt27ea1r/0TXvSi\ni3jggfv46EevIRQKMzYWw+/3c/nl76uUZ8+ex7n66r9F03Q8z+N97/sgk5Pr+ehHr+HBBx/Atgu8\n8Y2XcM45z+Haaz/Kb37zawAuvPCFvOIVf8iHPvR+4vE54vE411zzz9x0043ce+89uK7Lq171Gs4/\n/3ncfPNX+OY3b0NVVU466RT+6q/e3aVPX0J9RZOhcfbOPcFsLl7pOLccmVVOCNFLXnnB8XVr1e3w\nJ3/yRh577FFe//o38ZnPXMeOHcfxl3/5LtLpFCMjI3z0o5/AdV1e97pXMT09haLMr1dy6NBBPve5\nL5HP53npS1/I6173hsp2RVFIpVL8x398lrvvfoBLL/1rXvSii/jwh6/mqqs+yPbtO7juuk8yPT1V\nU55f/OJnnHLKk/nzP3879957D8lkkgcffIC5uTmuv/5GEokE//mfN6GqGgcPPsF1130W27Z561vf\nxFlnPRVFUTjrrKfzyle+mh//+EccOPAEn/zkv5HL5XjLW/6Upz3tbG6//Rbe9a7LOOmkk/n617+K\n4zhomtbRz71MQn0Fk+FxoHhdfeVQl5q6EGK4
eZ5Xc3/Llm0A+Hx+ZmZmeP/7ryAYDJFOp7Ht2s7F\nO3fuRFVVAoEAfr9/0WufcMKJAExOriefzwPF5v7t23cAcNppZ/C///vtmudcdNHvcdNNN/Kud/0l\nkUiYSy55G3v3Ps6ppz4FgGg0ypve9Ba+8IX/4LTTzgBA13We9KQns2vXLgC2bi0ew2OPPYJlPcTb\n334JAI7jcODAE1x22fv40pc+zxNP7OfUU5+y6DPoJOkot4KJ0Hyor0hmlRNCDDlFUXFLnYY9z6vU\ntH/ykx9x+PBB3v/+D/HmN7+VfD63RPjVX2W0ulZftn79BnbvLobvfffdu2j7D35wJ6eddgYf+9gn\nOe+85/L5z9/I9u07eOihYue7ZDLJu9/9l2zfvoN7770HANu2ue++X7Nly5aa9922bQdnnnkW1177\naT760U9w/vnP49hjN3PLLV/n3e++jI9//Doeftji/vt/0+Cn1XpSU19BeY30Ziag0ehOs4sQQnTb\n+Pg4tl3gU5+6Fr/fXwnEU045lRtvvIG//Mu3MD6+jlNOObXSVF7dxD5vcYAvtf1d73ovV1/9AYLB\nIIZhMDGxvuY5J510Mh/60PsxDAPHcXjHO97FCSeY/OIXP+Otb30TjuPwhje8mWc845ncffcvectb\n3kChUOC5z72QE088qeZ9zznnXO6++5e87W1/RiaT5txzzycUCrFz507e9rY3EQqFmZxczymnnNqK\nj3JVlG42EzTjl48+2JWCRkZ8XPm9D7MlcgyvOekPVtx/MrSOoB7sQMnaY5jWKh6mY4XhOt5hOlYY\nruNdeKw33/wVLrjgQsbGxrj++k9hGAavf/2buljC1pqcjNZvvlhAauorMDSdMf9IwxPQyFh1IYTo\nnPHxcd75zrcRDIaIRCJcccXfdrtIXSWh3oBxf4zH4o+TtXME9MWdN6pJZzkhhOic8857Lued99xu\nF6NnSEe5BjQzs5yEuhBCiG6RUG9AM6FuS+93IYQQXSKh3oDKam0NhLrMKieEEKJbJNQbUK6pH2lg\nCVaZVU4IIUS3SKg3IGKE8alGQz3g5Zq6EEI05u1vv4Q9e3Yvu1Lb6173qrrPv/POO5ienubo0SP8\n0z/9f+0qZl+R3u8NUBSF8UCMqcyRlRd2Kc0qp6kyAY0QortufuRW7j7c2tnNzlj/ZF5+fP2V1Zqj\nrLhS23K++tUvsWPHDrZu3c673nVpC8vUvyTUGzQeiHEwfZh4PsGYf7TuvjKrnBBiWF1xxd/wile8\nmtNPP5OHHnqAG2+8gSuv/ABXX/13pFJJpqenePnLX8FLX1qezMurrNT2e7/3cq655u959NHfsn79\nBlKpFFCcc/3jH/9nHMcllYrzV3/1HuLxOL/97cN88IPv58orP8AHP/g+Pv3pf+fnP/8J11//r/h8\nPkZHR7nssvfx8MMPcdNNn1u0Aly1T3/6E9xzzy+xbYfzzruA1772T7j//vu49tqP4Louk5OTXHXV\nB3n88V388z9/GFVV8fn8XHrpFbiuy6WX/jWjo2M885nP5hnPeBYf+9iH8TyvVIaryOcLvO99l+F5\nHvl8nne/+7LKXPatJKHeoOrOco2EuhBCdNvLj7+oxbXqlV188cu4/fZbOf30M7nttlv43d99Gfv2\n7eV5z3sBz3nO+UxPT/EXf3FJVajPT8N61113kMtlue66zzI7O8sf/mFxTfZdu3bxF3/xVxx33PH8\n9Kd3ctttt3DppVdwwgkn8jd/czm6Ph9l11xzNZ/61A1MTEzwla98iRtvvIFnPeucJVeAq/bd736L\na6+9jnXr1vGNb9wCwD/+49/zgQ/8PVu3bue22/6Hxx/fxTXX/D3vfe+VHH/8Cfzwh3dy7bUf5S/+\n4q84evQon/nMTei6zpvf/HquuOL9bNu2nVtv/W9uuulzPPnJT2F0dIz/9//+lt27d5HNZtry+Uuo\nN2h+WNss
x9XPdJlVTggxtJ7+9LP55Cc/Rjwe59577+Gd73wP09NTfPnLX+Suu75HKBTBcZau+OzZ\n8zgnnXQKAGNjY2zbVlx
"text/plain": [
"<matplotlib.figure.Figure at 0x10a51a9b0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve(3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we see that, by adding more model complexity, we've managed to lower the level at which the two curves converge to an rms error of about 1.0!\n",
"\n",
"What if we make the model even more complex?"
]
},
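The drop in the convergence level can also be checked numerically. The sketch below again assumes a hypothetical noiseless curve `true_curve` with Gaussian scatter of size 1.0 (not the notebook's own data). It fits polynomials of degree 1, 2 and 3 to a large training set, where training and test errors have essentially met, and reports the test rms error each one converges to.

```python
import numpy as np

rng = np.random.RandomState(42)

def true_curve(x):
    # Hypothetical noiseless curve, assumed for illustration only.
    return 10 - 1.0 / (x + 0.1)

# Large training and test sets, so each learning curve has essentially converged.
n = 5000
x_train = rng.uniform(0, 1, n)
y_train = true_curve(x_train) + rng.randn(n)  # scatter of size 1.0
x_test = rng.uniform(0, 1, n)
y_test = true_curve(x_test) + rng.randn(n)

plateau = {}
for degree in (1, 2, 3):
    coeffs = np.polyfit(x_train, y_train, degree)
    resid = np.polyval(coeffs, x_test) - y_test
    plateau[degree] = np.sqrt(np.mean(resid ** 2))
    print("d=%d: converged test rms error %.3f" % (degree, plateau[degree]))
```

Only the cubic gets close to the noise floor of 1.0; adding data to the linear model would not move its plateau there.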
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfUAAAFkCAYAAAA5cqL3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XmcXFWd///XvXVr7+ru6i0LCUnIciEEkgACAoOsrqiI\nI24jOm447qOOCHwBx2WYHzqig8sI4ogjghuggMCIIoijyJIECHCzkH3tfan9Lr8/blV1ddLdqe6u\nvT/Px6PTVXWrbp3b3an3Pcs9R3EcByGEEELUP7XaBRBCCCFEaUioCyGEEA1CQl0IIYRoEBLqQggh\nRIOQUBdCCCEahIS6EEII0SC0cu1Y13UPcAuwAnCAjxiGsbFg+xuBawAT+KFhGD8oV1mEEEKI2aCc\nNfWLANswjLOA/wd8NbdB13Uv8A3gQuBVwId1Xe8qY1mEEEKIhle2UDcM49fA5dm7i4H+gs3HAVsM\nwxg0DCMDPA6cXa6yCCGEELNB2ZrfAQzDsHRd/xHwFuDvCzY1A4MF94eBlnKWRQghhGh0ZR8oZxjG\n+3D71W/RdT2YfXgQiBQ8LcLYmvxhTNNycPvm5asOv37x/H3OpT/7J2f9vheqXhb5ki/5kq86+pqS\ncg6Uew+wwDCM64EEYDNawJeA5bquR4EYbtP71ybbX39/vFxFnVRnZ4Tu7uGqvHc1lOt4Y/E0AIOD\ncbq12vh5yu+2cc2mY4XZdbyz6VjBPd6pKGdN/ZfAGl3XHwUeBD4FvEXX9Q9l+9E/AzwE/B9wq2EY\n+8pYFiGEEKLhla2mbhhGAnj7JNvvA+4r1/uL2qJkv0+5LUkIIUTRZPIZIYQQokFIqAshhBANQkJd\nCCGEaBAS6qLCpFddCCHKRUJdVIQif2pCCFF28kkrhBBCNAgJdVFRjiPN70I0snQ6zX333VP08x94\n4D4ef/yxCbf/5Cc/4sUXN064XYxV1rnfhRBCVM/P/7CFJ186WNJ9vuLYLi49b9mE23t7e7j33l9z\n0UUXF7W/173uokm3/8M/vG8qxZv1JNSFEEKUzI9//EO2b3+ZH/3oB9i2zXPPbSCZTPCFL1zLAw/c\nh2G8yODgIMuWLeeqq67j1lu/T3t7B4sWLeYnP7kNn8/L3r17OP/8V3PZZe/nq1/9Ihdc8Bp6e3v4\ny1/+DFhs27add7/7vbzudRfxwgvPc+ONNxAKhWltjeL3+7nqquvy5dm5cwfXX/+veDwajuNw3XVf\nobOzixtvvIEXX3wB08zwgQ9czllnvYqbbrqR557bAMCFF76Wt73tHXz1q19kaGiQoaEhbrjhm9x+\n+208++x6bNvm7W9/F+eeewF33fULHnzwflRV5dhjV/LpT3+uSj99CXVRIcqRnyKEKLFLz1s2aa26\nHN773g/w8stbed/7PsgPf3gzS5Ycwyc/+Vni8RjNzc3ceON3sG2byy57Oz093SjK6KfDgQP7+fGP\n7ySdTnPxxa/lssven9+uKAqxWIz/+Z8fsW7dC1xxxT/zutddxNe/fj3XXvsVFi9ews03f5eenu4x\n5Xnqqb+xcuUJ/NM/fYJnn13PyMgIL774AoODg9xyy20MDw/zs5/djqp62L9/Lzff/CNM0+SjH/0g\nJ598CoqicPLJp3Lppe/kL3/5M/v27eW73/0BqVSKj3zkH3nFK07ngQfu5bOfvZJjjz2Oe+75JZZl\n4fF4Kvpzz5FQF0IIUTKHjptZuHARAD6fn/7+fr74xasJBkPE43FM0xzz3KVLl6KqKoFAAL/ff9i+\nly9fAUBnZxfptLtIVG9vD4sXLwFg9eq1/P73/zvmNRdd9GZuv/02PvvZT9LUFObyyz/Grl07WLXq\nRAAikQgf/OBH+OlP/4fVq9cCoGkaxx9/Atu2bQPg6KPdY3j55S0Yxkt84hOXA2BZFvv27eXKK6/j\nzjt/wt69e1i16sSqjh2SgXKiIgrPxoUQjUtR
VGzbBtyAz/3f/+tf/8zBg/v54he/yoc//FHS6dQ4\n4Tf558R4nyNdXXPYvt0N3+eff/aw7X/606OsXr2Wb33ru5xzzvn85Ce3sXjxEl56yR18NzIywuc+\n90kWL17Cs8+uB8A0TZ5/fgMLFy4c876LFi3hpJNO5qabvs+NN36Hc8+9gKOOWsC9997D5z53Jd/+\n9s1s2mSwceNzRf60Sk9q6kIIIUqmra0N08zwve/dhN/vzwfiypWruO22W/nkJz9CW1s7K1euyjeV\nFzaxjzo8wMfb/tnPfoHrr/8SwWAQr9dLR0fXmNcce+xxfPWrX8Tr9WJZFp/61GdZvlznqaf+xkc/\n+kEsy+L97/8wp532State5qPfOT9ZDIZzj//QlasOHbM+5511tmsW/c0H/vYh0gk4px99rmEQiGW\nLl3Kxz72QUKhMJ2dXaxcuaoUP8ppUerlEqPu7uGqFHQ2rt1bjuN9aPsf+M3LD/KRE9/HCR0rS77/\n6ZDfbeOaTccKs+t4Dz3Wu+76BeeddyGtra3ccsv38Hq9vO99H6xiCUurszMypWZOqakLIYSoW21t\nbXzmMx8jGAzR1NTE1Vf/a7WLVFUS6qJCpE9dCFF655xzPuecc361i1EzZKCcEEII0SAk1IUQQogG\nIaEuKkKuaBNCiPKTUBdCCCEahIS6EEKIqvjEJy5n587tE67Udtllb5/09Y8++gg9PT309fXyH//x\n/5WrmHVFRr+LiqqXeRGEaAR3bbmPdQdLO7vZ2q4TuGTZ5CurTY1yxJXaJvLLX97JkiVLOProxXz2\ns1eUsEz1S0JdVIQil7QJMStcffW/8La3vZM1a07ipZde4LbbbuWaa77E9dd/mVhshJ6ebi655G1c\nfPHfZ1/h5Fdqe/ObL+GGG/6NrVs309U1h1gsBrhzrn/729/EsmxisSE+/enPMzQ0xObNm/jKV77I\nNdd8ia985Tq+//3/5skn/8ott/wXPp+PlpYWrrzyOjZteonbb//xYSvAFfr+97/D+vVPY5oW55xz\nHu9+93vZuPF5brrpG9i2TWdnJ9de+xV27NjGN7/5dVRVxefzc8UVV2PbNldc8c+0tLTyyleeyWmn\nncG3vvV1HMfJluFa0ukM1113JY7jkE6n+dznrszPZV9KEupCCNGgLll2UYlr1Uf2xje+hQceuI81\na07i/vvv5U1vegu7d+/iggtew6tedS49Pd18/OOXF4T66DSsjz32CKlUkptv/hEDAwO84x3umuzb\ntm3j4x//NMccs4wnnniU+++/lyuuuJrly1fwL/9yFZo2GmU33HA93/verXR0dPCLX9zJbbfdyhln\nnDXuCnCFHn74IW666Wba29v57W/vBeBrX/s3vvSlf+Pooxdz//2/YceObdxww7/xhS9cw7Jly3n8\n8Ue56aYb+fjHP01fXx8//OHtaJrGhz/8Pq6++ossWrSY++77Nbff/mNOOOFEWlpa+X//71/Zvn0b\nyWSiLD9/CXVRUdL4LkRjO/XU0/nud7/F0NAQzz67ns985vP09HTz85/fwWOP/YFQqAnLssZ97c6d\nOzj2WHca6dbWVhYtcldf6+jo5Ec/uhW/349lpdG0w1dwAxgYGCAcDtPR0QHA6tVruPnm73LGGWcd\ncQW4a6/9Mt/73n/S19fL6aefAUB/fy9HH70YgDe84U0A9PR0s2zZcgBOPHEt//Vf3wZg3rz5+ZOL\nnTu38/WvXw+4i8MsXHg0p59+Jrt27eLKKz+LpmlcdtkHpvaDLZIMlBMVIc3vQswOqqpy7rkX8PWv\nX8/ZZ5+DoijceedPWLXqBK655suce+75OI497msXL16SX2ltaGiIXbt2AvCtb32dD3zgcq6++ous\nWLEiPzZHVUdXhAP3RCAWi9Hb2wPA+vXP5JdNnWxWy0wmwyOPPMy//uu/8Z//+V888MB97N+/n/b2\nTnbv3gXA
T3/6Yx577I90dHSydeuW/P5zS8uq6micLly4iGuu+RI33fR9Lr/8Y5x5prsQTHt7B9/4\nxre57LL3c/PN35nqj7Y
"text/plain": [
"<matplotlib.figure.Figure at 0x109465e48>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_learning_curve(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an even more complex model, we still converge, but the convergence only happens for *large* amounts of training data.\n",
"\n",
"So we see the following:\n",
"\n",
"- you can **cause the lines to converge** by adding more points or by simplifying the model.\n",
"- you can **bring the convergence error down** only by increasing the complexity of the model.\n",
"\n",
"Thus these curves can give you hints about how you might improve a sub-optimal model. If the curves are already close together, you need more model complexity. If the curves are far apart, you might also improve the model by adding more data.\n",
"\n",
"To make this more concrete, imagine some telescope data in which the results are not robust enough. You must think about whether to spend your valuable telescope time observing *more objects* to get a larger training set, or *more attributes of each object* in order to improve the model. The answer to this question has real consequences, and can be addressed using these metrics."
]
},
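The first rule above can be checked with the same kind of toy setup. This sketch uses hypothetical data (the curve $10 - 1/(x + 0.1)$ plus unit Gaussian scatter, an assumption rather than the notebook's data), fits a degree-10 polynomial to growing training sets, and prints the gap between test and training rms error. The gap is large when data is scarce and shrinks as points are added, while the level at which the curves meet stays near the noise floor.

```python
import numpy as np

rng = np.random.RandomState(0)

def true_curve(x):
    # Hypothetical noiseless curve, assumed for illustration only.
    return 10 - 1.0 / (x + 0.1)

def sample(n):
    # Draw n points with Gaussian scatter of size 1.0 around the curve.
    x = rng.uniform(0, 1, n)
    return x, true_curve(x) + rng.randn(n)

def rms(coeffs, x, y):
    return np.sqrt(np.mean((np.polyval(coeffs, x) - y) ** 2))

degree = 10
x_test, y_test = sample(5000)
x_pool, y_pool = sample(2000)

gaps = {}
for n in (20, 100, 2000):
    # Fit on the first n points of a fixed pool, so larger sets contain smaller ones.
    coeffs = np.polyfit(x_pool[:n], y_pool[:n], degree)
    train = rms(coeffs, x_pool[:n], y_pool[:n])
    test = rms(coeffs, x_test, y_test)
    gaps[n] = test - train
    print("N=%4d: train rms %.2f, test rms %.2f, gap %.2f" % (n, train, test, gaps[n]))
```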
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"We've gone over several useful tools for model validation:\n",
"\n",
"- The **Training Score** shows how well a model fits the data it was trained on. This is not a good indication of model effectiveness.\n",
"- The **Validation Score** shows how well a model fits hold-out data. The most effective method is some form of cross-validation, where multiple hold-out sets are used.\n",
"- **Validation Curves** are a plot of validation score and training score as a function of **model complexity**:\n",
"  + when the two curves are close, it indicates *underfitting*\n",
"  + when the two curves are separated, it indicates *overfitting*\n",
"  + the \"sweet spot\" is in the middle\n",
"- **Learning Curves** are a plot of the validation score and training score as a function of the **number of training samples**:\n",
"  + when the curves are close, it indicates *underfitting*, and adding more data will not generally improve the estimator.\n",
"  + when the curves are far apart, it indicates *overfitting*, and adding more data may increase the effectiveness of the model.\n",
" \n",
"These tools are powerful means of evaluating your model on your data."
]
}
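For the validation-curve and cross-validation ideas in the summary, here is a minimal hand-rolled sketch in NumPy, standing in for a library routine such as scikit-learn's `validation_curve`. The data (the same hypothetical curve with unit scatter), the fold count, and the helper `cv_validation_curve` are all assumptions for illustration. It splits the data into $K$ folds, holds out each fold in turn, and averages the training and validation rms error for each polynomial degree.

```python
import numpy as np

rng = np.random.RandomState(1)
n = 400
x = rng.uniform(0, 1, n)
y = 10 - 1.0 / (x + 0.1) + rng.randn(n)  # hypothetical data, scatter of size 1.0

def rms(coeffs, x, y):
    return np.sqrt(np.mean((np.polyval(coeffs, x) - y) ** 2))

def cv_validation_curve(x, y, degrees, n_folds=5, seed=0):
    """Mean training and validation rms error over K folds, for each degree."""
    folds = np.array_split(np.random.RandomState(seed).permutation(len(x)), n_folds)
    train_scores, val_scores = [], []
    for d in degrees:
        tr, va = [], []
        for k in range(n_folds):
            val_idx = folds[k]
            train_idx = np.concatenate([folds[j] for j in range(n_folds) if j != k])
            coeffs = np.polyfit(x[train_idx], y[train_idx], d)
            tr.append(rms(coeffs, x[train_idx], y[train_idx]))
            va.append(rms(coeffs, x[val_idx], y[val_idx]))
        train_scores.append(np.mean(tr))
        val_scores.append(np.mean(va))
    return np.array(train_scores), np.array(val_scores)

degrees = np.arange(1, 9)
train_err, val_err = cv_validation_curve(x, y, degrees)
best = degrees[np.argmin(val_err)]
for d, t, v in zip(degrees, train_err, val_err):
    print("d=%d: train %.3f, validation %.3f" % (d, t, v))
print("lowest validation error at d = %d" % best)
```

The training error cannot increase as the degree grows, because each polynomial model contains the previous one; it is the cross-validated error that locates the "sweet spot".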
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}