|
 #!/usr/bin/env python
-
-
 import os
 import argparse
 import datetime
|

 parser = argparse.ArgumentParser(description='CNN text classifier')
-parser.add_argument('-batch-size', type=int, default=64, metavar='N', help='batch size for training [default: 50]')
-parser.add_argument('-lr', type=float, default=0.001, metavar='LR', help='initial learning rate [default: 0.01]')
-parser.add_argument('-epochs', type=int, default=200, metavar='N', help='number of epochs for train [default: 10]')
-parser.add_argument('-dropout', type=float, default=0.5, metavar='', help='the probability for dropout [default: 0.5]')
-parser.add_argument('-max_norm', type=float, default=3.0, help='l2 constraint of parameters')
-parser.add_argument('-cpu', action='store_true', default=False, help='disable the gpu' )
-parser.add_argument('-device', type=int, default=-1, help='device to use for iterate data')
-# model
-parser.add_argument('-embed-dim', type=int, default=128)
-parser.add_argument('-static', action='store_true', default=False, help='fix the embedding')
-parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='Comma-separated kernel size to use for convolution')
-parser.add_argument('-kernel-num', type=int, default=100, help='number of each kind of kernel')
-parser.add_argument('-class-num', type=int, default=2, help='number of class')
+# learning
+parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate [default: 0.001]')
+parser.add_argument('-epochs', type=int, default=256, help='number of epochs for training [default: 256]')
+parser.add_argument('-batch-size', type=int, default=64, help='batch size for training [default: 64]')
+parser.add_argument('-log-interval', type=int, default=1, help='how many steps to wait before logging training status [default: 1]')
+parser.add_argument('-test-interval', type=int, default=100, help='how many steps to wait before testing [default: 100]')
+parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default: 500]')
+parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the checkpoint')
 # data
 parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the data every epoch')
-parser.add_argument('-num-workers', type=int, default=0, help='how many subprocesses to use for data loading [default: 0]')
-# log
-parser.add_argument('-log-interval', type=int, default=1, help='how many batches to wait before logging training status')
-parser.add_argument('-test-interval', type=int, default=100, help='how many epochs to wait before testing')
-parser.add_argument('-save-interval', type=int, default=100, help='how many epochs to wait before saving')
-parser.add_argument('-predict', type=str, default=None, help='predict the sentence given')
+# model
+parser.add_argument('-dropout', type=float, default=0.5, help='the probability for dropout [default: 0.5]')
+parser.add_argument('-max-norm', type=float, default=3.0, help='l2 constraint of parameters [default: 3.0]')
+parser.add_argument('-embed-dim', type=int, default=128, help='number of embedding dimensions [default: 128]')
+parser.add_argument('-kernel-num', type=int, default=100, help='number of kernels for each kernel size [default: 100]')
+parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel sizes to use for convolution [default: 3,4,5]')
+parser.add_argument('-static', action='store_true', default=False, help='fix the embedding (do not update it during training)')
+# device
+parser.add_argument('-device', type=int, default=-1, help='device to use for iterating data, -1 means CPU [default: -1]')
+parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the GPU')
+# option
 parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot [default: None]')
-parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the checkpoint')
+parser.add_argument('-predict', type=str, default=None, help='predict the class of the given sentence')
+parser.add_argument('-test', action='store_true', default=False, help='evaluate on the test set instead of training')
 args = parser.parse_args()
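+# Usage sketch (hypothetical invocations; the snapshot path is an example,
+# not a file this diff creates):
+#   python main.py                                          # train
+#   python main.py -test -snapshot snapshot/best_model.pt   # evaluate
+#   python main.py -predict "the movie was great" -snapshot snapshot/best_model.pt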
-args.cuda = not args.cpu and torch.cuda.is_available()
-args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
-args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d-%H-%M'))
+
|
 # load SST dataset
-'''
-print("Loading data...")
-text_field = data.Field(lower=True)
-label_field = data.Field(sequential=False)
-train_data, dev_data, test_data = datasets.SST.splits(text_field, label_field, fine_grained=True)
-text_field.build_vocab(train_data, dev_data, test_data)
-label_field.build_vocab(train_data)
-train_iter, dev_iter, test_iter = data.BucketIterator.splits(
-                            (train_data, dev_data, test_data),
-                            batch_sizes=(args.batch_size,
-                                         len(dev_data),
-                                         len(test_data)),
-                            device=-1, repeat=False)
-'''
+def sst(text_field, label_field, **kwargs):
+    train_data, dev_data, test_data = datasets.SST.splits(text_field, label_field, fine_grained=True)
+    text_field.build_vocab(train_data, dev_data, test_data)
+    label_field.build_vocab(train_data, dev_data, test_data)
+    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
+        (train_data, dev_data, test_data),
+        batch_sizes=(args.batch_size, len(dev_data), len(test_data)),
+        **kwargs)
+    return train_iter, dev_iter, test_iter
+
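+# note: dev and test are each loaded as a single full-split batch above, so an
+# evaluation pass sees the whole split at once (assumes the split fits in memory)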

 # load MR dataset
-print("Loading data...")
+def mr(text_field, label_field, **kwargs):
+    train_data, dev_data = mydatasets.MR.splits(text_field, label_field)
+    text_field.build_vocab(train_data, dev_data)
+    label_field.build_vocab(train_data, dev_data)
+    train_iter, dev_iter = data.Iterator.splits(
+        (train_data, dev_data),
+        batch_sizes=(args.batch_size, len(dev_data)),
+        **kwargs)
+    return train_iter, dev_iter
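+# note: MR ships without a held-out test split, so mr() returns only train/dev
+# iterators; test_iter exists only when the SST loader is used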
+
+
+# load data
+print("\nLoading data...")
 text_field = data.Field(lower=True)
 label_field = data.Field(sequential=False)
-train_data, dev_data = mydatasets.MR.splits(text_field, label_field)
-text_field.build_vocab(train_data, dev_data)
-label_field.build_vocab(train_data)
-train_iter, dev_iter = data.Iterator.splits((train_data, dev_data),
-                            batch_sizes=(args.batch_size, len(dev_data)),
-                            device=-1, repeat=False)
+#train_iter, dev_iter = mr(text_field, label_field, device=-1, repeat=False)
+train_iter, dev_iter, test_iter = sst(text_field, label_field, device=-1, repeat=False)
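+# to train on MR instead of SST, swap which of the two lines above is commented
+# out; with MR there is no test_iter, so the -test option cannot be used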

-# args from vacab
+
+# update args and print
 args.embed_num = len(text_field.vocab)
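+# the label vocab includes torchtext's <unk> entry, hence the -1 below
+# (assumption: the label Field keeps its default unknown token)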
+args.class_num = len(label_field.vocab) - 1
+args.cuda = (not args.no_cuda) and torch.cuda.is_available(); del args.no_cuda
+args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d-%H-%M'))
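+# each run writes its checkpoints into a fresh timestamped subdirectory of -save-dir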
+
+print("\nParameters:")
+for attr, value in sorted(args.__dict__.items()):
+    print("\t{}={}".format(attr.upper(), value))
+
|
 # model
 if args.snapshot is None:
     cnn = model.CNN_Text(args)
 else:
-    print('Loading model from [%s]...' % args.snapshot)
-    cnn = torch.load(args.snapshot)
+    print('\nLoading model from [%s]...' % args.snapshot)
+    try:
+        cnn = torch.load(args.snapshot)
+    except IOError:
+        print("Sorry, this snapshot doesn't exist."); exit()
+

 # train or predict
 if args.predict is not None:
     label = train.predict(args.predict, cnn, text_field, label_field)
-    print('\n[Text] %s'% args.predict)
-    print('[Label] %s\n'% label)
-else:
+    print('\n[Text]  {}\n[Label] {}\n'.format(args.predict, label))
+elif args.test:
+    try:
+        train.eval(test_iter, cnn, args)
+    except Exception:
+        print("\nSorry, the test dataset doesn't exist.\n")
+else:
+    print()
     train.train(train_iter, dev_iter, cnn, args)
+
|