Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactoring flag validation into a separate function and adding checks for invalid FLAGS.cloud_mlengine_master_type when running on GPU

PiperOrigin-RevId: 187098901
  • Loading branch information
brian-barnes authored and Ryan Sepassi committed Mar 2, 2018
commit d860c088b5d92a5291d193f9703306ed65ca1cf0
81 changes: 43 additions & 38 deletions tensor2tensor/utils/cloud_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,41 +89,23 @@ def flags_as_args():
return args


def machine_config(num_gpus=1, use_tpu=False, master_type=None):
"""Return dict specifying machine config for trainingInput."""
def get_default_master_type(num_gpus=1, use_tpu=False):
"""Returns master_type for trainingInput."""
if use_tpu:
master_type = 'standard_tpu'
return 'standard_tpu'
elif num_gpus <= 0:
master_type = master_type or 'standard'
cpu_types = ['standard', 'large_model', 'complex_model_s',
'complex_model_m', 'complex_model_l']
if master_type not in cpu_types:
raise ValueError('Expected `cloudml_engine_master_type` to be one of %s '
'when `worker_gpu` <= 0, found %s.', str(cpu_types),
master_type)
elif num_gpus >= 1:
if num_gpus == 1:
if master_type != 'standard_gpu':
master_type = 'standard_p100'
elif num_gpus == 4:
if master_type != 'complex_model_m_gpu':
master_type = 'complex_model_m_p100'
elif num_gpus == 8:
master_type = 'complex_model_l_gpu'
else:
raise ValueError('Must use exactly 1, 4, or 8 GPUs.')
assert master_type
return {
'scaleTier': 'CUSTOM',
'masterType': master_type
}
return 'standard'
elif num_gpus == 1:
return 'standard_p100'
elif num_gpus == 4:
return 'complex_model_m_p100'
elif num_gpus == 8:
return 'complex_model_l_gpu'
assert False


def configure_job():
"""Construct jobSpec for ML Engine job."""
train_dir = FLAGS.output_dir
assert train_dir.startswith('gs://')

# See documentation:
# https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#traininginput
training_input = {
Expand All @@ -132,15 +114,13 @@ def configure_job():
'region': cloud.default_region(),
'runtimeVersion': '1.4',
'pythonVersion': '3.5' if sys.version_info.major == 3 else '2.7',
'jobDir': train_dir,
}
training_input.update(
machine_config(
'jobDir': FLAGS.output_dir,
'scaleTier': 'CUSTOM',
'masterType': FLAGS.cloud_mlengine_master_type or get_default_master_type(
num_gpus=FLAGS.worker_gpu,
use_tpu=FLAGS.use_tpu,
master_type=FLAGS.cloud_mlengine_master_type))
use_tpu=FLAGS.use_tpu)
}
if FLAGS.hparams_range:
assert FLAGS.autotune_objective
tf.logging.info('Configuring hyperparameter tuning.')
training_input['hyperparameters'] = configure_autotune(
FLAGS.hparams_range,
Expand Down Expand Up @@ -278,15 +258,40 @@ def configure_usr_dir(job_spec, usr_tar):
job_spec['trainingInput']['args'].extend(usr_args)


def launch():
"""Launch t2t_trainer on Cloud ML Engine."""
def validate_flags():
  """Validates flags are set to acceptable values for CloudML Engine runs.

  Checks, in order:
    * `FLAGS.cloud_tpu` is unset and `job_dir()` returns a falsy value.
    * `FLAGS.output_dir` and `FLAGS.data_dir` are `gs://` paths.
    * At most one worker replica and no parameter-server replicas.
    * `FLAGS.autotune_objective` is set whenever `FLAGS.hparams_range` is.
    * `FLAGS.worker_gpu`, when non-zero, is 1, 4 or 8.
    * `FLAGS.cloud_mlengine_master_type`, when set, is one of the machine
      types accepted for the selected accelerator (TPU, 1/4/8 GPUs or CPU).

  Raises:
    AssertionError: if any flag holds an unsupported value.
  """
  # NOTE: these checks are `assert` statements, so they are skipped when
  # Python runs with -O.
  assert not FLAGS.cloud_tpu
  assert not job_dir()
  assert FLAGS.output_dir.startswith('gs://')
  assert FLAGS.data_dir.startswith('gs://')
  assert FLAGS.worker_replicas <= 1
  assert FLAGS.ps_replicas <= 0
  if FLAGS.hparams_range:
    assert FLAGS.autotune_objective
  if FLAGS.worker_gpu:
    assert FLAGS.worker_gpu in [1, 4, 8]

  # Every branch below must read the same flag, `cloud_mlengine_master_type`.
  # The GPU branches previously referenced a misspelled, non-existent flag
  # (`cloud_ml_engine_master_type`), so they failed on the attribute lookup
  # instead of validating the value.
  master_type = FLAGS.cloud_mlengine_master_type
  if not master_type:
    # Nothing to validate when no explicit master type was requested.
    return

  if FLAGS.use_tpu:
    assert master_type == 'standard_tpu'
  elif FLAGS.worker_gpu:
    if FLAGS.worker_gpu == 1:
      assert master_type in ['standard_gpu', 'standard_p100']
    elif FLAGS.worker_gpu == 4:
      assert master_type in ['complex_model_m_gpu', 'complex_model_m_p100']
    else:
      # Only 8 GPUs can reach this branch, given the [1, 4, 8] check above.
      assert master_type == 'complex_model_l_gpu'
  else:
    assert master_type in ['standard', 'large_model', 'complex_model_s',
                           'complex_model_m', 'complex_model_l']


def launch():
"""Launch t2t_trainer on Cloud ML Engine."""
validate_flags()
job_spec = configure_job()
job_name = job_spec['jobId']
tf.logging.info('Launching job %s with ML Engine spec:\n%s', job_name,
Expand Down