Этот фрагмент кода на C ++, использующий Armadillo, является как Backward, так и Forward частью нейронной сети LSTM seq2seq char toy, которую я кодирую. Я успешно писал Dense and Vanilla RNN.
Я видел различные сайты о градиентах, и даже я сделал «схему» лист, чтобы помочь мне Вот.
когда стандартное восточное время включен, LSTM ничего не изучает (возвращает только «пробелы»
). Но когда стандартное восточное время удаляется, он учится в небольших наборах данных и что-то в средних наборах данных.
// backward
for (n = out.n_rows - 2; n >= 0 ; n--){e = out(n,6) - pair.row(n + 1); // Out - Target (next char in sequence)
esum += pair.row(n + 1) % -arma::log(out(n,6) + 1e-12);
if (n > 0) st1 = out(n - 1,5); else st1 = st1.ones();
if (n > 0) ct1 = out(n - 1,4); else ct1 = ct1.ones();
// Out Softmax
beta = e % dsoftmax(out(n,6));
Dout += out(n,5).t() * beta;
DBout += beta;
e = beta * Wout.t() + est; // When est (h-1) is removed LSTM learns few usefull things but when is included nothing happen ...
t = arma::join_rows(st1,pair.row(n));
t = t.t();
// Wo
beta = (e % this->tanh(out(n,4)/* + ect*/)) % dsigmoid(out(n,2));
Do += t * beta;
DBo += beta;
temp = beta * Wo.t();
est = temp.subvec(0,middle - 1); // Backpropagate state (h-1)
e = ( e % out(n,2) % dtanh(out(n,4)) ) + ect;
// Wg
beta = e % out(n,0) % dtanh(out(n,3));
Dg += t * beta;
DBg += beta;
temp = beta * Wg.t();
est += temp.subvec(0,middle - 1);
// Wi
beta = e % out(n,3) % dsigmoid(out(n,0));
Di += t * beta;
DBi += beta;
temp = beta * Wi.t();
est += temp.subvec(0,middle - 1);// Wf
beta = (e % ct1) % dsigmoid(out(n,1));
Df += t * beta;
DBf += beta;
temp = beta * Wf.t();
est += temp.subvec(0,middle - 1);
ect = e % out(n,1);
}// backw
// Forward
for (a = 0; a < steps; a++) {
rowvec v = arma::join_rows(st,X.row(a)); // previous st concat with x of X ( X.row(a) )
i = this->sigmoid(v * Wi + Bi);
f = this->sigmoid(v * Wf + Bf);
o = this->sigmoid(v * Wo + Bo);
g = this->tanh(v * Wg + Bg);
ct = ct % f + g % i;
st = this->tanh(ct) % o;
outs = this->softmax(st * Wout + Bout);
out(a,0) = i;
out(a,1) = f;
out(a,2) = o;
out(a,3) = g;
out(a,4) = ct;
out(a,5) = st;
out(a,6) = outs;
}
return out;
Задача ещё не решена.
Других решений пока нет …