104 changes: 104 additions & 0 deletions tensor2tensor/data_generators/text_problems.py
@@ -17,6 +17,7 @@

* Text2TextProblem: input=text, target=text.
* Text2ClassProblem: input=text, target=class.
* Text2RealProblem: input=text, target=float.
* Text2SelfProblem (for language modeling): target=text
* QuestionAndContext2TextProblem: input=text, context=text, target=text.

@@ -605,6 +606,94 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
yield {"inputs": inputs, "targets": [label]}


class Text2RealProblem(Text2TextProblem):
"""Base class for text regression problems with one or more tasks.
Suitable for text-based problems where targets are continuous, real values.
When ntasks = 1, each text example is mapped to a single scalar value. When
ntasks > 1, each text example is mapped to a 1-d vector of length ntasks.
"""

@property
def ntasks(self):
"""Set to n > 1 for multitask regression."""
return 1

def generate_samples(self, data_dir, tmp_dir, dataset_split):
"""Generate samples of text and real-valued target pairs.
Each yielded dict will be a single example. The inputs should be raw text.
The target should be a list containing ntasks floats.
Args:
data_dir: final data directory. Typically only used in this method to copy
over user-supplied vocab files (for example, if vocab_type ==
VocabType.TOKEN).
tmp_dir: temporary directory that you can use for downloading and scratch.
dataset_split: problem.DatasetSplit, which data split to generate samples
for (for example, training and evaluation).
Yields:
{"inputs": text, "targets": [x1, x2, ..., xN]} where N is ntasks
"""
raise NotImplementedError()

def generate_text_for_vocab(self, data_dir, tmp_dir):
for i, sample in enumerate(
self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
yield sample["inputs"]
if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
break

def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
encoder = self.get_or_create_vocab(data_dir, tmp_dir)
for sample in generator:
inputs = encoder.encode(sample["inputs"])
inputs.append(text_encoder.EOS_ID)
yield {"inputs": inputs, "targets": sample["targets"]}

def feature_encoders(self, data_dir):
encoder = self.get_or_create_vocab(data_dir, None, force_get=True)

return {
"inputs": encoder,
"targets": text_encoder.RealEncoder(),
}

def hparams(self, defaults, unused_model_hparams):
p = defaults
p.modality = {
"inputs": modalities.ModalityType.SYMBOL,
"targets": modalities.ModalityType.REAL_L2_LOSS,
}
p.vocab_size = {
"inputs": self._encoders["inputs"].vocab_size,
"targets": self.ntasks
}
p.target_space_id = problem.SpaceID.REAL
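    # Decode-time code (transformer.py, t2t_model.py) reads this flag to emit
    # the full ntasks-vector in a single step instead of autoregressively.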
p.add_hparam("regression_targets", True)

def max_length(self, model_hparams):
return model_hparams.batch_size * self.ntasks

def preprocess_example(self, example, unused_mode, unused_hparams):
example = problem.preprocess_example_common(example, unused_mode,
unused_hparams)
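    # Per-example reshape to [1, 1, ntasks]; batching adds the leading
    # dimension, so each target is a fixed-length real vector in T2T's
    # 4-D tensor layout.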
example["targets"] = tf.reshape(example["targets"], [1, 1, self.ntasks])
return example

def example_reading_spec(self):
data_fields = {
"inputs": tf.VarLenFeature(tf.int64),
"targets": tf.FixedLenFeature([self.ntasks], tf.float32),
}
data_items_to_decoders = None
return (data_fields, data_items_to_decoders)

def eval_metrics(self):
metrics_list = [metrics.Metrics.RMSE]
if self.ntasks == 1:
metrics_list.append(metrics.Metrics.PEARSON)
return metrics_list


def txt_line_iterator(txt_path):
"""Iterate through lines of file."""
with tf.gfile.Open(txt_path) as f:
@@ -692,6 +781,21 @@ def text2class_txt_iterator(source_txt_path, label_txt_path, class_strs=None):
yield {"inputs": inputs, "label": label}


def text2real_txt_iterator(source_txt_path, target_txt_path):
"""Yield dicts for Text2RealProblem.generate_samples from lines of files.
Args:
source_txt_path: txt file with record per line.
target_txt_path: txt file with float (or space-separated float list for
multitask) per line.
Yields:
{"inputs": inputs, "targets": targets}
"""
for inputs, targets in zip(
txt_line_iterator(source_txt_path), txt_line_iterator(target_txt_path)):
targets = [float(x) for x in targets.split(" ")]
yield {"inputs": inputs, "targets": targets}


def text2text_txt_tab_iterator(txt_path):
"""Yield dicts for Text2TextProblem.generate_samples from lines of txt_path.

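To make the new surface area concrete, here is a minimal sketch of a subclass wired up with text2real_txt_iterator. It is illustrative only: the problem name and file names are hypothetical and not part of this PR.

import os

from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


@registry.register_problem
class SentenceScoreRegression(text_problems.Text2RealProblem):
  """Hypothetical single-task problem: map each sentence to one float."""

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    del data_dir, dataset_split  # A real problem would branch on the split.
    # Assumed inputs: one sentence per line and one float per line.
    return text_problems.text2real_txt_iterator(
        os.path.join(tmp_dir, "sentences.txt"),
        os.path.join(tmp_dir, "scores.txt"))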
17 changes: 17 additions & 0 deletions tensor2tensor/data_generators/text_problems_test.py
@@ -94,6 +94,13 @@ def setUpClass(cls):
tf.gfile.Copy(cls.targets_file, os.path.join(cls.tmp_dir,
"targets.eval.txt"))

cls.targets_regr = [[1.23, 2.34], [4.56, 5.67]]
cls.targets_regr_file = os.path.join(cls.tmp_dir, "targets_regr.train.txt")
with tf.gfile.Open(cls.targets_regr_file, "w") as f:
for targets in cls.targets_regr:
f.write(" ".join([str(x) for x in targets]) + "\n")


def testTxtLineIterator(self):
lines = [line for line in text_problems.txt_line_iterator(self.inputs_file)]
self.assertEqual(lines, self.inputs)
@@ -136,6 +143,16 @@ def testText2ClassTxtIteratorWithStrs(self):
self.assertEqual(inputs, self.inputs)
self.assertEqual(labels, self.labels)

def testText2RealTxtIterator(self):
inputs = []
targets = []
for entry in text_problems.text2real_txt_iterator(self.inputs_file,
self.targets_regr_file):
inputs.append(entry["inputs"])
targets.append(entry["targets"])
self.assertEqual(inputs, self.inputs)
self.assertEqual(targets, self.targets_regr)

def testText2TextTxtTabIterator(self):
inputs = []
targets = []
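The fixture pins down the on-disk format: each line of the targets file carries ntasks space-separated floats. A quick illustrative round trip, assuming the two files written in setUpClass:

for example in text_problems.text2real_txt_iterator(
    "inputs.train.txt", "targets_regr.train.txt"):
  print(example)  # e.g. {"inputs": "...", "targets": [1.23, 2.34]}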
6 changes: 4 additions & 2 deletions tensor2tensor/models/transformer.py
@@ -462,7 +462,8 @@ def _fast_decode_tpu(self,

if self.has_input:
inputs_shape = common_layers.shape_list(features["inputs"])
-      if target_modality == modalities.ModalityType.CLASS_LABEL:
+      if (target_modality == modalities.ModalityType.CLASS_LABEL or
+          self._problem_hparams.get("regression_targets")):
decode_length = 1
else:
decode_length = (
@@ -704,7 +705,8 @@ def _fast_decode(self,
" of the dataset when decoding.")
if self.has_input:
inputs_shape = common_layers.shape_list(features["inputs"])
-      if target_modality == modalities.ModalityType.CLASS_LABEL:
+      if (target_modality == modalities.ModalityType.CLASS_LABEL or
+          self._problem_hparams.get("regression_targets")):
decode_length = 1
else:
decode_length = (
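Both fast-decode paths now treat regression like classification: the entire target occupies a single decode position. A condensed sketch of the shared predicate follows (the helper name is ours, not in the PR). Note that problems predating this change never define "regression_targets", so .get() returns None and the existing behavior is untouched:

def _single_step_targets(problem_hparams, target_modality):
  # CLASS_LABEL targets and real-valued regression targets are both emitted
  # in one step, so there is nothing to decode autoregressively.
  return (target_modality == modalities.ModalityType.CLASS_LABEL or
          bool(problem_hparams.get("regression_targets")))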
12 changes: 8 additions & 4 deletions tensor2tensor/utils/t2t_model.py
@@ -806,8 +806,10 @@ def infer(self,

if self._problem_hparams:
target_modality = self._problem_hparams.modality["targets"]
-      if target_modality == modalities.ModalityType.CLASS_LABEL:
-        beam_size = 1  # No use to run beam-search for a single class.
+      if (target_modality == modalities.ModalityType.CLASS_LABEL or
+          self._problem_hparams.get("regression_targets")):
+        # No use to run beam-search for classification or regression.
+        beam_size = 1
if beam_size == 1:
log_info("Greedy Decoding")
results = self._greedy_infer(features, decode_length, use_tpu)
@@ -1064,7 +1066,8 @@ def infer_step(i, recent_output, recent_logits, unused_loss):
initial_output = tf.slice(initial_output, [0, 0, 0, 0],
common_layers.shape_list(initial_output))
target_modality = self._problem_hparams.modality["targets"]
-    if target_modality == modalities.ModalityType.CLASS_LABEL:
+    if (target_modality == modalities.ModalityType.CLASS_LABEL or
+        self._problem_hparams.get("regression_targets")):
decode_length = 1
else:
if "partial_targets" in features:
@@ -1243,7 +1246,8 @@ def infer_step(recent_output, recent_logits, unused_loss):
initial_output = tf.slice(initial_output, [0, 0, 0, 0],
common_layers.shape_list(initial_output))
target_modality = self._problem_hparams.modality["targets"]
-    if target_modality == modalities.ModalityType.CLASS_LABEL:
+    if (target_modality == modalities.ModalityType.CLASS_LABEL or
+        self._problem_hparams.get("regression_targets")):
decode_length = 1
else:
if "partial_targets" in features:
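Since beam search over a single continuous vector is meaningless, infer() now forces the greedy path for regression problems as well. Illustratively (the calls below are a sketch, not from the PR):

# For a Text2RealProblem these two calls take the same greedy route,
# because infer() resets beam_size to 1 before dispatching:
outputs_a = model.infer(features, beam_size=4)
outputs_b = model.infer(features, beam_size=1)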