Skip to content

Commit

Permalink
Add plot to show a timeline (#155)
Browse files Browse the repository at this point in the history
* add plot to show a timeline

* better

* changelogs
  • Loading branch information
xadupre authored Feb 18, 2024
1 parent 4b5a633 commit 3203c91
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.3.0
+++++

* :pr:`155`: add a function to draw a timeline from a profile
* :pr:`154`: improves ploting legend for profiling
* :pr:`151`: refactoring of TreeEnsemble code to make them faster
* :pr:`129`, :pr:`132`: support sparse features for TreeEnsemble
Expand Down
5 changes: 5 additions & 0 deletions _doc/api/tools_other.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ plot_ort_profile

.. autofunction:: onnx_extended.tools.js_profile.plot_ort_profile

plot_ort_profile_timeline
+++++++++++++++++++++++++

.. autofunction:: onnx_extended.tools.js_profile.plot_ort_profile_timeline

onnx_extended.tools.run_onnx
============================

Expand Down
51 changes: 51 additions & 0 deletions _unittests/ut_tools/test_js_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from onnx_extended.tools.js_profile import (
js_profile_to_dataframe,
plot_ort_profile,
plot_ort_profile_timeline,
_process_shape,
)

Expand Down Expand Up @@ -54,6 +55,7 @@ def _get_model(self):
],
),
opset_imports=[make_opsetid("", 18)],
ir_version=9,
)
check_model(model_def0)
return model_def0
Expand Down Expand Up @@ -218,6 +220,7 @@ def _get_model_domain(self):
make_opsetid("", 18),
make_opsetid("onnx_extented.ortops.tutorial.cpu", 1),
],
ir_version=9,
)
check_model(model_def0)
return model_def0
Expand Down Expand Up @@ -248,6 +251,54 @@ def test_plot_domain_agg(self):

os.remove(prof)

def _get_model2(self):
model_def0 = make_model(
make_graph(
[
make_node("Add", ["X", "init1"], ["X1"]),
make_node("Abs", ["X"], ["X2"]),
make_node("Add", ["X", "init3"], ["inter"]),
make_node("Mul", ["X1", "inter"], ["Xm"]),
make_node("MatMul", ["X1", "Xm"], ["Xm2"]),
make_node("Sub", ["X2", "Xm2"], ["final"]),
],
"test",
[make_tensor_value_info("X", TensorProto.FLOAT, [None, None])],
[make_tensor_value_info("final", TensorProto.FLOAT, [None, None])],
[
from_array(np.array([1], dtype=np.float32), name="init1"),
from_array(np.array([3], dtype=np.float32), name="init3"),
],
),
opset_imports=[make_opsetid("", 18)],
ir_version=9,
)
check_model(model_def0)
return model_def0

@ignore_warnings(UserWarning)
def test_plot_profile_timeline(self):
sess_options = SessionOptions()
sess_options.enable_profiling = True
sess = InferenceSession(
self._get_model2().SerializeToString(),
sess_options,
providers=["CPUExecutionProvider"],
)
for _ in range(11):
sess.run(None, dict(X=np.random.rand(2**10, 2**10).astype(np.float32)))
prof = sess.end_profiling()

df = js_profile_to_dataframe(prof, first_it_out=True)

fig, ax = plt.subplots(1, 1, figsize=(5, 10))
plot_ort_profile_timeline(df, ax, title="test_timeline", quantile=0.5)
fig.tight_layout()
fig.savefig("test_plot_profile_timeline.png")
self.assertNotEmpty(fig)

os.remove(prof)


if __name__ == "__main__":
import logging
Expand Down
112 changes: 110 additions & 2 deletions onnx_extended/tools/js_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def plot_ort_profile(
title: Optional[str] = None,
) -> "matplotlib.axes.Axes":
"""
Plots time spend in computation based on dataframe
Plots time spend in computation based on a dataframe
produced by function :func:`js_profile_to_dataframe`.
:param df: dataframe
Expand All @@ -216,7 +216,7 @@ def plot_ort_profile(
"""
fontsize = 10
if ax0 is None:
import matplotlib as plt
import matplotlib.pyplot as plt

ax0 = plt.gca()

Expand Down Expand Up @@ -255,3 +255,111 @@ def plot_ort_profile(
ax0.get_yaxis().set_label_text("")
ax0.set_yticklabels(ax0.get_yticklabels(), fontsize=fontsize)
return ax0


def plot_ort_profile_timeline(
df: DataFrame,
ax: Optional["matplotlib.axes.Axes"] = None,
iteration: int = -2,
title: Optional[str] = None,
quantile: float = 0.5,
fontsize: int = 12,
) -> "matplotlib.axes.Axes":
"""
Creates a timeline based on a dataframe
produced by function :func:`js_profile_to_dataframe`.
:param df: dataframe
:param ax: first axis to draw time
:param iteration: iteration to plot, negative value to start from the end
:param title: graph title
:param quantile: draw the 10% less consuming operators in a different color
:param fontsize: font size
:return: the graph
"""
if ax is None:
import matplotlib.pyplot as plt

ax = plt.gca()

df = df.copy()
df["iteration"] = df["iteration"].astype(int)
iterations = set(df["iteration"])
n_iter = iteration if iteration >= 0 else max(iterations) + 1 + iteration
dfi = df[df["iteration"] == n_iter]
assert dfi.shape[0] > 0, f"Iteration {iteration} cannot be found in {iterations}."

started = {}
data = []
for irow in dfi.iterrows():
assert isinstance(
irow, tuple
), f"pandas has changed its api, type is {type(row)}"
assert len(irow) == 2, f"pandas has changed its api, row is {row}"
row = irow[1]
it = row["iteration"]
op_type = row["args_op_name"]
op_name = row["op_name"]
event_name = row["event_name"]
provider = row["args_provider"]
ts = float(row["ts"])
dur = float(row["dur"])
if event_name == "fence_before":
started[op_type, op_name, it] = dict(
op_name=op_name, op_type=op_type, begin=ts
)
elif event_name == "kernel_time":
obs = started[op_type, op_name, it]
obs["duration"] = dur
obs["begin_kernel"] = ts
obs["provider"] = provider
elif event_name == "fence_after":
obs = started[op_type, op_name, it]
obs["end"] = ts
data.append(obs)
del started[op_type, op_name, it]
else:
assert event_name in {
"SequentialExecutor::Execute",
"model_run",
}, f"Unexpected event_name={event_name!r}, row={row}"

# durations
data_dur = list(sorted(d["duration"] for d in data))
threshold = data_dur[int(quantile * len(data_dur))]
origin = dfi["ts"].min()

colors = ["blue", "green", "red", "orange"]

import matplotlib.patches as mpatches

cs = [0, 0]
for i, obs in enumerate(data):
dur = obs["duration"]
cat = int(dur >= threshold)

# color
color = colors[cat * 2 + cs[cat] % 2]
cs[cat] += 1

# rectangle
t1 = obs["begin"] - origin
t2 = obs["end"] - origin
shape = mpatches.Rectangle((0, t1), 1, t2 - t1, ec="none", color=color)
ax.add_artist(shape)
tk1 = obs["begin_kernel"] - origin
tk2 = (obs["begin_kernel"] + obs["duration"]) - origin
ax.plot([0, 1], [tk1, tk1], "b--")
ax.plot([0, 1], [tk2, tk2], "b--")

# text
y = (tk1 + tk2) / 2
text = obs["op_type"]
prov = obs["provider"].replace("ExecutionProvider", "")
name = obs["op_name"]
if len(name) >= 10:
name = name[:5] + "..." + name[5:]
ax.text(1, y, f"{i}:{prov}:{text}-{name}", fontsize=fontsize, va="center")

ax.invert_yaxis()
return ax

0 comments on commit 3203c91

Please sign in to comment.