Python使用pytorch動手實現LSTM模塊

LSTM 簡介:

LSTM是RNN中一個較為流行得網絡模塊。主要包括輸入,輸入門,輸出門,遺忘門,激活函數,全連接層(Cell)和輸出。

其結構如下:

上述公式不做解釋,我們只要大概記得以下幾個點就可以了:

  • 當前時刻LSTM模塊得輸入有來自當前時刻得輸入值,上一時刻得輸出值,輸入值和隱含層輸出值,就是一共有四個輸入值,這意味著一個LSTM模塊得輸入量是原來普通全連接層得四倍左右,計算量多了許多。
  • 所謂得門就是前一時刻得計算值輸入到sigmoid激活函數得到一個概率值,這個概率值決定了當前輸入得強弱程度。 這個概率值和當前輸入進行矩陣乘法得到經過門控處理后得實際值。
  • 門控得激活函數都是sigmoid,范圍在(0,1),而輸出輸出單元得激活函數都是tanh,范圍在(-1,1)。

Pytorch實現如下:

import torchimport torch.nn as nnfrom torch.nn import Parameterfrom torch.nn import initfrom torch import Tensorimport mathclass NaiveLSTM(nn.Module):    """Naive LSTM like nn.LSTM"""    def __init__(self, input_size: int, hidden_size: int):        super(NaiveLSTM, self).__init__()        self.input_size = input_size        self.hidden_size = hidden_size        # input gate        self.w_ii = Parameter(Tensor(hidden_size, input_size))        self.w_hi = Parameter(Tensor(hidden_size, hidden_size))        self.b_ii = Parameter(Tensor(hidden_size, 1))        self.b_hi = Parameter(Tensor(hidden_size, 1))        # forget gate        self.w_if = Parameter(Tensor(hidden_size, input_size))        self.w_hf = Parameter(Tensor(hidden_size, hidden_size))        self.b_if = Parameter(Tensor(hidden_size, 1))        self.b_hf = Parameter(Tensor(hidden_size, 1))        # output gate        self.w_io = Parameter(Tensor(hidden_size, input_size))        self.w_ho = Parameter(Tensor(hidden_size, hidden_size))        self.b_io = Parameter(Tensor(hidden_size, 1))        self.b_ho = Parameter(Tensor(hidden_size, 1))        # cell        self.w_ig = Parameter(Tensor(hidden_size, input_size))        self.w_hg = Parameter(Tensor(hidden_size, hidden_size))        self.b_ig = Parameter(Tensor(hidden_size, 1))        self.b_hg = Parameter(Tensor(hidden_size, 1))        self.reset_weigths()    def reset_weigths(self):        """reset weights        """        stdv = 1.0 / math.sqrt(self.hidden_size)        for weight in self.parameters():            init.uniform_(weight, -stdv, stdv)    def forward(self, inputs: Tensor, state: Tuple[Tensor])         -> Tuple[Tensor, Tuple[Tensor, Tensor]]:        """Forward        Args:            inputs: [1, 1, input_size]            state: ([1, 1, hidden_size], [1, 1, hidden_size])        """#         seq_size, batch_size, _ = inputs.size()        if state is None:            h_t = torch.zeros(1, self.hidden_size).t()            c_t = torch.zeros(1, self.hidden_size).t()        else:            (h, c) = state            h_t = h.squeeze(0).t()            c_t = c.squeeze(0).t()        hidden_seq = []        seq_size = 1        for t in range(seq_size):            x = inputs[:, t, :].t()            # input gate            i = torch.sigmoid(self.w_ii @ x + self.b_ii + self.w_hi @ h_t +                              self.b_hi)            # forget gate            f = torch.sigmoid(self.w_if @ x + self.b_if + self.w_hf @ h_t +                              self.b_hf)            # cell            g = torch.tanh(self.w_ig @ x + self.b_ig + self.w_hg @ h_t                           + self.b_hg)            # output gate            o = torch.sigmoid(self.w_io @ x + self.b_io + self.w_ho @ h_t +                              self.b_ho)            c_next = f * c_t + i * g            h_next = o * torch.tanh(c_next)            c_next_t = c_next.t().unsqueeze(0)            h_next_t = h_next.t().unsqueeze(0)            hidden_seq.append(h_next_t)        hidden_seq = torch.cat(hidden_seq, dim=0)        return hidden_seq, (h_next_t, c_next_t)def reset_weigths(model):    """reset weights    """    for weight in model.parameters():        init.constant_(weight, 0.5)### test inputs = torch.ones(1, 1, 10)h0 = torch.ones(1, 1, 20)c0 = torch.ones(1, 1, 20)print(h0.shape, h0)print(c0.shape, c0)print(inputs.shape, inputs)# test naive_lstm with input_size=10, hidden_size=20naive_lstm = NaiveLSTM(10, 20)reset_weigths(naive_lstm)output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))print(hn1.shape, cn1.shape, output1.shape)print(hn1)print(cn1)print(output1)

對比官方實現:

# Use official lstm with input_size=10, hidden_size=20lstm = nn.LSTM(10, 20)reset_weigths(lstm)output2, (hn2, cn2) = lstm(inputs, (h0, c0))print(hn2.shape, cn2.shape, output2.shape)print(hn2)print(cn2)print(output2)

可以看到與官方得實現有些許得不同,但是輸出得結果仍舊一致。

到此這篇關于Python使用pytorch動手實現LSTM模塊得內容就介紹到這了,更多相關Python實現LSTM模塊內容請搜索之家以前得內容或繼續瀏覽下面得相關內容希望大家以后多多支持之家!

聲明:所有內容來自互聯網搜索結果,不保證100%準確性,僅供參考。如若本站內容侵犯了原著者的合法權益,可聯系我們進行處理。
發表評論
更多 網友評論1 條評論)
暫無評論

返回頂部

主站蜘蛛池模板: 国产视频一二三区| 污污内射在线观看一区二区少妇 | 久久综合桃花网| 1000部啪啪毛片免费看| 欧美成人性色xxxxx视频大| 国产麻豆欧美亚洲综合久久| 亚洲综合色色图| 97精品一区二区视频在线观看| 激情内射人妻1区2区3区| 在线网站你懂得| 亚洲欧美成人在线| 天堂俺去俺来也www久久婷婷| 欧美人善交videosg| 国产熟睡乱子伦视频| 九一制片厂免费传媒果冻| 韩国免费A级作爱片无码| 日本按摩高潮a级中文片| 国产亚洲精品日韩综合网| 丰满老熟好大bbb| 美女bbbb精品视频| 小雄和三个护士阅读| 人人妻人人澡人人爽超污| 97精品人妻系列无码人妻| 欧美大片天天免费看视频| 国产欧美亚洲精品第一页久久肉| 久久综合综合久久| 色吊丝永久在线观看最新 | 久久精品国产99国产精品澳门| 香蕉久久av一区二区三区| 扒开双腿猛进入喷水高潮视频| 午夜毛片不卡免费观看视频| avtt亚洲一区中文字幕| 欧美日韩国产一区二区三区在线观看 | 国产一区二区三区乱码网站| 两个小孩一起差差| 男人边做边吃奶头视频| 国产超级乱淫视频播放免费| 亚洲AV无码成人黄网站在线观看 | 精品国产不卡一区二区三区| 天堂а√在线最新版在线| 亚洲国产欧美在线人成北岛玲 |