それほどよく分かってもいませんが、続けていきます。 言語モデルを作っていく上で通常のRNNではできなかった長期依存をできるようにした、Long short-term memory (LSTM) と言うモデルがあるらしいです。
今日はこれをみていこうと思います。 引続き、MXnetのチュートリアルから。
Long short-term memory (LSTM) RNNs — The Straight Dope 0.1 documentation
実装
from __future__ import print_function import mxnet as mx from mxnet import nd, autograd import numpy as np mx.random.seed(1) with open("all_kaiseki.txt") as f: text = f.read() text = text[:-38083] character_list = list(set(text)) vocab_size = len(character_list) character_dict = {} for e, char in enumerate(character_list): character_dict[char] = e time_numerical = [character_dict[char] for char in text] def one_hots(numerical_list, vocab_size=vocab_size): result = nd.zeros((len(numerical_list), vocab_size)) for i, idx in enumerate(numerical_list): result[i, idx] = 1.0 return result def textify(embedding): result = "" indices = nd.argmax(embedding, axis=1).asnumpy() for idx in indices: result += character_list[int(idx)] return result batch_size = 32 seq_length = 64 # -1 here so we have enough characters for labels later num_samples = (len(time_numerical) - 1) // seq_length dataset = one_hots(time_numerical[:seq_length*num_samples]).reshape((num_samples, seq_length, vocab_size)) num_batches = len(dataset) // batch_size train_data = dataset[:num_batches*batch_size].reshape((num_batches, batch_size, seq_length, vocab_size)) # swap batch_size and seq_length axis to make later access easier train_data = nd.swapaxes(train_data, 1, 2) labels = one_hots(time_numerical[1:seq_length*num_samples+1]) train_label = labels.reshape((num_batches, batch_size, seq_length, vocab_size)) train_label = nd.swapaxes(train_label, 1, 2)
ここまではRNNと同じですね。 次からがLSTMの実装部分になるぽいです。
In [8]: num_inputs = vocab_size num_hidden = 256 num_outputs = vocab_size Wxg = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01 Wxi = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01 Wxf = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01 Wxo = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01 Whg = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01 Whi = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01 Whf = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01 Who = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01 bg = nd.random_normal(shape=num_hidden, ctx=ctx) * .01 bi = nd.random_normal(shape=num_hidden, ctx=ctx) * .01 bf = nd.random_normal(shape=num_hidden, ctx=ctx) * .01 bo = nd.random_normal(shape=num_hidden, ctx=ctx) * .01 Why = nd.random_normal(shape=(num_hidden,num_outputs), ctx=ctx) * .01 by = nd.random_normal(shape=num_outputs, ctx=ctx) * .01 params = [Wxg, Wxi, Wxf, Wxo, Whg, Whi, Whf, Who, bg, bi, bf, bo, Why, by] for param in params: param.attach_grad()
パラメーター多すぎ! 次にsoftmaxの設定。
def softmax(y_linear, temperature=1.0): lin = (y_linear-nd.max(y_linear)) / temperature exp = nd.exp(lin) partition = nd.sum(exp, axis=0, exclude=True).reshape((-1,1)) return exp / partition
次にLSTMのモデル
def lstm_rnn(inputs, h, c, temperature=1.0): outputs = [] for X in inputs: g = nd.tanh(nd.dot(X, Wxg) + nd.dot(h, Whg) + bg) i = nd.sigmoid(nd.dot(X, Wxi) + nd.dot(h, Whi) + bi) f = nd.sigmoid(nd.dot(X, Wxf) + nd.dot(h, Whf) + bf) o = nd.sigmoid(nd.dot(X, Wxo) + nd.dot(h, Who) + bo) c = f * c + i * g h = o * nd.tanh(c) yhat_linear = nd.dot(h, Why) + by yhat = softmax(yhat_linear, temperature=temperature) outputs.append(yhat) return (outputs, h, c)
損失関数というものを決めます。損失関数は入力と出力がどれだけあっているか?という計算らしいです。
def cross_entropy(yhat, y): return - nd.mean(nd.sum(y * nd.log(yhat), axis=0, exclude=True)) def average_ce_loss(outputs, labels): assert(len(outputs) == len(labels)) total_loss = 0. for (output, label) in zip(outputs,labels): total_loss = total_loss + cross_entropy(output, label) return total_loss / len(outputs)
最適化とパラメータの処理をしていきます。パラメーターが多すぎ。。。
def SGD(params, lr): for param in params: param[:] = param - lr * param.grad num_inputs = vocab_size num_hidden = 256 num_outputs = vocab_size Wxg = nd.random_normal(shape=(num_inputs,num_hidden)) * .01 Wxi = nd.random_normal(shape=(num_inputs,num_hidden)) * .01 Wxf = nd.random_normal(shape=(num_inputs,num_hidden)) * .01 Wxo = nd.random_normal(shape=(num_inputs,num_hidden)) * .01 Whg = nd.random_normal(shape=(num_hidden,num_hidden))* .01 Whi = nd.random_normal(shape=(num_hidden,num_hidden))* .01 Whf = nd.random_normal(shape=(num_hidden,num_hidden))* .01 Who = nd.random_normal(shape=(num_hidden,num_hidden))* .01 bg = nd.random_normal(shape=num_hidden) * .01 bi = nd.random_normal(shape=num_hidden) * .01 bf = nd.random_normal(shape=num_hidden) * .01 bo = nd.random_normal(shape=num_hidden) * .01 Why = nd.random_normal(shape=(num_hidden,num_outputs)) * .01 by = nd.random_normal(shape=num_outputs) * .01
最適化の部分。SGD(Stochastic Gradient Descent : 確率的勾配降下法)というのがあるそうです。
def SGD(params, lr): for param in params: param[:] = param - lr * param.grad
あとは出力部分になります。
def sample(prefix, num_chars, temperature=1.0): string = prefix prefix_numerical = [character_dict[char] for char in prefix] input = one_hots(prefix_numerical) h = nd.zeros(shape=(1, num_hidden)) c = nd.zeros(shape=(1, num_hidden)) for i in range(num_chars): outputs, h, c = lstm_rnn(input, h, c, temperature=temperature) choice = np.random.choice(vocab_size, p=outputs[-1][0].asnumpy()) string += character_list[choice] input = one_hots([choice]) return string epochs = 2000 moving_loss = 0. learning_rate = 2.0 for e in range(epochs): if ((e+1) % 100 == 0): learning_rate = learning_rate / 2.0 h = nd.zeros(shape=(batch_size, num_hidden)) c = nd.zeros(shape=(batch_size, num_hidden)) for i in range(num_batches): data_one_hot = train_data[i] label_one_hot = train_label[i] with autograd.record(): outputs, h, c = lstm_rnn(data_one_hot, h, c) loss = average_ce_loss(outputs, label_one_hot) loss.backward() SGD(params, learning_rate) if (i == 0) and (e == 0): moving_loss = nd.mean(loss).asscalar() else: moving_loss = .99 * moving_loss + .01 * nd.mean(loss).asscalar() print("Epoch %s. Loss: %s" % (e, moving_loss)) print(sample("こんにち", 1024, temperature=.1))
ちなみに実行結果はこんな感じです。 ぱっと見では、RNNと何が変わったのかわからないですね。
Epoch 17. Loss: 2.71476472981 こんにち を し て くださっ た ? https :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / t . / ttps :// t . / t . / ttps :// t . / t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / t . / ttps :// t . / t . / ttps :// t . / t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / t . / ttps :// t . / t . / ttps :// t . / t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / tpps 9 M 3 3 た ? https :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / ttps :// t . / t . / t
ひとまず今日はここまで。
今日の結果
今日のAKBの呟きはです。
'ない': 3, '楽しい': 3, '可愛い': 3, '良い': 2, 'すごい': 2, '嬉しい': 2, '若い': 2, '騒がしい': 1, '懐かしい': 1, '寒い': 1, '高い': 1, 'くい': 1, '凄い': 1, '悲しい': 1, '近い': 1, 'よい': 1, 'かわいい': 1 'こと': 9, '今日': 9, '日': 9, 'お願い': 8, '公演': 8, '生誕': 7, 'よう': 7, 'ちゃん': 7, 'さん': 6, 'ん': 6, '祭': 5, '私': 5, '写真': 5, '笑': 5, '時': 5, '方': 5, 'の': 4, '夜': 4, 'くるみ': 4, '前': 4, 'アクシュカイ': 4, 'する': 21, 'くださる': 10, 'こと': 9, '今日': 9, '日': 9, 'お願い': 8, '公演': 8, '撮る': 8, '生誕': 7, 'よう': 7, 'ちゃん': 7, 'なる': 7, 'くる': 7, 'れる': 7, 'さん': 6, 'ん': 6, 'ある': 6, 'くれる': 6, '見る': 6, 'いる': 6, '祭': 5, '私': 5, '写真': 5, '笑': 5, '時': 5, '方': 5, '来る': 5, 'てる': 5,
要約するとこんな感じです。
ドキドキ を 買い た 〜 ( ・ ᴗ ・ ) こういう 機会 ない と なかなか 踏み出せ ない ので 楽しかっ た ! ! ! お うま さん みんな 頑張っ て 走っ た ね お つかれ さ ま !" " 小倉 12 レース ! " " こんな チロル チョコ ある ん だ ! ! ! きゅうしゅ ー ! " " テレビ東京 本日 19 : 54 〜 池 の 水 ぜんぶ 抜く ! ! 放送 ! ! 高知 城 ロケ 行か せ て いただき た ! ぜひ 見 て ください ! " " 今日 の 18 時 頃 から showroom しよ う か な と ?? という こと に なっ た ので 、 、 よろしく お願い しま ー す ♪ " " この間 の showroom で 話し て た 握手会 で 十 夢 と セーラー服 を 着よ う か な 〜 って やつ ! 11 / 23 に やる こと に し た ?❤ ️ だから 皆さん 来 て ください ね ‼ ︎ 申し込み は 明日 の 13 … " " 本日 、 2 0 歳 に なり た ! 支え て 下さっ て いる 全て の 方 に ただただ 感謝 。 ここ まで 生き て こ れ た の も 奇跡 。 。 これから も 頑張っ て いき ? "# A ビート イベント 開催 中 ! ! という こと で … 少し だけ やっ て み た \(^ o ^)/ 10 月 31 日 AiiA シアター で お待ち し て ✨✨ # きょう の せい ちゃん " " メンバー に も 、 あずき に も 聞か れ た けど 、 、 、 最近 インスタ の 写真 は どこ で 撮っ た の 誰 に 撮っ て もらっ た の ❓ って 実は 、 、 、 自分 で 家 で 撮り た ??