-- ReverseUnreverse.lua (51 lines, 1.58 KB)
----------------------------------------------------------
--[[ ReverseUnreverse ]]--
-- Decorator used internally by BiSequencer modules to run the
-- backward pass of a bidirectional RNN: it reverses the input
-- sequence along the time dimension (dim 1), applies the wrapped
-- sequencer, then reverses the output back to the original order.
-- setZeroMask() reverses the zeroMask accordingly so that masked
-- (padded) steps stay aligned with the reversed sequence.
----------------------------------------------------------
local ReverseUnreverse, parent = torch.class("nn.ReverseUnreverse", "nn.Decorator")
function ReverseUnreverse:__init(sequencer)
   -- The decorated module must be a sequence module (AbstractSequencer,
   -- SeqLSTM or SeqGRU) so that it consumes a seqlen-first tensor.
   assert(nn.BiSequencer.isSeq(sequencer), "Expecting AbstractSequencer or SeqLSTM or SeqGRU at arg 1")
   -- Pipeline: reverse the input, run the sequencer, reverse the output
   -- back so callers see results in the original time order.
   local pipeline = nn.Sequential()
   pipeline:add(nn.ReverseSequence()) -- reverse
   pipeline:add(sequencer)
   pipeline:add(nn.ReverseSequence()) -- unreverse
   parent.__init(self, pipeline)
end
function ReverseUnreverse:setZeroMask(zeroMask)
   -- Reverse the zeroMask along the time dimension (dim 1) so that the
   -- mask stays aligned with the reversed sequence fed to the wrapped
   -- sequencer, then forward it to the inner module.
   -- zeroMask: tensor of shape (seqlen x batchsize [x ...]).
   assert(torch.isTensor(zeroMask))
   assert(zeroMask:dim() >= 2)
   -- Reusable buffer holding the reversed mask (same type as zeroMask).
   self._zeroMask = self._zeroMask or zeroMask.new()
   self._zeroMask:resizeAs(zeroMask)
   -- Index tensor [seqlen, seqlen-1, ..., 1]; must be a CudaLongTensor
   -- when the mask lives on the GPU so that :index accepts it.
   self._range = self._range or torch.isCudaTensor(zeroMask) and torch.CudaLongTensor() or torch.LongTensor()
   local seqlen = zeroMask:size(1)
   -- Only rebuild the descending range when seqlen changed since last call.
   if self._range:nElement() ~= seqlen then
      self._range:range(seqlen, 1, -1)
   end
   -- Gather rows of zeroMask in reverse time order into the buffer.
   self._zeroMask:index(zeroMask, 1, self._range)
   -- modules[1] is the Sequential built in __init; Sequential propagates
   -- the mask to its children.
   self.modules[1]:setZeroMask(self._zeroMask)
end
-- Reinforcement learning hook (see nn.Module:reinforce). Receives a
-- reward signal, not a zeroMask — the original parameter name was a
-- copy-paste from setZeroMask. Deliberately unsupported for this
-- decorator: the reward would need to be reversed in time to match the
-- reversed sequence, which is not implemented.
function ReverseUnreverse:reinforce(reward)
   error"Not implemented"
end
-- Release cached buffers and delegate to the parent Decorator.
-- _range must be dropped along with _zeroMask: it is allocated as a
-- LongTensor or CudaLongTensor to match the mask's device, so keeping a
-- stale copy across a :type()/:cuda() cast (which calls clearState)
-- would leave an index tensor of the wrong type for the next
-- setZeroMask call.
function ReverseUnreverse:clearState()
   self._zeroMask = nil
   self._range = nil
   return parent.clearState(self)
end
-- Cast the module to a new tensor type. Cached buffers are discarded
-- first so they are re-allocated with the new type on next use.
function ReverseUnreverse:type(...)
   self:clearState()
   return parent.type(self, ...)
end
-- Return the decorated sequencer: self:get(1) is the Sequential built
-- in __init, whose second child is the wrapped module.
function ReverseUnreverse:getModule()
   local pipeline = self:get(1)
   return pipeline:get(2)
end