diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
index 39836ca602b8..5356b8a9f6b6 100644
--- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py
@@ -432,6 +432,9 @@ def _initialize_graph_builder(self, training):
         grad_builder_config.enable_caching = self._enable_grad_acc_optimization
         grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._debug_options.logging.log_level)
         grad_builder_config.use_memory_efficient_gradient = self._use_memory_efficient_gradient
+
+        # C.OrtModuleGraphBuilder() cannot be pickled; keep the training flag so it can be recreated after unpickling.
+        self._graph_builder_training = training
         self._graph_builder = C.OrtModuleGraphBuilder()
 
         # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
@@ -450,6 +453,21 @@ def _initialize_graph_builder(self, training):
         self._graph_initializers = [param for name, param in self._flattened_module.named_parameters()
                                     if name in self._graph_initializer_names]
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        if '_graph_builder' in state:
+            state['_graph_builder'] = 'TORESTORE'
+        if '_graph_info' in state:
+            state['_graph_info'] = 'TORESTORE'
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        if self._graph_builder == 'TORESTORE':
+            self._initialize_graph_builder(self._graph_builder_training)
+        if self._graph_info == 'TORESTORE':
+            self._build_graph()
+
     def signal_model_changed(self):
         """Signals the execution manager to re-export the model on the next forward call"""
         self._original_model_has_changed = True
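The `__getstate__`/`__setstate__` pair added above follows the standard recipe for objects that hold an unpicklable native handle: swap the handle for a sentinel before serialization and rebuild it afterwards. Below is a minimal self-contained sketch of the same recipe; the `HandleOwner` class is illustrative only and not part of ORTModule.

    import pickle

    _SENTINEL = 'TORESTORE'

    class HandleOwner:
        # Illustrative only: owns a native object that cannot be pickled.

        def __init__(self):
            # Stand-in for a pybind11 object such as C.OrtModuleGraphBuilder().
            self._handle = object()

        def __getstate__(self):
            state = self.__dict__.copy()
            state['_handle'] = _SENTINEL  # drop the handle, keep a marker
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            if self._handle == _SENTINEL:
                self._handle = object()  # recreate the handle

    restored = pickle.loads(pickle.dumps(HandleOwner()))
    assert restored._handle is not None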
diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py
index dbf59bc654c8..51374314804d 100644
--- a/orttraining/orttraining/python/training/ortmodule/_io.py
+++ b/orttraining/orttraining/python/training/ortmodule/_io.py
@@ -94,6 +94,9 @@ def __init__(self,
         self.schema = schema if schema else []
         self.num_positionals = num_positionals
         self.num_expanded_positionals_non_none = num_expanded_positionals_non_none
+        if not isinstance(keyword_names, list):
+            # It must be a list (not a dict view) to avoid any pickling issue.
+            raise TypeError("keyword_names must be a list, not %r." % type(keyword_names))
         self.keyword_names = keyword_names
 
     def __repr__(self) -> str:
@@ -528,7 +531,7 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
                       schema=schema,
                       num_positionals=len(inputs),
                       num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs,
-                      keyword_names=kwargs.keys())
+                      keyword_names=list(kwargs.keys()))
 
 
 def parse_outputs_for_onnx_export_and_extract_schema(module, inputs, kwargs):
diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
index 182c8dbf5723..d5e8eeca8475 100644
--- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
@@ -47,7 +47,8 @@ def execution_session_run_forward(execution_session, onnx_model, device, gradien
     forward_inputs = C.OrtValueVector()
     forward_inputs.reserve(len(inputs))
     for input in inputs:
-        forward_inputs.push_back(_utils._torch_tensor_to_dlpack(input), input.dtype == torch.bool)
+        dlp = _utils._torch_tensor_to_dlpack(input)
+        forward_inputs.push_back(dlp, input.dtype == torch.bool)
     forward_outputs = C.OrtValueVector()
 
     # Run and return module outputs.
diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py
index 87aca7ce86b8..14f0b5109b48 100644
--- a/orttraining/orttraining/python/training/ortmodule/_utils.py
+++ b/orttraining/orttraining/python/training/ortmodule/_utils.py
@@ -5,7 +5,8 @@
 from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
 from onnxruntime.capi import _pybind_state as C
 
-from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception
+from ._fallback_exceptions import (
+    ORTModuleDeviceException, wrap_exception, ORTModuleIOError)
 from ._torch_module_pytorch import TorchModulePytorch
 import os
 
@@ -52,6 +53,9 @@ def _torch_tensor_to_dlpack(tensor):
     # We need to convert bool tensor to unit8 tensor to workaround this.
     # DLPack is discussing how to support bool type, we can remove this workaround once both DLPack
     # and PyTorch support bool type.
+    if not tensor.is_contiguous():
+        raise ORTModuleIOError(
+            "Only contiguous tensors are supported.")
     if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
         tensor = tensor.to(torch.uint8)
     return to_dlpack(tensor)
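The new guard in `_torch_tensor_to_dlpack` raises `ORTModuleIOError` instead of letting a non-contiguous tensor fail deep inside the DLPack exchange; under a policy such as `FALLBACK_UNSUPPORTED_DATA`, that exception routes execution back to the original PyTorch module. Callers can also sidestep the fallback by making inputs contiguous up front. A minimal sketch, assuming a hypothetical `ensure_contiguous` helper on the caller's side:

    import torch

    def ensure_contiguous(tensors):
        # Hypothetical helper: tensor.contiguous() returns the tensor itself
        # when it is already contiguous, so this copies only when needed.
        return [t.contiguous() for t in tensors]

    x = torch.arange(12).reshape(3, 4).t()  # transposed view: not contiguous
    assert not x.is_contiguous()
    y, = ensure_contiguous([x])
    assert y.is_contiguous()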
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py
index 250014846d44..ccc80418acca 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py
@@ -564,3 +564,116 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
     assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])
 
     del os.environ['ORTMODULE_SKIPCHECK_POLICY']
+
+
+@pytest.mark.parametrize("is_training,torch_forward",
+                         list(itertools.product([True, False], repeat=2)))
+def test_ortmodule_fallback_non_contiguous_tensors(is_training, torch_forward):
+    # is_training: True for torch.nn.Module training mode, eval mode otherwise.
+    # Validates the fix for https://github.com/pytorch/ort/issues/92.
+    import copy
+    import math
+    import pickle
+    from torch import nn, Tensor
+    from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+    policy = ('FALLBACK_UNSUPPORTED_DEVICE|FALLBACK_UNSUPPORTED_DATA|FALLBACK_UNSUPPORTED_TORCH_MODEL|'
+              'FALLBACK_UNSUPPORTED_ONNX_MODEL')
+    if torch_forward:
+        policy += '|FALLBACK_FORCE_TORCH_FORWARD'
+    os.environ['ORTMODULE_FALLBACK_POLICY'] = policy
+
+    class PositionalEncoding(nn.Module):
+
+        def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
+            super().__init__()
+            self.dropout = nn.Dropout(p=dropout)
+            position = torch.arange(max_len).unsqueeze(1)
+            div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+            pe = torch.zeros(max_len, 1, d_model)
+            pe[:, 0, 0::2] = torch.sin(position * div_term)
+            pe[:, 0, 1::2] = torch.cos(position * div_term)
+            self.register_buffer('pe', pe)
+
+        def forward(self, x: Tensor) -> Tensor:
+            x = x + self.pe[:x.size(0)]
+            return self.dropout(x)
+
+    class TransformerModel(nn.Module):
+
+        def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
+                     nlayers: int, dropout: float = 0.5):
+            super().__init__()
+            self.model_type = 'Transformer'
+            encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
+            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
+            self.pos_encoder = PositionalEncoding(d_model, dropout)
+            self.encoder = nn.Embedding(ntoken, d_model)
+            self.d_model = d_model
+            self.decoder = nn.Linear(d_model, ntoken)
+            self.init_weights()
+
+        def init_weights(self) -> None:
+            initrange = 0.1
+            self.encoder.weight.data.uniform_(-initrange, initrange)
+            self.decoder.bias.data.zero_()
+            self.decoder.weight.data.uniform_(-initrange, initrange)
+
+        def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
+            src = self.encoder(src) * math.sqrt(self.d_model)
+            src = self.pos_encoder(src)
+            output = self.transformer_encoder(src, src_mask)
+            output = self.decoder(output)
+            return output
+
+    def generate_square_subsequent_mask(sz):
+        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
+
+    def get_batch(source, i):
+        seq_len = min(bptt, len(source) - 1 - i)
+        data = source[i:i+seq_len]
+        target = source[i+1:i+1+seq_len].reshape(-1)
+        return data, target
+
+    criterion = nn.CrossEntropyLoss()
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    train_data = [9, 1352, 0, 9, 1352, 0, 26, 31, 818, 82, 2, 759, 5, 1127, 1376, 3, 23, 30, 8, 2313, 11, 4493, 440,
+        13, 1, 759, 149, 1, 2433, 6, 533, 3, 37, 10, 672, 22, 8, 4493, 440, 6, 1, 387, 7522, 651, 22, 4472, 8308, 2, 32, 10, 906, 6, 1605, 24, 1, 396, 721, 1127, 3, 23, 30, 8, 2313, 440, 6, 1, 759, 149, 3737, 400, 0, 6, 785, 3, 6, 786, 0, 1556, 8, 440, 17, 3619, 6, 1, 426, 8382, 12, 19, 680, 4, 1, 759, 149, 1, 211, 4881, 23, 1587, 1613, 1377, 676, 397, 5, 5709, 7633, 3, 23, 10, 1143, 6, 1, 697, 1127, 4400, 4, 1, 3345, 8138, 387, 5040, 5861, 2, 32, 10, 906, 24, 1, 3237, 1127, 6, 7952, 5, 1, 0, 0, 4870, 6, 531, 3, 23, 10, 794, 22, 400, 0, 5, 1587, 1613, 2087, 0, 2, 2794, 0, 2, 4927, 3740, 2, 5855, 0, 2, 8261, 8299, 5, 4820, 504, 3, 6, 433, 2, 0, 1587, 1613, 0, 6, 1, 387, 0, 651, 22, 676, 0, 3, 23, 614, 13, 8, 433, 426, 4, 1, 759, 149, 2, 3653, 2, 672, 22, 8, 440, 6, 1, 483, 1127, 317, 4, 529, 7, 4126, 794, 22, 0, 0, 3, 529, 7, 4126, 10, 906, 24, 3571, 1127, 6, 1, 531, 6866, 4, 0, 5, 7437, 3, 0, 1587, 6, 43, 1203, 6, 363, 2, 0, 0, 22, 7383, 2541, 0, 2, 5, 0, 3828, 794, 22, 0, 6852, 3, 6, 86, 363, 2, 0, 88, 8, 2313, 684, 13, 8, 43, 11, 162, 426, 3539, 4, 1, 759, 149, 8509, 1, 585, 2, 672, 22, 31, 684, 13, 1, 759, 149, 0, 6, 200, 363, 3, 23, 30, 8, 4422, 440, 6, 565, 2926, 4, 1, 759, 149, 0, 6, 350, 2, 17, 0, 5842, 3, 0, 1587, 6, 1, 394, 82, 0, 794, 22, 2541, 0, 3, 9, 9, 276, 9, 9, 9, 9, 9, 533, 48, 697, 9, 9, 9, 6, 533, 0, 30, 8, 2313, 11, 4493, 440, 13, 1, 759, 149, 1, 2433, 23, 3814, 2189, 7903, 6, 1, 426, 2, 6, 4454, 1463, 3, 0, 1587, 17, 2189, 6, 1, 387, 7522, 651, 22, 4472, 8308, 2, 32, 10, 906, 6, 1605, 24, 1, 396, 721, 1127, 3, 8, 1289, 4, 0, 12, 19, 827, 6, 1, 1552, 13, 2384, 376, 70, 17, 10338, 7761, 6, 1, 440, 2, 5, 23, 333, 1014, 2031, 6, 1, 7520, 2, 5, 1201, 1586, 3, 23, 614, 6, 1, 759, 149, 3737, 400, 0, 6, 785, 17, 0, 0, 6, 1, 426, 1224, 0, 2, 5, 30, 8, 440, 17, 8, 535, 718, 12129, 11916, 13, 1, 2433, 3, 23, 30, 8, 4422, 440, 6, 521, 13, 43, 2926, 4, 1, 2433, 2, 17, 718, 9435, 3364, 3, 6, 786, 0, 1556, 8, 440, 17, 3619, 6, 1, 426, 8382, 12, 19, 680, 4, 1, 759, 149, 1, 211, 4881, 23, 1587, 1613, 1377, 676, 397, 5, 5709, 7633, 3, 0, 1587, 17, 9574, 2, 6, 1, 697, 1127, 4400, 4, 1, 3345, 8138, 387, 5040, 5861, 3, 29, 10, 906, 24, 1, 3237, 1127, 6, 7952, 2, 5, 1, 0, 0, 4870, 6, 531, 3, 23, 10, 794, 22, 400, 0, 5, 1587, 1613, 2087, 0, 2, 2794, 0, 2, 4927, 3740, 2, 5855, 0, 2, 8261, 8299, 5, 4820, 504, 3, 0, 333, 8, 1846, 1289, 6, 1, 2111, 3907, 1, 2245, 26, 0, 2721, 2, 21, 0, 1888, 28, 2087, 0, 15, 250, 0, 28, 25, 827, 17, 12177, 0, 12, 19, 1639, 16, 2, 1352, 0, 2, 2794, 0, 5, 5855, 0, 3, 1, 4917, 1118, 2, 2087, 0, 5, 1352, 0, 2754, 0, 8897, 1, 0, 3, 9, 9, 9, 433, 48, 677, 9, 9, 9, 6, 433, 0, 1587, 6, 1, 387, 0, 651, 22, 676, 0, 3, 1, 387, 10, 162, 4, 8, 149, 32, 1151, 535, 0, 2, 2050, 6898, 103, 0, 103, 0, 3, 6, 8, 433, 1994, 2, 2131, 1376, 2087, 0, 1550, 0, 17, 44, 4, 25, 2934, 1956, 11, 2203, 120, 7716, 887, 21, 8, 4919, 190, 1352, 0, 2, 58, 10, 6, 1, 760, 2433, 4, 6898, 2, 0, 5, 0, 24, 1, 137, 3, 23, 369, 1028, 1390, 6, 5040, 5861, 3, 23, 3814, 3735, 8458, 13, 1, 433, 426, 4, 1, 759, 149, 2, 3653, 2, 2050, 1359, 120, 0, 3, 0, 1587, 17, 665, 6, 1, 483, 317, 4, 529, 7, 4126, 794, 22, 0, 0, 3, 529, 7, 4126, 10, 906, 24, 3571, 1127, 6, 1, 531, 6866, 4, 0, 5, 7437, 3, 6, 8, 1289, 4, 1, 317, 20, 1, 2111, 3907, 2, 1127, 1724, 928, 11867, 1118, 2, 1352, 0, 4702, 8, 12148, 3113, 7, 1, 1361, 17, 665, 3, 0, 1587, 6, 43, 1203, 6, 363, 2, 0, 0, 22, 7383, 2541, 0, 2, 5, 0, 3828, 794, 22, 0, 6852, 3, 0, 3814, 8, 718, 257, 8181, 6, 0, 3828, 2, 58, 0, 155, 21, 718, 7653, 17, 1, 8026, 1390, 3, 3, 3, 58,
+        1406, 29, 179, 21, 0, 3, 0, 2313, 1587, 13, 8, 43, 11, 162, 426, 3539, 0, 6, 86, 363, 4, 1, 759, 149, 8509, 1, 585, 17, 718, 7647, 0, 3, 23, 614, 13, 1, 759, 149, 0, 17, 7826, 6, 200, 363, 3, 23, 30, 8, 4422, 440, 6, 565, 2926, 4, 1, 759, 149, 0, 6, 350, 2, 17, 0, 5842, 3, 23, 3814, 31, 7283, 7939, 6770, 20, 8, 1064, 0, 3, 23, 2666, 13, 1, 7588, 4148, 6, 7963, 8, 7939, 13, 759, 1169, 8, 346, 26, 8, 4505, 1398, 3, 0, 306, 1863, 193, 306, 12, 416, 3441, 96, 53, 306, 1015, 12, 235, 26, 362, 6851, 38, 74, 33, 5442, 13, 360, 58, 33, 7356, 24, 1177, 306, 90, 11329, 5, 1851, 306, 1, 5647, 7, 3889, 74, 5, 1272, 148, 306, 1863, 193, 306, 12, 416, 1837, 3, 0]
+    train_data = torch.tensor(train_data, dtype=torch.int64).to(device)
+    bptt = 35
+    src_mask = generate_square_subsequent_mask(bptt).to(device)
+    ntokens, emsize, nhead, d_hid, nlayers, dropout = 12455, 200, 2, 200, 2, 0.2
+    model = ORTModule(
+        TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)).to(device)
+    optimizer = torch.optim.SGD(model.parameters(), lr=5.0)
+
+    for epoch in range(1, 2):
+        model.train(is_training)  # train or eval mode, following the parametrization
+
+        num_batches = len(train_data) // bptt
+        for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
+            data, targets = get_batch(train_data, i)
+            batch_size = data.size(0)
+            if batch_size != bptt:  # only on the last batch
+                src_mask = src_mask[:batch_size, :batch_size]
+            try:
+                output = model(data, src_mask)
+            except RuntimeError as e:
+                # An escaping RuntimeError means ORT did not fall back to PyTorch.
+                if torch_forward:
+                    raise AssertionError("Fallback failed: %r." % e)
+                raise AssertionError("Fallback was not used but policy is %r." % policy)
+            nrows = min(ntokens, targets.shape[0])
+            loss = criterion(output.view(nrows, -1), targets)
+
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+            optimizer.step()
+            break
+
+    model_copied = copy.deepcopy(model)
+    assert model_copied is not model
+    pkl = pickle.dumps(model)
+    assert pkl is not None
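With the changes above, an ORTModule instance is expected to survive both `copy.deepcopy` and a pickle round-trip, which is what the tail of the new test asserts. A minimal usage sketch under that assumption (toy model, default device; the first forward call triggers the ONNX export and graph build):

    import copy
    import pickle

    import torch
    from onnxruntime.training.ortmodule import ORTModule

    model = ORTModule(torch.nn.Linear(4, 2))
    model(torch.randn(3, 4))      # build the graph before serializing
    clone = copy.deepcopy(model)  # relies on __getstate__/__setstate__
    restored = pickle.loads(pickle.dumps(model))
    assert restored is not model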