diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index 3fc689fbe7..481f45890c 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -13,7 +13,7 @@ # limitations under the License. ## Common utilities to be shared by iree utilities. - +import functools import os import sys import subprocess @@ -93,6 +93,7 @@ def iree_target_map(device): # Finds whether the required drivers are installed for the given device. +@functools.cache def check_device_drivers(device): """Checks necessary drivers present for gpu and vulkan devices""" if "://" in device: diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 4790f631f7..2f24f7b8c3 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -11,12 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import iree.runtime as ireert -import iree.compiler as ireec -from shark.iree_utils._common import iree_device_map, iree_target_map -from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args -from shark.iree_utils.benchmark_utils import * -from shark.parser import shark_args +import functools import numpy as np import os import re @@ -24,6 +19,15 @@ import time from pathlib import Path +import iree.runtime as ireert +import iree.compiler as ireec +from shark.parser import shark_args + +from .trace import DetailLogger +from ._common import iree_device_map, iree_target_map +from .cpu_utils import get_iree_cpu_rt_args +from .benchmark_utils import * + # Get the iree-compile arguments given device. def get_iree_device_args(device, extra_args=[]): @@ -318,7 +322,6 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None): device = iree_device_map(device) print("registering device id: ", device_idx) haldriver = ireert.get_driver(device) - haldevice = haldriver.create_device( haldriver.query_available_devices()[device_idx]["device_id"], allocators=shark_args.device_allocator, @@ -338,63 +341,64 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None): def load_vmfb_using_mmap( flatbuffer_blob_or_path, device: str, device_idx: int = None ): - instance = ireert.VmInstance() - device = iree_device_map(device) - haldriver = ireert.get_driver(device) - haldevice = haldriver.create_device_by_uri( - device, - allocators=[], - ) - # First get configs. - if device_idx is not None: - device = iree_device_map(device) - print("registering device id: ", device_idx) - haldriver = ireert.get_driver(device) - - haldevice = haldriver.create_device( - haldriver.query_available_devices()[device_idx]["device_id"], - allocators=shark_args.device_allocator, - ) - config = ireert.Config(device=haldevice) - else: - config = get_iree_runtime_config(device) - if "task" in device: - print( - f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}" - ) - for flag in get_iree_cpu_rt_args(): - ireert.flags.parse_flags(flag) - # Now load vmfb. - # Two scenarios we have here :- - # 1. We either have the vmfb already saved and therefore pass the path of it. - # (This would arise if we're invoking `load_module` from a SharkInference obj) - # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with. - # (This would arise if we're invoking `compile` from a SharkInference obj) - temp_file_to_unlink = None - if isinstance(flatbuffer_blob_or_path, Path): - flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__() - if ( - isinstance(flatbuffer_blob_or_path, str) - and ".vmfb" in flatbuffer_blob_or_path - ): - vmfb_file_path = flatbuffer_blob_or_path - print( - f"Loading module {flatbuffer_blob_or_path}... ", end="", flush=True - ) - mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path) - print(f"mmap complete... ", end="", flush=True) - ctx = ireert.SystemContext(config=config) - ctx.add_vm_module(mmaped_vmfb) - print(f"module initialized. Ready to run!") - mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name) - else: - with tempfile.NamedTemporaryFile(delete=False) as tf: - tf.write(flatbuffer_blob_or_path) - tf.flush() - vmfb_file_path = tf.name - temp_file_to_unlink = vmfb_file_path - mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path) - return mmaped_vmfb, config, temp_file_to_unlink + print(f"Loading module {flatbuffer_blob_or_path}...") + + with DetailLogger(timeout=2.5) as dl: + # First get configs. + if device_idx is not None: + dl.log(f"Mapping device id: {device_idx}") + device = iree_device_map(device) + haldriver = ireert.get_driver(device) + dl.log(f"ireert.get_driver()") + + haldevice = haldriver.create_device( + haldriver.query_available_devices()[device_idx]["device_id"], + allocators=shark_args.device_allocator, + ) + dl.log(f"ireert.create_device()") + config = ireert.Config(device=haldevice) + dl.log(f"ireert.Config()") + else: + config = get_iree_runtime_config(device) + dl.log("get_iree_runtime_config") + if "task" in device: + print( + f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}" + ) + for flag in get_iree_cpu_rt_args(): + ireert.flags.parse_flags(flag) + # Now load vmfb. + # Two scenarios we have here :- + # 1. We either have the vmfb already saved and therefore pass the path of it. + # (This would arise if we're invoking `load_module` from a SharkInference obj) + # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with. + # (This would arise if we're invoking `compile` from a SharkInference obj) + temp_file_to_unlink = None + if isinstance(flatbuffer_blob_or_path, Path): + flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__() + if ( + isinstance(flatbuffer_blob_or_path, str) + and ".vmfb" in flatbuffer_blob_or_path + ): + vmfb_file_path = flatbuffer_blob_or_path + mmaped_vmfb = ireert.VmModule.mmap( + config.vm_instance, flatbuffer_blob_or_path + ) + dl.log(f"mmap {flatbuffer_blob_or_path}") + ctx = ireert.SystemContext(config=config) + dl.log(f"ireert.SystemContext created") + ctx.add_vm_module(mmaped_vmfb) + dl.log(f"module initialized") + mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name) + else: + with tempfile.NamedTemporaryFile(delete=False) as tf: + tf.write(flatbuffer_blob_or_path) + tf.flush() + vmfb_file_path = tf.name + temp_file_to_unlink = vmfb_file_path + mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path) + dl.log(f"mmap temp {vmfb_file_path}") + return mmaped_vmfb, config, temp_file_to_unlink def get_iree_compiled_module( @@ -502,31 +506,44 @@ def get_results( config, frontend="torch", send_to_host=True, + debug_timeout: float = 5.0, ): """Runs a .vmfb file given inputs and config and returns output.""" - device_inputs = [ireert.asdevicearray(config.device, a) for a in input] - result = compiled_vm[function_name](*device_inputs) - result_tensors = [] - if isinstance(result, tuple): - if send_to_host: - for val in result: - result_tensors.append(np.asarray(val, val.dtype)) + with DetailLogger(debug_timeout) as dl: + device_inputs = [] + for input_array in input: + dl.log(f"Load to device: {input_array.shape}") + device_inputs.append( + ireert.asdevicearray(config.device, input_array) + ) + dl.log(f"Invoke function: {function_name}") + result = compiled_vm[function_name](*device_inputs) + dl.log(f"Invoke complete") + result_tensors = [] + if isinstance(result, tuple): + if send_to_host: + for val in result: + dl.log(f"Result to host: {val.shape}") + result_tensors.append(np.asarray(val, val.dtype)) + else: + for val in result: + result_tensors.append(val) + return result_tensors + elif isinstance(result, dict): + data = list(result.items()) + if send_to_host: + res = np.array(data, dtype=object) + return np.copy(res) + return data else: - for val in result: - result_tensors.append(val) - return result_tensors - elif isinstance(result, dict): - data = list(result.items()) - if send_to_host: - res = np.array(data, dtype=object) - return np.copy(res) - return data - else: - if send_to_host and result is not None: - return result.to_host() - return result + if send_to_host and result is not None: + dl.log("Result to host") + return result.to_host() + return result + dl.log("Execution complete") +@functools.cache def get_iree_runtime_config(device): device = iree_device_map(device) haldriver = ireert.get_driver(device) diff --git a/shark/iree_utils/cpu_utils.py b/shark/iree_utils/cpu_utils.py index 182b8c8c48..8aca117093 100644 --- a/shark/iree_utils/cpu_utils.py +++ b/shark/iree_utils/cpu_utils.py @@ -14,6 +14,7 @@ # All the iree_cpu related functionalities go here. +import functools import subprocess import platform from shark.parser import shark_args @@ -30,6 +31,7 @@ def get_cpu_count(): # Get the default cpu args. +@functools.cache def get_iree_cpu_args(): uname = platform.uname() os_name, proc_name = uname.system, uname.machine @@ -51,6 +53,7 @@ def get_iree_cpu_args(): # Get iree runtime flags for cpu +@functools.cache def get_iree_cpu_rt_args(): default = get_cpu_count() default = default if default <= 8 else default - 2 diff --git a/shark/iree_utils/gpu_utils.py b/shark/iree_utils/gpu_utils.py index eeca7d17a8..c4a822f850 100644 --- a/shark/iree_utils/gpu_utils.py +++ b/shark/iree_utils/gpu_utils.py @@ -14,12 +14,14 @@ # All the iree_gpu related functionalities go here. +import functools import iree.runtime as ireert import ctypes from shark.parser import shark_args # Get the default gpu args given the architecture. +@functools.cache def get_iree_gpu_args(): ireert.flags.FUNCTION_INPUT_VALIDATION = False ireert.flags.parse_flags("--cuda_allow_inline_execution") @@ -37,6 +39,7 @@ def get_iree_gpu_args(): # Get the default gpu args given the architecture. +@functools.cache def get_iree_rocm_args(): ireert.flags.FUNCTION_INPUT_VALIDATION = False # get arch from rocminfo. @@ -65,6 +68,7 @@ def get_iree_rocm_args(): CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36 +@functools.cache def get_cuda_sm_cc(): libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll") for libname in libnames: diff --git a/shark/iree_utils/metal_utils.py b/shark/iree_utils/metal_utils.py index 053c03a9b2..8704709fb2 100644 --- a/shark/iree_utils/metal_utils.py +++ b/shark/iree_utils/metal_utils.py @@ -14,12 +14,15 @@ # All the iree_vulkan related functionalities go here. +import functools + from shark.iree_utils._common import run_cmd import iree.runtime as ireert from sys import platform from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag +@functools.cache def get_metal_device_name(device_num=0): iree_device_dump = run_cmd("iree-run-module --dump_devices") iree_device_dump = iree_device_dump[0].split("\n\n") diff --git a/shark/iree_utils/trace.py b/shark/iree_utils/trace.py new file mode 100644 index 0000000000..ea51243b87 --- /dev/null +++ b/shark/iree_utils/trace.py @@ -0,0 +1,76 @@ +# Copyright 2023 The Nod Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +import os +import threading +import time + + +def _enable_detail_trace() -> bool: + return os.getenv("SHARK_DETAIL_TRACE", "0") == "1" + + +class DetailLogger: + """Context manager which can accumulate detailed log messages. + + Detailed log is only emitted if the operation takes a long time + or errors. + """ + + def __init__(self, timeout: float): + self._timeout = timeout + self._messages: List[Tuple[float, str]] = [] + self._start_time = time.time() + self._active = not _enable_detail_trace() + self._lock = threading.RLock() + self._cond = threading.Condition(self._lock) + self._thread = None + + def __enter__(self): + self._thread = threading.Thread(target=self._run) + self._thread.start() + return self + + def __exit__(self, type, value, traceback): + with self._lock: + self._active = False + self._cond.notify() + if traceback: + self.dump_on_error(f"exception") + + def _run(self): + with self._lock: + timed_out = not self._cond.wait(self._timeout) + if timed_out: + self.dump_on_error(f"took longer than {self._timeout}s") + + def log(self, msg): + with self._lock: + timestamp = time.time() + if self._active: + self._messages.append((timestamp, msg)) + else: + print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}") + + def dump_on_error(self, summary: str): + with self._lock: + if self._active: + print(f"::: Detailed report ({summary}):") + for timestamp, msg in self._messages: + print( + f" +{(timestamp - self._start_time) * 1000}ms: {msg}" + ) + self._active = False diff --git a/shark/iree_utils/vulkan_target_env_utils.py b/shark/iree_utils/vulkan_target_env_utils.py index 1559fc261d..36cc7b397d 100644 --- a/shark/iree_utils/vulkan_target_env_utils.py +++ b/shark/iree_utils/vulkan_target_env_utils.py @@ -13,8 +13,10 @@ # limitations under the License. from collections import OrderedDict +import functools +@functools.cache def get_vulkan_target_env(vulkan_target_triple): arch, product, os = vulkan_target_triple.split("=")[1].split("-") triple = (arch, product, os) @@ -52,6 +54,7 @@ def get_version(triple): return "v1.3" +@functools.cache def get_extensions(triple): def make_ext_list(ext_list): res = "" @@ -122,6 +125,7 @@ def make_ext_list(ext_list): return make_ext_list(ext_list=ext) +@functools.cache def get_vendor(triple): arch, product, os = triple if arch == "unknown": @@ -146,6 +150,7 @@ def get_vendor(triple): return "Unknown" +@functools.cache def get_device_type(triple): arch, product, _ = triple if arch == "unknown": @@ -166,6 +171,7 @@ def get_device_type(triple): # get all the capabilities for the device # TODO: make a dataclass for capabilites and init using vulkaninfo +@functools.cache def get_vulkan_target_capabilities(triple): def get_subgroup_val(l): return int(sum([subgroup_feature[sgf] for sgf in l])) diff --git a/shark/iree_utils/vulkan_utils.py b/shark/iree_utils/vulkan_utils.py index 3675a86929..956929b073 100644 --- a/shark/iree_utils/vulkan_utils.py +++ b/shark/iree_utils/vulkan_utils.py @@ -14,6 +14,7 @@ # All the iree_vulkan related functionalities go here. +import functools from os import linesep from shark.iree_utils._common import run_cmd import iree.runtime as ireert @@ -22,6 +23,7 @@ from shark.parser import shark_args +@functools.cache def get_vulkan_device_name(device_num=0): vulkaninfo_dump, _ = run_cmd("vulkaninfo") vulkaninfo_dump = vulkaninfo_dump.split(linesep) @@ -48,6 +50,7 @@ def get_os_name(): return "linux" +@functools.cache def get_vulkan_target_triple(device_name): """This method provides a target triple str for specified vulkan device. @@ -172,6 +175,7 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]): return res_vulkan_flag +@functools.cache def get_iree_vulkan_runtime_flags(): vulkan_runtime_flags = [ f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",