
Commit f2d82e1

Added support for alternate activation functions
1 parent a7c9629 commit f2d82e1

File tree

README.md
cnn_text_classification.py

2 files changed: +32 −24 lines

README.md

Lines changed: 13 additions & 10 deletions

@@ -10,7 +10,7 @@ Fork of Shawn Ng's [CNNs for Sentence Classification in PyTorch](https://github.
 
 ## Known Issues
 * The predict method is probably not as efficient as it could be.
-* Doesn't play well with GridSearchCV if num_jobs isn't 1.
+* Doesn't play well with GridSearchCV if num_jobs isn't 1 (unless not using CUDA).
 * Only supports pre-trained word vectors from TorchText.
 * The random_state parameter probably only works with integers or None.
 * Training samples shorter than the maximum kernel size are ignored.
@@ -64,23 +64,26 @@ Fork of Shawn Ng's [CNNs for Sentence Classification in PyTorch](https://github.
 **cuda : boolean, optional (default=True)**
 If true, use the GPU if available.
 
-**class_weight : dict, "balanced" or None, optional (default=None)**
-Weights associated with each class (see class_weight parameter in existing scikit-learn classifiers).
-
-**split_ratio : float, optional (default=0.9)**
-Ratio of training data used for training. The remainder will be used for validation.
+**activation_func : string, optional (default='relu')**
+Activation function. If 'relu' or 'tanh', uses rectified linear unit or hyperbolic tangent, respectively. Otherwise, uses no activation function (f(x) = x).
 
-**random_state : integer, optional (default=None)**
-Seed for the random number generator.
+**scoring : callable or None, optional (default=sklearn.metrics.accuracy_score)**
+Scoring method for testing model performance during fitting.
 
 **vectors : string, optional (default=None)**
 Which pretrained TorchText vectors to use (see [torchtext.vocab.pretrained_aliases](https://torchtext.readthedocs.io/en/latest/vocab.html#pretrained-aliases) for options).
 
+**split_ratio : float, optional (default=0.9)**
+Ratio of training data used for training. The remainder will be used for validation.
+
 **preprocessor : callable or None, optional (default=None)**
 Override default string preprocessing.
 
-**scoring : callable or None, optional (default=sklearn.metrics.accuracy_score)**
-Scoring method for testing model performance during fitting.
+**class_weight : dict, "balanced" or None, optional (default=None)**
+Weights associated with each class (see class_weight parameter in existing scikit-learn classifiers).
+
+**random_state : integer, optional (default=None)**
+Seed for the random number generator.
 
 **verbose : integer, optional (default=0)**
 Controls the verbosity when fitting.
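
For reference, a minimal usage sketch of the classifier with the new parameter. The import path, the toy training sentences, and the hyperparameter values below are illustrative assumptions, not part of the commit; only the CNNClassifier name and its keyword arguments come from the diff above.

# Hypothetical example: fit CNNClassifier with the new activation_func option.
# The data and settings are made up; any string other than 'relu' or 'tanh'
# falls back to the identity activation, f(x) = x.
from cnn_text_classification import CNNClassifier

# Samples are kept at >= 5 tokens because training samples shorter than the
# maximum kernel size (default "3,4,5") are ignored.
X_train = ["a thoroughly enjoyable and well acted movie",
           "wonderful performances and a clever script overall",
           "a dull boring and badly paced film",
           "poorly written with wooden acting throughout"]
y_train = [1, 1, 0, 0]

# split_ratio=0.5 just keeps the tiny toy validation split non-empty.
clf = CNNClassifier(activation_func="tanh", epochs=4, batch_size=2,
                    cuda=False, split_ratio=0.5, random_state=0, verbose=1)
clf.fit(X_train, y_train)
print(clf.predict(["a clever and enjoyable film overall"]))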

cnn_text_classification.py

Lines changed: 19 additions & 14 deletions

@@ -15,10 +15,10 @@ class CNNClassifier(BaseEstimator, ClassifierMixin):
     def __init__(self, lr=0.001, epochs=256, batch_size=64, test_interval=100,
                  early_stop=1000, save_best=True, dropout=0.5, max_norm=0.0,
                  embed_dim=128, kernel_num=100, kernel_sizes="3,4,5",
-                 static=False, device=-1, cuda=True, class_weight=None,
-                 split_ratio=0.9, random_state=None, vectors=None,
-                 preprocessor=None, scoring=make_scorer(accuracy_score),
-                 verbose=0):
+                 static=False, device=-1, cuda=True, activation_func="relu",
+                 scoring=make_scorer(accuracy_score), vectors=None,
+                 split_ratio=0.9, preprocessor=None, class_weight=None,
+                 random_state=None, verbose=0):
         self.lr = lr
         self.epochs = epochs
         self.batch_size = batch_size
@@ -33,12 +33,13 @@ def __init__(self, lr=0.001, epochs=256, batch_size=64, test_interval=100,
         self.static = static
         self.device = device
         self.cuda = cuda
-        self.class_weight = class_weight
-        self.split_ratio = split_ratio
-        self.random_state = random_state
+        self.activation_func = activation_func
+        self.scoring = scoring
         self.vectors = vectors
+        self.split_ratio = split_ratio
         self.preprocessor = preprocessor
-        self.scoring = scoring
+        self.class_weight = class_weight
+        self.random_state = random_state
         self.verbose = verbose
 
     def __clean_str(self, string):
@@ -100,7 +101,8 @@ def fit(self, X, y, sample_weight=None):
         kernel_sizes = [int(k) for k in self.kernel_sizes.split(",")]
         self.__model = _CNNText(embed_num, self.embed_dim, class_num,
                                 self.kernel_num, kernel_sizes, self.dropout,
-                                self.static, self.__text_field.vocab.vectors)
+                                self.static, self.activation_func,
+                                vectors=self.__text_field.vocab.vectors)
 
         if self.cuda and torch.cuda.is_available():
             torch.cuda.set_device(self.device)
@@ -257,7 +259,7 @@ def __print_elapsed_time(self, seconds):
 
 class _CNNText(nn.Module):
     def __init__(self, embed_num, embed_dim, class_num, kernel_num,
-                 kernel_sizes, dropout, static, vectors=None):
+                 kernel_sizes, dropout, static, activation_func, vectors=None):
         super(_CNNText, self).__init__()
 
         self.__embed = nn.Embedding(embed_num, embed_dim)
@@ -272,13 +274,16 @@ def __init__(self, embed_num, embed_dim, class_num, kernel_num,
         self.__fc1 = nn.Linear(len(Ks) * kernel_num, class_num)
         self.__static = static
 
-    def conv_and_pool(self, x, conv):
-        x = F.relu(conv(x)).squeeze(3)
-        return F.max_pool1d(x, x.size(2)).squeeze(2)
+        if activation_func == "relu":
+            self.__f = F.relu
+        elif activation_func == "tanh":
+            self.__f = torch.tanh
+        else:
+            self.__f = lambda x: x
 
     def forward(self, x):
         x = Variable(self.__embed(x)) if self.__static else self.__embed(x)
-        x = [F.relu(conv(x.unsqueeze(1))).squeeze(3) for conv in self.__convs1]
+        x = [self.__f(cnv(x.unsqueeze(1))).squeeze(3) for cnv in self.__convs1]
         x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
         return self.__fc1(self.__dropout(torch.cat(x, 1)))
 
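
As a standalone sketch of the dispatch this commit introduces inside _CNNText.__init__ (the pick_activation helper name below is made up for illustration; the real code stores the chosen callable on self.__f):

import torch
import torch.nn.functional as F

def pick_activation(activation_func):
    # Same mapping as the commit: 'relu' -> F.relu, 'tanh' -> torch.tanh,
    # anything else -> identity, i.e. f(x) = x.
    if activation_func == "relu":
        return F.relu
    elif activation_func == "tanh":
        return torch.tanh
    return lambda x: x

x = torch.randn(2, 3)
print(pick_activation("relu")(x))    # negative entries clamped to zero
print(pick_activation("tanh")(x))    # values squashed into (-1, 1)
print(pick_activation("linear")(x))  # tensor returned unchanged

Resolving the string to a callable once in __init__ keeps forward() free of per-batch string comparisons; the convolution outputs are simply passed through whatever function was selected.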
