This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Use ".incomplete" suffix during data generation, rename on complete
PiperOrigin-RevId: 187114891
Ryan Sepassi committed Mar 2, 2018
commit f7aeeb7da4188f9813c21f10494186d229f44b67
tensor2tensor/data_generators/generator_utils.py (5 additions, 1 deletion)

@@ -147,8 +147,9 @@ def generate_files(generator, output_filenames, max_cases=None):
   if outputs_exist(output_filenames):
     tf.logging.info("Skipping generator because outputs files exist")
     return
+  tmp_filenames = [fname + ".incomplete" for fname in output_filenames]
   num_shards = len(output_filenames)
-  writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
+  writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filenames]
   counter, shard = 0, 0
   for case in generator:
     if case is None:
@@ -165,6 +166,9 @@
   for writer in writers:
     writer.close()

+  for tmp_name, final_name in zip(tmp_filenames, output_filenames):
+    tf.gfile.Rename(tmp_name, final_name)
+
   tf.logging.info("Generated %s Examples", counter)

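For context, this change follows the standard write-then-rename pattern: each shard is written under a temporary ".incomplete" name and is renamed to its final name only after its writer has been closed, so an interrupted run never leaves behind files that outputs_exist would mistake for finished shards. Below is a minimal, self-contained sketch of the same pattern; it uses the Python standard library and made-up names (write_shards, records, shard_paths) rather than the repository's tf.python_io.TFRecordWriter and tf.gfile.Rename.

import os

def write_shards(records, shard_paths):
  """Write records round-robin to shards, renaming each shard when done."""
  tmp_paths = [path + ".incomplete" for path in shard_paths]
  # Write only to the ".incomplete" temp files, never to the final names.
  files = [open(path, "w") for path in tmp_paths]
  for i, record in enumerate(records):
    files[i % len(files)].write(record + "\n")
  for f in files:
    f.close()
  # Promote each shard to its final name only after it is fully written.
  for tmp_path, final_path in zip(tmp_paths, shard_paths):
    os.rename(tmp_path, final_path)

write_shards(["a", "b", "c", "d"], ["shard-0-of-2.txt", "shard-1-of-2.txt"])
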
tensor2tensor/data_generators/text_problems.py (3 additions, 2 deletions)

@@ -186,10 +186,11 @@ def generate_text_for_vocab(self, data_dir, tmp_dir):
   @property
   def vocab_filename(self):
     if self.vocab_type == VocabType.SUBWORD:
-      return "vocab.%s.%d.%s" % (self.name, self.approx_vocab_size,
+      return "vocab.%s.%d.%s" % (self.dataset_filename(),
+                                 self.approx_vocab_size,
                                  VocabType.SUBWORD)
     else:
-      return "vocab.%s.%s" % (self.name, VocabType.TOKEN)
+      return "vocab.%s.%s" % (self.dataset_filename(), VocabType.TOKEN)

   def get_or_create_vocab(self, data_dir, tmp_dir, force_get=False):
     if self.vocab_type == VocabType.CHARACTER:
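The effect of the vocab_filename change is that vocabulary files are now named after dataset_filename() instead of the problem's registered name, presumably so that problems which override dataset_filename() to share data files also share the vocabulary file. Below is a minimal, self-contained sketch of the resulting naming scheme; ExampleProblem, its dataset_filename value, and the VocabType string values are illustrative placeholders, not taken from the repository.

class VocabType(object):
  # Placeholder string values for illustration only.
  SUBWORD = "subwords"
  TOKEN = "tokens"

class ExampleProblem(object):
  approx_vocab_size = 2**15
  vocab_type = VocabType.SUBWORD

  def dataset_filename(self):
    # Hypothetical value; stands in for a problem that shares data files.
    return "example_problem"

  @property
  def vocab_filename(self):
    if self.vocab_type == VocabType.SUBWORD:
      return "vocab.%s.%d.%s" % (self.dataset_filename(),
                                 self.approx_vocab_size,
                                 VocabType.SUBWORD)
    return "vocab.%s.%s" % (self.dataset_filename(), VocabType.TOKEN)

print(ExampleProblem().vocab_filename)  # vocab.example_problem.32768.subwords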