Skip to content

Commit

Permalink
Macos fixes (#1883)
Browse files Browse the repository at this point in the history
* fix venv setup for MacOS

* allow stream fuse binding on mac

* clean iree metal args
  • Loading branch information
PhaneeshB authored Oct 10, 2023
1 parent 2004d16 commit a731eb6
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 25 deletions.
3 changes: 0 additions & 3 deletions apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ def forward(self, noise_pred, sigma, latent, dt):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")

def _import(self):
scaling_model = ScalingModel()
Expand Down
4 changes: 0 additions & 4 deletions apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,10 +535,6 @@ def get_opt_flags(model, precision="fp16"):
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)

# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")

if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
Expand Down
9 changes: 8 additions & 1 deletion setup_venv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
if [ "$torch_mlir_bin" = true ]; then
if [[ $(uname -s) = 'Darwin' ]]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
Expand Down Expand Up @@ -128,7 +129,13 @@ if [[ ! -z "${IMPORTER}" ]]; then
fi
fi

$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
if [[ $(uname -s) = 'Darwin' ]]; then
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
else
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
fi

$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}

if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
Expand Down
20 changes: 3 additions & 17 deletions shark/iree_utils/metal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,24 +89,10 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):


def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]

# Add any metal spefic compilation flags here
res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
metal_triple_flag = arg
break

if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)

if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(
"-iree-vulkan-target-triple=m1-moltenvk-macos"
)
res_metal_flag.append(vulkan_target_env)
if len(extra_args) > 0:
res_metal_flag.extend(extra_args)
return res_metal_flag


Expand Down

0 comments on commit a731eb6

Please sign in to comment.