-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsparse_weights.py
182 lines (153 loc) · 6.54 KB
/
sparse_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2019, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero Public License for more details.
#
# You should have received a copy of the GNU Affero Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------
import abc
import math
import warnings
import numpy as np
import torch
import torch.nn as nn
def rezero_weights(m):
    """Re-apply the zero mask of a sparse module, if it has one.

    Meant to be handed to :meth:`torch.nn.Module.apply`, e.g. after every
    epoch when required: ``m.apply(rezero_weights)``. Modules that are not
    :class:`HasRezeroWeights` instances are skipped.

    :param m: HasRezeroWeights module (any other module is ignored)
    """
    if not isinstance(m, HasRezeroWeights):
        return
    m.rezero_weights()
class HasRezeroWeights(metaclass=abc.ABCMeta):
    """Interface for modules that can re-zero their masked weights.

    Concrete subclasses must implement :meth:`rezero_weights`; because the
    method is abstract, a subclass that omits it cannot be instantiated.
    """

    @abc.abstractmethod
    def rezero_weights(self):
        """Set the previously selected weights to zero."""
        raise NotImplementedError()
class SparseWeightsBase(nn.Module, HasRezeroWeights):
    """
    Base class for all Sparse Weights modules.

    Wraps ``module``, forwards calls to it unchanged and exposes its
    ``weight`` and ``bias``. Subclasses choose which weights are zero and
    implement :meth:`rezero_weights`.

    :param module:
      The module to sparsify the weights
    :param weight_sparsity:
      Deprecated. Fraction of weights that are *non-zero*; converted to
      ``sparsity = 1 - weight_sparsity``. Ignored when ``sparsity`` is also
      given.
    :param sparsity:
      Pct of weights that are zero in the layer.
    """

    def __init__(self, module, weight_sparsity=None, sparsity=None):
        super(SparseWeightsBase, self).__init__()
        assert weight_sparsity is not None or sparsity is not None, (
            "Either `sparsity` or the deprecated `weight_sparsity` is required"
        )
        if weight_sparsity is not None and sparsity is None:
            sparsity = 1.0 - weight_sparsity
            # stacklevel=2 attributes the warning to the caller of __init__
            # instead of this line.
            warnings.warn(
                "Parameter `weight_sparsity` is deprecated. Use `sparsity` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        self.module = module
        self.sparsity = sparsity

    def extra_repr(self):
        return "sparsity={}".format(self.sparsity)

    def forward(self, x):
        """Delegate directly to the wrapped module."""
        return self.module(x)

    @property
    def weight_sparsity(self):
        """Deprecated: fraction of non-zero weights, i.e. ``1 - sparsity``."""
        # stacklevel=2 points the warning at the code reading the attribute.
        # Default filters (PEP 565) hide DeprecationWarnings attributed to
        # library modules, so without it users would normally never see it.
        warnings.warn(
            "Parameter `weight_sparsity` is deprecated. Use `sparsity` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return 1.0 - self.sparsity

    @property
    def weight(self):
        """The wrapped module's ``weight``."""
        return self.module.weight

    @property
    def bias(self):
        """The wrapped module's ``bias``."""
        return self.module.bias
class SparseWeights(SparseWeightsBase):
    """Enforce weight sparsity on linear module during training.

    Sample usage::

        model = nn.Linear(784, 10)
        model = SparseWeights(model, sparsity=0.4)

    For every output unit, ``round((1 - sparsity) * in_features)`` weights
    are chosen at random (via ``np.random.choice``) to stay non-zero. All
    other weights are zeroed at construction and again by every call to
    :meth:`rezero_weights`.

    :param module:
      The module to sparsify the weights. Any module whose ``weight`` is a 2D
      ``(out_features, in_features)`` tensor, e.g. ``nn.Linear``.
    :param weight_sparsity:
      Deprecated fraction of non-zero weights; use ``sparsity`` instead.
    :param sparsity:
      Pct of weights that are zero in the layer.
    :param allow_extremes:
      Allow values sparsity=0 and sparsity=1. These values are often a sign that
      there is a bug in the configuration, because they lead to Identity and
      Zero layers, respectively, but they can make sense in scenarios where the
      mask is dynamic.
    """

    def __init__(self, module, weight_sparsity=None, sparsity=None,
                 allow_extremes=False):
        assert len(module.weight.shape) == 2, "Should resemble a nn.Linear"
        super(SparseWeights, self).__init__(
            module, weight_sparsity=weight_sparsity, sparsity=sparsity
        )
        if allow_extremes:
            assert 0 <= self.sparsity <= 1, "sparsity must be in [0, 1]"
        else:
            assert 0 < self.sparsity < 1, (
                "sparsity must be in (0, 1); pass allow_extremes=True for 0 or 1"
            )

        # Read the dimensions from the weight itself, so any module with a 2D
        # (out, in) weight works, not only ones exposing in/out_features.
        out_features, in_features = self.module.weight.shape
        num_nz = int(round((1 - self.sparsity) * in_features))

        # For each unit, decide which weights are going to be zero
        zero_mask = torch.ones(out_features, in_features, dtype=torch.bool)
        for out_feature in range(out_features):
            in_indices = np.random.choice(in_features, num_nz, replace=False)
            zero_mask[out_feature, in_indices] = False

        # Use float16 because pytorch distributed nccl doesn't support bools.
        # Put the mask on the weight's device: masked_fill_ requires mask and
        # weight on the same device, and the module may already be on a GPU.
        self.register_buffer(
            "zero_mask", zero_mask.half().to(self.module.weight.device)
        )
        self.rezero_weights()

    def rezero_weights(self):
        """Zero every weight selected by ``zero_mask`` (in place)."""
        self.module.weight.data.masked_fill_(self.zero_mask.bool(), 0)
class SparseWeights2d(SparseWeightsBase):
    """Enforce weight sparsity on CNN modules.

    Sample usage::

        model = nn.Conv2d(in_channels, out_channels, kernel_size, ...)
        model = SparseWeights2d(model, sparsity=0.4)

    Sparsity is applied per output channel over that channel's whole fan-in,
    taken from ``weight.shape[1:]``. For grouped / depthwise convolutions this
    is ``(in_channels // groups) * kernel_h * kernel_w``. The kept positions
    are chosen at random via ``np.random.choice``.

    :param module:
      The module to sparsify the weights; its ``weight`` must be 4D.
    :param weight_sparsity:
      Deprecated fraction of non-zero weights; use ``sparsity`` instead.
    :param sparsity:
      Pct of weights that are zero in the layer.
    :param allow_extremes:
      Allow values sparsity=0 and sparsity=1. These values are often a sign that
      there is a bug in the configuration, because they lead to Identity and
      Zero layers, respectively, but they can make sense in scenarios where the
      mask is dynamic. Note: defaults to True here, unlike SparseWeights.
    """

    def __init__(self, module, weight_sparsity=None, sparsity=None,
                 allow_extremes=True):
        assert len(module.weight.shape) == 4, "Should resemble a nn.Conv2d"
        super(SparseWeights2d, self).__init__(
            module, weight_sparsity=weight_sparsity, sparsity=sparsity
        )
        if allow_extremes:
            assert 0 <= self.sparsity <= 1, "sparsity must be in [0, 1]"
        else:
            assert 0 < self.sparsity < 1, (
                "sparsity must be in (0, 1); pass allow_extremes=True for 0 or 1"
            )

        # Derive the per-output-channel fan-in from the weight's own shape.
        # Using in_channels * kernel area would be wrong when groups > 1,
        # because each filter then sees only in_channels // groups channels.
        weight_shape = self.module.weight.shape
        out_channels = weight_shape[0]
        input_size = weight_shape[1:].numel()
        num_nz = int(round((1 - self.sparsity) * input_size))

        # For each unit, decide which weights are going to be zero
        zero_mask = torch.ones(out_channels, input_size, dtype=torch.bool)
        for out_channel in range(out_channels):
            in_indices = np.random.choice(input_size, num_nz, replace=False)
            zero_mask[out_channel, in_indices] = False
        zero_mask = zero_mask.view(weight_shape)

        # Use float16 because pytorch distributed nccl doesn't support bools.
        # Put the mask on the weight's device: masked_fill_ requires mask and
        # weight on the same device, and the module may already be on a GPU.
        self.register_buffer(
            "zero_mask", zero_mask.half().to(self.module.weight.device)
        )
        self.rezero_weights()

    def rezero_weights(self):
        """Zero every weight selected by ``zero_mask`` (in place)."""
        self.module.weight.data.masked_fill_(self.zero_mask.bool(), 0)