Commit
add one example
xadupre committed Apr 21, 2024
1 parent 43c27b4 commit c8a64db
Showing 3 changed files with 278 additions and 0 deletions.
7 changes: 7 additions & 0 deletions _doc/benchmarks.rst
@@ -107,6 +107,13 @@ See :ref:`l-example-op-scatternd_cuda`.
The benchmark compares two consecutive Mul operators
with their fusion into a single operator.

plot_op_transpose2dcast_cuda
++++++++++++++++++++++++++++

See :ref:`l-example-op-transpose2dcast_cuda`.

The benchmark looks into the fusion of Transpose + Cast into a single operator.
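
The fused pattern is::

    output = Cast(Transpose(X), to=FLOAT16)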

No specific provider
====================

270 changes: 270 additions & 0 deletions _doc/examples/plot_op_transpose_2d_cast_cuda.py
@@ -0,0 +1,270 @@
"""
.. _l-example-op-plot_op_transpose2dcast_cuda:
==============================
Fuse Tranpose and Cast on CUDA
==============================
This configuration happens in a :epkg:`Llama` model.
::
output = Cast(Transpose(X), to=FLOAT16)
Where the shapes are:
* X: 4096,4096
Transpose + Cast
================
"""

from onnx_extended.args import get_parsed_args

script_args = get_parsed_args(
    "plot_op_transpose_2d_cast",
    description=__doc__,
    config=(
        "small",
        "small, short optimization (default), "
        "medium for medium sizes, "
        "large for big sizes, "
        "llama for a specific case on llama",
    ),
    warmup=3,
    repeat=5,
    itype=(10, "1 or 10 for float or float16"),
    expose="config,itype,warmup,repeat",
)
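
####################################
# The script exposes ``config``, ``itype``, ``warmup`` and ``repeat`` as
# command line arguments, so a run such as the following is presumably
# possible (the flag names are assumed from the exposed parameters)::
#
#   python plot_op_transpose_2d_cast_cuda.py --config large --itype 10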

import time
import numpy as np
from numpy.testing import assert_almost_equal
from pandas import DataFrame
from tqdm import tqdm
import onnx.helper as oh
from onnx import TensorProto
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

itype = script_args.itype
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
config = script_args.config
print(f"config={config}")
print(f"itype={itype}, dtype={dtype}")

if config == "small":
sizes = (256, 512, 1024)
elif config == "medium":
sizes = (512, 1024, 2048)
elif config == "large":
sizes = (1024, 2048, 4096, 8192)
elif config == "llama":
sizes = (2048, 4096, 8192)
else:
try:
sizes = list(map(int, config.split(",")))
except (ValueError, TypeError) as e:
raise AssertionError(f"Unexpected config value {config!r}.") from e


def get_model(fused=False, itype=TensorProto.FLOAT):
    # the input type is the opposite of the requested output type,
    # the Cast (or the fused kernel) performs the conversion
    iitype = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16
    suffix = "32" if itype == TensorProto.FLOAT else "16"
    if fused:
        nodes = [
            oh.make_node(
                f"Transpose2DCastFP{suffix}",
                ["X"],
                ["Y"],
                domain="onnx_extended.ortops.optim.cuda",
            )
        ]
    else:
        nodes = [
            oh.make_node("Transpose", ["X"], ["xt"], perm=[1, 0]),
            oh.make_node("Cast", ["xt"], ["Y"], to=itype),
        ]
    model = oh.make_model(
        oh.make_graph(
            nodes,
            "g",
            [oh.make_tensor_value_info("X", iitype, ["a", "b"])],
            [oh.make_tensor_value_info("Y", itype, ["b", "a"])],
        ),
        opset_imports=[
            oh.make_opsetid("", 18),
            oh.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
        ],
        ir_version=9,
    )
    return model


model = get_model(itype=itype)
print(onnx_simple_text_plot(model))
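
############################################
# As a quick reference, the unfused graph computes the equivalent of the
# following NumPy expression. This is a minimal sketch assuming the default
# ``itype=10`` (float32 input, float16 output); the fused kernel is expected
# to return the same result.

x_check = np.random.randn(4, 8).astype(np.float32)
y_check = x_check.T.astype(np.float16)  # Cast(Transpose(x_check), to=FLOAT16)
print("reference:", y_check.shape, y_check.dtype)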

###################################
# Models
# ======


def get_session(model):
    import onnxruntime
    from onnx_extended.ortops.optim.cuda import get_ort_ext_libs

    if "CUDAExecutionProvider" not in onnxruntime.get_available_providers():
        return None

    opts = onnxruntime.SessionOptions()
    opts.register_custom_ops_library(get_ort_ext_libs()[0])
    sess = onnxruntime.InferenceSession(
        model.SerializeToString(),
        opts,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    return sess
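
############################################
# Every session creation is guarded: without a CUDA device,
# ``get_session`` returns ``None`` and the following cells are skipped.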


X = np.random.randn(64, 32).astype(
    np.float16 if itype == TensorProto.FLOAT else np.float32
)
feeds = dict(X=X)

sess1 = get_session(model)
if sess1 is not None:
    for k, v in feeds.items():
        print(k, v.dtype, v.shape)
    expected = sess1.run(None, feeds)[0]
    print(expected[:4, :4])

##################################################
# Same model but using the fused op.

model = get_model(fused=True, itype=itype)
print(onnx_simple_text_plot(model))

sess2 = get_session(model)
if sess2 is not None:
    got = sess2.run(None, feeds)[0]
    print(got[:4, :4])
    assert_almost_equal(expected, got)

#################################################
# Benchmark
# =========


def move_inputs(sess, feeds):
    from onnxruntime.capi._pybind_state import (
        SessionIOBinding,
        OrtDevice as C_OrtDevice,
        OrtValue as C_OrtValue,
    )

    input_names = [i.name for i in sess.get_inputs()]

    ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0)

    feed_ort_value = [
        (name, C_OrtValue.ortvalue_from_numpy(feeds[name], ort_device))
        for name in input_names
    ]

    bind = SessionIOBinding(sess._sess)
    for name, value in feed_ort_value:
        bind.bind_input(
            name, ort_device, feeds[name].dtype, value.shape(), value.data_ptr()
        )
    for o in sess.get_outputs():
        bind.bind_output(o.name, ort_device)
    return bind, feed_ort_value
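
############################################
# The inputs are copied onto the CUDA device once, before the timed loops,
# so the benchmark measures the kernels rather than host/device transfers.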


def benchmark(
    sess, sizes, config, label, itype, times_col: int = 1, times_indices: int = 1
):
    data = []
    for size in tqdm(sizes):
        X = np.random.randn(size, size).astype(
            np.float16 if itype == TensorProto.FLOAT else np.float32
        )
        feeds = dict(X=X)
        bind, cuda_feeds = move_inputs(sess, feeds)

        begin = time.perf_counter()
        for i in range(script_args.warmup):
            # sess.run(None, feeds)
            sess._sess.run_with_iobinding(bind, None)
        warmup = time.perf_counter() - begin

        times = []
        for i in range(script_args.repeat):
            begin = time.perf_counter()
            # sess.run(None, feeds)
            sess._sess.run_with_iobinding(bind, None)
            times.append(time.perf_counter() - begin)

        npt = np.array(times)
        obs = dict(
            warmup=warmup,
            time=npt.mean(),
            std=npt.std(),
            min=npt.min(),
            max=npt.max(),
            repeat=script_args.repeat,
            size=size,
            label=label,
        )
        data.append(obs)
    return data


#######################################
# Not Fused.


if sess1 is not None:
    print(f"sizes={sizes}")

    data_nd1 = benchmark(sess1, sizes, script_args.config, "Not Fused", itype=itype)

#######################################
# Fused.

if sess2 is not None:
    data_nd2 = benchmark(sess2, sizes, script_args.config, "Fused", itype=itype)


##########################################
# Data
# ++++

if sess2 is not None:
    df = DataFrame(data_nd1 + data_nd2)
    df.to_csv("plot_op_transpose_2d_cast_cuda.csv", index=False)
    df.to_excel("plot_op_transpose_2d_cast_cuda.xlsx", index=False)
    print(df.head())

#####################
# Pivot.

if sess2 is not None:
    pivot = df.pivot(index="size", columns="label", values="time")
    # a ratio above 1 means the fused operator is faster
    pivot["ratio"] = pivot["Not Fused"] / pivot["Fused"]
    print(pivot)

    ax = pivot[["Not Fused", "Fused"]].plot(
        logx=True,
        logy=True,
        title=f"Not Fused/Fused implementation for Transpose + Cast on CUDA\nitype={itype}",
    )
    ax.get_figure().savefig("plot_op_transpose_2d_cast_cuda.png")

##############################
# It seems worth it to combine both operators.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -324,6 +324,7 @@ max-complexity = 10
[tool.ruff.lint.per-file-ignores]
"_doc/examples/plot_op_mul_cuda.py" = ["E402"]
"_doc/examples/plot_op_scatternd_cuda.py" = ["E402"]
"_doc/examples/plot_op_transpose_2d_cast_cuda.py" = ["E402"]
"onnx_extended/helper/__init__.py" = ["F401"]
"onnx_extended/reference/__init__.py" = ["F401"]
"onnx_extended/tools/__init__.py" = ["F401"]
