
Parallelization failed on tf.keras model #499

@XiYuan68

Describe the bug
Attack parallelization fails on a tf.keras model.

To Reproduce
Steps to reproduce the behavior:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, GlobalMaxPooling1D
from textattack.attack_recipes import PWWSRen2019
from textattack import Attacker, AttackArgs


from task import TrainSequence, load_data
from audionlp import get_dataset, get_tokenizer, TwentynewsModelWrapper


data = 'twentynews'
batch_size = 32
epochs = 1
train = TrainSequence(data, batch_size)
validation = load_data(data, 'val')[:-1]
input_shape = train.x.shape[1:]
n_choice = train.y.shape[-1]

model = Sequential()
model.add(Embedding(10000, 128, input_length=500, name='embedding'))
model.add(GlobalMaxPooling1D())
model.add(Dense(n_choice, 'softmax'))
model.summary()
model.compile('adam', 'categorical_crossentropy', ['accuracy'])
model.fit(train, epochs=epochs, validation_data=validation)

dataset = get_dataset(data)
tokenizer = get_tokenizer()
wrapper = TwentynewsModelWrapper(model, tokenizer)
attack = PWWSRen2019.build(wrapper)
attack_args = AttackArgs(num_examples=-1, disable_stdout=True,
                         # parallel=True, num_workers_per_device=2,
                         )
attacker = Attacker(attack, dataset, attack_args)
result = attacker.attack_dataset()

When parallel=False, everything works fine.

But with parallel=True, num_workers_per_device=2, the following error is raised:

textattack: Running 2 worker(s) on 1 GPU(s).
Traceback (most recent call last):

  File "<ipython-input-6-625f4babcd86>", line 4, in <module>
    result = attacker.attack_dataset()

  File "/home/chenxiyuan/projects/python/TextAttack/textattack/attacker.py", line 427, in attack_dataset
    self._attack_parallel()

  File "/home/chenxiyuan/projects/python/TextAttack/textattack/attacker.py", line 280, in _attack_parallel
    worker_pool = torch.multiprocessing.Pool(

  File "/usr/lib/python3.9/multiprocessing/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,

  File "/usr/lib/python3.9/multiprocessing/pool.py", line 212, in __init__
    self._repopulate_pool()

  File "/usr/lib/python3.9/multiprocessing/pool.py", line 303, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,

  File "/usr/lib/python3.9/multiprocessing/pool.py", line 326, in _repopulate_pool_static
    w.start()

  File "/usr/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)

  File "/usr/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)

  File "/usr/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)

  File "/usr/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)

  File "/usr/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)

  File "/usr/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)

TypeError: cannot pickle 'weakref' object
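
For reference, this failure does not look specific to TextAttack: on tensorflow-gpu==2.5.0 a compiled tf.keras model apparently cannot be pickled at all, and spawning the worker pool is exactly what forces torch.multiprocessing to pickle the attack (and the wrapped model with it). A minimal sketch, assuming the same model definition as above, seems to hit the same error directly:

import pickle

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, GlobalMaxPooling1D

# Minimal model mirroring the reproduction above (20 classes assumed).
model = Sequential()
model.add(Embedding(10000, 128, input_length=500, name='embedding'))
model.add(GlobalMaxPooling1D())
model.add(Dense(20, activation='softmax'))
model.compile('adam', 'categorical_crossentropy', ['accuracy'])

# Spawning pool workers pickles the attack, and the wrapped model with it;
# pickling the model by itself appears to raise the same error on TF 2.5:
pickle.dumps(model)  # TypeError: cannot pickle 'weakref' object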

Incidentally, when loading a similar pretrained model that contains a custom Keras layer like the following:

from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K

class Dropout(Layer):
    def __init__(self, level, **kwargs):
        super(Dropout, self).__init__(**kwargs)
        self.level = level

    def build(self, input_shape):
        pass

    def call(self, inputs):
        return K.dropout(inputs, self.level)

parallelization raises a slightly different error:

textattack: Running 2 worker(s) on 1 GPU(s).
Traceback (most recent call last):

  ...

  File "/home/chenxiyuan/projects/python/TextAttack/textattack/attacker.py", line 427, in attack_dataset
    self._attack_parallel()

  File "/home/chenxiyuan/projects/python/TextAttack/textattack/attacker.py", line 280, in _attack_parallel
    worker_pool = torch.multiprocessing.Pool(

  File "/usr/lib/python3.9/multiprocessing/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,

  File "/usr/lib/python3.9/multiprocessing/pool.py", line 212, in __init__
    self._repopulate_pool()

  File "/usr/lib/python3.9/multiprocessing/pool.py", line 303, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,

  File "/usr/lib/python3.9/multiprocessing/pool.py", line 326, in _repopulate_pool_static
    w.start()

  File "/usr/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)

  File "/usr/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)

  File "/usr/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)

  File "/usr/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)

  File "/usr/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)

  File "/usr/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)

PicklingError: Can't pickle <class 'tensorflow.python.keras.saving.saved_model.load.Dropout'>: attribute lookup Dropout on tensorflow.python.keras.saving.saved_model.load failed
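
The second error looks like a consequence of how SavedModel loading reconstructs unknown layers as dynamically generated classes under tensorflow.python.keras.saving.saved_model.load, which pickle then cannot look up by name. One possible mitigation, sketched below and not tested against this issue, is to register the custom layer and pass it through custom_objects at load time so the reloaded model keeps the original, importable class; note that this would not by itself fix the weakref pickling failure above:

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K

# Registering the layer lets Keras serialization resolve it by name, and
# custom_objects keeps the importable class on the reloaded model.
@tf.keras.utils.register_keras_serializable(package='custom')
class Dropout(Layer):
    def __init__(self, level, **kwargs):
        super(Dropout, self).__init__(**kwargs)
        self.level = level

    def call(self, inputs):
        return K.dropout(inputs, self.level)

    def get_config(self):
        config = super(Dropout, self).get_config()
        config.update({'level': self.level})
        return config

# 'pretrained_model_dir' is a hypothetical path to the saved model.
model = tf.keras.models.load_model(
    'pretrained_model_dir', custom_objects={'Dropout': Dropout}
)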

Expected behavior
Parallelization should work on tf.keras models.

Screenshots or Traceback
See the tracebacks above.

System Information:

  • OS: Linux
  • Library versions: torch==1.7.1, transformers==4.8.2, tensorflow-gpu==2.5.0
  • TextAttack version: 0.3.0

Additional context
In my limited experience with parallelization, Python multiprocessing does not seem ideal for this job, so perhaps consider replacing it with TensorFlow's own APIs?
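
As a stopgap until then, one possible user-level workaround might be to avoid pickling the model entirely: shard the dataset and build the model, tokenizer, and attack inside each worker process, so only plain integers ever cross the process boundary. A rough, untested sketch follows; the saved-model path, the shard size, and the use of AttackArgs's num_examples_offset are assumptions here:

import multiprocessing as mp


def attack_shard(offset, shard_size):
    # Build everything unpicklable (model, tokenizer, attack) inside the
    # child process so that nothing has to cross the process boundary.
    from tensorflow.keras.models import load_model
    from textattack import Attacker, AttackArgs
    from textattack.attack_recipes import PWWSRen2019
    from audionlp import get_dataset, get_tokenizer, TwentynewsModelWrapper

    model = load_model('twentynews_model')  # hypothetical saved-model path
    wrapper = TwentynewsModelWrapper(model, get_tokenizer())
    attack = PWWSRen2019.build(wrapper)
    attack_args = AttackArgs(
        num_examples=shard_size,
        num_examples_offset=offset,  # assumed: start of this worker's shard
        disable_stdout=True,
        # per-worker result logging is omitted in this sketch
    )
    Attacker(attack, get_dataset('twentynews'), attack_args).attack_dataset()


if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    shard_size = 500  # hypothetical number of examples per worker
    workers = [
        ctx.Process(target=attack_shard, args=(i * shard_size, shard_size))
        for i in range(2)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()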
