Skip to content

Commit

Permalink
[Pre Neuron Inf Cache system]Support neff/weights decoupling (#402)
Browse files Browse the repository at this point in the history
* add decoupling args

* add to modeling api

* workaround

* support replace weights of compiled model during the loading

* better sep the method

* add test

* fix style

* fix test

* unblock inf2 tests

* fix tests

* fix test

* fix test

* fix test

* Update optimum/neuron/utils/misc.py

* Update optimum/neuron/modeling_base.py

* improve help

* Update tests/inference/test_modeling.py
  • Loading branch information
JingyaHuang authored Jan 30, 2024
1 parent c114fc8 commit de5752d
Show file tree
Hide file tree
Showing 15 changed files with 306 additions and 100 deletions.
5 changes: 5 additions & 0 deletions optimum/commands/export/neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def parse_args_neuronx(parser: "ArgumentParser"):
type=Path,
help="Path indicating the directory where to store intermediary files generated by Neuronx compiler.",
)
optional_group.add_argument(
"--disable-weights-neff-inline",
action="store_true",
help="Whether to disable the weights / neff graph inline. You can only replace weights of neuron-compiled models when the weights-neff inlining has been disabled during the compilation.",
)
optional_group.add_argument(
"--disable-validation",
action="store_true",
Expand Down
3 changes: 3 additions & 0 deletions optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def main_export(
atol: Optional[float] = None,
cache_dir: Optional[str] = None,
compiler_workdir: Optional[Union[str, Path]] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
trust_remote_code: bool = False,
subfolder: str = "",
Expand Down Expand Up @@ -415,6 +416,7 @@ def main_export(
models_and_neuron_configs=models_and_neuron_configs,
output_dir=output,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
output_file_names=output_model_names,
compiler_kwargs=compiler_kwargs,
Expand Down Expand Up @@ -523,6 +525,7 @@ def main():
atol=args.atol,
cache_dir=args.cache_dir,
compiler_workdir=args.compiler_workdir,
inline_weights_to_neff=not args.disable_weights_neff_inline,
optlevel=optlevel,
trust_remote_code=args.trust_remote_code,
subfolder=args.subfolder,
Expand Down
15 changes: 15 additions & 0 deletions optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def export_models(
],
output_dir: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
output_file_names: Optional[Dict[str, str]] = None,
compiler_kwargs: Optional[Dict[str, Any]] = {},
Expand All @@ -288,6 +289,8 @@ def export_models(
Output directory to store the exported Neuron models.
compiler_workdir (`Optional[Path]`, defaults to `None`):
The directory to store intermediary outputs of the neuron compiler.
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
optlevel (`str`, defaults to `"2"`):
The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
1: enables the core performance optimizations in the compiler, while also minimizing compile time.
Expand Down Expand Up @@ -334,6 +337,7 @@ def export_models(
config=sub_neuron_config,
output=output_path,
compiler_workdir=compiler_workdir_path,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
**compiler_kwargs,
)
Expand Down Expand Up @@ -362,6 +366,7 @@ def export_models(
dynamic_batch_size=sub_neuron_config.dynamic_batch_size,
compiler_type=NEURON_COMPILER_TYPE,
compiler_version=NEURON_COMPILER_VERSION,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
model_type=getattr(sub_neuron_config, "MODEL_TYPE", None),
task=getattr(sub_neuron_config, "task", None),
Expand Down Expand Up @@ -392,6 +397,7 @@ def export(
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
auto_cast: Optional[str] = None,
auto_cast_type: str = "bf16",
Expand All @@ -406,6 +412,7 @@ def export(
config=config,
output=output,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
auto_cast=auto_cast,
auto_cast_type=auto_cast_type,
Expand All @@ -421,6 +428,7 @@ def export_neuronx(
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
auto_cast: Optional[str] = None,
auto_cast_type: str = "bf16",
Expand All @@ -437,6 +445,8 @@ def export_neuronx(
Directory to store the exported Neuron model.
compiler_workdir (`Optional[Path]`, defaults to `None`):
The directory used by neuronx-cc, where you can find intermediary outputs (neff, weight, hlo...).
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
optlevel (`str`, defaults to `"2"`):
The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
1: enables the core performance optimizations in the compiler, while also minimizing compile time.
Expand Down Expand Up @@ -504,10 +514,15 @@ def export_neuronx(
dummy_inputs_tuple,
compiler_args=compiler_args,
input_output_aliases=aliases,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
)

if config.dynamic_batch_size is True:
if not inline_weights_to_neff:
raise ValueError(
"Dynamic batching is not yet compatible with the weights/neff non-inlined model. Please set `dynamic_batch_size=False` or `inline_weights_to_neff=True`."
)
neuron_model = neuronx.dynamic_batch(neuron_model)

# diffusers specific
Expand Down
30 changes: 18 additions & 12 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
@register_in_tasks_manager("bert", *COMMON_TEXT_TASKS)
class BertNeuronConfig(TextEncoderNeuronConfig):
NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert")
ATOL_FOR_VALIDATION = 1e-4
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> List[str]:
Expand All @@ -83,6 +83,8 @@ class AlbertNeuronConfig(BertNeuronConfig):

@register_in_tasks_manager("convbert", *COMMON_TEXT_TASKS)
class ConvBertNeuronConfig(BertNeuronConfig):
ATOL_FOR_VALIDATION = 1e-1 # TODO: why accuracy more off than other arch

@property
def outputs(self) -> List[str]:
if self.task == "feature-extraction":
Expand All @@ -91,12 +93,16 @@ def outputs(self) -> List[str]:


@register_in_tasks_manager("electra", *COMMON_TEXT_TASKS)
class ElectraNeuronConfig(ConvBertNeuronConfig):
pass
class ElectraNeuronConfig(BertNeuronConfig):
@property
def outputs(self) -> List[str]:
if self.task == "feature-extraction":
return ["last_hidden_state"]
return self._TASK_TO_COMMON_OUTPUTS[self.task]


@register_in_tasks_manager("flaubert", *COMMON_TEXT_TASKS)
class FlaubertNeuronConfig(ConvBertNeuronConfig):
class FlaubertNeuronConfig(ElectraNeuronConfig):
pass


Expand All @@ -106,18 +112,18 @@ class MobileBertNeuronConfig(BertNeuronConfig):


@register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS)
class RoFormerNeuronConfig(ConvBertNeuronConfig):
class RoFormerNeuronConfig(ElectraNeuronConfig):
pass


@register_in_tasks_manager("xlm", *COMMON_TEXT_TASKS)
class XLMNeuronConfig(ConvBertNeuronConfig):
class XLMNeuronConfig(ElectraNeuronConfig):
pass


@register_in_tasks_manager("distilbert", *COMMON_TEXT_TASKS)
class DistilBertNeuronConfig(BertNeuronConfig):
ATOL_FOR_VALIDATION = 1e-4
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> List[str]:
Expand All @@ -132,7 +138,7 @@ def outputs(self) -> List[str]:

@register_in_tasks_manager("camembert", *COMMON_TEXT_TASKS)
class CamembertNeuronConfig(BertNeuronConfig):
ATOL_FOR_VALIDATION = 1e-4
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> List[str]:
Expand All @@ -156,8 +162,8 @@ class XLMRobertaNeuronConfig(CamembertNeuronConfig):

# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Failed only for INF1: 'XSoftmax'
@register_in_tasks_manager("deberta", *COMMON_TEXT_TASKS)
class DebertaNeuronConfig(BertNeuronConfig):
@register_in_tasks_manager("deberta", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"]))
class DebertaNeuronConfig(ElectraNeuronConfig):
@property
def inputs(self) -> List[str]:
common_inputs = super().inputs
Expand All @@ -169,8 +175,8 @@ def inputs(self) -> List[str]:

# https://github.com/aws-neuron/aws-neuron-sdk/issues/642
# Failed only for INF1: 'XSoftmax'
@register_in_tasks_manager("deberta-v2", *COMMON_TEXT_TASKS)
class DebertaV2NeuronConfig(DebertaNeuronConfig):
@register_in_tasks_manager("deberta-v2", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"]))
class DebertaV2NeuronConfig(ElectraNeuronConfig):
pass


Expand Down
26 changes: 24 additions & 2 deletions optimum/neuron/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from ..exporters.tasks import TasksManager
from ..modeling_base import OptimizedModel
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .utils import NEURON_FILE_NAME, is_neuron_available, store_compilation_config
from .utils import (
NEURON_FILE_NAME,
check_if_weights_replacable,
is_neuron_available,
replace_weights,
store_compilation_config,
)
from .utils.import_utils import is_neuronx_available
from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version

Expand Down Expand Up @@ -103,7 +109,13 @@ def load_model(path: Union[str, Path]) -> torch.jit._script.ScriptModule:
path = Path(path)

if path.is_file():
return torch.jit.load(path)
model = torch.jit.load(path)
return model

def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None):
    """
    Swap the weights of the loaded Neuron model for `weights`.

    The model config is first validated with `check_if_weights_replacable`, which raises when
    weights are supplied for a model whose weights are inlined in the neff. When `weights` is
    `None`, only that validation runs and the model is left untouched.

    Args:
        weights (`Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]]`, defaults to `None`):
            The new weights, given as a state dict or as a module.
    """
    check_if_weights_replacable(self.config, weights)
    if weights is None:
        return
    # Resolves to the module-level `replace_weights` helper imported from `.utils`, not to this method.
    replace_weights(self.model, weights)

def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Expand Down Expand Up @@ -216,6 +228,7 @@ def _export(
force_download: bool = False,
cache_dir: Optional[str] = None,
compiler_workdir: Optional[Union[str, Path]] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
subfolder: str = "",
local_files_only: bool = False,
Expand Down Expand Up @@ -303,6 +316,7 @@ def _export(
config=neuron_config,
output=save_dir_path / NEURON_FILE_NAME,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
**compiler_kwargs,
)
Expand All @@ -316,6 +330,7 @@ def _export(
dynamic_batch_size=dynamic_batch_size,
compiler_type=compiler_type,
compiler_version=compiler_version,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
task=task,
)
Expand Down Expand Up @@ -570,3 +585,10 @@ def remove_padding(
]

return outputs

@property
def is_weights_neff_separated(self) -> bool:
    """
    Whether the Neuron model has separated weights and neff graph (by setting `inline_weights_to_neff=False` during the compilation).

    Returns `False` when the config has no `neuron` entry, which matches how
    `check_if_weights_replacable` treats such configs.
    """
    neuron_config = getattr(self.config, "neuron", None)
    if neuron_config is None:
        return False
    # A missing key means the model was compiled with the default, i.e. inlined weights.
    return not neuron_config.get("inline_weights_to_neff", True)
4 changes: 4 additions & 0 deletions optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def _export(
force_download: bool = True,
cache_dir: Optional[str] = None,
compiler_workdir: Optional[str] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
subfolder: str = "",
local_files_only: bool = False,
Expand Down Expand Up @@ -580,6 +581,8 @@ def _export(
standard cache should not be used.
compiler_workdir (`Optional[str]`, defaults to `None`):
Path to a directory in which the neuron compiler will store all intermediary files during the compilation(neff, weight, hlo graph...).
inline_weights_to_neff (`bool`, defaults to `True`):
Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
optlevel (`str`, defaults to `"2"`):
The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2".
1: enables the core performance optimizations in the compiler, while also minimizing compile time.
Expand Down Expand Up @@ -640,6 +643,7 @@ def _export(
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
trust_remote_code=trust_remote_code,
subfolder=subfolder,
Expand Down
2 changes: 2 additions & 0 deletions optimum/neuron/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def _export(
force_download: bool = True,
cache_dir: Optional[str] = None,
compiler_workdir: Optional[str] = None,
inline_weights_to_neff: bool = True,
optlevel: str = "2",
subfolder: str = "",
local_files_only: bool = False,
Expand Down Expand Up @@ -302,6 +303,7 @@ def _export(
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
trust_remote_code=trust_remote_code,
subfolder=subfolder,
Expand Down
1 change: 1 addition & 0 deletions optimum/neuron/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_transformers_neuronx_available,
)
from .input_generators import DummyBeamValuesGenerator
from .misc import check_if_weights_replacable, replace_weights
from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl
from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function
from .training_utils import (
Expand Down
2 changes: 2 additions & 0 deletions optimum/neuron/utils/argument_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def store_compilation_config(
dynamic_batch_size: bool,
compiler_type: str,
compiler_version: str,
inline_weights_to_neff: bool,
optlevel: str,
model_type: Optional[str] = None,
task: str = None,
Expand All @@ -161,6 +162,7 @@ def store_compilation_config(
# Add neuron version to the config, so it can be checked at load time
config_args["compiler_type"] = compiler_type
config_args["compiler_version"] = compiler_version
config_args["inline_weights_to_neff"] = inline_weights_to_neff

# Add input shapes during compilation to the config
for axis, shape in input_shapes.items():
Expand Down
42 changes: 41 additions & 1 deletion optimum/neuron/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import re
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import torch
from transformers.modeling_utils import _add_variant
Expand All @@ -42,6 +42,9 @@
from .require_utils import requires_safetensors


if TYPE_CHECKING:
from transformers import PretrainedConfig

logger = logging.get_logger()


Expand Down Expand Up @@ -508,3 +511,40 @@ def download_checkpoints_in_cache(
resolved_archive_file = filenames_to_safetensors_filenames[Path(resolved_archive_file).name]

return resolved_archive_file, sharded_metadata


def replace_weights(
    model: torch.jit._script.RecursiveScriptModule,
    weights: Union[Dict[str, torch.Tensor], torch.nn.Module],
    prefix: str = "model",
):
    """
    Replaces the weights in a Neuron Model with weights from another model, the original neuron model should have
    separated weights (by setting `inline_weights_to_neff=False` during the tracing).

    Args:
        model (`torch.jit._script.RecursiveScriptModule`):
            The traced Neuron model. Its `weights._c` module is read for the parameter paths and updated in place.
        weights (`Union[Dict[str, torch.Tensor], torch.nn.Module]`):
            The new weights. When a module is given, its `state_dict()` is used.
        prefix (`str`, defaults to `"model"`):
            Leading component of the traced parameter paths that is absent from the keys of `weights`.

    Raises:
        IndexError: If the code of the weights module contains no `__parameters__ = [` list.
        KeyError: If a parameter path has no matching key in `weights`.
    """
    if isinstance(weights, torch.nn.Module):
        weights = weights.state_dict()

    # Extract the parameter paths from the generated code of the weights C module. The list is
    # written as `__parameters__ = ["a->b", "c->d", ]`; the `[:-1]` slice drops the last character,
    # which is expected to be the trailing comma (TODO confirm against torch_neuronx output).
    code = model.weights._c.code
    start_str = "__parameters__ = ["
    end_str = "]\n"
    module_paths = code.split(start_str)[1].split(end_str)[0].strip()[:-1:].replace('"', "").split(", ")
    module_paths = [module_path for module_path in module_paths if module_path != ""]

    path_prefix = prefix + "->"
    for module_path in module_paths:
        # Paths containing a word character directly followed by digits are left untouched
        # (presumably compiler-generated constants rather than model parameters; verify).
        if re.search(r"\w\d+", module_path):
            continue
        # Strip only the leading prefix: a nested submodule may legitimately share its name.
        if module_path.startswith(path_prefix):
            key = module_path[len(path_prefix) :]
        else:
            key = module_path
        model.weights._c.setattr(module_path, weights[key.replace("->", ".")])


def check_if_weights_replacable(
    config: "PretrainedConfig", weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]]
):
    """
    Checks that `weights` may be loaded into the Neuron model described by `config`.

    Args:
        config (`PretrainedConfig`):
            Model config. Its optional `neuron` mapping is read for the `inline_weights_to_neff` entry.
        weights (`Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]]`):
            The replacement weights. `None` means there is nothing to replace, so the check always passes.

    Raises:
        RuntimeError: If `weights` is given but the config has no `neuron` attribute or records inlined
            weights (`inline_weights_to_neff` missing or `True`).
    """
    # A config without a `neuron` attribute is treated as having inlined weights.
    is_weights_neff_separated = (
        not config.neuron.get("inline_weights_to_neff", True) if hasattr(config, "neuron") else False
    )
    if weights is not None and not is_weights_neff_separated:
        raise RuntimeError(
            "Unable to replace weights of the neuron model since its weights and neff are not separated, please set `inline_weights_to_neff=False` when converting the model to Neuron format."
        )
Loading

0 comments on commit de5752d

Please sign in to comment.