data-science-ipython-notebooks/deep-learning/keras-tutorial/3.3 (Extra) LSTM for Sentence Generation.ipynb
2017-08-26 08:47:27 -04:00

965 lines
25 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Credits: Forked from [deep-learning-keras-tensorflow](https://github.com/leriomaggio/deep-learning-keras-tensorflow) by Valerio Maggio"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"\n",
"# RNN using LSTM \n",
" \n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"imgs/RNN-rolled.png\" width=\"80px\" height=\"80px\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"imgs/RNN-unrolled.png\" width=\"400px\" height=\"400px\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"imgs/LSTM3-chain.png\" width=\"60%\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_source: http://colah.github.io/posts/2015-08-Understanding-LSTMs_"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from keras.optimizers import SGD\n",
"from keras.optimizers import RMSprop\n",
"from keras.preprocessing.text import one_hot, text_to_word_sequence, base_filter\n",
"from keras.utils import np_utils\n",
"from keras.models import Sequential\n",
"from keras.layers.core import Dense, Dropout, Activation\n",
"from keras.layers.embeddings import Embedding\n",
"from keras.layers.recurrent import LSTM, GRU\n",
"from keras.preprocessing import sequence"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reading blog posts from the data directory"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/valerio/deep-learning-keras-euroscipy2016/data\n"
]
}
],
"source": [
"# The data folder is resolved relative to the directory the notebook runs from\n",
"DATA_DIRECTORY = os.path.join(os.getcwd(), 'data')\n",
"print(DATA_DIRECTORY)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Placeholders for the blog posts that are loaded from disk further down.\n",
"# The second name is `female_posts` (plural) to match the name used when loading.\n",
"male_posts = []\n",
"female_posts = []"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# The two files are read with pickle even though they carry a .txt extension.\n",
"# pickle can execute arbitrary code while loading: only open files you trust.\n",
"male_path = os.path.join(DATA_DIRECTORY, \"male_blog_list.txt\")\n",
"female_path = os.path.join(DATA_DIRECTORY, \"female_blog_list.txt\")\n",
"\n",
"with open(male_path, \"rb\") as male_file:\n",
"    male_posts = pickle.load(male_file)\n",
"\n",
"with open(female_path, \"rb\") as female_file:\n",
"    female_posts = pickle.load(female_file)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Keep only the posts that actually contain something\n",
"filtered_male_posts = [post for post in male_posts if len(post) > 0]\n",
"filtered_female_posts = [post for post in female_posts if len(post) > 0]"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# text processing - one_hot turns each post into a list of integer word indices\n",
"n = 30000  # vocabulary size handed to one_hot\n",
"\n",
"\n",
"def _encode_posts(posts, vocabulary_size):\n",
"    \"\"\"Encode every post with one_hot, skipping posts that cannot be encoded.\n",
"\n",
"    Returns a tuple (encoded_posts, skipped_count).\n",
"    \"\"\"\n",
"    encoded, skipped = [], 0\n",
"    for post in posts:\n",
"        try:\n",
"            encoded.append(one_hot(post, vocabulary_size, split=\" \",\n",
"                                   filters=base_filter(), lower=True))\n",
"        except Exception:\n",
"            # Still best effort, but no longer swallows KeyboardInterrupt/SystemExit,\n",
"            # and the number of dropped posts is reported below.\n",
"            skipped += 1\n",
"    return encoded, skipped\n",
"\n",
"\n",
"male_one_hot, male_skipped = _encode_posts(filtered_male_posts, n)\n",
"female_one_hot, female_skipped = _encode_posts(filtered_female_posts, n)\n",
"print('posts skipped during encoding (male, female):', male_skipped, female_skipped)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Label convention: 0 = male, 1 = female (all male labels first, then all female labels)\n",
"male_labels = np.zeros(len(male_one_hot))\n",
"female_labels = np.ones(len(female_one_hot))\n",
"concatenate_array_rnn = np.concatenate((male_labels, female_labels))"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.cross_validation import train_test_split\n",
"\n",
"# The features must be stacked in the same order as the labels.\n",
"# concatenate_array_rnn holds the male labels (0) first and the female labels (1) second,\n",
"# so the male posts have to come first here as well; the reverse order pairs every post\n",
"# with a label that belongs to a different post.\n",
"X_train_rnn, X_test_rnn, y_train_rnn, y_test_rnn = train_test_split(\n",
"    np.concatenate((male_one_hot, female_one_hot)),\n",
"    concatenate_array_rnn,\n",
"    test_size=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train_rnn shape: (3873, 100) (3873,)\n",
"X_test_rnn shape: (969, 100) (969,)\n"
]
}
],
"source": [
"# Bring every encoded post to exactly maxlen word indices\n",
"maxlen = 100\n",
"X_train_rnn, X_test_rnn = [sequence.pad_sequences(split, maxlen=maxlen)\n",
"                           for split in (X_train_rnn, X_test_rnn)]\n",
"print('X_train_rnn shape:', X_train_rnn.shape, y_train_rnn.shape)\n",
"print('X_test_rnn shape:', X_test_rnn.shape, y_test_rnn.shape)"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"max_features = 30000\n",
"dimension = 128\n",
"output_dimension = 128\n",
"\n",
"# Embedding -> LSTM -> Dropout -> one sigmoid output unit\n",
"model = Sequential([\n",
"    Embedding(max_features, dimension),\n",
"    LSTM(output_dimension),\n",
"    Dropout(0.5),\n",
"    Dense(1),\n",
"    Activation('sigmoid')\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# NOTE(review): binary_crossentropy is the more usual loss for a single sigmoid output;\n",
"# mean squared error is kept here to preserve the recorded experiment.\n",
"model.compile(\n",
"    loss='mean_squared_error',\n",
"    optimizer='sgd',\n",
"    metrics=['accuracy']\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 3873 samples, validate on 969 samples\n",
"Epoch 1/4\n",
"3873/3873 [==============================] - 3s - loss: 0.2487 - acc: 0.5378 - val_loss: 0.2506 - val_acc: 0.5191\n",
"Epoch 2/4\n",
"3873/3873 [==============================] - 3s - loss: 0.2486 - acc: 0.5401 - val_loss: 0.2508 - val_acc: 0.5191\n",
"Epoch 3/4\n",
"3873/3873 [==============================] - 3s - loss: 0.2484 - acc: 0.5417 - val_loss: 0.2496 - val_acc: 0.5191\n",
"Epoch 4/4\n",
"3873/3873 [==============================] - 3s - loss: 0.2484 - acc: 0.5399 - val_loss: 0.2502 - val_acc: 0.5191\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa1e96ac4e0>"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Four training epochs, reporting loss/accuracy on the held-out split after each one\n",
"model.fit(\n",
"    X_train_rnn, y_train_rnn,\n",
"    batch_size=32,\n",
"    nb_epoch=4,\n",
"    validation_data=(X_test_rnn, y_test_rnn))"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"969/969 [==============================] - 0s \n"
]
}
],
"source": [
"# The model was compiled with one metric, so evaluate yields the loss and the accuracy\n",
"score, acc = model.evaluate(\n",
"    X_test_rnn, y_test_rnn,\n",
"    batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.250189056399 0.519091847357\n"
]
}
],
"source": [
"# Test loss followed by test accuracy\n",
"print(score, acc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using TFIDF Vectorizer as an input instead of one hot encoder"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Fit ONE vocabulary on all posts, then transform each group with it. Calling\n",
"# fit_transform once per group would re-fit the vectorizer and give the two matrices\n",
"# different, incompatible column spaces.\n",
"vectorizer = TfidfVectorizer(decode_error='ignore', norm='l2', min_df=5)\n",
"vectorizer.fit(filtered_male_posts + filtered_female_posts)\n",
"tfidf_male = vectorizer.transform(filtered_male_posts)\n",
"tfidf_female = vectorizer.transform(filtered_female_posts)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Convert each sparse TF-IDF matrix to a dense array (one row per post).\n",
"# The female array must come from tfidf_female; building it from tfidf_male would\n",
"# label a second copy of the male posts as female.\n",
"flattened_array_tfidf_male = tfidf_male.toarray()\n",
"flattened_array_tfidf_female = tfidf_female.toarray()"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 0 = male, 1 = female\n",
"male_labels = np.zeros(len(flattened_array_tfidf_male))\n",
"female_labels = np.ones(len(flattened_array_tfidf_female))\n",
"y_rnn = np.concatenate((male_labels, female_labels))"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Stack male rows first, then female rows, and split 80/20\n",
"tfidf_features = np.concatenate((flattened_array_tfidf_male,\n",
"                                 flattened_array_tfidf_female))\n",
"X_train_rnn, X_test_rnn, y_train_rnn, y_test_rnn = train_test_split(\n",
"    tfidf_features, y_rnn, test_size=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train_rnn shape: (4152, 100) (4152,)\n",
"X_test_rnn shape: (1038, 100) (1038,)\n"
]
}
],
"source": [
"# NOTE(review): pad_sequences defaults to dtype='int32' and keeps only the last maxlen\n",
"# columns, so the fractional TF-IDF weights appear to be truncated to 0 here -- confirm\n",
"# that this is the intended input for the model below.\n",
"maxlen = 100\n",
"X_train_rnn, X_test_rnn = [sequence.pad_sequences(split, maxlen=maxlen)\n",
"                           for split in (X_train_rnn, X_test_rnn)]\n",
"print('X_train_rnn shape:', X_train_rnn.shape, y_train_rnn.shape)\n",
"print('X_test_rnn shape:', X_test_rnn.shape, y_test_rnn.shape)"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"max_features = 30000\n",
"\n",
"# Same stack as the first experiment: Embedding -> LSTM -> Dropout -> sigmoid unit\n",
"model = Sequential([\n",
"    Embedding(max_features, dimension),\n",
"    LSTM(output_dimension),\n",
"    Dropout(0.5),\n",
"    Dense(1),\n",
"    Activation('sigmoid')\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Mean squared error with plain SGD, tracking accuracy\n",
"model.compile(\n",
"    loss='mean_squared_error',\n",
"    optimizer='sgd',\n",
"    metrics=['accuracy']\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 4152 samples, validate on 1038 samples\n",
"Epoch 1/4\n",
"4152/4152 [==============================] - 3s - loss: 0.2502 - acc: 0.4988 - val_loss: 0.2503 - val_acc: 0.4865\n",
"Epoch 2/4\n",
"4152/4152 [==============================] - 3s - loss: 0.2507 - acc: 0.4843 - val_loss: 0.2500 - val_acc: 0.4865\n",
"Epoch 3/4\n",
"4152/4152 [==============================] - 3s - loss: 0.2504 - acc: 0.4952 - val_loss: 0.2501 - val_acc: 0.4865\n",
"Epoch 4/4\n",
"4152/4152 [==============================] - 3s - loss: 0.2506 - acc: 0.4913 - val_loss: 0.2500 - val_acc: 0.5135\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa1f466f278>"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Four training epochs, reporting loss/accuracy on the held-out split after each one\n",
"model.fit(\n",
"    X_train_rnn, y_train_rnn,\n",
"    batch_size=32,\n",
"    nb_epoch=4,\n",
"    validation_data=(X_test_rnn, y_test_rnn))"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1038/1038 [==============================] - 0s \n"
]
}
],
"source": [
"# The model was compiled with one metric, so evaluate yields the loss and the accuracy\n",
"score, acc = model.evaluate(\n",
"    X_test_rnn, y_test_rnn,\n",
"    batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.249981284572 0.513487476145\n"
]
}
],
"source": [
"# Test loss followed by test accuracy\n",
"print(score, acc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sentence Generation using LSTM"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Join every male post into one long training string\n",
"male_post = ' '.join(filtered_male_posts)\n",
"\n",
"# Character vocabulary of the male posts. It is sorted so that each character gets the\n",
"# same index on every run: the iteration order of a set of strings can change between\n",
"# interpreter sessions, which would silently change the encoding.\n",
"character_set_male = sorted(set(male_post))\n",
"# Two lookup tables: character -> index and index -> character\n",
"char_indices = {c: i for i, c in enumerate(character_set_male)}\n",
"indices_char = {i: c for i, c in enumerate(character_set_male)}\n",
"\n",
"# Cut the text into semi-redundant windows of maxlen characters; the character that\n",
"# follows each window is the value the model has to predict for it.\n",
"maxlen = 20\n",
"step = 1\n",
"window_starts = range(0, len(male_post) - maxlen, step)\n",
"sentences = [male_post[start: start + maxlen] for start in window_starts]\n",
"next_chars = [male_post[start + maxlen] for start in window_starts]"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(2552476, 20, 152) (2552476, 152)\n",
"(2552476, 20, 152) (2552476, 152)\n"
]
}
],
"source": [
"# Vectorisation of input: one-hot encode every window and its target character.\n",
"# There is one row per window. Sizing the arrays by len(male_post) instead would leave\n",
"# trailing all-zero samples (with all-zero targets) that are still fed to training.\n",
"n_samples = len(sentences)\n",
"n_chars = len(character_set_male)\n",
"# Built-in bool: the np.bool alias is no longer available in current NumPy releases\n",
"x_male = np.zeros((n_samples, maxlen, n_chars), dtype=bool)\n",
"y_male = np.zeros((n_samples, n_chars), dtype=bool)\n",
"\n",
"print(x_male.shape, y_male.shape)\n",
"\n",
"for i, sentence in enumerate(sentences):\n",
"    for t, char in enumerate(sentence):\n",
"        x_male[i, t, char_indices[char]] = 1\n",
"    y_male[i, char_indices[next_chars[i]]] = 1\n",
"\n",
"print(x_male.shape, y_male.shape)"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Build model...\n"
]
}
],
"source": [
"# Build the model: a single LSTM followed by a softmax over the character set\n",
"print('Build model...')\n",
"n_chars = len(character_set_male)\n",
"model = Sequential([\n",
"    LSTM(128, input_shape=(maxlen, n_chars)),\n",
"    Dense(n_chars),\n",
"    Activation('softmax')\n",
"])\n",
"\n",
"# NOTE(review): the recorded run below shows the loss rising from ~1.8 to ~11 after\n",
"# a few iterations; a smaller learning rate may be worth trying.\n",
"optimizer = RMSprop(lr=0.01)\n",
"model.compile(loss='categorical_crossentropy', optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# This cell used to compile `auto_text_generating_male_model`, a name that is not defined\n",
"# anywhere in this notebook, so it raised NameError on a fresh kernel. The character model\n",
"# is already compiled in the cell that builds it; show its layer summary instead.\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import random, sys"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# helper function to sample an index from a probability array\n",
"def sample(a, diversity=0.75):\n",
"    \"\"\"Pick an index from the probability array `a`.\n",
"\n",
"    With probability `1 - diversity` the most likely index (argmax) is returned.\n",
"    Otherwise an index is drawn at random in proportion to the values in `a`.\n",
"    When `a` carries no usable probability mass (all zeros, NaN or inf) the argmax\n",
"    is returned, so the function always terminates.\n",
"    \"\"\"\n",
"    if random.random() > diversity:\n",
"        return int(np.argmax(a))\n",
"\n",
"    weights = np.asarray(a, dtype=np.float64)\n",
"    total = weights.sum()\n",
"    if not np.isfinite(total) or total <= 0:\n",
"        # A draw proportional to `a` is impossible here; fall back to the greedy choice\n",
"        return int(np.argmax(a))\n",
"\n",
"    # Inverse-CDF draw: one random number instead of an open-ended rejection loop.\n",
"    # side='right' makes entries with zero probability unreachable.\n",
"    threshold = random.random() * total\n",
"    index = int(np.searchsorted(np.cumsum(weights), threshold, side='right'))\n",
"    # Guard against floating-point round-off pushing the index past the last entry\n",
"    return min(index, len(weights) - 1)"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"--------------------------------------------------\n",
"Iteration 1\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 226s - loss: 1.8022 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \"p from the lack of \"\n",
"sense of the search \n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \"p from the lack of \"\n",
"through that possibl\n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \"p from the lack of \"\n",
". This is a \" by p\n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \"p from the lack of \"\n",
"d he latermal ta we \n",
"\n",
"\n",
"--------------------------------------------------\n",
"Iteration 2\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 228s - loss: 1.7312 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \"s Last Dance\" with t\"\n",
" screening on the st\n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \"s Last Dance\" with t\"\n",
"r song think of the \n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \"s Last Dance\" with t\"\n",
". I'm akin computer \n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \"s Last Dance\" with t\"\n",
"played that comment \n",
"\n",
"\n",
"--------------------------------------------------\n",
"Iteration 3\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 229s - loss: 1.8693 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \", as maybe someone w\"\n",
"the ssone the so the\n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \", as maybe someone w\"\n",
"the sasd nouts and t\n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \", as maybe someone w\"\n",
"p hin I had at f¿ to\n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \", as maybe someone w\"\n",
"oge rely bluy leanda\n",
"\n",
"\n",
"--------------------------------------------------\n",
"Iteration 4\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 228s - loss: 1.9135 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \"o the package :(. Ah\"\n",
" suadedbe teacher th\n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \"o the package :(. Ah\"\n",
"e a searingly the id\n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \"o the package :(. Ah\"\n",
"propost the bure so \n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \"o the package :(. Ah\"\n",
"ing.Lever fan. By in\n",
"\n",
"\n",
"--------------------------------------------------\n",
"Iteration 5\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 229s - loss: 4.5892 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \"ot as long as my fri\"\n",
"atde getu th> QQ.“]\n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \"ot as long as my fri\"\n",
"tQ t[we QaaefYhere Q\n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \"ot as long as my fri\"\n",
"ew[”*ing”e[ t[w that\n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \"ot as long as my fri\"\n",
" me]sQoonQ“]e” ti nw\n",
"\n",
"\n",
"--------------------------------------------------\n",
"Iteration 6\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 229s - loss: 6.7174 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \"use I'm pretty damn \"\n",
"me g 'o a a a a\n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \"use I'm pretty damn \"\n",
" a o theT a o a \n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \"use I'm pretty damn \"\n",
" n . thot auupe to \n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \"use I'm pretty damn \"\n",
" tomalek ho tt Ion i\n",
"\n",
"\n",
"--------------------------------------------------\n",
"Iteration 7\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 227s - loss: 6.9138 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \"ats all got along be\"\n",
" thrtg t ia thv i c\n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \"ats all got along be\"\n",
"th wtot.. t to gt? \n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \"ats all got along be\"\n",
" ed dthwnn,is a ment\n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \"ats all got along be\"\n",
" t incow . wmiyit\n",
"\n",
"\n",
"--------------------------------------------------\n",
"Iteration 8\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 228s - loss: 11.0629 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \"oot of my sleeping b\"\n",
"m g te>t e s t anab\n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \"oot of my sleeping b\"\n",
" dttoe s s“snge es s\n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \"oot of my sleeping b\"\n",
"tut hou wen a onap\n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \"oot of my sleeping b\"\n",
"evtyr tt e io on tok\n",
"\n",
"\n",
"--------------------------------------------------\n",
"Iteration 9\n",
"Epoch 1/1\n",
"2552476/2552476 [==============================] - 228s - loss: 8.7874 \n",
"\n",
"----- diversity: 0.2\n",
"----- Generating with seed: \" Ive always looked \"\n",
"ea e ton ann n ffee\n",
"\n",
"\n",
"----- diversity: 0.4\n",
"----- Generating with seed: \" Ive always looked \"\n",
"o tire n a anV sia a\n",
"\n",
"\n",
"----- diversity: 0.6\n",
"----- Generating with seed: \" Ive always looked \"\n",
"r i jooe Vag o en \n",
"\n",
"\n",
"----- diversity: 0.8\n",
"----- Generating with seed: \" Ive always looked \"\n",
" ao at ge ena oro o\n",
"\n"
]
}
],
"source": [
"# train the model, output generated text after each iteration\n",
"for iteration in range(1, 10):\n",
"    print()\n",
"    print('-' * 50)\n",
"    print('Iteration', iteration)\n",
"    model.fit(x_male, y_male, batch_size=128, nb_epoch=1)\n",
"\n",
"    # One random seed window per training iteration, shared by all diversity values\n",
"    start_index = random.randint(0, len(male_post) - maxlen - 1)\n",
"\n",
"    for diversity in [0.2, 0.4, 0.6, 0.8]:\n",
"        print()\n",
"        print('----- diversity:', diversity)\n",
"\n",
"        sentence = male_post[start_index: start_index + maxlen]\n",
"        print('----- Generating with seed: \"' + sentence + '\"')\n",
"\n",
"        generated = ''\n",
"        # A separate loop name keeps the training-iteration counter from being overwritten\n",
"        for position in range(400):\n",
"            try:\n",
"                # One-hot encode the current maxlen-character window\n",
"                x = np.zeros((1, maxlen, len(character_set_male)))\n",
"                for t, char in enumerate(sentence):\n",
"                    x[0, t, char_indices[char]] = 1.\n",
"\n",
"                preds = model.predict(x, verbose=0)[0]\n",
"                next_index = sample(preds, diversity)\n",
"                next_char = indices_char[next_index]\n",
"            except Exception as error:\n",
"                # Report the problem instead of silently retrying the same window\n",
"                print('generation stopped after', position, 'characters:', repr(error))\n",
"                break\n",
"\n",
"            generated += next_char\n",
"            # Slide the window one character forward\n",
"            sentence = sentence[1:] + next_char\n",
"\n",
"        # Show everything that was generated, not only the final maxlen-character window\n",
"        print(generated)\n",
"        print()"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}