Skip to content

Commit 380653e

Browse files
committed
Text shorter than maximum kernel size is now padded
Also changed the kernel_sizes parameter from a string to an iterable.
1 parent 685fada commit 380653e

File tree

2 files changed

+21
-26
lines changed

2 files changed

+21
-26
lines changed

README.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,8 @@ Fork of Shawn Ng's [CNNs for Sentence Classification in PyTorch](https://github.
1111
## Known Issues
1212
* The predict method is probably not as efficient as it could be.
1313
* Doesn't play well with GridSearchCV if num_jobs isn't 1 (unless not using CUDA).
14-
* Only supports pre-trained word vectors from TorchText.
14+
* Only supports pre-trained word vectors from TorchText (or no pre-trained vectors).
1515
* The random_state parameter probably only works with integers or None.
16-
* Training samples shorter than the maximum kernel size are ignored.
17-
* Test samples shorter than the maximum kernel size are classified as the most common class found during training.
1816
* Features my idiosyncratic coding style.
1917

2018
## To Do
@@ -52,8 +50,8 @@ Fork of Shawn Ng's [CNNs for Sentence Classification in PyTorch](https://github.
5250
**kernel_num : integer, optional (default=100)**
5351
The number of each size of kernel.
5452

55-
**kernel_sizes : string, optional (default='3,4,5')**
56-
Comma-separated kernel sizes to use for convolution.
53+
**kernel_sizes : iterable of integers, optional (default=(3, 4, 5))**
54+
Kernel sizes to use for convolution.
5755

5856
**static : boolean, optional (default=False)**
5957
If true, fix the embedding.

cnn_text_classification.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class CNNClassifier(BaseEstimator, ClassifierMixin):
1515
def __init__(self, lr=0.001, epochs=256, batch_size=64, test_interval=100,
1616
early_stop=1000, save_best=True, dropout=0.5, max_norm=0.0,
17-
embed_dim=128, kernel_num=100, kernel_sizes="3,4,5",
17+
embed_dim=128, kernel_num=100, kernel_sizes=(3, 4, 5),
1818
static=False, device=-1, cuda=True, activation_func="relu",
1919
scoring=make_scorer(accuracy_score), vectors=None,
2020
split_ratio=0.9, preprocessor=None, class_weight=None,
@@ -29,7 +29,7 @@ def __init__(self, lr=0.001, epochs=256, batch_size=64, test_interval=100,
2929
self.max_norm = max_norm
3030
self.embed_dim = embed_dim
3131
self.kernel_num = kernel_num
32-
self.kernel_sizes = kernel_sizes
32+
self.kernel_sizes = sorted(kernel_sizes)
3333
self.static = static
3434
self.device = device
3535
self.cuda = cuda
@@ -97,10 +97,10 @@ def fit(self, X, y, sample_weight=None):
9797
train_iter, dev_iter = self.__preprocess(X, y, sample_weight)
9898
embed_num = len(self.__text_field.vocab)
9999
class_num = len(self.__label_field.vocab) - 1
100-
kernel_sizes = [int(k) for k in self.kernel_sizes.split(",")]
101100
self.__model = _CNNText(embed_num, self.embed_dim, class_num,
102-
self.kernel_num, kernel_sizes, self.dropout,
103-
self.static, self.activation_func,
101+
self.kernel_num, self.kernel_sizes,
102+
self.dropout, self.static,
103+
self.activation_func,
104104
vectors=self.__text_field.vocab.vectors)
105105

106106
if self.cuda and torch.cuda.is_available():
@@ -154,18 +154,11 @@ def fit(self, X, y, sample_weight=None):
154154

155155
def predict(self, X):
156156
y_pred = []
157-
max_krnl_sz = int(self.kernel_sizes[self.kernel_sizes.rfind(",") + 1:])
158157

159158
for text in X:
160159
assert isinstance(text, str)
161160

162-
text = self.__text_field.preprocess(text)
163-
164-
if len(text) < max_krnl_sz:
165-
most_common = self.__label_field.vocab.freqs.most_common(1)[0]
166-
167-
y_pred.append(most_common[0])
168-
continue
161+
text = self.__pad(self.__text_field.preprocess(text), True)
169162

170163
self.__model.eval()
171164

@@ -179,21 +172,25 @@ def predict(self, X):
179172
torch.cuda.empty_cache()
180173
return y_pred
181174

175+
def __pad(self, x, preprocessed=False):
176+
tokens = x if preprocessed else self.__text_field.preprocess(x)
177+
difference = self.kernel_sizes[-1] - len(tokens)
178+
179+
if difference > 0:
180+
padding = [self.__text_field.pad_token] * difference
181+
return x + padding if preprocessed else " ".join([x] + padding)
182+
183+
return x
184+
182185
def __preprocess(self, X, y, sample_weight):
183186
self.__text_field = Field(lower=True)
184187
self.__label_field = Field(sequential=False)
185188
self.__text_field.preprocessing = Pipeline(self.__preprocess_text)
186-
max_krnl_sz = int(self.kernel_sizes[self.kernel_sizes.rfind(",") + 1:])
187189
X, y = list(X), list(y)
188190
sample_weight = None if sample_weight is None else list(sample_weight)
189191

190-
for i in range(len(X) - 1, -1, -1):
191-
if len(self.__text_field.preprocess(X[i])) < max_krnl_sz:
192-
del X[i]
193-
del y[i]
194-
195-
if sample_weight is not None:
196-
del sample_weight[i]
192+
for i in range(len(X)):
193+
X[i] = self.__pad(X[i])
197194

198195
fields = [("text", self.__text_field), ("label", self.__label_field)]
199196
exmpl = [Example.fromlist([X[i], y[i]], fields) for i in range(len(X))]

0 commit comments

Comments
 (0)