見出し画像

「Deep Learning」って何?--誤差逆伝播法3(加算、活性化関数レイヤの実装)。

参考にしたいサイトです。

まず、乗算の復習しておきます。ざっくりとしたことをいうと、forword()、順伝播はx,yがあるとx*yが順伝播の解です。

backward()、逆伝播はというと順伝播のxとyをひっくり返した値乗算することで解が求められます。

では今回の話です。

加算レイヤの実装です。足し算ですね。この場合は順伝播はx,yという値があるとして、結果としては、x + yと加算します。

class AddLayer:
 def __init__(self):
    pass 
 
 def forward(self, x, y):
    out = x + y
    return out
 
 def backward(self, dout):
    dx = dout * 1 
    dy = dout * 1
   
    return dx, dy

逆伝播はというと、

dx = dout * 1

ということで1をかけても数字は変わらないので、同じデータを出力しますね。

乗算、加算とみてきましたが、次は活性化関数レイヤの実装といきましょう!

最初はReLUレイヤです。ここで大事なところは、"mask"変数です。

"x"が"0"以下のものを"True"と判定して"mask"変数に入れます。そして逆伝播の時にmask変数がTrueの場所を"0"にします。

スイッチの振る舞いです。オンの場合はデータが流れていきます、オフの場合はデータはストップしてしまいます。

class Relu:
 def __init__(self):
   self.mask = None
 
def forward(self, x):
   self.mask = (x <= 0)
   out = x.copy() 
   out[self.mask] = 0
   return out

 def backward(self, dout):
   dout[self.mask] = 0
   dx = dout
   return dx

次にシグモイドです。順伝播の際に"out"変数に入れ保持します。そして逆伝播じに"out"変数を使って計算します。

class Sigmoid:
 def __init__(self):
   self.out = None
 
 def forward(self, x):
   out = 1 / (1 + np.exp(-x))
   self.out = out
   return out
 
 def backward(self, dout):
   dx = dout * (1.0 - self.out) * self.out
   return dx

Affineレイヤです。

class Affine:
 def __init__(self, W, b):
   self.W =W
   self.b = b
   self.x = None
   self.dW = None
   self.db = None

 def forward(self, x):
   self.x = x
   out = np.dot(self.x, self.W) + self.b
   return out

 def backward(self, dout):
   dx = np.dot(dout, self.W.T)
   self.dW = np.dot(self.x.T, dout)
   self.db = np.sum(dout, axis=0)
   return dx

Softmax-with-Loss レイヤです。

class SoftmaxWithLoss:
 def __init__(self):
   self.loss = None
   self.y = None 
   self.t = None # 教師データ(one-hot vector)

 def forward(self, x, t):
   self.t = t
   self.y = softmax(x)
   self.loss = cross_entropy_error(self.y, self.t)
   return self.loss

 def backward(self, dout=1):
   batch_size = self.t.shape[0]
   dx = (self.y - self.t) / batch_size
   return dx

これで誤差逆伝播法の実装の準備ができました。

わかりにくいので実装例で詳しく説明されているサイトがありましたのでリンクしておきます


この記事が気に入ったらサポートをしてみませんか?