diff --git a/docker/Dockerfile.intel b/docker/Dockerfile.intel
index 60fd51b42..a7f1dc978 100644
--- a/docker/Dockerfile.intel
+++ b/docker/Dockerfile.intel
@@ -27,6 +27,8 @@ RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \
     libpng-dev \
     python3 \
     python3-pip \
+    python3-dev \
+    libnuma-dev \
     && rm -rf /var/lib/apt/lists/*"
 RUN /usr/sbin/update-ccache-symlinks
 RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache
@@ -43,12 +45,11 @@ RUN python3 -m pip install --no-cache-dir \
     torchaudio==${TORCHAUDIO_VERSION} \
     -f https://download.pytorch.org/whl/torch_stable.html && \
     python3 -m pip install intel-extension-for-pytorch==$IPEX_VERSION && \
-    python3 -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
+    python3 -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \
+    python3 -m pip install --no-cache-dir numa
 
-ARG OMP_NUM_THREADS=1
-ENV OMP_NUM_THREADS=${OMP_NUM_THREADS}
 ARG KMP_BLOCKTIME=1
 ENV KMP_BLOCKTIME=${KMP_BLOCKTIME}
 ARG KMP_HW_SUBSET=1T
 ENV KMP_HW_SUBSET=${KMP_HW_SUBSET}
-ENV LD_PRELOAD="/usr/local/lib/libiomp5.so /usr/lib/x86_64-linux-gnu/libtcmalloc.so"
\ No newline at end of file
+ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so"
diff --git a/optimum/intel/utils/__init__.py b/optimum/intel/utils/__init__.py
index d77588f89..50cdfa143 100644
--- a/optimum/intel/utils/__init__.py
+++ b/optimum/intel/utils/__init__.py
@@ -22,6 +22,7 @@
     is_neural_compressor_available,
     is_neural_compressor_version,
     is_nncf_available,
+    is_numa_available,
     is_openvino_available,
     is_torch_version,
     is_transformers_available,
diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py
index 6be0aac47..032280e94 100644
--- a/optimum/intel/utils/import_utils.py
+++ b/optimum/intel/utils/import_utils.py
@@ -150,6 +150,14 @@
 except importlib_metadata.PackageNotFoundError:
     _accelerate_available = False
 
+_numa_available = importlib.util.find_spec("numa") is not None
+
+if _numa_available:
+    try:
+        importlib_metadata.version("numa")
+    except importlib_metadata.PackageNotFoundError:
+        _numa_available = False
+
 
 def is_transformers_available():
     return _transformers_available
@@ -272,6 +280,10 @@ def is_accelerate_available():
     return _accelerate_available
 
 
+def is_numa_available():
+    return _numa_available
+
+
 # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
 def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
     """
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index cd5b34f86..1d2f7b03c 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -12,16 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import math
+import os
+import platform
 import re
 from pathlib import Path
 from typing import List, Optional, Union
 
+import psutil
 import torch
 from huggingface_hub import HfApi, HfFolder
 
+from .import_utils import is_numa_available
+
 
 MULTI_QUERY_ATTN_MODELS = {"gpt_bigcode"}
 
+logger = logging.getLogger(__name__)
+
 
 def get_model_device(model: torch.nn.Module) -> torch.device:
     """
@@ -135,3 +144,78 @@ def replace_customized_linear_with_linear(model):
             setattr(model, child_name, new_m)
         else:
             replace_customized_linear_with_linear(child)
+
+
+def get_int_from_env(env_keys, default):
+    """Returns the first non-negative env value found in the `env_keys` list, or the default."""
+    for e in env_keys:
+        val = int(os.environ.get(e, -1))
+        if val >= 0:
+            return val
+    return default
+
+
+def bind_cores_for_best_perf():
+    """
+    Sets the number of threads per rank, the NUMA CPU affinity and the NUMA memory binding, if not already set, for better out-of-the-box (OOB) performance.
+    Works for world_size >= 1 and rank >= 0.
+
+    Example:
+    .. code-block:: python
+
+        import torch
+        from transformers import AutoTokenizer
+        from optimum.intel.ipex import IPEXModelForCausalLM
+        from optimum.intel.utils.modeling_utils import bind_cores_for_best_perf
+
+        bind_cores_for_best_perf()
+        model = IPEXModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16, export=True)
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        input_sentence = ["tell me a story about a trip to the moon"]
+        model_inputs = tokenizer(input_sentence, return_tensors="pt")
+        generation_kwargs = dict(max_new_tokens=500)
+        generated_ids = model.generate(**model_inputs, **generation_kwargs)
+
+    Returns:
+        None
+
+    """
+    if platform.system() != "Linux":
+        logger.error("bind_cores_for_best_perf: OS not supported, this function can only be run on Linux systems.")
+        raise OSError("bind_cores_for_best_perf: OS not supported, this function can only be run on Linux systems.")
+    if not is_numa_available():
+        logger.error("'numa' module not found")
+        raise ImportError("'numa' module not found, install with 'pip install numa'")
+    import numa
+
+    local_size = get_int_from_env(
+        ["LOCAL_WORLD_SIZE", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
+    )
+    rank_id = get_int_from_env(
+        ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0
+    )
+    nodes = numa.get_max_node() + 1
+    rank_per_node = math.ceil(local_size / nodes)
+    num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
+    node_id = int(rank_id / rank_per_node)
+    rank_offset_per_node = rank_id % rank_per_node
+    if os.getenv("OMP_NUM_THREADS") is None:
+        num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
+        logger.info(f"Setting OMP_NUM_THREADS to {num_cpus_per_rank} for better performance")
+    else:
+        num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
+        logger.info(f"OMP_NUM_THREADS already set to {num_cpus_per_rank}")
+    if len(numa.get_membind()) == nodes:
+        # if numa memory binding is not set, bind memory to the node where the rank is running
+        numa.set_membind([node_id])
+
+    torch.set_num_threads(num_cpus_per_rank)
+
+    if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True):
+        # if numa affinity is unset (default is all logical cores), pin the rank to its slice of physical cores
+        cpu_start = num_cpus_per_rank * rank_offset_per_node
+        numa.set_affinity(
+            0,
+            list(numa.node_to_cpus(node_id))[cpu_start : cpu_start + num_cpus_per_rank],
+        )
+    logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}")
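
For reviewers, a minimal dry-run sketch of the core-assignment arithmetic that `bind_cores_for_best_perf` performs, with the `numa`/`psutil`/env lookups replaced by hard-coded values. The node count, core count and rank count below are illustrative assumptions, not part of the patch:

```python
import math

# Illustrative topology: a 2-socket host with 2 NUMA nodes,
# 112 physical cores total, and 4 local ranks.
nodes = 2
physical_cores = 112
local_size = 4

rank_per_node = math.ceil(local_size / nodes)        # 2 ranks per node
num_cpus_per_node = physical_cores // nodes          # 56 physical cores per node
num_cpus_per_rank = max(num_cpus_per_node // rank_per_node, 1)  # 28 threads per rank

for rank_id in range(local_size):
    node_id = rank_id // rank_per_node               # NUMA node hosting this rank
    rank_offset = rank_id % rank_per_node            # position of the rank on its node
    cpu_start = num_cpus_per_rank * rank_offset      # start of this rank's core slice
    print(f"rank {rank_id}: node {node_id}, node-local cores "
          f"[{cpu_start}, {cpu_start + num_cpus_per_rank})")
```

Each rank ends up pinned to a disjoint slice of one node's physical cores, with its memory bound to the same node, which avoids cross-node traffic and oversubscription when several ranks share a host.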
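A quick way to inspect what the call actually did on a Linux host, using only stdlib and existing `torch` APIs (a usage sketch, not part of the patch):

```python
import os

import torch

from optimum.intel.utils.modeling_utils import bind_cores_for_best_perf

bind_cores_for_best_perf()

# os.sched_getaffinity(0) reflects the affinity applied via numa.set_affinity;
# torch.get_num_threads() reflects the torch.set_num_threads(...) call.
print(f"CPU affinity: {sorted(os.sched_getaffinity(0))}")
print(f"torch threads: {torch.get_num_threads()}")
```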