Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] pickle 2 #2

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
baf2754
tiny improvments in convolve
sdpython Oct 16, 2020
4729888
Revert "tiny improvments in convolve"
sdpython Oct 16, 2020
b9defba
tiny improvments in convolve
sdpython Oct 16, 2020
e047a8b
Revert "tiny improvments in convolve"
sdpython Oct 16, 2020
c8a3f0e
tiny improvments in convolve
sdpython Oct 16, 2020
70036af
Revert "tiny improvments in convolve"
sdpython Oct 16, 2020
c744d9b
tiny improvments in convolve
sdpython Oct 16, 2020
056c935
Revert "tiny improvments in convolve"
sdpython Oct 16, 2020
aef0bba
tiny improvments in convolve
sdpython Oct 16, 2020
a4570e7
Revert "tiny improvments in convolve"
sdpython Oct 16, 2020
b4ce8f1
tiny improvments in convolve
sdpython Oct 16, 2020
55a4c1c
Revert "tiny improvments in convolve"
sdpython Oct 16, 2020
1b77c90
Enable FALLBACK_FORCE_TORCH_BACKWARD as fallback policy
sdpython Oct 14, 2021
507be15
fix policy
sdpython Oct 14, 2021
df154d9
Add unit test to validate the issue
sdpython Oct 14, 2021
1f19b0f
try catch around an exception
sdpython Oct 14, 2021
72d6548
fix missing fallback policy
sdpython Oct 14, 2021
cc4d163
creating an exception
sdpython Oct 14, 2021
3f046da
replace iterator by a list
sdpython Oct 14, 2021
d211e22
removes the added exceptions
sdpython Oct 14, 2021
d9f7af9
Merge branch 'master' of https://github.com/microsoft/onnxruntime int…
sdpython Oct 18, 2021
5442707
raise exception when not contiguous
sdpython Oct 18, 2021
8e9b883
Merge branch 'master' of https://github.com/microsoft/onnxruntime int…
sdpython Oct 19, 2021
0c81b33
lint
sdpython Oct 19, 2021
bc32689
Update orttraining/orttraining/test/python/orttraining_test_ortmodule…
Oct 19, 2021
77b9630
Update orttraining/orttraining/test/python/orttraining_test_ortmodule…
Oct 19, 2021
06e5e5d
Merge branch 'master' of https://github.com/microsoft/onnxruntime int…
sdpython Oct 19, 2021
488eb5d
pickle
sdpython Oct 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,9 @@ def _initialize_graph_builder(self, training):
grad_builder_config.enable_caching = self._enable_grad_acc_optimization
grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._debug_options.logging.log_level)
grad_builder_config.use_memory_efficient_gradient = self._use_memory_efficient_gradient

# C.OrtModuleGraphBuilder() cannot be pickled.
self._graph_builder_training = training
self._graph_builder = C.OrtModuleGraphBuilder()

# It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way
Expand All @@ -450,6 +453,23 @@ def _initialize_graph_builder(self, training):
self._graph_initializers = [param for name, param in self._flattened_module.named_parameters()
if name in self._graph_initializer_names]

def __getstate__(self):
state = self.__dict__.copy()
if '_graph_builder' in state:
del state['_graph_builder']
state['_graph_builder'] = "TORESTORE"
if '_graph_info' in state:
del state['_graph_info']
state['_graph_info'] = "TORESTORE"
return state

def __setstate__(self, state):
self.__dict__.update(state)
if self._graph_builder == 'TORESTORE':
self._initialize_graph_builder(self._graph_builder_training)
if self._graph_info == 'TORESTORE':
self._build_graph()

def signal_model_changed(self):
"""Signals the execution manager to re-export the model on the next forward call"""
self._original_model_has_changed = True
5 changes: 4 additions & 1 deletion orttraining/orttraining/python/training/ortmodule/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def __init__(self,
self.schema = schema if schema else []
self.num_positionals = num_positionals
self.num_expanded_positionals_non_none = num_expanded_positionals_non_none
if not isinstance(keyword_names, list):
# It must be a list to avoid any pickling issue.
raise TypeError("keyword_names must be a list not %r." % type(keyword_names))
self.keyword_names = keyword_names

def __repr__(self) -> str:
Expand Down Expand Up @@ -528,7 +531,7 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names):
schema=schema,
num_positionals=len(inputs),
num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs,
keyword_names=kwargs.keys())
keyword_names=list(kwargs.keys()))


def parse_outputs_for_onnx_export_and_extract_schema(module, inputs, kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def execution_session_run_forward(execution_session, onnx_model, device, gradien
forward_inputs = C.OrtValueVector()
forward_inputs.reserve(len(inputs))
for input in inputs:
forward_inputs.push_back(_utils._torch_tensor_to_dlpack(input), input.dtype == torch.bool)
dlp = _utils._torch_tensor_to_dlpack(input)
forward_inputs.push_back(dlp, input.dtype == torch.bool)

forward_outputs = C.OrtValueVector()
# Run and return module outputs.
Expand Down
6 changes: 5 additions & 1 deletion orttraining/orttraining/python/training/ortmodule/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
from onnxruntime.capi import _pybind_state as C
from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception
from ._fallback_exceptions import (
ORTModuleDeviceException, wrap_exception, ORTModuleIOError)
from ._torch_module_pytorch import TorchModulePytorch

import os
Expand Down Expand Up @@ -52,6 +53,9 @@ def _torch_tensor_to_dlpack(tensor):
# We need to convert bool tensor to unit8 tensor to workaround this.
# DLPack is discussing how to support bool type, we can remove this workaround once both DLPack
# and PyTorch support bool type.
if not tensor.is_contiguous():
raise ORTModuleIOError(
"Only contiguous tensors are supported.")
if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
tensor = tensor.to(torch.uint8)
return to_dlpack(tensor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,116 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback):
assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0])

del os.environ['ORTMODULE_SKIPCHECK_POLICY']



@pytest.mark.parametrize("is_training,torch_forward",
list(itertools.product([True, False], repeat=2)))
def test_ortmodule_fallback_non_contiguous_tensors(is_training, torch_forward):
# is_training: True for torch.nn.Module training model, eval mode otherwise
# Validate fix for issue: https://github.com/pytorch/ort/issues/92

policy = ('FALLBACK_UNSUPPORTED_DEVICE|FALLBACK_UNSUPPORTED_DATA|FALLBACK_UNSUPPORTED_TORCH_MODEL|'
'FALLBACK_UNSUPPORTED_ONNX_MODEL')
if torch_forward:
policy += '|FALLBACK_FORCE_TORCH_FORWARD'
os.environ['ORTMODULE_FALLBACK_POLICY'] = policy

class PositionalEncoding(nn.Module):

def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = (torch.exp(torch.arange(0, d_model, 2) *
(-math.log(10000.0) / d_model)))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)

def forward(self, x: Tensor) -> Tensor:
x = x + self.pe[:x.size(0)]
return self.dropout(x)


class TransformerModel(nn.Module):

def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
nlayers: int, dropout: float = 0.5):
super().__init__()
self.model_type = 'Transformer'
encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.pos_encoder = PositionalEncoding(d_model, dropout)
self.encoder = nn.Embedding(ntoken, d_model)
self.d_model = d_model
self.decoder = nn.Linear(d_model, ntoken)
self.init_weights()

def init_weights(self) -> None:
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)

def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
src = self.encoder(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, src_mask)
output = self.decoder(output)
return output


def generate_square_subsequent_mask(sz):
return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)


def get_batch(source, i):
seq_len = min(bptt, len(source) - 1 - i)
data = source[i:i+seq_len]
target = source[i+1:i+1+seq_len].reshape(-1)
return data, target


criterion = nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_data = [9, 1352, 0, 9, 1352, 0, 26, 31, 818, 82, 2, 759, 5, 1127, 1376, 3, 23, 30, 8, 2313, 11, 4493, 440, 13, 1, 759, 149, 1, 2433, 6, 533, 3, 37, 10, 672, 22, 8, 4493, 440, 6, 1, 387, 7522, 651, 22, 4472, 8308, 2, 32, 10, 906, 6, 1605, 24, 1, 396, 721, 1127, 3, 23, 30, 8, 2313, 440, 6, 1, 759, 149, 3737, 400, 0, 6, 785, 3, 6, 786, 0, 1556, 8, 440, 17, 3619, 6, 1, 426, 8382, 12, 19, 680, 4, 1, 759, 149, 1, 211, 4881, 23, 1587, 1613, 1377, 676, 397, 5, 5709, 7633, 3, 23, 10, 1143, 6, 1, 697, 1127, 4400, 4, 1, 3345, 8138, 387, 5040, 5861, 2, 32, 10, 906, 24, 1, 3237, 1127, 6, 7952, 5, 1, 0, 0, 4870, 6, 531, 3, 23, 10, 794, 22, 400, 0, 5, 1587, 1613, 2087, 0, 2, 2794, 0, 2, 4927, 3740, 2, 5855, 0, 2, 8261, 8299, 5, 4820, 504, 3, 6, 433, 2, 0, 1587, 1613, 0, 6, 1, 387, 0, 651, 22, 676, 0, 3, 23, 614, 13, 8, 433, 426, 4, 1, 759, 149, 2, 3653, 2, 672, 22, 8, 440, 6, 1, 483, 1127, 317, 4, 529, 7, 4126, 794, 22, 0, 0, 3, 529, 7, 4126, 10, 906, 24, 3571, 1127, 6, 1, 531, 6866, 4, 0, 5, 7437, 3, 0, 1587, 6, 43, 1203, 6, 363, 2, 0, 0, 22, 7383, 2541, 0, 2, 5, 0, 3828, 794, 22, 0, 6852, 3, 6, 86, 363, 2, 0, 88, 8, 2313, 684, 13, 8, 43, 11, 162, 426, 3539, 4, 1, 759, 149, 8509, 1, 585, 2, 672, 22, 31, 684, 13, 1, 759, 149, 0, 6, 200, 363, 3, 23, 30, 8, 4422, 440, 6, 565, 2926, 4, 1, 759, 149, 0, 6, 350, 2, 17, 0, 5842, 3, 0, 1587, 6, 1, 394, 82, 0, 794, 22, 2541, 0, 3, 9, 9, 276, 9, 9, 9, 9, 9, 533, 48, 697, 9, 9, 9, 6, 533, 0, 30, 8, 2313, 11, 4493, 440, 13, 1, 759, 149, 1, 2433, 23, 3814, 2189, 7903, 6, 1, 426, 2, 6, 4454, 1463, 3, 0, 1587, 17, 2189, 6, 1, 387, 7522, 651, 22, 4472, 8308, 2, 32, 10, 906, 6, 1605, 24, 1, 396, 721, 1127, 3, 8, 1289, 4, 0, 12, 19, 827, 6, 1, 1552, 13, 2384, 376, 70, 17, 10338, 7761, 6, 1, 440, 2, 5, 23, 333, 1014, 2031, 6, 1, 7520, 2, 5, 1201, 1586, 3, 23, 614, 6, 1, 759, 149, 3737, 400, 0, 6, 785, 17, 0, 0, 6, 1, 426, 1224, 0, 2, 5, 30, 8, 440, 17, 8, 535, 718, 12129, 11916, 13, 1, 2433, 3, 23, 30, 8, 4422, 440, 6, 521, 13, 43, 2926, 4, 1, 2433, 2, 17, 718, 9435, 3364, 3, 6, 786, 0, 1556, 8, 440, 17, 3619, 6, 1, 426, 8382, 12, 19, 680, 4, 1, 759, 149, 1, 211, 4881, 23, 1587, 1613, 1377, 676, 397, 5, 5709, 7633, 3, 0, 1587, 17, 9574, 2, 6, 1, 697, 1127, 4400, 4, 1, 3345, 8138, 387, 5040, 5861, 3, 29, 10, 906, 24, 1, 3237, 1127, 6, 7952, 2, 5, 1, 0, 0, 4870, 6, 531, 3, 23, 10, 794, 22, 400, 0, 5, 1587, 1613, 2087, 0, 2, 2794, 0, 2, 4927, 3740, 2, 5855, 0, 2, 8261, 8299, 5, 4820, 504, 3, 0, 333, 8, 1846, 1289, 6, 1, 2111, 3907, 1, 2245, 26, 0, 2721, 2, 21, 0, 1888, 28, 2087, 0, 15, 250, 0, 28, 25, 827, 17, 12177, 0, 12, 19, 1639, 16, 2, 1352, 0, 2, 2794, 0, 5, 5855, 0, 3, 1, 4917, 1118, 2, 2087, 0, 5, 1352, 0, 2754, 0, 8897, 1, 0, 3, 9, 9, 9, 433, 48, 677, 9, 9, 9, 6, 433, 0, 1587, 6, 1, 387, 0, 651, 22, 676, 0, 3, 1, 387, 10, 162, 4, 8, 149, 32, 1151, 535, 0, 2, 2050, 6898, 103, 0, 103, 0, 3, 6, 8, 433, 1994, 2, 2131, 1376, 2087, 0, 1550, 0, 17, 44, 4, 25, 2934, 1956, 11, 2203, 120, 7716, 887, 21, 8, 4919, 190, 1352, 0, 2, 58, 10, 6, 1, 760, 2433, 4, 6898, 2, 0, 5, 0, 24, 1, 137, 3, 23, 369, 1028, 1390, 6, 5040, 5861, 3, 23, 3814, 3735, 8458, 13, 1, 433, 426, 4, 1, 759, 149, 2, 3653, 2, 2050, 1359, 120, 0, 3, 0, 1587, 17, 665, 6, 1, 483, 317, 4, 529, 7, 4126, 794, 22, 0, 0, 3, 529, 7, 4126, 10, 906, 24, 3571, 1127, 6, 1, 531, 6866, 4, 0, 5, 7437, 3, 6, 8, 1289, 4, 1, 317, 20, 1, 2111, 3907, 2, 1127, 1724, 928, 11867, 1118, 2, 1352, 0, 4702, 8, 12148, 3113, 7, 1, 1361, 17, 665, 3, 0, 1587, 6, 43, 1203, 6, 363, 2, 0, 0, 22, 7383, 2541, 0, 2, 5, 0, 3828, 794, 22, 0, 6852, 3, 0, 3814, 8, 718, 257, 8181, 6, 0, 3828, 2, 58, 0, 155, 21, 718, 7653, 17, 1, 8026, 1390, 3, 3, 3, 58, 1406, 29, 179, 21, 0, 3, 0, 2313, 1587, 13, 8, 43, 11, 162, 426, 3539, 0, 6, 86, 363, 4, 1, 759, 149, 8509, 1, 585, 17, 718, 7647, 0, 3, 23, 614, 13, 1, 759, 149, 0, 17, 7826, 6, 200, 363, 3, 23, 30, 8, 4422, 440, 6, 565, 2926, 4, 1, 759, 149, 0, 6, 350, 2, 17, 0, 5842, 3, 23, 3814, 31, 7283, 7939, 6770, 20, 8, 1064, 0, 3, 23, 2666, 13, 1, 7588, 4148, 6, 7963, 8, 7939, 13, 759, 1169, 8, 346, 26, 8, 4505, 1398, 3, 0, 306, 1863, 193, 306, 12, 416, 3441, 96, 53, 306, 1015, 12, 235, 26, 362, 6851, 38, 74, 33, 5442, 13, 360, 58, 33, 7356, 24, 1177, 306, 90, 11329, 5, 1851, 306, 1, 5647, 7, 3889, 74, 5, 1272, 148, 306, 1863, 193, 306, 12, 416, 1837, 3, 0]
train_data = tensor(numpy.array(train_data, dtype=numpy.int64))
train_data = train_data.to(torch.int64).to(device)
bptt = 35
src_mask = generate_square_subsequent_mask(bptt).to(device)
ntokens, emsize, nhead, d_hid, nlayers, dropout = 12455, 200, 2, 200, 2, 0.2
model = ORTModule(
TransformerModel(ntokens, emsize, nhead, d_hid, nlayers, dropout)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=5.0)

for epoch in range(1, 2):
model.train() # turn on train mode

num_batches = len(train_data) // bptt
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i)
batch_size = data.size(0)
if batch_size != bptt: # only on last batch
src_mask = src_mask[:batch_size, :batch_size]
try:
output = model(data, src_mask)
except RuntimeError as e:
if torch_forward:
raise AssertionError("Fallback failed: %r." % e)
if not torch_forward:
raise AssertionError("Fallback was not used but policy is %r." % policy)
nrows = min(ntokens, targets.shape[0])
loss = criterion(output.view(nrows, -1), targets)

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
break

model_copied = copy.deepcopy(model)
assert model_copied is not model_copied
pkl = pickle.dump(model)
assert pkl is not None