Commit cec6eda

Optimize device enumeration overhead and log details on long operations. (#1734)

* Optimize device enumeration overhead and log details on long operations.

* Various fixes to add `@functools.cache` to what should be one-time, expensive device enumeration and setup activities (see the sketch after this list). This cuts several seconds off initialization on my machine.
* Add detailed tracing to actual invocations if they exceed a certain timeout or raise an exception.
* Add detailed tracing to loading status.
* By default, detail logging is printed only if an operation takes an excessive amount of time. All logging/timing can be printed by setting the environment variable `$env:SHARK_DETAIL_TRACE = "1"`.

* Remove cache from unhashable functions
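
As a rough illustration of the caching pattern this commit applies (a minimal sketch; `enumerate_devices` is a hypothetical stand-in, not a function from this commit):

```python
import functools


@functools.cache  # memoizes per process; arguments must be hashable
def enumerate_devices(driver_name: str):
    # Stands in for an expensive one-time probe, e.g. querying a HAL driver.
    print(f"probing {driver_name}...")
    return (f"{driver_name}://0",)


enumerate_devices("vulkan")  # performs the probe
enumerate_devices("vulkan")  # served from the cache; no second probe
```

On PowerShell, `$env:SHARK_DETAIL_TRACE = "1"` enables unconditional detail logging; the POSIX-shell equivalent would be `export SHARK_DETAIL_TRACE=1`.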
stellaraccident authored Aug 8, 2023
1 parent 9e37e03 commit cec6eda
Showing 8 changed files with 199 additions and 85 deletions.
3 changes: 2 additions & 1 deletion shark/iree_utils/_common.py
@@ -13,7 +13,7 @@
# limitations under the License.

## Common utilities to be shared by iree utilities.

import functools
import os
import sys
import subprocess
@@ -93,6 +93,7 @@ def iree_target_map(device):


# Finds whether the required drivers are installed for the given device.
@functools.cache
def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""
if "://" in device:
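A note on the "Remove cache from unhashable functions" item above: `functools.cache` keys on the call arguments, so it is safe for `check_device_drivers(device)` (a string), but it raises at call time if a caller passes an unhashable value such as a list. A minimal sketch of the failure mode (the function below is hypothetical):

```python
import functools


@functools.cache
def get_device_args(device, extra_args=()):
    return [device, *extra_args]


get_device_args("cpu")               # fine: "cpu" is hashable
get_device_args("cpu", ("--foo",))   # fine: tuples are hashable
get_device_args("cpu", ["--foo"])    # TypeError: unhashable type: 'list'
```

That is why the decorator could not stay on functions like `get_iree_device_args(device, extra_args=[])` below.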
185 changes: 101 additions & 84 deletions shark/iree_utils/compile_utils.py
@@ -11,19 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import functools
import numpy as np
import os
import re
import tempfile
import time
from pathlib import Path

import iree.runtime as ireert
import iree.compiler as ireec
from shark.parser import shark_args

from .trace import DetailLogger
from ._common import iree_device_map, iree_target_map
from .cpu_utils import get_iree_cpu_rt_args
from .benchmark_utils import *


# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
@@ -318,7 +322,6 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)

haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
@@ -338,63 +341,64 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First get configs.
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)

haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
print(
f"Loading module {flatbuffer_blob_or_path}... ", end="", flush=True
)
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
print(f"mmap complete... ", end="", flush=True)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
print(f"module initialized. Ready to run!")
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink
print(f"Loading module {flatbuffer_blob_or_path}...")

with DetailLogger(timeout=2.5) as dl:
# First get configs.
if device_idx is not None:
dl.log(f"Mapping device id: {device_idx}")
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
dl.log(f"ireert.get_driver()")

haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)
dl.log("get_iree_runtime_config")
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(
config.vm_instance, flatbuffer_blob_or_path
)
dl.log(f"mmap {flatbuffer_blob_or_path}")
ctx = ireert.SystemContext(config=config)
dl.log(f"ireert.SystemContext created")
ctx.add_vm_module(mmaped_vmfb)
dl.log(f"module initialized")
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(config.vm_instance, vmfb_file_path)
dl.log(f"mmap temp {vmfb_file_path}")
return mmaped_vmfb, config, temp_file_to_unlink


def get_iree_compiled_module(
@@ -502,31 +506,44 @@ def get_results(
config,
frontend="torch",
send_to_host=True,
debug_timeout: float = 5.0,
):
"""Runs a .vmfb file given inputs and config and returns output."""
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm[function_name](*device_inputs)
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
for val in result:
result_tensors.append(np.asarray(val, val.dtype))
with DetailLogger(debug_timeout) as dl:
device_inputs = []
for input_array in input:
dl.log(f"Load to device: {input_array.shape}")
device_inputs.append(
ireert.asdevicearray(config.device, input_array)
)
dl.log(f"Invoke function: {function_name}")
result = compiled_vm[function_name](*device_inputs)
dl.log(f"Invoke complete")
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
for val in result:
dl.log(f"Result to host: {val.shape}")
result_tensors.append(np.asarray(val, val.dtype))
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
if send_to_host and result is not None:
return result.to_host()
return result
if send_to_host and result is not None:
dl.log("Result to host")
return result.to_host()
return result
dl.log("Execution complete")


@functools.cache
def get_iree_runtime_config(device):
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
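Condensing the mmap loading path from the diff above into a standalone sketch (assuming a local CPU task driver and a hypothetical `/tmp/model.vmfb`; the runtime calls mirror the ones in `load_vmfb_using_mmap`):

```python
import iree.runtime as ireert

# Build a runtime config; in SHARK this comes from get_iree_runtime_config(device).
config = ireert.Config("local-task")

# Map the compiled flatbuffer into memory instead of reading it eagerly.
vm_module = ireert.VmModule.mmap(config.vm_instance, "/tmp/model.vmfb")

ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)

# Entry points are exposed as attributes of the registered module.
module = getattr(ctx.modules, vm_module.name)
```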
3 changes: 3 additions & 0 deletions shark/iree_utils/cpu_utils.py
@@ -14,6 +14,7 @@

# All the iree_cpu related functionalities go here.

import functools
import subprocess
import platform
from shark.parser import shark_args
@@ -30,6 +31,7 @@ def get_cpu_count():


# Get the default cpu args.
@functools.cache
def get_iree_cpu_args():
uname = platform.uname()
os_name, proc_name = uname.system, uname.machine
@@ -51,6 +53,7 @@ def get_iree_cpu_args():


# Get iree runtime flags for cpu
@functools.cache
def get_iree_cpu_rt_args():
default = get_cpu_count()
default = default if default <= 8 else default - 2
4 changes: 4 additions & 0 deletions shark/iree_utils/gpu_utils.py
@@ -14,12 +14,14 @@

# All the iree_gpu related functionalities go here.

import functools
import iree.runtime as ireert
import ctypes
from shark.parser import shark_args


# Get the default gpu args given the architecture.
@functools.cache
def get_iree_gpu_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
ireert.flags.parse_flags("--cuda_allow_inline_execution")
@@ -37,6 +39,7 @@ def get_iree_gpu_args():


# Get the default rocm args given the architecture.
@functools.cache
def get_iree_rocm_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# get arch from rocminfo.
@@ -65,6 +68,7 @@ def get_iree_rocm_args():
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36


@functools.cache
def get_cuda_sm_cc():
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
for libname in libnames:
3 changes: 3 additions & 0 deletions shark/iree_utils/metal_utils.py
@@ -14,12 +14,15 @@

# All the iree_metal related functionalities go here.

import functools

from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag


@functools.cache
def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")
76 changes: 76 additions & 0 deletions shark/iree_utils/trace.py
@@ -0,0 +1,76 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import os
import threading
import time


def _enable_detail_trace() -> bool:
return os.getenv("SHARK_DETAIL_TRACE", "0") == "1"


class DetailLogger:
"""Context manager which can accumulate detailed log messages.
Detailed log is only emitted if the operation takes a long time
or errors.
"""

def __init__(self, timeout: float):
self._timeout = timeout
self._messages: List[Tuple[float, str]] = []
self._start_time = time.time()
self._active = not _enable_detail_trace()
self._lock = threading.RLock()
self._cond = threading.Condition(self._lock)
self._thread = None

def __enter__(self):
self._thread = threading.Thread(target=self._run)
self._thread.start()
return self

def __exit__(self, type, value, traceback):
with self._lock:
self._active = False
self._cond.notify()
if traceback:
self.dump_on_error(f"exception")

def _run(self):
with self._lock:
timed_out = not self._cond.wait(self._timeout)
if timed_out:
self.dump_on_error(f"took longer than {self._timeout}s")

def log(self, msg):
with self._lock:
timestamp = time.time()
if self._active:
self._messages.append((timestamp, msg))
else:
print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}")

def dump_on_error(self, summary: str):
with self._lock:
if self._active:
print(f"::: Detailed report ({summary}):")
for timestamp, msg in self._messages:
print(
f" +{(timestamp - self._start_time) * 1000}ms: {msg}"
)
self._active = False
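
A usage sketch for the class above (timings and messages are illustrative):

```python
import time

from shark.iree_utils.trace import DetailLogger

with DetailLogger(timeout=2.5) as dl:
    dl.log("step 1")  # buffered while the block stays under the timeout
    time.sleep(3)     # exceeding 2.5s makes the watchdog thread dump the buffer
    dl.log("step 2")  # printed immediately once the dump has fired
```

Setting `SHARK_DETAIL_TRACE=1` makes every `dl.log()` call print immediately with a `+<ms>` offset instead of buffering.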