-
Notifications
You must be signed in to change notification settings - Fork 417
/
temporal_shift.py
executable file
·203 lines (170 loc) · 7.23 KB
/
temporal_shift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
import torch
import torch.nn as nn
import torch.nn.functional as F
class TemporalShift(nn.Module):
    """Run ``net`` on a temporally shifted copy of its input.

    The input's first dimension packs ``n_segment`` consecutive frames per clip; ``shift``
    reshapes it to ``(nt // n_segment, n_segment, c, h, w)``. With ``fold = c // n_div``,
    channels ``[:fold]`` of frame ``t`` take the values of frame ``t + 1``, channels
    ``[fold:2 * fold]`` take the values of frame ``t - 1`` (vacated frames are zero), and all
    remaining channels are passed through unchanged.
    """

    def __init__(self, net, n_segment=3, n_div=8, inplace=False):
        super().__init__()
        self.net = net
        self.n_segment = n_segment
        self.fold_div = n_div  # channel divisor: each shifted group has c // n_div channels
        self.inplace = inplace
        if inplace:
            print('=> Using in-place shift...')
        print('=> Using fold div: {}'.format(self.fold_div))

    def forward(self, x):
        """Shift ``x`` over time, then apply the wrapped module."""
        shifted = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(shifted)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        """Return a temporally shifted copy of the 4-D tensor ``x`` with the same shape.

        Args:
            x: tensor unpacked as (nt, c, h, w); nt must be a multiple of ``n_segment``.
            n_segment: frames per clip.
            fold_div: channel divisor; each shifted group holds ``c // fold_div`` channels.
            inplace: requesting the in-place variant raises ``NotImplementedError``.
        """
        nt, c, h, w = x.size()
        clips = x.view(nt // n_segment, n_segment, c, h, w)
        fold = c // fold_div
        if inplace:
            # The in-place variant is disabled: it hit out-of-order errors during parallel
            # computation and would likely need a dedicated CUDA kernel.
            raise NotImplementedError
        shifted = torch.zeros_like(clips)
        # channels [0, fold): frame t <- frame t + 1 (last frame stays zero)
        shifted[:, :-1, :fold] = clips[:, 1:, :fold]
        # channels [fold, 2 * fold): frame t <- frame t - 1 (first frame stays zero)
        shifted[:, 1:, fold:2 * fold] = clips[:, :-1, fold:2 * fold]
        # every other channel is copied through unchanged
        shifted[:, :, 2 * fold:] = clips[:, :, 2 * fold:]
        return shifted.view(nt, c, h, w)
class InplaceShift(torch.autograd.Function):
    """Autograd Function that applies the TSM temporal shift by overwriting its input.

    Computes the same values as the out-of-place branch of ``TemporalShift.shift``:
    channels ``[:fold]`` of frame t take frame t + 1, channels ``[fold:2 * fold]`` of frame t
    take frame t - 1, and vacated frames are zero-filled. All writes go through ``.data``,
    so they are not recorded by autograd.

    NOTE(review): forward mutates ``input`` without calling ``ctx.mark_dirty`` and backward
    overwrites ``grad_output`` in place; confirm no other consumer relies on the pre-shift
    values of either tensor.
    """
    # Special thanks to @raoyongming for the help to this function
    @staticmethod
    def forward(ctx, input, fold):
        """Shift ``input`` in place and return it.

        Args:
            input: 5-D tensor indexed as (n, t, c, h, w) (the size unpack requires rank 5).
            fold: number of channels in each of the two shifted groups.

        Returns:
            ``input`` itself, with its storage modified.
        """
        # not support higher order gradient
        # (the shift is applied to .data, outside autograd tracking)
        # input = input.detach_()
        ctx.fold_ = fold  # read back in backward
        n, t, c, h, w = input.size()
        # Scratch buffer for one channel group; zero-filled so the vacated frame ends up 0.
        buffer = input.data.new(n, t, fold, h, w).zero_()
        # Group 1: frame t <- frame t + 1; the last frame is left as zeros.
        buffer[:, :-1] = input.data[:, 1:, :fold]
        input.data[:, :, :fold] = buffer
        buffer.zero_()
        # Group 2: frame t <- frame t - 1; the first frame is left as zeros.
        buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold]
        input.data[:, :, fold: 2 * fold] = buffer
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Send gradients back through the shift (the reverse movement of forward), in place.

        Returns:
            ``(grad_output, None)``: the shifted gradient for ``input`` and no gradient for
            ``fold``.
        """
        # grad_output = grad_output.detach_()
        fold = ctx.fold_
        n, t, c, h, w = grad_output.size()
        buffer = grad_output.data.new(n, t, fold, h, w).zero_()
        # Group 1 moved frame t + 1 -> t in forward, so its gradient moves t -> t + 1.
        buffer[:, 1:] = grad_output.data[:, :-1, :fold]
        grad_output.data[:, :, :fold] = buffer
        buffer.zero_()
        # Group 2 moved frame t - 1 -> t in forward, so its gradient moves t -> t - 1.
        buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
        grad_output.data[:, :, fold: 2 * fold] = buffer
        return grad_output, None
class TemporalPool(nn.Module):
    """Max-pool the temporal axis of the input (stride 2), then run ``net``.

    The input's first dimension packs ``n_segment`` frames per clip.
    """

    def __init__(self, net, n_segment):
        super(TemporalPool, self).__init__()
        self.net = net
        self.n_segment = n_segment

    def forward(self, x):
        """Temporally pool ``x``, then apply the wrapped module."""
        x = self.temporal_pool(x, n_segment=self.n_segment)
        return self.net(x)

    @staticmethod
    def temporal_pool(x, n_segment):
        """Max-pool over time with kernel 3, stride 2 and padding 1.

        Args:
            x: tensor unpacked as (nt, c, h, w); nt must be a multiple of ``n_segment``.
            n_segment: frames per clip.

        Returns:
            Tensor of shape (n_batch * ceil(n_segment / 2), c, h, w), where
            ``n_batch = nt // n_segment``. For even ``n_segment`` this is (nt // 2, c, h, w).
        """
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2)  # n, c, t, h, w
        x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        # Take the pooled length from the result instead of assuming nt // 2: with padding 1
        # and stride 2 it is ceil(n_segment / 2), which differs from n_segment // 2 when odd.
        pooled_t = x.size(2)
        x = x.transpose(1, 2).contiguous().view(n_batch * pooled_t, c, h, w)
        return x
def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
    """Insert TemporalShift modules into ``net.layer1`` .. ``net.layer4`` of a ResNet, in place.

    Args:
        net: torchvision ``ResNet`` to modify (its stage attributes are reassigned).
        n_segment: frames per clip at the input of layer1.
        n_div: channel divisor passed to every TemporalShift.
        place: ``'block'`` wraps every block of every stage; any string containing
            ``'blockres'`` wraps only each selected block's ``conv1`` (every block, or every
            other block when layer3 has at least 23 blocks).
        temporal_pool: if True, layers 2-4 are configured for the frame count left by
            ``TemporalPool`` (kernel 3, stride 2, padding 1), i.e. ceil(n_segment / 2).

    Raises:
        NotImplementedError: if ``net`` is not a torchvision ResNet or ``place`` is unknown.
    """
    if temporal_pool:
        # ceil(n_segment / 2) matches TemporalPool's output; equals n_segment // 2 when even.
        pooled_segment = (n_segment + 1) // 2
        n_segment_list = [n_segment, pooled_segment, pooled_segment, pooled_segment]
    else:
        n_segment_list = [n_segment] * 4
    assert n_segment_list[-1] > 0
    print('=> n_segment per stage: {}'.format(n_segment_list))

    import torchvision
    if not isinstance(net, torchvision.models.ResNet):
        raise NotImplementedError(place)

    if place == 'block':
        def make_block_temporal(stage, this_segment):
            # Wrap each whole block so its full input is shifted.
            blocks = list(stage.children())
            print('=> Processing stage with {} blocks'.format(len(blocks)))
            for i, b in enumerate(blocks):
                blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
            return nn.Sequential(*blocks)
    elif 'blockres' in place:
        n_round = 1
        # For deep networks (layer3 with >= 23 blocks) only every other block in each stage
        # receives a shift.
        if len(list(net.layer3.children())) >= 23:
            n_round = 2
            print('=> Using n_round {} to insert temporal shift'.format(n_round))

        def make_block_temporal(stage, this_segment):
            # Wrap only conv1 of the selected blocks (presumably the start of the residual
            # branch, leaving the block's identity path unshifted).
            blocks = list(stage.children())
            print('=> Processing stage with {} blocks residual'.format(len(blocks)))
            for i, b in enumerate(blocks):
                if i % n_round == 0:
                    b.conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
            return nn.Sequential(*blocks)
    else:
        # Fail loudly rather than returning with the network silently unmodified.
        raise NotImplementedError(place)

    net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
    net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
    net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
    net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
def make_temporal_pool(net, n_segment):
    """Wrap ``net.layer2`` of a torchvision ResNet in a ``TemporalPool``, in place.

    Args:
        net: torchvision ``ResNet`` whose ``layer2`` attribute is reassigned.
        n_segment: frames per clip entering layer2's TemporalPool.

    Raises:
        NotImplementedError: if ``net`` is not a torchvision ResNet.
    """
    import torchvision
    if not isinstance(net, torchvision.models.ResNet):
        raise NotImplementedError
    # NOTE(review): the log text says "nonlocal" although a TemporalPool is inserted.
    print('=> Injecting nonlocal pooling')
    net.layer2 = TemporalPool(net.layer2, n_segment)
if __name__ == '__main__':
    # Self-test: the out-of-place TemporalShift and the in-place InplaceShift must agree on
    # both outputs and gradients.
    n_segment, n_div = 8, 8
    # Use enough channels that fold = channels // n_div > 0. With 3 channels and n_div=8 the
    # fold is 0, nothing is shifted and the comparison proves nothing.
    n_channels, size = 64, 28

    tsm1 = TemporalShift(nn.Sequential(), n_segment=n_segment, n_div=n_div, inplace=False)

    def inplace_shift(x):
        # TemporalShift(..., inplace=True) raises NotImplementedError, so exercise
        # InplaceShift directly with the same reshaping as TemporalShift.shift.
        # Note: this overwrites x's storage.
        nt, c, h, w = x.size()
        x5 = x.view(nt // n_segment, n_segment, c, h, w)
        return InplaceShift.apply(x5, c // n_div).view(nt, c, h, w)

    def run_checks(device):
        # test forward
        with torch.no_grad():
            for _ in range(10):
                x = torch.rand(2 * n_segment, n_channels, size, size, device=device)
                y1 = tsm1(x)  # computed before inplace_shift mutates x
                y2 = inplace_shift(x)
                assert torch.norm(y1 - y2).item() < 1e-5
        # test backward
        with torch.enable_grad():
            for _ in range(10):
                x1 = torch.rand(2 * n_segment, n_channels, size, size, device=device)
                x1.requires_grad_()
                x2 = x1.clone()
                y1 = tsm1(x1)
                y2 = inplace_shift(x2)
                grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
                grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
                assert torch.norm(grad1 - grad2).item() < 1e-5

    print('=> Testing CPU...')
    run_checks('cpu')
    if torch.cuda.is_available():
        print('=> Testing GPU...')
        tsm1.cuda()
        run_checks('cuda')
    else:
        print('=> CUDA not available, skipping GPU test')
    print('Test passed.')