Describe the bug
Attack parallelization fails on a tf.keras model
To Reproduce
Steps to reproduce the behavior:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, GlobalMaxPooling1D
from textattack.attack_recipes import PWWSRen2019
from textattack import Attacker, AttackArgs
from task import TrainSequence, load_data
from audionlp import get_dataset, get_tokenizer, TwentynewsModelWrapper
data = 'twentynews'
batch_size = 32
epochs = 1
train = TrainSequence(data, batch_size)
validation = load_data(data, 'val')[:-1]
input_shape = train.x.shape[1:]
n_choice = train.y.shape[-1]
model = Sequential()
model.add(Embedding(10000, 128, input_length=500, name='embedding'))
model.add(GlobalMaxPooling1D())
model.add(Dense(n_choice, 'softmax'))
model.summary()
model.compile('adam', 'categorical_crossentropy', ['accuracy'])
model.fit(train, epochs=epochs, validation_data=validation)
dataset = get_dataset(data)
tokenizer = get_tokenizer()
wrapper = TwentynewsModelWrapper(model, tokenizer)
attack = PWWSRen2019.build(wrapper)
attack_args = AttackArgs(num_examples=-1, disable_stdout=True,)
# parallel=True, num_workers_per_device=2)
attacker = Attacker(attack, dataset, attack_args)
result = attacker.attack_dataset()
When parallel=False everything works fine, but with parallel=True, num_workers_per_device=2 the following error is raised:
textattack: Running 2 worker(s) on 1 GPU(s).
Traceback (most recent call last):
  File "<ipython-input-6-625f4babcd86>", line 4, in <module>
    result = attacker.attack_dataset()
  File "/home/chenxiyuan/projects/python/TextAttack/textattack/attacker.py", line 427, in attack_dataset
    self._attack_parallel()
  File "/home/chenxiyuan/projects/python/TextAttack/textattack/attacker.py", line 280, in _attack_parallel
    worker_pool = torch.multiprocessing.Pool(
  File "/usr/lib/python3.9/multiprocessing/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,
  File "/usr/lib/python3.9/multiprocessing/pool.py", line 212, in __init__
    self._repopulate_pool()
  File "/usr/lib/python3.9/multiprocessing/pool.py", line 303, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,
  File "/usr/lib/python3.9/multiprocessing/pool.py", line 326, in _repopulate_pool_static
    w.start()
  File "/usr/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/usr/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/usr/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/usr/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/usr/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/usr/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
TypeError: cannot pickle 'weakref' object
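
For reference (this snippet is mine, not part of the failing run): with tensorflow-gpu==2.5.0 the same TypeError appears when pickling the compiled model directly, which suggests the error comes from the wrapped tf.keras model being sent to the spawned workers rather than from TextAttack itself.

import pickle
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, GlobalMaxPooling1D

# Same architecture as above, without the private task/audionlp helpers.
model = Sequential()
model.add(Embedding(10000, 128, input_length=500, name='embedding'))
model.add(GlobalMaxPooling1D())
model.add(Dense(20, activation='softmax'))
model.compile('adam', 'categorical_crossentropy', ['accuracy'])

try:
    pickle.dumps(model)
except TypeError as e:
    print(e)  # expected: cannot pickle 'weakref' object
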
By the way, when loading a similar pretrained model that contains a custom Keras layer like:
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class Dropout(Layer):
    def __init__(self, level, **kwargs):
        super(Dropout, self).__init__(**kwargs)
        self.level = level

    def build(self, input_shape):
        pass

    def call(self, inputs):
        return K.dropout(inputs, self.level)
parallelization raises a slightly different error:
textattack: Running 2 worker(s) on 1 GPU(s).
Traceback (most recent call last):
  ... (same stack as the traceback above) ...
  File "/usr/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
PicklingError: Can't pickle <class 'tensorflow.python.keras.saving.saved_model.load.Dropout'>: attribute lookup Dropout on tensorflow.python.keras.saving.saved_model.load failed
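
A possible mitigation for this second error (untested, and assuming the pretrained model is loaded with tf.keras.models.load_model): passing the custom class via custom_objects should make the loaded layer an instance of the importable Dropout class above, instead of the class Keras generates inside tensorflow.python.keras.saving.saved_model.load, which is what fails the attribute lookup during pickling. The path below is a placeholder, and the layer would likely also need a get_config method returning {'level': self.level}.

import tensorflow as tf

# 'path/to/pretrained_model' is a placeholder for the actual SavedModel
# directory; Dropout is the custom layer class defined above.
model = tf.keras.models.load_model(
    'path/to/pretrained_model',
    custom_objects={'Dropout': Dropout},
)
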
Expected behavior
Parallelization should work on tf.keras models.
Screenshots or Traceback
See the tracebacks included above.
System Information:
- OS: Linux
- Library versions: torch==1.7.1, transformers==4.8.2, tensorflow-gpu==2.5.0
- TextAttack version: 0.3.0
Additional context
With my limited experience of parallelization, multiprocessing does not seem ideal for this job, so maybe consider replacing it with TensorFlow APIs?
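
To make that suggestion a bit more concrete (a generic sketch, not TextAttack code and not one of TensorFlow's distribution APIs): whatever mechanism is used, the model has to be constructed inside each worker instead of being pickled into it. In the sketch below only a weights path crosses the process boundary, via a Pool initializer; build_model, tmp_weights.h5, and the toy batches are all made up for the illustration.

import multiprocessing as mp
import numpy as np

def build_model():
    # Same architecture as in the report (20 = number of classes in
    # twentynews, i.e. n_choice); TensorFlow is imported inside the function
    # so each spawned worker initializes it independently.
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, Embedding, GlobalMaxPooling1D
    return Sequential([
        Embedding(10000, 128, input_length=500),
        GlobalMaxPooling1D(),
        Dense(20, activation='softmax'),
    ])

_worker_model = None  # per-process global, set once by the Pool initializer

def _init_worker(weights_path):
    global _worker_model
    _worker_model = build_model()
    _worker_model.load_weights(weights_path)

def _predict(batch):
    # Runs in the worker process against its own copy of the model.
    return _worker_model.predict(np.asarray(batch))

if __name__ == '__main__':
    weights_path = 'tmp_weights.h5'
    build_model().save_weights(weights_path)  # share weights via a file
    batches = [np.random.randint(0, 10000, size=(4, 500)) for _ in range(2)]
    with mp.get_context('spawn').Pool(
        processes=2, initializer=_init_worker, initargs=(weights_path,)
    ) as pool:
        outputs = pool.map(_predict, batches)
    print([o.shape for o in outputs])  # [(4, 20), (4, 20)]
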