data-science-ipython-notebooks/deep-learning/theano-tutorial/intro_theano/intro_theano.ipynb

883 lines
892 KiB
Python
Raw Normal View History

2015-12-27 22:25:19 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction to Theano\n",
"\n",
"Credits: Forked from [summerschool2015](https://github.com/mila-udem/summerschool2015) by mila-udem\n",
"\n",
"# Overview\n",
"\n",
"## Basic usage\n",
"\n",
"### Defining an expression"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import theano\n",
"from theano import tensor as T\n",
"x = T.vector('x')\n",
"W = T.matrix('W')\n",
"b = T.vector('b')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dot = T.dot(x, W)\n",
"out = T.nnet.sigmoid(dot + b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Graph visualization"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dot [@A] '' \n",
" |x [@B]\n",
" |W [@C]\n"
]
}
],
"source": [
"from theano.printing import debugprint\n",
"debugprint(dot)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sigmoid [@A] '' \n",
" |Elemwise{add,no_inplace} [@B] '' \n",
" |dot [@C] '' \n",
" | |x [@D]\n",
" | |W [@E]\n",
" |b [@F]\n"
]
}
],
"source": [
"debugprint(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compiling a Theano function"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"f = theano.function(inputs=[x, W], outputs=dot)\n",
"g = theano.function([x, W, b], out)\n",
"h = theano.function([x, W, b], [dot, out])\n",
"i = theano.function([x, W, b], [dot + b, out])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Graph visualization of compiled functions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CGemv{inplace} [@A] '' 3\n",
" |AllocEmpty{dtype='float64'} [@B] '' 2\n",
" | |Shape_i{1} [@C] '' 1\n",
" | |W [@D]\n",
" |TensorConstant{1.0} [@E]\n",
" |InplaceDimShuffle{1,0} [@F] 'W.T' 0\n",
" | |W [@D]\n",
" |x [@G]\n",
" |TensorConstant{0.0} [@H]\n"
]
}
],
"source": [
"debugprint(f)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Elemwise{ScalarSigmoid}[(0, 0)] [@A] '' 2\n",
" |CGemv{no_inplace} [@B] '' 1\n",
" |b [@C]\n",
" |TensorConstant{1.0} [@D]\n",
" |InplaceDimShuffle{1,0} [@E] 'W.T' 0\n",
" | |W [@F]\n",
" |x [@G]\n",
" |TensorConstant{1.0} [@D]\n"
]
}
],
"source": [
"debugprint(g)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at pydotprint_f.png\n"
]
}
],
"source": [
"from theano.printing import pydotprint\n",
"pydotprint(f, outfile='pydotprint_f.png')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABlgAAAH9CAIAAAD55ObJAAAABmJLR0QA/wD/AP+gvaeTAAAgAElE\nQVR4nOzdd3RUZf4G8OdOJpMy6ZUkhECIhNC7SiiCkp9UQaQtLSJKkVVBVlB0dVFQEQERCxZgEVR0\nFUWKgnSwAKFJCpAQ0itpk8kkU39/3J3ZkEYSMrkpz+fkcGbu3Hnv996Z5Bye833fK5hMJhARERER\nEREREbV0MqkLICIiIiIiIiIiagwMwoiIiIiIiIiIqFVgEEZERERERERERK2CXOoCiOrmk08+kboE\nIvqfhx56KDg4WOoqiIiIiIiIakXgYvnUvAiCIHUJRPQ/u3btmjx5stRVEBERERER1Qo7wqgZ2gXw\n/91ETQFzaSIiIiIiala4RhgREREREREREbUKDMKIiIiIiIiIiKhVYBBGREREREREREStAoMwIiIi\nIiIiIiJqFRiEERERERERERFRq8AgjIiIiIiIiIiIWgUGYURERERERERE1CowCCMiIiIiIiIiolaB\nQRgREREREREREbUKDMKIiIiIiIiIiKhVYBBGREREREREREStAoMwIiIiIiIiIiJqFRiEERERERER\nERFRqyCXugAiIuu4BZwAYoGXrDD4deB7wAYYD4RYYXwiIiIiIiKyAnaEETWWdYAjIABjgN+AdOBl\nQAAEYCZwwrzbKeBBQA68AOgqDXIUEAA3oA9wLyAA9sC9QC9ACQhARqOek/RVRQPrzY9NwBrgRWAw\nIAdmA48C2xv6iCrgSWA8MBhYWlUK9j4gNPRB715XYN6d9tEDrwCpjVEOERERERGRJNgRRtRYlgA6\nYDnQDRgIAHgDSAJ2AA8DQ8y7DQJmAh2BNVUNUgJEAHsAOwCAALQH/gQAFADhgMbap9GUqvoF+BLY\nYn66DlgLZAJFwHTgBWDfXR/iJtC+3NM84EFAD5wC3Kva/yyw7K4PWg83b6+zMl/A406DyIHlwBzg\nTSC4gQojIiIiIiJqStgRRtSI5gEOwA7AYN6yGEC5KEd0FHiqmhE0wFJz3lSBGzBfoiBMkqouA08D\n7wM25i0fAR6ADHAD9pXLFustBZhV7qkJmAn8BXxdTQqWD/wIBN71ceuqQp1VOgK8WYuhlMAqYBxQ\n2AB1ERERERERNTUMwogakRswAUgDfjFv6QW4A0eAePOWYuAa0LeaEUYBw6of/0ngngYrtg4avyoD\nMAt4HHApt/Fmgx4iGxgNZJfbchDYD0wAula1vwl4HfhHo8+LrFznXQoBOgNLG25AIiIiIiKiJoNB\nGLUsJmAvsAgIBJKBhwE7oAdw3rxDNDAOeBmYAwwAfgcAqIFvgEggHPgS8AA6AWeBU0A4YA90Ay6V\nO4oKWAnMBQYBg4BzAIBbQFw1P0nl3jsbAPCZ+elRQHn7lm+BSdWHKY41Tmi2BxRVlXfHy3IJGAb8\nC3gJsAFUAIBs4O/AYuAFYBCwAMgCDMBJ4AUgGEgE+gLeQNGdqvqPebGw9YAeAPAN4AjsAM4ALwEd\ngThgiPlqH6jxUgPYDVwCxpqf7gXmAwYgE5gPzAeKK5VR5emIqvxWfAT8ZR5QJDbueQO9AAXQE9hb\nbvz3gSmAa/XXobK6fvHuWGeVn04a8A0w29widwUYAwjAZCAP+CfQEfj69sLGAJ8D1+pyLkRERERE\nRM2CiahZAYBdgKmaHyOQbZ629gaQDhwCBKCveYd2QIh5zzbmxwYgDQDgBhwB0gA5EAisAzTAVUAO\nDDWPYADGAmnmp5MAd6AAeKf6X7PwchXqAX9ADmQAJmCaOQvzBbSACXgAyKz+BCv8AAi9fUuV5eXf\n6bIEA23Nj58EsoBsoD2w2ryxAAgD2gJJwFnAGQCwDjgKTAXy7lSVybxyVqz56Q1gPKAHfjGPtgSI\nAr4H3AAbIKr6S20CHgVsAN2djmvZUt3pZFT/
rag8YAAAYAugAi4CHQAZ8BtgAn4D3jXvFir+ba3F\nT52+eLWps6yaT6fo9nNRA2FAD0ALTAOuVipMTN9erdU3cNeuXVL/VSAiIiIiIqotdoRRyyIA3oA3\nAGAF4Ac8BAQBF8w7PAM8CwAwAY5AAgBABvgBAHyBYYA/EAikAIsBe6AT0A44ax7hV+AnIMB8w8dv\ngXzgCLC0+rDgVLkKbYCZgB74N5AHXAWGAlOALGAPcB1wAnzv4gpUWd7RO12WPCAV+AAwms/6LeBm\nuaXKXIFXgVTgHaCf+XI9BTwAfFXNglkViMOuNT/dATwB2AAR5tHeBPoAE4DVgAHYWP2lBvAn4FuX\nu31UdzqrAFTzragsE2gLPA44AT2BtwEjsAm4BXwGPFfrYizq9MWrTZ2Kaj4dp9t3cwT+DUQDg4ER\nQKdK47QFYO44IyIiIiIiakEYhFFLVGFeoR1gND9+HpgBbAA2AWXmbprKb1Hc/tQWKDE//h3oUSnq\nmlCX8mYDAD4DdgBTAQGYCwD4FNgGTK/LUJXVUF4Nl2UDYAMsAgYA+YALcByAubdI9AAA4HS5oZR1\nKcwXmAtsN3d4HQUeNr8kjma55uKEx4s1nksm4FiXo9d8OtV9Kyqwv/2LIY5wBVgAzACumWfClgEA\n4qoP1Mqr/Rev9nVW/nQqz7TtDywDzgC9qhpBvFDpNRVORERERETUHDEIo1bmCNAJ6AU8U6lNppa0\nQDxQevtGQ63XCAMQBvQH4oHXzbHXfUAX4CDwJTCuXlXdsbyazQbOAg8CUcAgYKM5OilfuQeAOsZP\nFfwDMAHrgbPAfdX3c7UBANjXeC5C9TFQlWo+nVp+K8KAnHLHdTfXuQcYDoSZf26ad/6/ulRYG3f/\n7bUwAvFAIDDLnNwRERERERG1AgzCqJWJBJTmXp46JSkWXYESYFO5LWnAJmBruSikwk/lJi+xKaw/\n4A8AEIAnABMwsC5JU5X1V1dezd4CegO/At8BAF4GHgQA/Fxun1QAwJh6VSVqB8wANgObgDnV75YP\nAIio8VwCzOte1VLNpxNZ/bfCWO7xI4AKiDM/zQUAhAOlt/esWdYIi0cDq2WdtbEGGA9sAa4Ar1Z6\nVQ3AvCYaERERERFRC8IgjFoisWnIkhToAJiTgmIgHbgI7ATyAACxQEalt4g766sa8BGgHfAC8Bzw\nA7ABmAVE1nqNMNFUwNYch4lmArbAlLqcptgqVaGdp7ryar4s68xX41HAHwgBXgDuAdaaYykAHwP9\ngGequj53rMriVaAMSAZCKr1kaVs7DHQEFtd4LuFAzu3TBrW3D2IpT9xS8+lU963wArLM69nDfNtN\nyzJnewBPYEk1Z2rxAhAEbK3m1dp/8WpfZ+VPR3/7lj+B88BU4EFgIfBOpa+oONR9dzo1IiIiIiKi\n5oZBGLU4X5hnwL0PFAFbzVPVVgMaYC3gCEwGvIHFgAKYB9wC3gYApAEngeNACgBgFZAHbDEP+BGQ\nCyiBQ8AIYDMQCZwHvgRc61ikJzDr9lmQ3sDsukym+9W8OvtN4J/AH+bt1ZVX82XJAe4H3gT+AfQA\n/gN4AL8D44AxwDJgMSADjgImYB2QCAD4J3CldlVZtAdGA09UdUYfAkVABhAPnAbca7zUYoZ43vze\nOOB1AEAi8LF5Oqq4EH4SsAUQqjkdsf+uym+FDHgDMJW7H6gbcAIoAKYDy4DDwCnzuvI1SAeSq1lK\nP6cuX7za1Kmu6tNRA+vNl2IbsAMYD/iZp4t6A0ZgPLCzXGHnAQGYdqdTIyIiIiIiam4Ek6l+08OI\npCEIAnYBk6Wug+rHANwPHLt9Bmhn4GodZ6qagAigN7CmYeuzjlRgNHBJ6jJq71HABdhWiz0F7Nq1\na/Jk/kISEREREVHzwI4wImpEnwFD727FfZEAbAX2m2cINmUa4EXgU6nLqL3LQLS5iYyIiIiIiKhl\nqe62bURE
DecXYDGgB/KA2EqviquV6ev4B6kt8AXwHPAZoGiYMq3iGrAaCJS6jFrKBVYAB8z3xCQi\nIiIiImpZ2BFGRNbnDxQ
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 9,
"metadata": {
"image/png": {
"width": 1000
}
},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import Image\n",
"Image('pydotprint_f.png', width=1000)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at pydotprint_g.png\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABScAAAJ4CAIAAAAY0cErAAAABmJLR0QA/wD/AP+gvaeTAAAgAElE\nQVR4nOzdaXgUZd7+/bOzmwAJkbAEghC2RBhWFZGAOCgiICBMQJFNBUHl9hYHQR3cQR1kwL/ggoPo\njeijqDgyCKIMSAwwgkRQMEFlJwtJyN5Zu9PPi5ruabKRkHQ6hO/n6MOj6+rqq35V3UU866qqNtls\nNgEAAAAAABfwcHcBAAAAAAA0WqRuAAAAAABchdQNAAAAAICreLm7AAAAAKDB2bZt27Fjx9xdBYD/\nCA8Pv/nmm91dxUUidQMAAABlvf3225988om7qwDwH9HR0aRuAAAAoHGJlta7uwYAkia4u4Da4bpu\nAAAAAABchdQNAAAAAICrkLoBAAAAAHAVUjcAAAAAAK5C6gYAAAAAwFVI3QAAAAAAuAqpGwAAAAAA\nVyF1AwAAAADgKqRuAAAAAABchdQNAAAAAICrkLoBAAAAAHAVUjcAAAAAAK5C6gYAAAAAwFW83F0A\nAAAAAFzKzkkxUrz0pAs6/03aIHlKY6XOLugfrsdYNwAAANC4LJP8JZM0StotJUkLJZNkkqZIMfbZ\nYqWhkpc0Xyop18kOySQFSX2l/pJJ8pP6S72lAMkkJdfrOrm/qsPScvtzm7REekIaJHlJ06Rx0tq6\nXmKuNFMaKw2S5lUUuVdIprpeaO11l2ZdaB6L9JR0pj7KaQgY6wYAAAAal0elEulxqYd0gyRpkXRS\nWicNlwbbZ4uSpkidpCUVdZIvDZM2Sr6SJJPUQfpekpQlDZQKXL0aDamqrdKH0hr75DJpqZQi5Uh3\nS/OlL2u9iBNSB6fJDGmoZJFipeYVzb9PWlDrhV6EE+fXWV4rKfhCnXhJj0v3Si9J4XVUWAPGWDcA\nAADQ6MySrpDWSVZ7y1xJTrnRsEO6v5IeCqR59nBbRpA0202p2y1V/SQ9JK2QPO0tb0rBkocUJH3p\ndCDjop2WpjpN2qQp0s/SR5VE7kzpCyms1sutqTJ1Vmi79FI1ugqQFkujpew6qKuBI3UDAAAAjU6Q\ndIeUKG21t/SWmkvbpd/tLXnSr1K/SnoYId1Uef8zpS51VmwN1H9VVmmqdI/UzKnxRJ0uIlUaKaU6\ntXwtbZbukLpXNL9NekF6rN5PLy9fZy11liKkeXXXYUNF6gYAAABqziZtkuZIYdIpabjkK/WU4uwz\nHJZGSwule6XrpD2SJLO0XpouDZQ+lIKlrtI+KVYaKPlJPaSDTkvJlZ6XZkhRUpT0gyTpnJRQyeOk\n03unSZJW2yd3SAHnt3wiRVee3PyrvBrVT/KpqLwLbpaD0k3Sc9KTkqeUK0lKlf5HmivNl6KkB6Sz\nklX6TpovhUvHpX5SiJRzoao+tV/gvVyySJLWS/7SOmmv9KTUSUqQBtu39pYqN7Wkz6WD0u32yU3S\nbMkqpUizpdlSXrkyKlwdQ4Xfijeln+0dGoxTEkKk3pKP1Eva5NT/CmmiFFj5diivpl+8C9ZZ4aeT\nKK2XptkH/w9JoySTNEHKkJ6WOkkfnV/YKOkd6dearMulyAYAAADgfNHR0YqWbJU/SqVU+9m/i6Qk\n6RvJJPWzz9Be6myfs7X9uVVKlCQFSdulRMlLCpOWSQXSEclLutHeg1W6XUq0T0ZLzaUs6ZXK/+d+\noFOFFilU8pKSJZt0lz14t5KKJZs0REqpch2dH5K6nd9SYXmZF9os4VI7+/OZ0lkpVeogvWhvzJIi\npXbSSWmf1FSStEzaId0pZVyoKpv9aud4++Qxaaxkkbbae3tU2i9tkIIkT2l/5ZvaJo2TPKWSCy3X\n0VLZ6iRX/q0o32FbSdIaKVc6IHWUPKTdkk3aLf3NPls3I9FV41GjL1516iyq5NPJOX9dzFKk1FMq\nlu6SjpQrzIj6z1yo/mhFR0e7+1+Fi8dYNwAA
AFBzJilECpEk/UVqI90sXSX9aJ/hYel/JUk2yV86\nKknykNpIklpJN0mhUph0Wpor+UldpfbSPnsP26R/Sm3ttx//RMqUtkvzKg8nsU4VekpTJIv0f1KG\ndES6UZoonZU2Sr9JTaRWtdgCFZa340KbJUM6I70uldrX+mXphNPl5YHSM9IZ6RXpGvvmul8aIv1/\nlVzkXIbR7VL75DrpPslTGmbv7SWpr3SH9KJklV6rfFNL+l5qVZObUFe2OoslVfKtKC9FaifdIzWR\nekl/lUqlldI5abX0SLWLcajRF686dfpU8uk0OX82f+n/pMPSIOkWqWu5ftpJso+lN16kbgAAAOBi\nlTk921cqtT//szRZelVaKRXZxwnLv8Xn/ElvKd/+fI/Us1yuvqMm5U2TJK2W1kl3SiZphiTp79J7\n0t016aq8KsqrYrO8KnlKc6TrpEypmbRTkn3U1DBEkrTLqauAmhTWSpohrbWPXe+QhttfMnpzbHPj\nvPEDVa5LiuRfk6VXvTqVfSvK8Dv/i2H0cEh6QJos/Wq/oKBIkpRQeXp3Vv0vXvXrLP/plL9g4Vpp\ngbRX6l1RD8aGSqqq8EaA1A0AAAC4wHapq9RberjcAGA1FUu/S4XnN1qrfV23pEjpWul36QV7xr5e\nulr6WvpQGn1RVV2wvKpNk/ZJQ6X9UpT0mj2nOVdu/O5UjbJuGY9JNmm5tE+6vvKR6taSJL8q18VU\neeasUNWrU81vRaSU5rTc5vY6N0p/lCLtjxP2mW+tSYXVUftvr0Op9LsUJk21Hya4/JC6AQAAABeY\nLgXYRylrFNscukv50kqnlkRppfSuU+4q8yg/fG0Md18rhUqSTNJ9kk26oSaxtsL6Kyuvai9LfaRt\n0meSpIXSUEnSV07znJEkjbqoqgztpcnSKmmldG/ls2VKkoZVuS5t7dcqV1PVqzO98m9FqdPzMVKu\nlGCfTJckDZQKzx+Nd1zX/bvqWDXrrI4l0lhpjXRIeqbcq2ZJ9uvYGy9SNwAAAHCxjOFQRywpkWSP\nJXlSknRA+kDKkCTFS8nl3mLMbKmowzFSe2m+9Ij0D+lVaao0vdrXdRvulLzt2dswRfKWJtZkNY1B\n4DIDlZWVV/VmWWbfGuOkUKmzNF/qIi21Z2BJb0nXSA9XtH0uWJXDM1KRdErqXO4lx4D8v6RO0twq\n12WglHb+2dfF53fiKM9oqXp1KvtWtJDO2m94JvtN4B2Xpm+UrpQerWRNHeZLV0nvVvJq9b941a+z\n/KdjOb/leylOulMaKj0ovVLuK2p0df2FVu0SR+oGAAAALsr79hOJV0g50rv2M35flAqkpZK/NEEK\nkeZKPtIs6Zz0V0lSovSdtFM6LUlaLGVIa+wdvimlSwHSN9It0ippuhQnfVjD34uSdKU09fyTyUOk\naTU5J3mb/fZdJ6SnpX/b2ysrr+rNkiYNkF6SHpN6Sp9KwdIeabQ0SlogzZU8pB2STVomHZckPS0d\nql5VDh2kkdJ9Fa3RG1KOlCz9Lu2Smle5qY0DFo5fPkuQXpAkHZfesp/Vb9wp7aS0RjJVsjrGmQUV\nfis8pEWSzenu9EFSjJQl3S0tkP4lxdpvPFaFJOlUJfdaS6vJF686dZor+nTM0nL7pnhPWieNldrY\nz7oPkUqlsdIHToXFSSbprgut2iXOZLNd3PkuAAAAQKM1YcKET/SJ1ru7Dlw0qzRA+vb8E+kj7D9e\nVX02aZjUR1pSt/W5xhlp5Pm/vN3AjZOaSe9daLYJilb0+vWX6g7JWDcAAACARme1dGPtbslmMEnv\nSpvtJ1o3ZAXSE9Lf3V1G9f0kHbYPjzdq1f/hOQAAAABo2LZKcyWLlCHFl3vVuMLcUsMY1E56X3pE\nWl3uB7calF+lF6Uwd5dRTenSX6Qt1fsN9kscY90AAAAAGotQKUsqkj6TQpzazdIi6ZgkaYG0v4bd\n9pGekl6r
szJdotelE7lLpNXS+1K4uyupF4x1AwAAAGgs/iAlVdQeIC2UFtai5y7SvFq8Hc68pcfd\nXUM9YqwbAAAAAABXIXU
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 10,
"metadata": {
"image/png": {
"width": 1000
}
},
"output_type": "execute_result"
}
],
"source": [
"pydotprint(g, outfile='pydotprint_g.png')\n",
"Image('pydotprint_g.png', width=1000)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at pydotprint_h.png\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABlgAAALdCAIAAADPoPliAAAABmJLR0QA/wD/AP+gvaeTAAAgAElE\nQVR4nOzdd3RUZf4G8OdOJpMy6ZUkhECIhNC7SiiCkp9UQaQtLSJKkVVBVlB0dVFQEQERCxZgEVR0\nFUWKgnSwAKFJCpAQ0itpk8kkU39/3J3ZkEYSMrkpz+fkcGbu3Hnv996Z5Bye833fK5hMJhARERER\nEREREbV0MqkLICIiIiIiIiIiagwMwoiIiIiIiIiIqFVgEEZERERERERERK2CXOoCiOrmk08+kboE\nIvqfhx56KDg4WOoqiIiIiIiIakXgYvnUvAiCIHUJRPQ/u3btmjx5stRVEBERERER1Qo7wqgZ2gXw\n/91ETQFzaSIiIiIiala4RhgREREREREREbUKDMKIiIiIiIiIiKhVYBBGREREREREREStAoMwIiIi\nIiIiIiJqFRiEERERERERERFRq8AgjIiIiIiIiIiIWgUGYURERERERERE1CowCCMiIiIiIiIiolaB\nQRgREREREREREbUKDMKIiIiIiIiIiKhVYBBGREREREREREStAoMwIiIiIiIiIiJqFRiEERERERER\nERFRqyCXugAiIuu4BZwAYoGXrDD4deB7wAYYD4RYYXwiIiIiIiKyAnaEETWWdYAjIABjgN+AdOBl\nQAAEYCZwwrzbKeBBQA68AOgqDXIUEAA3oA9wLyAA9sC9QC9ACQhARqOek/RVRQPrzY9NwBrgRWAw\nIAdmA48C2xv6iCrgSWA8MBhYWlUK9j4gNPRB715XYN6d9tEDrwCpjVEOERERERGRJNgRRtRYlgA6\nYDnQDRgIAHgDSAJ2AA8DQ8y7DQJmAh2BNVUNUgJEAHsAOwCAALQH/gQAFADhgMbap9GUqvoF+BLY\nYn66DlgLZAJFwHTgBWDfXR/iJtC+3NM84EFAD5wC3Kva/yyw7K4PWg83b6+zMl/A406DyIHlwBzg\nTSC4gQojIiIiIiJqStgRRtSI5gEOwA7AYN6yGEC5KEd0FHiqmhE0wFJz3lSBGzBfoiBMkqouA08D\n7wM25i0fAR6ADHAD9pXLFustBZhV7qkJmAn8BXxdTQqWD/wIBN71ceuqQp1VOgK8WYuhlMAqYBxQ\n2AB1ERERERERNTUMwogakRswAUgDfjFv6QW4A0eAePOWYuAa0LeaEUYBw6of/0ngngYrtg4avyoD\nMAt4HHApt/Fmgx4iGxgNZJfbchDYD0wAula1vwl4HfhHo8+LrFznXQoBOgNLG25AIiIiIiKiJoNB\nGLUsJmAvsAgIBJKBhwE7oAdw3rxDNDAOeBmYAwwAfgcAqIFvgEggHPgS8AA6AWeBU0A4YA90Ay6V\nO4oKWAnMBQYBg4BzAIBbQFw1P0nl3jsbAPCZ+elRQHn7lm+BSdWHKY41Tmi2BxRVlXfHy3IJGAb8\nC3gJsAFUAIBs4O/AYuAFYBCwAMgCDMBJ4AUgGEgE+gLeQNGdqvqPebGw9YAeAPAN4AjsAM4ALwEd\ngThgiPlqH6jxUgPYDVwCxpqf7gXmAwYgE5gPzAeKK5VR5emIqvxWfAT8ZR5QJDbueQO9AAXQE9hb\nbvz3gSmAa/XXobK6fvHuWGeVn04a8A0w29widwUYAwjAZCAP+CfQEfj69sLGAJ8D1+pyLkRERERE\nRM2CiahZAYBdgKmaHyOQbZ629gaQDhwCBKCveYd2QIh5zzbmxwYgDQDgBhwB0gA5EAisAzTAVUAO\nDDWPYADGAmnmp5MAd6AAeKf6X7PwchXqAX9ADmQAJmCaOQvzBbSACXgAyKz+BCv8AAi9fUuV5eXf\n6bIEA23Nj58EsoBsoD2w2ryxAAgD2gJJwFnAGQCwDjgKTAXy7lSVybxyVqz56Q1gPKAHfjGPtgSI\nAr4H3AAbIKr6S20CHgVsAN2djmvZUt3pZFT/
rag8YAAAYAugAi4CHQAZ8BtgAn4D3jXvFir+ba3F\nT52+eLWps6yaT6fo9nNRA2FAD0ALTAOuVipMTN9erdU3cNeuXVL/VSAiIiIiIqotdoRRyyIA3oA3\nAGAF4Ac8BAQBF8w7PAM8CwAwAY5AAgBABvgBAHyBYYA/EAikAIsBe6AT0A44ax7hV+AnIMB8w8dv\ngXzgCLC0+rDgVLkKbYCZgB74N5AHXAWGAlOALGAPcB1wAnzv4gpUWd7RO12WPCAV+AAwms/6LeBm\nuaXKXIFXgVTgHaCf+XI9BTwAfFXNglkViMOuNT/dATwB2AAR5tHeBPoAE4DVgAHYWP2lBvAn4FuX\nu31UdzqrAFTzragsE2gLPA44AT2BtwEjsAm4BXwGPFfrYizq9MWrTZ2Kaj4dp9t3cwT+DUQDg4ER\nQKdK47QFYO44IyIiIiIiakEYhFFLVGFeoR1gND9+HpgBbAA2AWXmbprKb1Hc/tQWKDE//h3oUSnq\nmlCX8mYDAD4DdgBTAQGYCwD4FNgGTK/LUJXVUF4Nl2UDYAMsAgYA+YALcByAubdI9AAA4HS5oZR1\nKcwXmAtsN3d4HQUeNr8kjma55uKEx4s1nksm4FiXo9d8OtV9Kyqwv/2LIY5wBVgAzACumWfClgEA\n4qoP1Mqr/Rev9nVW/nQqz7TtDywDzgC9qhpBvFDpNRVORERERETUHDEIo1bmCNAJ6AU8U6lNppa0\nQDxQevtGQ63XCAMQBvQH4oHXzbHXfUAX4CDwJTCuXlXdsbyazQbOAg8CUcAgYKM5OilfuQeAOsZP\nFfwDMAHrgbPAfdX3c7UBANjXeC5C9TFQlWo+nVp+K8KAnHLHdTfXuQcYDoSZf26ad/6/ulRYG3f/\n7bUwAvFAIDDLnNwRERERERG1AgzCqJWJBJTmXp46JSkWXYESYFO5LWnAJmBruSikwk/lJi+xKaw/\n4A8AEIAnABMwsC5JU5X1V1dezd4CegO/At8BAF4GHgQA/Fxun1QAwJh6VSVqB8wANgObgDnV75YP\nAIio8VwCzOte1VLNpxNZ/bfCWO7xI4AKiDM/zQUAhAOlt/esWdYIi0cDq2WdtbEGGA9sAa4Ar1Z6\nVQ3AvCYaERERERFRC8IgjFoisWnIkhToAJiTgmIgHbgI7ATyAACxQEalt4g766sa8BGgHfAC8Bzw\nA7ABmAVE1nqNMNFUwNYch4lmArbAlLqcptgqVaGdp7ryar4s68xX41HAHwgBXgDuAdaaYykAHwP9\ngGequj53rMriVaAMSAZCKr1kaVs7DHQEFtd4LuFAzu3TBrW3D2IpT9xS8+lU963wArLM69nDfNtN\nyzJnewBPYEk1Z2rxAhAEbK3m1dp/8WpfZ+VPR3/7lj+B88BU4EFgIfBOpa+oONR9dzo1IiIiIiKi\n5oZBGLU4X5hnwL0PFAFbzVPVVgMaYC3gCEwGvIHFgAKYB9wC3gYApAEngeNACgBgFZAHbDEP+BGQ\nCyiBQ8AIYDMQCZwHvgRc61ikJzDr9lmQ3sDsukym+9W8OvtN4J/AH+bt1ZVX82XJAe4H3gT+AfQA\n/gN4AL8D44AxwDJgMSADjgImYB2QCAD4J3CldlVZtAdGA09UdUYfAkVABhAPnAbca7zUYoZ43vze\nOOB1AEAi8LF5Oqq4EH4SsAUQqjkdsf+uym+FDHgDMJW7H6gbcAIoAKYDy4DDwCnzuvI1SAeSq1lK\nP6cuX7za1Kmu6tNRA+vNl2IbsAMYD/iZp4t6A0ZgPLCzXGHnAQGYdqdTIyIiIiIiam4Ek6l+08OI\npCEIAnYBk6Wug+rHANwPHLt9Bmhn4GodZ6qagAigN7CmYeuzjlRgNHBJ6jJq71HABdhWiz0F7Nq1\na/Jk/kISEREREVHzwI4wImpEnwFD727FfZEAbAX2m2cINmUa4EXgU6nLqL3LQLS5iYyIiIiIiKhl\nqe62bURE
DecXYDGgB/KA2EqviquV6ev4B6kt8AXwHPAZoGiYMq3iGrAaCJS6jFrKBVYAB8z3xCQi\nIiIiImpZ2BFGRNbnDxQ
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 11,
"metadata": {
"image/png": {
"width": 1000
}
},
"output_type": "execute_result"
}
],
"source": [
"pydotprint(h, outfile='pydotprint_h.png')\n",
"Image('pydotprint_h.png', width=1000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Executing a Theano function"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1.79048354, 0.03158954, -0.26423186])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.random.seed(42)\n",
"W_val = np.random.randn(4, 3)\n",
"x_val = np.random.rand(4)\n",
"b_val = np.ones(3)\n",
"\n",
"f(x_val, W_val)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.9421594 , 0.73722395, 0.67606977])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"g(x_val, W_val, b_val)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[array([ 1.79048354, 0.03158954, -0.26423186]),\n",
" array([ 0.9421594 , 0.73722395, 0.67606977])]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h(x_val, W_val, b_val)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[array([ 2.79048354, 1.03158954, 0.73576814]),\n",
" array([ 0.9421594 , 0.73722395, 0.67606977])]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"i(x_val, W_val, b_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Graph Definition and Syntax\n",
"## Graph structure"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at pydotprint_f_notcompact.png\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABXwAAALwCAIAAACrxiohAAAABmJLR0QA/wD/AP+gvaeTAAAgAElE\nQVR4nOzdd3hTZd8H8G9GkzbdpS0dQG1B2soqCCh7KMhQWQIiMkQEVEBBHkT0caCPAxy8ggguQEQF\nBZRVlF2mlFEeWihQoKWD0t2kSZs0yXn/OG/yhi4KNE3H93NxeSUnJ/f5ndPTXp7vue/7SARBABER\nERERERFRTZM6ugAiIiIiIiIiapgYOhARERERERGRXTB0ICIiIiIiIiK7kDu6AKp5V69e3bNnj6Or\nIKL/N23aNEeXQERERETkABJOJNnwbNy4cezYsY6ugoj+H//SEhEREVHjxJ4ODRevcYjqgo0AM0Ai\nIiIiaqw4pwMRERERERER2QVDByIiIiIiIiKyC4YORERERERERGQXDB2IiIiIiIiIyC4YOhARERER\nERGRXTB0ICIiIiIiIiK7YOhARERERERERHbB0IGIiIiIiIiI7IKhAxERERERERHZBUMHIiIiIiIi\nIrILhg5EREREREREZBcMHYiIiIiIiIjILhg6EBEREREREZFdyB1dABE5VC4QA1wAFtqh8cvAZkAG\nDAda2aF9IiIiIiKq29jTgeqYzwEVIAEeB44CGcBbgASQABOAGMtqh4FHADkwHygt18h+QAJ4AZ2A\nhwAJ4Aw8BEQBroAEuFGr++T4qhKALyyvBWAx8AbQC5ADk4CRwI81vUUN8AIwHOgFzKsocVgGSGp6\no/euDTD9dusYgX8DabVRDhERERFRfceeDlTHzAVKgQVAW6A7AOADIAX4CRgE9Las1hOYALQEFlfU\niA4YCGwFlAAACXAf8A8AoADoARTbezfqUlV/AT8DP1jefg58CmQCamA8MB/Ycc+bSAbus3mbBzwC\nGIHDgHdF68cCr9/zRu9C8q11ltcU8LldI3JgATAF+AgIq6HCiIiIiIgaKPZ0oLpnOuAC/ASYLEvm\nALC5bBbtB6ZV0kIxMM9ybV+GFzDDQaGDQ6r6L/AysAyQWZZ8DfgAUsAL2GGT49y1VGCizVsBmACc\nA36tJHHIB/4Emt/zdu9UmTortA/4qBpNuQL/AZ4ECmugLiIiIiKiBoyhA9U9XsAIIB34y7IkCvAG\n9gFJliVFwCXgwUpaGAL0q7z9F4D7a6zYO1D7VZmAicBzgIfNwuQa3UQWMBTIslnyN7ATGAG0qWh9\nAXgf+Fetj60oX+c9agVEAPNqrkEiIiIiooaIoUOjJADbgZlAc+A6MAhQAu2B05YVEoAngbeAKUBX\n4BgAQAtsBCYDPYCfAR+gNRALHAZ6AM5AW+CszVY0wCJgKtAT6AmcBADkAomV/Eux+e4kAMB3lrf7\nAddbl/wGjK78wlVV5cghZ0BRUXm3PSxngX7Ae8BCQAZoAABZwCxgDjAf6Am8CNwETMAhYD4QBlwD\nHgT8APXtqvrdMrnDF4ARALARUAE/ASeAhUBLIBHobTna0VUeagBbgLPAE5a324EZgAnIBGYAM4Ci\ncmVUuDuiCs+Kr4FzlgZFYocUPyAKUAAdgO027S8DxgKelR+H8u70xLttnRX+dNKBjcAkS9ePeOBx\nQAKMAfKAt4GWwK+3FvY48D1w6U72hYiIiIiosRGowdmwYQMACJX/MwNZlq7vHwAZwG5AAjxoWaEF\n0MqyZoDltQlIBwB4AfuAdEAONAc+B4qBi4Ac6GNpwQQ8AaRb3o4GvIECYEnl52IPmwqNQBAgB24A\nAjDOkjs0BQyAAPQFMqvcR9t/AMJvXVJhefm3OyxhQDPL6xeAm0AWcB/woWVhARAJNANSgFjAHQDw\nObAfeBrIu11VgmWmgwuWt1eB4YAR+MvS2lzgFLAZ8AJkwKnKD7UAjARkQOnttmtdUtnu3Kj8rCjf\nYDAA4AdAA8QBoYAUOAoIwFHgM8tq4bjNWWr7
w6r+iVedOvWV/HTUt+6LFogE2gMGYBxwsVxhYtLx\nzu3q38C/tERERETUeLGnQ6MkAfwAPwDAm0Ag8CgQApyxrDAbeAUAIAAq4AoAQAoEAgCaAv2AIKA5\nkArMAZyB1kALINbSwh5gGxBsefDEb0A+sA+YV/m12WGbCmXABMAIrAXygItAH2AscBPYClwG3ICm\n93AEKixv/+0OSx6QBnwFmC17/TGQbDO1hCfwDpAGLAE6Ww7XNKAv8EslExyUITb7qeXtT8DzgAwY\naGntI6ATMAL4EDABX1Z+qAH8AzS9k+liK9ud/wCo5KwoLxNoBjwHuAEdgE8AM7AcyAW+A16tdjFW\nd3TiVadORSU/HbdbV1MBa4EEoBcwAGhdrp1mACw9KYiIiIiIqCIMHRqxMmMTlIDZ8vo14FlgKbAc\n0FvuEpf/iuLWt06AzvL6GNC+XKww4k7KmwQA+A74CXgakABTAQDfAmuA8XfSVHlVlFfFYVkKyICZ\nQFcgH/AADgKw3DMX9QUAHLFpyvVOCmsKTAV+tPRc2A8MsnwktmY95uKgibgq9yUTUN3J1qvencrO\nijKcbz0xxBbigReBZ4FLltE0egBAYuXhha3qn3jVr7P8T6f8aJ0uwOvACSCqohbEA5VRVeFERERE\nRI0cQweqyD6gNRAFzC53+7eaDEASUHLrQlO153QAEAl0AZKA9y0Rw8PAA8DfwM/Ak3dV1W3Lq9ok\nIBZ4BDgF9AS+tFym2lYuPnDxji71y/gXIABfALHAw5X3UwgAADhXuS+Syi+5K1T17lTzrIgEsm22\n622pcyvQH4i0/Eu2rPzYnVRYHfd+9lqZgSSgOTDRkpIQEREREdGdYOhAFZkMuFruUd/RVatVG0AH\nLLdZkg4sB1bbXHaW+Ve+84LY2aELEAQAkADPAwLQ/U6u6iusv7LyqvYx0BHYA2wCALwFPAIA2GWz\nThoA4PG7qkrUAngWWAUsB6ZUvlo+AGBglfsSbJmnoJqq3p3JlZ8VZpvXwwANkGh5mwMA6AGU3NoX\nwzqnQxJqWDXrrI7FwHDgByAeeKfcp1oAljksiIiIiIioIgwdGjHxZrj1qqwUgOWqrAjIAOKA9UAe\nAOACcKPcV8SVjRU1OAxoAcwHXgX+AJYCE4HJ1Z7TQfQ04GSJHkQTACdg7J3sptgFoMxt6srKq/qw\nfG45GiOBIKAVMB+4H/jUEgEAWAl0BmZXdHxuW5XVO4AeuA60KveRtTvGXqAlMKfKfekBZN869MBw\nayPW8sQlVe9OZWeFL3DTMtcjLI//sE5LsRVoAsytZE+t5gMhwOpKPq3+iVf9Osv/dIy3LvkHOA08\nDTwCvAQsKXeKik09fLtdIyIiIiJqxBg6NFbrLL3olwFqYLWlu/uHQDHwKaACxgB+wBxAAUwHcoFP\nAADpwCHgIJAKAPgPkAf8YGnwayAHcAV2AwOAVcBk4DTw8x0+KBFAE2DirSMp/IBJd9Ihf49l5sJk\n4G3guGV5ZeVVfViygW7AR8C/gPbA74APcAx4EngceB2YA0iB/YAAfA5cAwC8DcRXryqr+4ChwPMV\n7dEKQA3cAJKAI4B3lYdazGusj/xMBN4HAFwDVlqGtIiTRKYAPwCSSnZH7FdS4VkhBT4ABJvnkngB\nMUABMB54HdgLHLbMuViFDOB6JdNMZt/JiVedOrUV/XS0wBeWQ7EG+AkYDgRahpz4AWZgOLDeprDT\ngAQYd7tdIyIiIiJqxCSCcHe956nu2rhx49ixY+9yWATVBSagG3Dg1lEkEZanNlafAAwEOgKLa7Y+\n+0gDhlqeQ1kvjAQ8gDW3W20jMBb8S0tEREREjRN7OhDVPd8Bfe5tNkqRBFgN7LSMMqjLioE3gG8d\nXUb1/RdIsHSOICIiIiKiSlQ2Mz4R1bq/gDmAEcgDLpT7VJxdwniHv7XNgHXAq8B35Z40WadcAj4E\nmju6jGrK
Ad4Eoi3P5iAiIiIiokqwpwNRnREEFAB6YBPgZ7NcC3wAXAUAvA6cusNmOwL/Br6ssTLt\nokP9SRxKge+AdUCYoys
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 16,
"metadata": {
"image/png": {
"width": 1000
}
},
"output_type": "execute_result"
}
],
"source": [
"pydotprint(f, compact=False, outfile='pydotprint_f_notcompact.png')\n",
"Image('pydotprint_f_notcompact.png', width=1000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Strong typing\n",
"### Broadcasting tensors"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(True, False)\n"
]
}
],
"source": [
"r = T.row('r')\n",
"print(r.broadcastable)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(False, True)\n"
]
}
],
"source": [
"c = T.col('c')\n",
"print(c.broadcastable)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 1.1 2.1 3.1]\n",
" [ 1.2 2.2 3.2]]\n"
]
}
],
"source": [
"f = theano.function([r, c], r + c)\n",
"print(f([[1, 2, 3]], [[.1], [.2]]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Graph Transformations\n",
"## Substitution and Cloning\n",
"### The `givens` keyword"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1.90651511, 0.60431744, -0.64253361])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_ = T.vector('x_')\n",
"x_n = (x_ - x_.mean()) / x_.std()\n",
"f_n = theano.function([x_, W], dot, givens={x: x_n})\n",
"f_n(x_val, W_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Cloning with replacement"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1.90651511, 0.60431744, -0.64253361])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dot_n, out_n = theano.clone([dot, out], replace={x: (x - x.mean()) / x.std()}) \n",
"f_n = theano.function([x, W], dot_n) \n",
"f_n(x_val, W_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient\n",
"### Using `theano.grad`"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y = T.vector('y')\n",
"C = ((out - y) ** 2).sum()\n",
"dC_dW = theano.grad(C, W)\n",
"dC_db = theano.grad(C, b)\n",
"# dC_dW, dC_db = theano.grad(C, [W, b])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using the gradients"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[array(0.6137821438190066), array([[ 0.01095277, 0.07045955, 0.051161 ],\n",
" [ 0.01889131, 0.12152849, 0.0882424 ],\n",
" [ 0.01555008, 0.10003427, 0.07263534],\n",
" [ 0.01048429, 0.06744584, 0.04897273]]), array([ 0.03600015, 0.23159028, 0.16815877])]\n"
]
}
],
"source": [
"cost_and_grads = theano.function([x, W, b, y], [C, dC_dW, dC_db])\n",
"y_val = np.random.uniform(size=3)\n",
"print(cost_and_grads(x_val, W_val, b_val, y_val))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[array(0.6137821438190066), array([[ 0.49561888, -0.14531026, 0.64257244],\n",
" [ 1.52114073, -0.24630622, -0.2429612 ],\n",
" [ 1.57765781, 0.7574313 , -0.47673792],\n",
" [ 0.54151161, -0.47016228, -0.47062703]]), array([ 0.99639999, 0.97684097, 0.98318412])]\n"
]
}
],
"source": [
"upd_W = W - 0.1 * dC_dW\n",
"upd_b = b - 0.1 * dC_db\n",
"cost_and_upd = theano.function([x, W, b, y], [C, upd_W, upd_b])\n",
"print(cost_and_upd(x_val, W_val, b_val, y_val))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at pydotprint_cost_and_upd.png\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAACVMAAAPwCAIAAAC5/esPAAAABmJLR0QA/wD/AP+gvaeTAAAgAElE\nQVR4nOzdeVxVdeL/8fdlVVDBBcUNFUwwzH1JuWqW46SZS6a2QTqZWt++NVajLbaM7U3LTIvtZpBN\nLlmj/rTSXCFLQ7OUpRQQlUUMZd8u3N8fZ+CLIgoKHJbX8+FjHveee/ic9zkczAfv+XyOxW63CwAA\nAAAAAAAAAEAD52B2AAAAAAAAAAAAAAA1gOYPAAAAAAAAAAAAaAxo/gAAAAAAAAAAAIDGwMnsAAAA\nAAAAAABQp7Zs2RIXF2d2CgD/5evrO3bsWLNTAI0EzR8AAAAAAACApuX9999fvXq12SkA/Nf06dNp\n/oCaQvMHAAAAAAAAoOmZLq0yOwMASTPMDgA0LjznDwAAAAAAAAAAAGgMaP4AAAAAAAAAAACAxoDm\nDwAAAAAAAAAAAGgMaP4AAAAAAAAAAACAxoDmDwAAAAAAAAAAAGgMaP4AAAAAAAAAAACAxoDmDwAA\nAAAAAAAAAGgMaP4AAAAAAAAAAACAxoDmDwAAAAAAAAAAAGgMaP4AAAAAAAAAAACAxoDmDwAAAAAA\nAAAAAGgMaP4AAAAAAAAAAACAxoDmDwAAAAAAAAAAAGgMnMwOAAAAAAAAAABAffWHtFOKlh6rhcF/\nl9ZKjtIUqWctjA+g6WHOHwAAAAAAAAA0HK9JbpJFmih9LyVJiyWLZJGCpZ2lu4VL10lO0kKpqMIg\n2ySL5CkNlIZJFqmZNEzqL7lLFim5Ts/J/FSHpNdLX9ull6VHpZGSk3SndJMUWtNHzJLulqZII6WH\nz1f7vSlZavqgly9QmnexfWzSE9LxuogDoCLm/AEAAAAAAABAw/GgVCQ9IvWRRkiSnpWOSp9K10uj\nSnezSsGSn/Ty+QbJlcZJ6yRXSZJF6i79KEk6IwVJebV9GvUp1TfSZ9Ky0revSa9IKVKmdLu0UPp/\nl32IBKl7ubfp0nWSTQqXWp9v/73Soss+6CVIODtnRR2kNhcbxEl6RPqL9ILkW0PBAFQZc/4AAAAA\nAAAAoEGZJzWXPpWKS7cskFSuuzJsk+ZWMkKe9HBpwXYOT2m+Sc2fKal+kf5HelNyLN3yjtRGcpA8\npf9Xrky9ZMekkHJv7VKw9Kv0eSW132npP1LXyz5udZ2T87y2Si9UYSh36TlpkpRRA7kAVAvNHwAA\nAAAAAAA0KJ7SVOmE9E3plv5Sa2mrdLh0S7b0mzSokhEmSGMqH/9u6YoaC1sNdZ+qWAqRZkutym1M\nqNFDnJRukE6W2/KttFGaKgWeb3+79Iz0tzpf6rNizsvUUwqQHq65AQFUDc0fAAAAAAAAAJzNLm2Q\n7pO6SonS9ZKr1FfaV7rDIWmStFj6izRU2i1JypFWSbOkIOkzqY3US9orhUtBUjOpj3Sg3FGypCXS\nHMkqWaWfJEl/SDGV/Dla7mvvlCR9WPp2m+R+9pbV0vTK2yO3Cz4Jqpnkcr54F70sB6Qx0t+lxyRH\nKUuSdFL6X2mBtFCySvdIqVKxtEtaKPlK8dIgyUvKvFiqNaUP/HtdskmSVklu0qfSHukxyU+KkUaV\nXu1NF7zUkr6UDkg3lr7dIM2XiqUUab40X8quEOO8p2M4713xjvRr6YAGY2qml9RfcpH6SRvKjf+m\nNFPyqPw6VFTdG++iOc/73TkhrZLuLJ0EeVCaKFmkGVK69KTkJ31+drCJ0kfSb9U5FwCXzWK3283O\nAAAAAAAAAAB1Z8aMGau1Wqsq38MunZL8pdPSs9JfpEPSOGlgaWnUTXKRfpfsUiephfS7VCKlSJ0l\nT2mt5C91kzpKC6R7pEQpUAqStkuSSqQp0rtSJyOTtEWKlz6Q/lZJqiApvPR1seQjnZSOSd7SbdJc\naaLUQjomOUtjpM+lDlW7IhbJX4opt+W88eKk
ogteFj+pUDomSZorPStZpKHSXOlRSVKGNFzKkiKk\nk9K1Upb0mjRAek9aevbqlxVTSXpEekmKlgIkSfHSg9Ia6TvpZilLelC6XToq/UXKkvZI/Su51B7S\nNOk/Uv7ZjWPF45ZtSavkdPZK3pXcFRUH7CKdkJZJ06Uj0lTpqBQuDZd2S7ulByVJAVKsVJXf31fr\nxlMVchZKv5zvu+MktSp3LrnSYMlZ+km6U3pa6nV2sF+kftJT0tMXzD9D0zV91aoL/EACqAbm/AEA\nAAAAAADA2SySl+QlSXpc6iiNlbpJ+0t3uF96QJJkl9ykI5IkB6mjJKmDNEbqJHWVjkkLpGZSL8lH\n2ls6whZpvdRZskgWabV0WtoqPSzZK/kT/n8B5SgFSzbpEyldipVGSzOlVGmd9LvUosq133mdN962\ni12WdOm49LZUUnrWL0oJ5R436CE9JR2X/iENLr1cc6VrpH9X8tC7cxjDvlL69lPpLslRGlc62gvS\nQGmq9LxULL1R+aWW9KPU4YITDc9R2ek8J6mSu6KiFKmLNFtqIfWTXpJKpLekP6QPpb9WOUyZat14\nVcnpUsl3p8XZu7lJn0iHpJHSnyrUfpK6SCqdUwigrtD8AQAAAAAAAMD5nLNUpqtUUvr6IekO6Z/S\nW1JBuYlZ53yJy9lvnaXc0te7pb4Vur2p1YlXtuDnp9ItkkWaI0n6QFou3V6doSq6QLwLXJZ/So7S\nfdJQ6bTUStohSWpZbv9rJEkR5YZyr06wDtIcKVQ6IdmlbdL1pR8Zo5Vdc2MNz58veC4pklt1jn7h\n06nsrjhHs7NvDGOEg9I90h3Sb6WLuxZIkmIqbxDLq/qNV/WcFb87FRePHSItKp1YWZFxoZIuFBxA\njaP5AwAAAAAAAIBq2ir1kvpL91eYCFVFhdJhKf/sjcVVfs6fpN7SEOmw9Expz3e1dKX0rfSZNOmS\nUl003oXdKe2VrpMiJav0RmlXVD55G0nV7NvO8TfJLr0u7ZWurnzGnrckqdkFz8VSteU0y1z4dKp4\nV/SW0sodt3VpznXStVLv0j8JpTv/uToJq+Ly794yJdJhqasUUlpVAjAbzR8AAAAAAAAAVNMsyb10\ntla1qqMygVKu9Fa5LSekt6SPy3U/5/ypOI3PmPY3pPQJdhbpLskujahOtXbe/JXFu7AXpQHSFukL\nSdJi6TpJ0tfl9jkuSZp4SakMPtId0nvSW9JfKt/ttCRp3AXPpbOUebEk5V34dGZVfleUlHs9Wcoq\n99i/U5KkICn/7FmJ/qXjHK5OwqqoYs6qeFmaIi2TDkpPVfg0R5LUufoJAVwGmj8AAAAAAAAAOB9j\nWlhZNVIkqbQayZaSpJ+lFVK6JClaSq7wJcbOtvMNOFnykRZKf5W+kv4phUizqvycP8MtknNp/2cI\nlpylmdU5TWMy3DkTtiqLd+HL8lrp1bhJ6iT1lBZKV0ivlPZwkt6VBkv3n+/6XDRVmaekAilR6lnh\no7KJid9JftKCC55LkJR29kqYhWcPUhbP2HLh06nsrmgnpUonSr/kPqlruUcVrpPaSg9WcqZlFkrd\npI8r+bTqN17Vc1b87tjO3vKjtE+6RbpOulf6R4Vb1Bjq6oudGoAaRfMHAAAAAAAAABWElS7q+KaU\nKX1cuvri81Ke9IrkJs2QvKQFkos0T/pDekmSdELaJe2QjkmSnpPSpWWlA74jnZLcpc3Sn6T3pFnS\nPukzyaOaIdtKIWcv7Okl3Vmd9SG3SH+VJCVIT0o/lG6vLN6FL0uaNFx6Qfqb1FdaI7WRdkuTpInS\nImmB5CBtk+zSa1K8JOlJ6WDVUpXpLt0g3XW+M1oqZUrJ0mEpQmp9wUttlKb7Sr82RnpGkhQvvVu6\nwupzkqSj0jLJUsnpGDMsz3tXOEjPSnbpH6VH8ZR2Smek26VF0ndSuNTlIt8oJUmJpZflHGnVufGq\nkjPnfN+d
HOn10kuxXPpUmiJ1LF0B1UsqkaZIK8oF2ydZpFsvdmoAapTFbr+0uegAAAAAAAAA0CDN\nmDFjtVZrldk5cMmKpeH
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 25,
"metadata": {
"image/png": {
"width": 1000
}
},
"output_type": "execute_result"
}
],
"source": [
"pydotprint(cost_and_upd, outfile='pydotprint_cost_and_upd.png')\n",
"Image('pydotprint_cost_and_upd.png', width=1000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Shared variables\n",
"### Updating values manually"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"C_val, dC_dW_val, dC_db_val = cost_and_grads(x_val, W_val, b_val, y_val)\n",
"W_val -= 0.1 * dC_dW_val\n",
"b_val -= 0.1 * dC_db_val\n",
"\n",
"C_val, W_val, b_val = cost_and_upd(x_val, W_val, b_val, y_val)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using shared variables"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1.78587062 0.00189954 -0.28566499]\n"
]
}
],
   "source": [
    "x = T.vector('x')\n",
    "y = T.vector('y')\n",
    "# Wrap the current W_val / b_val in shared variables: the compiled functions\n",
    "# below read them implicitly instead of taking them as explicit inputs.\n",
    "W = theano.shared(W_val)\n",
    "b = theano.shared(b_val)\n",
    "# Same expressions as in the first cells, now built on the shared W and b.\n",
    "dot = T.dot(x, W)\n",
    "out = T.nnet.sigmoid(dot + b)\n",
    "f = theano.function([x], dot) # W is an implicit input\n",
    "g = theano.function([x], out) # W and b are implicit inputs\n",
    "print(f(x_val))"
   ]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.94151144 0.72221187 0.66391952]\n"
]
}
],
   "source": [
    "# g uses both shared variables, W and b, as implicit inputs.\n",
    "print(g(x_val))"
   ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Updating shared variables"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false
},
"outputs": [],
   "source": [
    "# Sum of squared errors, and its gradients w.r.t. the shared W and b.\n",
    "C = ((out - y) ** 2).sum()\n",
    "dC_dW, dC_db = theano.grad(C, [W, b])\n",
    "# Symbolic new values after one gradient-descent step (learning rate 0.1).\n",
    "upd_W = W - 0.1 * dC_dW\n",
    "upd_b = b - 0.1 * dC_db\n",
    "\n",
    "# Each (shared variable, new value) pair in `updates` is applied on every call:\n",
    "# the function returns C and writes upd_W / upd_b back into W / b.\n",
    "cost_and_perform_updates = theano.function(\n",
    "    inputs=[x, y],\n",
    "    outputs=C,\n",
    "    updates=[(W, upd_W),\n",
    "             (b, upd_b)])"
   ]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output file is available at pydotprint_cost_and_perform_updates.png\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB90AAAOlCAIAAABPBGD/AAAABmJLR0QA/wD/AP+gvaeTAAAgAElE\nQVR4nOzdeVxU9f4/8NdsgAMKKCj7MuxLoKKmglpmpqbZtTBbXK5paXnz5i29ll1zvd1yqexmptm9\nXm1BszTSa3VLcUsRwwUBUXYQAWXfZ+b8/jhf5jcO26DAjPh6PubhY86ZM5/zPnMAP/M+n/P+SARB\nABERERERERERERERdQmpqQMgIiIiIiIiIiIiIrqHMC9PRERERERERERERNR1mJcnIiIiIiIiIiIi\nIuo6clMHcG8pLS0VBEGr1ZaVlQFoaGiorKwEUFtbW1NTA6C6urqurs7gXbrNjGFjY6NQKAxWWlpa\nKpVKAEql0tLSEkDPnj3lcjkAW1tbqVQqkUjs7Ozu6NiIiIiIiIiIiIiIyAjMy7dPTU1NaWlpSaOq\nqqqysrLKysqqqqqqqipxTXV1dUVFRVlZmbiyvLy8oqJCrVbf9k4VCoWNjY2RG1dWVjY0NNz2vuRy\nec+ePW1tbZVKpbW1ta2tbc+ePa2tra2tre3s7MQnNjY2dnZ2SqXSXo+VldVt75SIiIiIiIiIiIjo\n3iERBMHUMZieWq0uKioqKiq6fv16YWFhUVFRSUmJfv5dp7a21uC99vb2Btlqa2vrXr169erVS1xv\na2urG8NuMDhdTIKjufHsHavpeHzdpYJmh/CXlpZWV1eLFxXKy8vFCwwGVyAMdtGjRw97e3s7Ozv7\nW9nZ2Tk6Ovbt27dfv36Ojo6Ojo7iOH3qWFOnTt29e7epoyDqhvi/JBERUeeRSCSmDoHonsA+LRGR\nGbonMqQajeb69es5OTn5+fn5+fnFxcVFRUUFBQViLr6wsPDGjRu6jeVyuaOjo35mWaVSNU03i8Rk\nuvmzsrISx7Pb29t3VJtidr6lCxglJSXp6enik6KiIv3bBRwcHMQEvX6y3sXFxcXFxc3NzcnJSSrl\ntAe3ZSiwyNQxEHUnJ4GNpo6BiIio23sVGGbqGIi6MfZpiYjMVffJy9fX14uZ9+zs7GvXruXm5ubm\n5oqL169f1+WFe/fu3bdvXzEpHBoaKiaF+/XrJ650dHR0cHAw7YHcLcS7Adzc3IzZWLwWIl4FuX79\nuvi8oKDg/Pnz4sqbN2+KW8rl8n79+nl4eIhpejc3NxcXF3d3d/FfCwuLzjymu5w7EG3qGIi6Ew4q\nIiIi6gJD2Ykl6kzs0xIRmau7Mi9fXFyc3kRubq5Go8Gtid3Bgwf/4Q9/cHFxERddXV1ZBt0kHBwc\nHBwcgoKCWtqgtrY2Ly9PvI6Sn58vXlY5derUnj17CgoKxDMrk8nc3NxUenx8fFQqVZ8+fbrwUIiI\niIiIiIiIiIjuiLnn5aurq1NSUi5dunTp0qXU1FQxBV9eXo5bs7RjxoxRqVTe3t7u7u79+vWTyWSm\nDpzax8rKysfHx8fHp+lLGo2moKAgJycnMzNTdxnmp59+ys3N1Wq1AGxtbcWzHxAQEBISEhQUFBQU\n1KNHjy4/CCIiIiIiIiIiIqK2mVdevrKyMjk5OSkpKTk5WczFZ2ZmarVahULh6+sbGBgo5t9Fnp6e\nrGpyL5DJZK6urq6urkOHDtVfX19fr5+pT09P/+6779atW9fQ0CCVSr28vIIbhYSEBAYG2tjYmOoQ\niIiIiIiIiIiIiHRMnJcvLCw8e/bs77//fvbs2bNnz6anpwOwtbUNCAgICgqaO3duYGBgcHCwSqWS\ny83rEgKZnIWFhb+/v7+/v/7KhoaG9PR08e6KlJSUw4cPb9mypaysDIBKpRqox9HR0USBExERERER\nERER0T2tq5Pd+fn5Z86c0SXic3NzZTKZv7//wIEDX3rppf79+wcFBbm4uHRxVNRtKBSKgICAgIAA\n/ZX5+fnJycmJiYlnz57dsWPHsmXLtFqtu7v7
gAEDxBx9REQEf+qIiIiIiIiIiIioa3R6Xl4QhJSU\nlGON0tPTFQpFcHDwgAEDFi9ePHDgwPDwcBYYoU7l4uLi4uLy0EMPiYuVlZWJiYnixaG9e/euXr1a\nrVb7+PhERkaOGDEiMjIyMDBQIpGYNmYiIiIiIiIiIiLqrjolL6/VauPj448ePXr06NETJ04UFxfb\n2NgMGzbsj3/846hRowYPHmxlZdUZ+yUyho2NTVRUVFRUlLhYU1MTHx9/5MiRuLi4P//5z1VVVY6O\njsOHDx8xYsSIESMGDRoklUpNGzARERERERERERF1Jx2Zly8vL//xxx9jY2MPHDhQVFRkZ2cXFRW1\nePHikSNHRkREsEA8macePXqMHDly5MiRABoaGhISEuLi4uLi4latWlVWVta3b98JEyZMnDhx7Nix\nPXv2NHWwREREREREREREdNfrgFz51atXY2NjY2Nj4+Li1Gr1wIED58+fP2HChMGDB3OgMd1dFArF\n0KFDhw4dunjxYo1GEx8ff+DAgQMHDvz73/9WKBSjRo2aNGnSo48+qlKpTB0pERERERERERER3a0k\ngiDc3juLi4u//PLL//znP/Hx8XZ2dmPHjp0wYcK4ceP69evXsSESmdz169cPHjx44MCBn376qbS0\ndMiQIdOnT582bZqDg4OpQ/s/U6dO3Y3diDF1HETdSQzwFG77f0kiIiJqk0QiwdfAVFPHQdSNsU9L\nRGSu2j2eva6u7ptvvpk8ebKLi8tf//pXf3////73v0VFRV9//fXMmTOZlKduqV+/frNmzYqJiSkq\nKjp48KCvr++SJUtcXV0ff/zxvXv31tXVmTpAIiIiIiIiIiIiumu0Iy9fXFy8Zs0aDw+PqVOnVlZW\nfvrppwUFBTt37nzkkUdYO57uEXK5fNy4cbt27SooKNiyZUt5eXl0dLSHh8fatWtv3Lhh6uiIiIiI\niIiIiIjoLmBUXj4/P/+ll17y8PB4//33586dm5mZ+b///W/WrFmcBpPuWT179pw1a9Yvv/ySkZEx\nZ86cDRs2eHh4LFiw4Nq1a6YOjYiIiIiIiIiIiMxaG3n5ysrKpUuX+vn57du3b926ddnZ2atXr3Z3\nd++a4DrPjRs3vv3227Vr15o6kI5XXl7eGc1240/sDnl4eKxZsyY7O/vdd9/du3evn5/fm2++WVVV\nZeq4Os4N4FuAZ/6u0KknKw34B7AOuNI57bcLfyyJiIiIup97pzdLRETUel7+4MGDwcHBW7Zseeut\nt9LS0l566aUePXp0WWRNHTt27G9/+5tEIpFIJHPmzPn+++9vr52UlJR33nlnypQpO3bs6MDw4uLi\npk+fLoYnzoI7ZMiQ8ePHf/zxxzU1NQYbh4SEvPjii7exF0EQPv3009DQ0P79+/v4+Ii7++WXXwBs\n3Lhx9OjRtzcTqSAI77777tKlS0eMGBEaGpqcnKy/Ri6Xz5w50/hPLCMjY/z48WPGjDl9+rT++ry8\nvO3bt0+dOnXYsGGtB/Phhx9GR0cvX7582rRpW7Zs0Z+j5vTp0w899NC4ceOysrJu40g7iVKpfPnl\nl9PS0t54442PP/44JCTkxx9/NHVQHSEFeAeYArTrdyUP2A5MBVo7z4AAfAhEA8uBacAWoNm5iH4F\nJIAdMBC4H5AAVsD9QH/AGpAAJrlFwYRRJQEbG58LwLvAUmAEIAdmtv9kGaMCmAs8DowAXgN8m2yw\nCZB09E5bcXs/lgDUwFtAbqcERURERNQ89mYNsDd759itJSLqHoTmaLXalStXSqXSp556qqCgoNlt\nTMXLywtAbW3tnTSiVqsBBAQEdFRUIjH/7uvrKy5qtdrDhw/7+Ph4enqeO3dOf8sHH3zwr3/9623s\nYtOmTQC++eYbcfHgwYO2trY7duwQBKG+vt7Jyamlc9q6devWOTo6ajSakpKSCRMmHDt2zGDNkSNH\njP/EpkyZ
AiA1NbXpS+Jw/tbbWbFihZ+fX1VVlSAIVVVVfn5+q1at0t8gJSUFwNSpU9tziF3n2rVr\n0dHRUql09erVWq22a3Y
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 30,
"metadata": {
"image/png": {
"width": 1000
}
},
"output_type": "execute_result"
}
],
   "source": [
    "# Render the graph of the compiled update function to a PNG file, then show it inline.\n",
    "pydotprint(cost_and_perform_updates, outfile='pydotprint_cost_and_perform_updates.png')\n",
    "Image('pydotprint_cost_and_perform_updates.png', width=1000)"
   ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Advanced Topics\n",
"## Extending Theano\n",
"### The easy way: Python"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false
},
"outputs": [],
   "source": [
    "import theano\n",
    "import numpy\n",
    "from theano.compile.ops import as_op\n",
    "\n",
    "# Shape-inference helper for the op below: given the shapes of the two inputs,\n",
    "# return a list with one output shape. For a matrix product (m, k) x (k, n) this is\n",
    "# (m, n): all but the last dim of the first shape plus the last dim of the second.\n",
    "def infer_shape_numpy_dot(node, input_shapes):\n",
    "    ashp, bshp = input_shapes\n",
    "    return [ashp[:-1] + bshp[-1:]]\n",
    "\n",
    "# `as_op` wraps a plain Python/NumPy function as a Theano op with the declared\n",
    "# input types (two fmatrix) and output type (one fmatrix); the helper above is\n",
    "# passed as `infer_shape` so Theano can derive the output shape.\n",
    "@as_op(itypes=[theano.tensor.fmatrix, theano.tensor.fmatrix],\n",
    "       otypes=[theano.tensor.fmatrix], infer_shape=infer_shape_numpy_dot)\n",
    "def numpy_dot(a, b):\n",
    "    return numpy.dot(a, b)"
   ]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}