{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Recurrent Neural Networks in Theano\n",
"\n",
"Credits: Forked from [summerschool2015](https://github.com/mila-udem/summerschool2015) by mila-udem\n",
"\n",
"First, we import some dependencies:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"from synthetic import mackey_glass\n",
"import matplotlib.pyplot as plt\n",
"import theano\n",
"import theano.tensor as T\n",
"import numpy\n",
"\n",
"floatX = theano.config.floatX"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a class that uses `scan` to build an RNN and apply it to a sequence of data vectors. The constructor initializes the shared variables (the parameters), after which the instance can be called on a symbolic variable to construct the RNN graph. Note that this class only computes the hidden-layer activations; we'll define a set of output weights later."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class SimpleRNN(object):\n",
"    def __init__(self, input_dim, recurrent_dim):\n",
"        # Small random input-to-hidden and hidden-to-hidden weights, zero bias\n",
"        w_xh = numpy.random.normal(0, .01, (input_dim, recurrent_dim))\n",
"        w_hh = numpy.random.normal(0, .02, (recurrent_dim, recurrent_dim))\n",
"        self.w_xh = theano.shared(numpy.asarray(w_xh, dtype=floatX), name='w_xh')\n",
"        self.w_hh = theano.shared(numpy.asarray(w_hh, dtype=floatX), name='w_hh')\n",
"        self.b_h = theano.shared(numpy.zeros((recurrent_dim,), dtype=floatX), name='b_h')\n",
"        self.parameters = [self.w_xh, self.w_hh, self.b_h]\n",
"\n",
"    def _step(self, input_t, previous):\n",
"        # input_t already contains x_t . w_xh + b_h (precomputed in __call__)\n",
"        return T.tanh(T.dot(previous, self.w_hh) + input_t)\n",
"\n",
"    def __call__(self, x):\n",
"        # Project the whole input sequence at once; only the recurrence needs scan\n",
"        x_w_xh = T.dot(x, self.w_xh) + self.b_h\n",
"        result, updates = theano.scan(self._step,\n",
"                                      sequences=[x_w_xh],\n",
"                                      outputs_info=[T.zeros_like(self.b_h)])\n",
"        return result"
]
},
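{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the recurrence that `scan` unrolls in `SimpleRNN` can also be written as an explicit loop. The following is a minimal NumPy sketch under the same conventions as the class above (row-vector inputs, `w_xh` of shape `(input_dim, recurrent_dim)`, `w_hh` of shape `(recurrent_dim, recurrent_dim)`, zero initial state); the function name `rnn_forward` is hypothetical and the code does not touch the Theano graph.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def rnn_forward(seq, w_xh, w_hh, b_h):\n",
"    # seq has shape (time, input_dim); returns hidden states of shape (time, recurrent_dim)\n",
"    x_w_xh = seq.dot(w_xh) + b_h   # input projection for all steps at once\n",
"    h = np.zeros_like(b_h)         # zero initial state, like T.zeros_like(self.b_h)\n",
"    hidden_states = []\n",
"    for input_t in x_w_xh:\n",
"        h = np.tanh(h.dot(w_hh) + input_t)   # same update as SimpleRNN._step\n",
"        hidden_states.append(h)\n",
"    return np.stack(hidden_states)\n",
"\n",
"\n",
"rng = np.random.RandomState(0)\n",
"seq = rng.normal(size=(5, 1))\n",
"w_xh = rng.normal(0, .01, (1, 15))\n",
"w_hh = rng.normal(0, .02, (15, 15))\n",
"b_h = np.zeros(15)\n",
"states = rnn_forward(seq, w_xh, w_hh, b_h)\n",
"print(states.shape)  # (5, 15)\n",
"```\n",
"\n",
"Only the `tanh` update depends on the previous state, which is why both versions compute the input projection for the whole sequence up front and keep just the recurrence inside the loop."
]
},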
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For visualization purposes, and to keep the optimization time manageable, we will train the RNN on a short synthetic chaotic time series (the Mackey-Glass series). Let's first have a look at the data:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f0f3cac8750>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data = numpy.asarray(mackey_glass(2000)[0], dtype=floatX)\n",
"plt.plot(data)\n",
"plt.show()\n",
"data_train = data[:1500]\n",
"data_val = data[1500:]"
]
},
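{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `synthetic` module is not included in this notebook. For reference, the Mackey-Glass series is defined by the delay differential equation `dx/dt = beta * x(t - tau) / (1 + x(t - tau)**n) - gamma * x(t)`, which is chaotic for the commonly used parameters `beta=0.2`, `gamma=0.1`, `n=10`, `tau=17`. The sketch below integrates it with a simple Euler scheme. It is only an illustration: the step size, the constant initial history and the hypothetical name `mackey_glass_euler` are assumptions, and `synthetic.mackey_glass` may generate (and scale) its data differently.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def mackey_glass_euler(n_samples, tau=17, beta=0.2, gamma=0.1, n=10, dt=1.0, x0=1.2):\n",
"    # Constant initial history x(t) = x0 for t <= 0\n",
"    delay = int(round(tau / dt))\n",
"    x = np.full(n_samples + delay, x0)\n",
"    for t in range(delay, n_samples + delay - 1):\n",
"        x_tau = x[t - delay]\n",
"        dx = beta * x_tau / (1.0 + x_tau ** n) - gamma * x[t]\n",
"        x[t + 1] = x[t] + dt * dx   # explicit Euler step\n",
"    # Drop the initial history; return a (time, 1) column like the data used above\n",
"    return x[delay:].reshape(-1, 1)\n",
"\n",
"\n",
"series = mackey_glass_euler(2000)\n",
"print(series.shape)  # (2000, 1)\n",
"```"
]
},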
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To train an RNN model on this sequence, we need to build a Theano graph that computes the cost and its gradient. The task is to predict the next time step, and the error objective is the mean squared error (MSE) between the output at step t and the input at step t + 1. We also need to define shared variables for the output weights. Finally, we add a small L2 weight-decay term to the cost."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Output layer: maps the 15 hidden units to a single predicted value\n",
"w_ho_np = numpy.random.normal(0, .01, (15, 1))\n",
"w_ho = theano.shared(numpy.asarray(w_ho_np, dtype=floatX), name='w_ho')\n",
"b_o = theano.shared(numpy.zeros((1,), dtype=floatX), name='b_o')\n",
"\n",
"x = T.matrix('x')  # sequence of shape (time, 1)\n",
"my_rnn = SimpleRNN(1, 15)\n",
"hidden = my_rnn(x)\n",
"prediction = T.dot(hidden, w_ho) + b_o\n",
"parameters = my_rnn.parameters + [w_ho, b_o]\n",
"l2 = sum((p**2).sum() for p in parameters)\n",
"# The output at step t is compared with the input at step t + 1\n",
"mse = T.mean((prediction[:-1] - x[1:])**2)\n",
"cost = mse + .0001 * l2\n",
"gradient = T.grad(cost, wrt=parameters)"
]
},
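{
"cell_type": "markdown",
"metadata": {},
"source": [
"The slicing `prediction[:-1] - x[1:]` is what turns the model into a next-step predictor: the output computed after reading step t is compared with the input at step t + 1, and the last output has no target. A small NumPy sketch of the same objective, with hypothetical arrays standing in for the Theano variables (`l2_coef` mirrors the `.0001` used above):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def next_step_cost(prediction, seq, params, l2_coef=1e-4):\n",
"    # prediction[t] is the model's guess for seq[t + 1]; the last one has no target\n",
"    mse = np.mean((prediction[:-1] - seq[1:]) ** 2)\n",
"    l2 = sum((p ** 2).sum() for p in params)   # squared L2 norm of all parameters\n",
"    return mse + l2_coef * l2, mse\n",
"\n",
"\n",
"seq = np.array([[0.0], [1.0], [2.0], [3.0]])\n",
"perfect = np.array([[1.0], [2.0], [3.0], [99.0]])   # last entry is ignored\n",
"cost, mse = next_step_cost(perfect, seq, params=[np.ones((2, 2))])\n",
"print(mse)   # 0.0\n",
"print(cost)  # 0.0004\n",
"```\n",
"\n",
"A perfect next-step predictor gets zero MSE even though its last output is arbitrary; the remaining cost comes only from the weight-decay term."
]
},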
{
"cell_type": "markdown",
"metadata": {},
"source": [
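"We now compile a function that updates the parameters of the model using plain gradient descent, along with helper functions that return the MSE, the predictions, the hidden states and the gradient."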
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"lr = .3\n",
"updates = [(par, par - lr * gra) for par, gra in zip(parameters, gradient)] \n",
"update_model = theano.function([x], cost, updates=updates)\n",
"get_cost = theano.function([x], mse)\n",
"predict = theano.function([x], prediction)\n",
"get_hidden = theano.function([x], hidden)\n",
"get_gradient = theano.function([x], gradient)"
]
},
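{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `updates` list pairs each shared variable with its new value `par - lr * gra`; Theano computes all new values from the old parameters and then assigns them together. A minimal NumPy sketch of the same simultaneous update, on a made-up quadratic objective (`targets` and the objective are purely illustrative):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Toy objective: sum over parameters of ||p - target||^2, with gradient 2 * (p - target)\n",
"targets = [np.array([1.0, -2.0]), np.array([[0.5]])]\n",
"params = [np.zeros(2), np.zeros((1, 1))]\n",
"lr = 0.3\n",
"\n",
"for _ in range(50):\n",
"    grads = [2.0 * (p - t) for p, t in zip(params, targets)]   # all gradients use the old values\n",
"    params = [p - lr * g for p, g in zip(params, grads)]       # then every parameter is replaced\n",
"\n",
"print([np.round(p, 4) for p in params])\n",
"```\n",
"\n",
"With `lr = 0.3`, each step shrinks the error by a factor of `1 - 2 * 0.3 = 0.4`; on this toy problem any learning rate above 1.0 would diverge."
]
},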
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now train the network by calling `update_model` repeatedly on the training data. Each call performs one full-batch gradient descent step over the entire training sequence."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0: train mse: 0.0515556260943 validation mse: 0.0469637289643\n",
"Epoch 100: train mse: 0.0407442860305 validation mse: 0.0401079840958\n",
"Epoch 200: train mse: 0.00225670542568 validation mse: 0.00203324528411\n",
"Epoch 300: train mse: 0.00185390282422 validation mse: 0.00163305236492\n",
"Epoch 400: train mse: 0.00161687470973 validation mse: 0.00139373319689\n",
"Epoch 500: train mse: 0.00145859015174 validation mse: 0.00123134546448\n",
"Epoch 600: train mse: 0.00134439510293 validation mse: 0.00111229927279\n",
"Epoch 700: train mse: 0.00125755299814 validation mse: 0.00102029775735\n",
"Epoch 800: train mse: 0.0011889107991 validation mse: 0.000946390733588\n",
"Epoch 900: train mse: 0.00113300536759 validation mse: 0.000885214598384\n",
"Epoch 1000: train mse: 0.00108635553624 validation mse: 0.000833337020595\n"
]
}
],
"source": [
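"for i in range(1001):\n",
"    # One gradient descent step on the full training sequence; the returned value\n",
"    # is the regularized cost (MSE + L2), computed before the update is applied\n",
"    mse_train = update_model(data_train)\n",
"\n",
"    if i % 100 == 0:\n",
"        mse_val = get_cost(data_val)\n",
"        print('Epoch {}: train mse: {} validation mse: {}'.format(i, mse_train, mse_val))"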
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we're only looking at a very small toy problem here, the model has probably already fit the training data quite well. Let's check by plotting the network's one-step-ahead predictions:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f0f35c39350>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"predict = theano.function([x], prediction)\n",
"prediction_np = predict(data)\n",
"plt.plot(data[1:], label='data')\n",
"plt.plot(prediction_np, label='prediction')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
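"Small-scale optimization problems like this one often benefit from more advanced methods that use curvature information, such as quasi-Newton methods. The following block defines helper functions that convert between the list of shared variables and a single flat parameter vector, so that you can experiment with off-the-shelf optimization routines from SciPy. Here we use BFGS (`scipy.optimize.fmin_bfgs`)."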
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Desired error not necessarily achieved due to precision loss.\n",
" Current function value: 0.000218\n",
" Iterations: 5\n",
" Function evaluations: 31\n",
" Gradient evaluations: 19\n",
"train mse: 0.000217512235395 validation mse: 0.00018158860621\n"
]
}
],
"source": [
"from scipy.optimize import fmin_bfgs\n",
"\n",
"# The objective whose gradient is `gradient` (MSE + L2)\n",
"get_total_cost = theano.function([x], cost)\n",
"\n",
"\n",
"def vector_to_params(v):\n",
"    return_list = []\n",
"    offset = 0\n",
"    # note the global variable here\n",
"    for par in parameters:\n",
"        shape = par.get_value().shape\n",
"        par_size = int(numpy.prod(shape))\n",
"        return_list.append(v[offset:offset + par_size].reshape(shape))\n",
"        offset += par_size\n",
"    return return_list\n",
"\n",
"\n",
"def set_params(values):\n",
"    for parameter, value in zip(parameters, values):\n",
"        parameter.set_value(numpy.asarray(value, dtype=floatX))\n",
"\n",
"\n",
"def f_obj(x):\n",
"    set_params(vector_to_params(x))\n",
"    # Use the same objective that `gradient` differentiates (MSE + L2);\n",
"    # otherwise BFGS sees inconsistent function values and gradients\n",
"    return float(get_total_cost(data_train))\n",
"\n",
"\n",
"def f_prime(x):\n",
"    set_params(vector_to_params(x))\n",
"    grad = get_gradient(data_train)\n",
"    return numpy.asarray(numpy.concatenate([g.flatten() for g in grad]), dtype='float64')\n",
"\n",
"\n",
"x0 = numpy.asarray(numpy.concatenate([p.get_value().flatten() for p in parameters]), dtype='float64')\n",
"result = fmin_bfgs(f_obj, x0, f_prime)\n",
"# fmin_bfgs returns the best parameter vector; store it in the model explicitly,\n",
"# since the last evaluated point (e.g. from a line search) may differ\n",
"set_params(vector_to_params(result))\n",
"\n",
"print('train mse: {} validation mse: {}'.format(get_cost(data_train), get_cost(data_val)))"
]
},
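{
"cell_type": "markdown",
"metadata": {},
"source": [
"The flatten/unflatten bookkeeping in `vector_to_params` is easy to get wrong, and SciPy's optimizers only ever see one flat float64 vector. Here is a minimal standalone NumPy sketch of that round trip, using hypothetical helper names `pack` and `unpack` that take the parameter shapes explicitly instead of reading the global `parameters` list:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def pack(arrays):\n",
"    # Concatenate all parameters into one flat float64 vector\n",
"    return np.concatenate([np.asarray(a, dtype='float64').ravel() for a in arrays])\n",
"\n",
"\n",
"def unpack(vector, shapes):\n",
"    # Split the flat vector back into arrays with the given shapes, in order\n",
"    arrays, offset = [], 0\n",
"    for shape in shapes:\n",
"        size = int(np.prod(shape))\n",
"        arrays.append(vector[offset:offset + size].reshape(shape))\n",
"        offset += size\n",
"    assert offset == vector.size, 'vector length does not match the shapes'\n",
"    return arrays\n",
"\n",
"\n",
"# Shapes of w_xh, w_hh, b_h, w_ho, b_o for SimpleRNN(1, 15) plus the output layer\n",
"shapes = [(1, 15), (15, 15), (15,), (15, 1), (1,)]\n",
"rng = np.random.RandomState(0)\n",
"params = [rng.normal(size=s).astype('float32') for s in shapes]\n",
"flat = pack(params)\n",
"restored = unpack(flat, shapes)\n",
"print(flat.size)  # 271\n",
"```\n",
"\n",
"Checking that the offsets consume the whole vector catches a mismatch between the shapes and the vector length early, instead of silently truncating parameters."
]
},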
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generating sequences\n",
"Predicting a single step ahead is a relatively easy task. It would be more interesting to see whether the network has actually learned to generate multiple time steps, so that it can continue the sequence by feeding its own predictions back in as inputs.\n",
"Write code that generates the next 1000 examples after processing the training sequence."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f0f3ca042d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
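"# Symbolic graph for a single RNN step: input x_t, previous hidden state h_p\n",
"x_t = T.vector()\n",
"h_p = T.vector()\n",
"preactivation = T.dot(x_t, my_rnn.w_xh) + my_rnn.b_h\n",
"h_t = my_rnn._step(preactivation, h_p)\n",
"o_t = T.dot(h_t, w_ho) + b_o\n",
"\n",
"single_step = theano.function([x_t, h_p], [o_t, h_t])\n",
"\n",
"def generate(single_step, x_t, h_p, n_steps):\n",
"    # Free-running generation: each prediction becomes the next input\n",
"    output = numpy.zeros((n_steps, 1))\n",
"    for output_t in output:\n",
"        x_t, h_p = single_step(x_t, h_p)\n",
"        output_t[:] = x_t\n",
"    return output\n",
"\n",
"\n",
"# New names, so the symbolic `hidden` from the model definition is not overwritten\n",
"prediction_train = predict(data_train)\n",
"hidden_train = get_hidden(data_train)\n",
"\n",
"# prediction_train[-1] estimates data_val[0]; feeding it in with the last hidden\n",
"# state yields an estimate of data_val[1], so compare against data_val[1:]\n",
"generated = generate(single_step, prediction_train[-1], hidden_train[-1], n_steps=200)\n",
"plt.plot(generated, label='generated')\n",
"plt.plot(data_val[1:201], label='data')\n",
"plt.legend()\n",
"plt.show()"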
]
},
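{
"cell_type": "markdown",
"metadata": {},
"source": [
"The key difference between one-step prediction and generation is that, during generation, the network's own output replaces the true next input. The sketch below shows this free-running loop in plain NumPy for an RNN with the same structure as `SimpleRNN` plus a linear output layer; the function names and the random weights are hypothetical and only illustrate the data flow.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def rnn_step(x_t, h_prev, w_xh, w_hh, b_h, w_ho, b_o):\n",
"    # Same recurrence as SimpleRNN, followed by the linear output layer\n",
"    h_t = np.tanh(h_prev.dot(w_hh) + x_t.dot(w_xh) + b_h)\n",
"    o_t = h_t.dot(w_ho) + b_o\n",
"    return o_t, h_t\n",
"\n",
"\n",
"def generate_free_running(x_0, h_0, weights, n_steps):\n",
"    # Feed each prediction back in as the next input\n",
"    x_t, h_t = x_0, h_0\n",
"    outputs = []\n",
"    for _ in range(n_steps):\n",
"        x_t, h_t = rnn_step(x_t, h_t, *weights)\n",
"        outputs.append(x_t)\n",
"    return np.stack(outputs)\n",
"\n",
"\n",
"rng = np.random.RandomState(0)\n",
"weights = (rng.normal(0, .01, (1, 15)), rng.normal(0, .02, (15, 15)), np.zeros(15),\n",
"           rng.normal(0, .01, (15, 1)), np.zeros(1))\n",
"generated = generate_free_running(np.array([0.5]), np.zeros(15), weights, n_steps=200)\n",
"print(generated.shape)  # (200, 1)\n",
"```\n",
"\n",
"Because errors are fed back in, small one-step mistakes can accumulate over the generated sequence, which is why generation is a harder test than one-step prediction."
]
},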
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Things to Try\n",
"The quality of the generated sequence is probably not very good. Let's try to improve on it. Things to consider are:\n",
"* The initial weight values\n",
"* Using L2/L1 regularization\n",
"* Using weight noise\n",
"* The number of hidden units\n",
"* The non-linearity\n",
"* Adding direct connections between the input and the output"
]
},
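{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an example of the last suggestion, a direct (skip) connection adds a term that maps the current input straight to the output, so the output becomes `h_t . w_ho + x_t . w_xo + b_o`. The NumPy sketch below shows only the forward computation; `w_xo` is a hypothetical extra weight matrix of shape `(input_dim, output_dim)`, and in the Theano model it would be another shared variable appended to `parameters`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def output_with_skip(h_seq, x_seq, w_ho, w_xo, b_o):\n",
"    # h_seq: (time, recurrent_dim), x_seq: (time, input_dim)\n",
"    # The skip term x_seq.dot(w_xo) makes the output depend linearly on the current input\n",
"    return h_seq.dot(w_ho) + x_seq.dot(w_xo) + b_o\n",
"\n",
"\n",
"rng = np.random.RandomState(0)\n",
"h_seq = np.tanh(rng.normal(size=(10, 15)))\n",
"x_seq = rng.normal(size=(10, 1))\n",
"w_ho = rng.normal(0, .01, (15, 1))\n",
"w_xo = np.ones((1, 1))   # identity skip weight, for illustration\n",
"b_o = np.zeros(1)\n",
"out = output_with_skip(h_seq, x_seq, w_ho, w_xo, b_o)\n",
"print(out.shape)  # (10, 1)\n",
"```\n",
"\n",
"With `w_ho` at zero, this output is just a linear function of the current input, so the recurrent part only has to model what that linear term misses."
]
},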
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}