@@ -16,7 +16,8 @@ def __init__(self, args):
1616 Ks = args .kernel_sizes
1717
1818 self .embed = nn .Embedding (V , D )
19- self .convs1 = [nn .Conv2d (Ci , Co , (K , D )) for K in Ks ]
19+ #self.convs1 = [nn.Conv2d(Ci, Co, (K, D)) for K in Ks]
20+ self .convs1 = nn .ModuleList ([nn .Conv2d (Ci , Co , (K , D )) for K in Ks ])
2021 '''
2122 self.conv13 = nn.Conv2d(Ci, Co, (3, D))
2223 self.conv14 = nn.Conv2d(Ci, Co, (4, D))
@@ -38,9 +39,14 @@ def forward(self, x):
3839 x = Variable (x )
3940
4041 x = x .unsqueeze (1 ) # (N,Ci,W,D)
42+
4143 x = [F .relu (conv (x )).squeeze (3 ) for conv in self .convs1 ] #[(N,Co,W), ...]*len(Ks)
44+
45+
4246 x = [F .max_pool1d (i , i .size (2 )).squeeze (2 ) for i in x ] #[(N,Co), ...]*len(Ks)
47+
4348 x = torch .cat (x , 1 )
49+
4450 '''
4551 x1 = self.conv_and_pool(x,self.conv13) #(N,Co)
4652 x2 = self.conv_and_pool(x,self.conv14) #(N,Co)
@@ -49,4 +55,4 @@ def forward(self, x):
4955 '''
5056 x = self .dropout (x ) # (N,len(Ks)*Co)
5157 logit = self .fc1 (x ) # (N,C)
52- return logit
58+ return logit
0 commit comments