Add TensorFlow nearest neighbor notebook.

pull/20/head^2
Donne Martin 2015-12-28 07:54:11 -05:00
parent 47c1483f46
commit 60ebead762
2 changed files with 387 additions and 0 deletions

View File

@ -97,6 +97,7 @@ IPython Notebook(s) demonstrating deep learning functionality.
| [tsf-basics](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/1_intro/basic_operations.ipynb) | Learn basic operations in TensorFlow, a library for various kinds of perceptual and language understanding tasks from Google. |
| [tsf-linear](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/2_basic_classifiers/linear_regression.ipynb) | Implement linear regression in TensorFlow. |
| [tsf-logistic](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/2_basic_classifiers/logistic_regression.ipynb) | Implement logistic regression in TensorFlow. |
| [tsf-nn](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/2_basic_classifiers/nearest_neighbor.ipynb) | Implement nearest neighboars in TensorFlow. |
### tensor-flow-exercises

View File

@ -0,0 +1,386 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Nearest Neighbor in TensorFlow\n",
"\n",
"Credits: Forked from [TensorFlow-Examples](https://github.com/aymericdamien/TensorFlow-Examples) by Aymeric Damien\n",
"\n",
"## Setup\n",
"\n",
"Refer to the [setup instructions](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/Setup_TensorFlow.md)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
"Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
"Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
"Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
]
}
],
"source": [
"# Import MINST data\n",
"import input_data\n",
"mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# In this example, we limit mnist data\n",
"Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)\n",
"Xte, Yte = mnist.test.next_batch(200) #200 for testing"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Reshape images to 1D\n",
"Xtr = np.reshape(Xtr, newshape=(-1, 28*28))\n",
"Xte = np.reshape(Xte, newshape=(-1, 28*28))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# tf Graph Input\n",
"xtr = tf.placeholder(\"float\", [None, 784])\n",
"xte = tf.placeholder(\"float\", [784])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Nearest Neighbor calculation using L1 Distance\n",
"# Calculate L1 Distance\n",
"distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)\n",
"# Predict: Get min distance index (Nearest neighbor)\n",
"pred = tf.arg_min(distance, 0)\n",
"\n",
"accuracy = 0."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Initializing the variables\n",
"init = tf.initialize_all_variables()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test 0 Prediction: 7 True Class: 7\n",
"Test 1 Prediction: 2 True Class: 2\n",
"Test 2 Prediction: 1 True Class: 1\n",
"Test 3 Prediction: 0 True Class: 0\n",
"Test 4 Prediction: 4 True Class: 4\n",
"Test 5 Prediction: 1 True Class: 1\n",
"Test 6 Prediction: 4 True Class: 4\n",
"Test 7 Prediction: 9 True Class: 9\n",
"Test 8 Prediction: 8 True Class: 5\n",
"Test 9 Prediction: 9 True Class: 9\n",
"Test 10 Prediction: 0 True Class: 0\n",
"Test 11 Prediction: 0 True Class: 6\n",
"Test 12 Prediction: 9 True Class: 9\n",
"Test 13 Prediction: 0 True Class: 0\n",
"Test 14 Prediction: 1 True Class: 1\n",
"Test 15 Prediction: 5 True Class: 5\n",
"Test 16 Prediction: 4 True Class: 9\n",
"Test 17 Prediction: 7 True Class: 7\n",
"Test 18 Prediction: 3 True Class: 3\n",
"Test 19 Prediction: 4 True Class: 4\n",
"Test 20 Prediction: 9 True Class: 9\n",
"Test 21 Prediction: 6 True Class: 6\n",
"Test 22 Prediction: 6 True Class: 6\n",
"Test 23 Prediction: 5 True Class: 5\n",
"Test 24 Prediction: 4 True Class: 4\n",
"Test 25 Prediction: 0 True Class: 0\n",
"Test 26 Prediction: 7 True Class: 7\n",
"Test 27 Prediction: 4 True Class: 4\n",
"Test 28 Prediction: 0 True Class: 0\n",
"Test 29 Prediction: 1 True Class: 1\n",
"Test 30 Prediction: 3 True Class: 3\n",
"Test 31 Prediction: 1 True Class: 1\n",
"Test 32 Prediction: 3 True Class: 3\n",
"Test 33 Prediction: 4 True Class: 4\n",
"Test 34 Prediction: 7 True Class: 7\n",
"Test 35 Prediction: 2 True Class: 2\n",
"Test 36 Prediction: 7 True Class: 7\n",
"Test 37 Prediction: 1 True Class: 1\n",
"Test 38 Prediction: 2 True Class: 2\n",
"Test 39 Prediction: 1 True Class: 1\n",
"Test 40 Prediction: 1 True Class: 1\n",
"Test 41 Prediction: 7 True Class: 7\n",
"Test 42 Prediction: 4 True Class: 4\n",
"Test 43 Prediction: 1 True Class: 2\n",
"Test 44 Prediction: 3 True Class: 3\n",
"Test 45 Prediction: 5 True Class: 5\n",
"Test 46 Prediction: 1 True Class: 1\n",
"Test 47 Prediction: 2 True Class: 2\n",
"Test 48 Prediction: 4 True Class: 4\n",
"Test 49 Prediction: 4 True Class: 4\n",
"Test 50 Prediction: 6 True Class: 6\n",
"Test 51 Prediction: 3 True Class: 3\n",
"Test 52 Prediction: 5 True Class: 5\n",
"Test 53 Prediction: 5 True Class: 5\n",
"Test 54 Prediction: 6 True Class: 6\n",
"Test 55 Prediction: 0 True Class: 0\n",
"Test 56 Prediction: 4 True Class: 4\n",
"Test 57 Prediction: 1 True Class: 1\n",
"Test 58 Prediction: 9 True Class: 9\n",
"Test 59 Prediction: 5 True Class: 5\n",
"Test 60 Prediction: 7 True Class: 7\n",
"Test 61 Prediction: 8 True Class: 8\n",
"Test 62 Prediction: 9 True Class: 9\n",
"Test 63 Prediction: 3 True Class: 3\n",
"Test 64 Prediction: 7 True Class: 7\n",
"Test 65 Prediction: 4 True Class: 4\n",
"Test 66 Prediction: 6 True Class: 6\n",
"Test 67 Prediction: 4 True Class: 4\n",
"Test 68 Prediction: 3 True Class: 3\n",
"Test 69 Prediction: 0 True Class: 0\n",
"Test 70 Prediction: 7 True Class: 7\n",
"Test 71 Prediction: 0 True Class: 0\n",
"Test 72 Prediction: 2 True Class: 2\n",
"Test 73 Prediction: 7 True Class: 9\n",
"Test 74 Prediction: 1 True Class: 1\n",
"Test 75 Prediction: 7 True Class: 7\n",
"Test 76 Prediction: 3 True Class: 3\n",
"Test 77 Prediction: 7 True Class: 2\n",
"Test 78 Prediction: 9 True Class: 9\n",
"Test 79 Prediction: 7 True Class: 7\n",
"Test 80 Prediction: 7 True Class: 7\n",
"Test 81 Prediction: 6 True Class: 6\n",
"Test 82 Prediction: 2 True Class: 2\n",
"Test 83 Prediction: 7 True Class: 7\n",
"Test 84 Prediction: 8 True Class: 8\n",
"Test 85 Prediction: 4 True Class: 4\n",
"Test 86 Prediction: 7 True Class: 7\n",
"Test 87 Prediction: 3 True Class: 3\n",
"Test 88 Prediction: 6 True Class: 6\n",
"Test 89 Prediction: 1 True Class: 1\n",
"Test 90 Prediction: 3 True Class: 3\n",
"Test 91 Prediction: 6 True Class: 6\n",
"Test 92 Prediction: 9 True Class: 9\n",
"Test 93 Prediction: 3 True Class: 3\n",
"Test 94 Prediction: 1 True Class: 1\n",
"Test 95 Prediction: 4 True Class: 4\n",
"Test 96 Prediction: 1 True Class: 1\n",
"Test 97 Prediction: 7 True Class: 7\n",
"Test 98 Prediction: 6 True Class: 6\n",
"Test 99 Prediction: 9 True Class: 9\n",
"Test 100 Prediction: 6 True Class: 6\n",
"Test 101 Prediction: 0 True Class: 0\n",
"Test 102 Prediction: 5 True Class: 5\n",
"Test 103 Prediction: 4 True Class: 4\n",
"Test 104 Prediction: 9 True Class: 9\n",
"Test 105 Prediction: 9 True Class: 9\n",
"Test 106 Prediction: 2 True Class: 2\n",
"Test 107 Prediction: 1 True Class: 1\n",
"Test 108 Prediction: 9 True Class: 9\n",
"Test 109 Prediction: 4 True Class: 4\n",
"Test 110 Prediction: 8 True Class: 8\n",
"Test 111 Prediction: 7 True Class: 7\n",
"Test 112 Prediction: 3 True Class: 3\n",
"Test 113 Prediction: 9 True Class: 9\n",
"Test 114 Prediction: 7 True Class: 7\n",
"Test 115 Prediction: 9 True Class: 4\n",
"Test 116 Prediction: 9 True Class: 4\n",
"Test 117 Prediction: 4 True Class: 4\n",
"Test 118 Prediction: 9 True Class: 9\n",
"Test 119 Prediction: 7 True Class: 2\n",
"Test 120 Prediction: 5 True Class: 5\n",
"Test 121 Prediction: 4 True Class: 4\n",
"Test 122 Prediction: 7 True Class: 7\n",
"Test 123 Prediction: 6 True Class: 6\n",
"Test 124 Prediction: 7 True Class: 7\n",
"Test 125 Prediction: 9 True Class: 9\n",
"Test 126 Prediction: 0 True Class: 0\n",
"Test 127 Prediction: 5 True Class: 5\n",
"Test 128 Prediction: 8 True Class: 8\n",
"Test 129 Prediction: 5 True Class: 5\n",
"Test 130 Prediction: 6 True Class: 6\n",
"Test 131 Prediction: 6 True Class: 6\n",
"Test 132 Prediction: 5 True Class: 5\n",
"Test 133 Prediction: 7 True Class: 7\n",
"Test 134 Prediction: 8 True Class: 8\n",
"Test 135 Prediction: 1 True Class: 1\n",
"Test 136 Prediction: 0 True Class: 0\n",
"Test 137 Prediction: 1 True Class: 1\n",
"Test 138 Prediction: 6 True Class: 6\n",
"Test 139 Prediction: 4 True Class: 4\n",
"Test 140 Prediction: 6 True Class: 6\n",
"Test 141 Prediction: 7 True Class: 7\n",
"Test 142 Prediction: 2 True Class: 3\n",
"Test 143 Prediction: 1 True Class: 1\n",
"Test 144 Prediction: 7 True Class: 7\n",
"Test 145 Prediction: 1 True Class: 1\n",
"Test 146 Prediction: 8 True Class: 8\n",
"Test 147 Prediction: 2 True Class: 2\n",
"Test 148 Prediction: 0 True Class: 0\n",
"Test 149 Prediction: 1 True Class: 2\n",
"Test 150 Prediction: 9 True Class: 9\n",
"Test 151 Prediction: 9 True Class: 9\n",
"Test 152 Prediction: 5 True Class: 5\n",
"Test 153 Prediction: 5 True Class: 5\n",
"Test 154 Prediction: 1 True Class: 1\n",
"Test 155 Prediction: 5 True Class: 5\n",
"Test 156 Prediction: 6 True Class: 6\n",
"Test 157 Prediction: 0 True Class: 0\n",
"Test 158 Prediction: 3 True Class: 3\n",
"Test 159 Prediction: 4 True Class: 4\n",
"Test 160 Prediction: 4 True Class: 4\n",
"Test 161 Prediction: 6 True Class: 6\n",
"Test 162 Prediction: 5 True Class: 5\n",
"Test 163 Prediction: 4 True Class: 4\n",
"Test 164 Prediction: 6 True Class: 6\n",
"Test 165 Prediction: 5 True Class: 5\n",
"Test 166 Prediction: 4 True Class: 4\n",
"Test 167 Prediction: 5 True Class: 5\n",
"Test 168 Prediction: 1 True Class: 1\n",
"Test 169 Prediction: 4 True Class: 4\n",
"Test 170 Prediction: 9 True Class: 4\n",
"Test 171 Prediction: 7 True Class: 7\n",
"Test 172 Prediction: 2 True Class: 2\n",
"Test 173 Prediction: 3 True Class: 3\n",
"Test 174 Prediction: 2 True Class: 2\n",
"Test 175 Prediction: 1 True Class: 7\n",
"Test 176 Prediction: 1 True Class: 1\n",
"Test 177 Prediction: 8 True Class: 8\n",
"Test 178 Prediction: 1 True Class: 1\n",
"Test 179 Prediction: 8 True Class: 8\n",
"Test 180 Prediction: 1 True Class: 1\n",
"Test 181 Prediction: 8 True Class: 8\n",
"Test 182 Prediction: 5 True Class: 5\n",
"Test 183 Prediction: 0 True Class: 0\n",
"Test 184 Prediction: 2 True Class: 8\n",
"Test 185 Prediction: 9 True Class: 9\n",
"Test 186 Prediction: 2 True Class: 2\n",
"Test 187 Prediction: 5 True Class: 5\n",
"Test 188 Prediction: 0 True Class: 0\n",
"Test 189 Prediction: 1 True Class: 1\n",
"Test 190 Prediction: 1 True Class: 1\n",
"Test 191 Prediction: 1 True Class: 1\n",
"Test 192 Prediction: 0 True Class: 0\n",
"Test 193 Prediction: 4 True Class: 9\n",
"Test 194 Prediction: 0 True Class: 0\n",
"Test 195 Prediction: 1 True Class: 3\n",
"Test 196 Prediction: 1 True Class: 1\n",
"Test 197 Prediction: 6 True Class: 6\n",
"Test 198 Prediction: 4 True Class: 4\n",
"Test 199 Prediction: 2 True Class: 2\n",
"Done!\n",
"Accuracy: 0.92\n"
]
}
],
"source": [
"# Launch the graph\n",
"with tf.Session() as sess:\n",
" sess.run(init)\n",
"\n",
" # loop over test data\n",
" for i in range(len(Xte)):\n",
" # Get nearest neighbor\n",
" nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i,:]})\n",
" # Get nearest neighbor class label and compare it to its true label\n",
" print \"Test\", i, \"Prediction:\", np.argmax(Ytr[nn_index]), \\\n",
" \"True Class:\", np.argmax(Yte[i])\n",
" # Calculate accuracy\n",
" if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):\n",
" accuracy += 1./len(Xte)\n",
" print \"Done!\"\n",
" print \"Accuracy:\", accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}