1414class CNNClassifier (BaseEstimator , ClassifierMixin ):
1515 def __init__ (self , lr = 0.001 , epochs = 256 , batch_size = 64 , test_interval = 100 ,
1616 early_stop = 1000 , save_best = True , dropout = 0.5 , max_norm = 0.0 ,
17- embed_dim = 128 , kernel_num = 100 , kernel_sizes = "3,4,5" ,
17+ embed_dim = 128 , kernel_num = 100 , kernel_sizes = ( 3 , 4 , 5 ) ,
1818 static = False , device = - 1 , cuda = True , activation_func = "relu" ,
1919 scoring = make_scorer (accuracy_score ), vectors = None ,
2020 split_ratio = 0.9 , preprocessor = None , class_weight = None ,
@@ -29,7 +29,7 @@ def __init__(self, lr=0.001, epochs=256, batch_size=64, test_interval=100,
2929 self .max_norm = max_norm
3030 self .embed_dim = embed_dim
3131 self .kernel_num = kernel_num
32- self .kernel_sizes = kernel_sizes
32+ self .kernel_sizes = sorted ( kernel_sizes )
3333 self .static = static
3434 self .device = device
3535 self .cuda = cuda
@@ -97,10 +97,10 @@ def fit(self, X, y, sample_weight=None):
9797 train_iter , dev_iter = self .__preprocess (X , y , sample_weight )
9898 embed_num = len (self .__text_field .vocab )
9999 class_num = len (self .__label_field .vocab ) - 1
100- kernel_sizes = [int (k ) for k in self .kernel_sizes .split ("," )]
101100 self .__model = _CNNText (embed_num , self .embed_dim , class_num ,
102- self .kernel_num , kernel_sizes , self .dropout ,
103- self .static , self .activation_func ,
101+ self .kernel_num , self .kernel_sizes ,
102+ self .dropout , self .static ,
103+ self .activation_func ,
104104 vectors = self .__text_field .vocab .vectors )
105105
106106 if self .cuda and torch .cuda .is_available ():
@@ -154,18 +154,11 @@ def fit(self, X, y, sample_weight=None):
154154
155155 def predict (self , X ):
156156 y_pred = []
157- max_krnl_sz = int (self .kernel_sizes [self .kernel_sizes .rfind ("," ) + 1 :])
158157
159158 for text in X :
160159 assert isinstance (text , str )
161160
162- text = self .__text_field .preprocess (text )
163-
164- if len (text ) < max_krnl_sz :
165- most_common = self .__label_field .vocab .freqs .most_common (1 )[0 ]
166-
167- y_pred .append (most_common [0 ])
168- continue
161+ text = self .__pad (self .__text_field .preprocess (text ), True )
169162
170163 self .__model .eval ()
171164
@@ -179,21 +172,25 @@ def predict(self, X):
179172 torch .cuda .empty_cache ()
180173 return y_pred
181174
175+ def __pad (self , x , preprocessed = False ):
176+ tokens = x if preprocessed else self .__text_field .preprocess (x )
177+ difference = self .kernel_sizes [- 1 ] - len (tokens )
178+
179+ if difference > 0 :
180+ padding = [self .__text_field .pad_token ] * difference
181+ return x + padding if preprocessed else " " .join ([x ] + padding )
182+
183+ return x
184+
182185 def __preprocess (self , X , y , sample_weight ):
183186 self .__text_field = Field (lower = True )
184187 self .__label_field = Field (sequential = False )
185188 self .__text_field .preprocessing = Pipeline (self .__preprocess_text )
186- max_krnl_sz = int (self .kernel_sizes [self .kernel_sizes .rfind ("," ) + 1 :])
187189 X , y = list (X ), list (y )
188190 sample_weight = None if sample_weight is None else list (sample_weight )
189191
190- for i in range (len (X ) - 1 , - 1 , - 1 ):
191- if len (self .__text_field .preprocess (X [i ])) < max_krnl_sz :
192- del X [i ]
193- del y [i ]
194-
195- if sample_weight is not None :
196- del sample_weight [i ]
192+ for i in range (len (X )):
193+ X [i ] = self .__pad (X [i ])
197194
198195 fields = [("text" , self .__text_field ), ("label" , self .__label_field )]
199196 exmpl = [Example .fromlist ([X [i ], y [i ]], fields ) for i in range (len (X ))]
0 commit comments