From 3203c91ebfe8be9358dbccc5ec27fd9ae6d0325c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 19 Feb 2024 00:21:57 +0100 Subject: [PATCH] Add plot to show a timeline (#155) * add plot to show a timeline * better * changelogs --- CHANGELOGS.rst | 1 + _doc/api/tools_other.rst | 5 ++ _unittests/ut_tools/test_js_profile.py | 51 +++++++++++ onnx_extended/tools/js_profile.py | 112 ++++++++++++++++++++++++- 4 files changed, 167 insertions(+), 2 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 57766156..3a3ccce0 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.3.0 +++++ +* :pr:`155`: add a function to draw a timeline from a profile * :pr:`154`: improves ploting legend for profiling * :pr:`151`: refactoring of TreeEnsemble code to make them faster * :pr:`129`, :pr:`132`: support sparse features for TreeEnsemble diff --git a/_doc/api/tools_other.rst b/_doc/api/tools_other.rst index bc9c4162..5f3ba1d1 100644 --- a/_doc/api/tools_other.rst +++ b/_doc/api/tools_other.rst @@ -24,6 +24,11 @@ plot_ort_profile .. autofunction:: onnx_extended.tools.js_profile.plot_ort_profile +plot_ort_profile_timeline ++++++++++++++++++++++++++ + +.. autofunction:: onnx_extended.tools.js_profile.plot_ort_profile_timeline + onnx_extended.tools.run_onnx ============================ diff --git a/_unittests/ut_tools/test_js_profile.py b/_unittests/ut_tools/test_js_profile.py index c7e88eec..e9caf199 100644 --- a/_unittests/ut_tools/test_js_profile.py +++ b/_unittests/ut_tools/test_js_profile.py @@ -17,6 +17,7 @@ from onnx_extended.tools.js_profile import ( js_profile_to_dataframe, plot_ort_profile, + plot_ort_profile_timeline, _process_shape, ) @@ -54,6 +55,7 @@ def _get_model(self): ], ), opset_imports=[make_opsetid("", 18)], + ir_version=9, ) check_model(model_def0) return model_def0 @@ -218,6 +220,7 @@ def _get_model_domain(self): make_opsetid("", 18), make_opsetid("onnx_extented.ortops.tutorial.cpu", 1), ], + ir_version=9, ) check_model(model_def0) return model_def0 @@ -248,6 +251,54 @@ def test_plot_domain_agg(self): os.remove(prof) + def _get_model2(self): + model_def0 = make_model( + make_graph( + [ + make_node("Add", ["X", "init1"], ["X1"]), + make_node("Abs", ["X"], ["X2"]), + make_node("Add", ["X", "init3"], ["inter"]), + make_node("Mul", ["X1", "inter"], ["Xm"]), + make_node("MatMul", ["X1", "Xm"], ["Xm2"]), + make_node("Sub", ["X2", "Xm2"], ["final"]), + ], + "test", + [make_tensor_value_info("X", TensorProto.FLOAT, [None, None])], + [make_tensor_value_info("final", TensorProto.FLOAT, [None, None])], + [ + from_array(np.array([1], dtype=np.float32), name="init1"), + from_array(np.array([3], dtype=np.float32), name="init3"), + ], + ), + opset_imports=[make_opsetid("", 18)], + ir_version=9, + ) + check_model(model_def0) + return model_def0 + + @ignore_warnings(UserWarning) + def test_plot_profile_timeline(self): + sess_options = SessionOptions() + sess_options.enable_profiling = True + sess = InferenceSession( + self._get_model2().SerializeToString(), + sess_options, + providers=["CPUExecutionProvider"], + ) + for _ in range(11): + sess.run(None, dict(X=np.random.rand(2**10, 2**10).astype(np.float32))) + prof = sess.end_profiling() + + df = js_profile_to_dataframe(prof, first_it_out=True) + + fig, ax = plt.subplots(1, 1, figsize=(5, 10)) + plot_ort_profile_timeline(df, ax, title="test_timeline", quantile=0.5) + fig.tight_layout() + fig.savefig("test_plot_profile_timeline.png") + self.assertNotEmpty(fig) + + os.remove(prof) + if __name__ == "__main__": import logging diff --git a/onnx_extended/tools/js_profile.py b/onnx_extended/tools/js_profile.py index cf02e258..e6076c25 100644 --- a/onnx_extended/tools/js_profile.py +++ b/onnx_extended/tools/js_profile.py @@ -205,7 +205,7 @@ def plot_ort_profile( title: Optional[str] = None, ) -> "matplotlib.axes.Axes": """ - Plots time spend in computation based on dataframe + Plots time spend in computation based on a dataframe produced by function :func:`js_profile_to_dataframe`. :param df: dataframe @@ -216,7 +216,7 @@ def plot_ort_profile( """ fontsize = 10 if ax0 is None: - import matplotlib as plt + import matplotlib.pyplot as plt ax0 = plt.gca() @@ -255,3 +255,111 @@ def plot_ort_profile( ax0.get_yaxis().set_label_text("") ax0.set_yticklabels(ax0.get_yticklabels(), fontsize=fontsize) return ax0 + + +def plot_ort_profile_timeline( + df: DataFrame, + ax: Optional["matplotlib.axes.Axes"] = None, + iteration: int = -2, + title: Optional[str] = None, + quantile: float = 0.5, + fontsize: int = 12, +) -> "matplotlib.axes.Axes": + """ + Creates a timeline based on a dataframe + produced by function :func:`js_profile_to_dataframe`. + + :param df: dataframe + :param ax: first axis to draw time + :param iteration: iteration to plot, negative value to start from the end + :param title: graph title + :param quantile: draw the 10% less consuming operators in a different color + :param fontsize: font size + :return: the graph + """ + if ax is None: + import matplotlib.pyplot as plt + + ax = plt.gca() + + df = df.copy() + df["iteration"] = df["iteration"].astype(int) + iterations = set(df["iteration"]) + n_iter = iteration if iteration >= 0 else max(iterations) + 1 + iteration + dfi = df[df["iteration"] == n_iter] + assert dfi.shape[0] > 0, f"Iteration {iteration} cannot be found in {iterations}." + + started = {} + data = [] + for irow in dfi.iterrows(): + assert isinstance( + irow, tuple + ), f"pandas has changed its api, type is {type(row)}" + assert len(irow) == 2, f"pandas has changed its api, row is {row}" + row = irow[1] + it = row["iteration"] + op_type = row["args_op_name"] + op_name = row["op_name"] + event_name = row["event_name"] + provider = row["args_provider"] + ts = float(row["ts"]) + dur = float(row["dur"]) + if event_name == "fence_before": + started[op_type, op_name, it] = dict( + op_name=op_name, op_type=op_type, begin=ts + ) + elif event_name == "kernel_time": + obs = started[op_type, op_name, it] + obs["duration"] = dur + obs["begin_kernel"] = ts + obs["provider"] = provider + elif event_name == "fence_after": + obs = started[op_type, op_name, it] + obs["end"] = ts + data.append(obs) + del started[op_type, op_name, it] + else: + assert event_name in { + "SequentialExecutor::Execute", + "model_run", + }, f"Unexpected event_name={event_name!r}, row={row}" + + # durations + data_dur = list(sorted(d["duration"] for d in data)) + threshold = data_dur[int(quantile * len(data_dur))] + origin = dfi["ts"].min() + + colors = ["blue", "green", "red", "orange"] + + import matplotlib.patches as mpatches + + cs = [0, 0] + for i, obs in enumerate(data): + dur = obs["duration"] + cat = int(dur >= threshold) + + # color + color = colors[cat * 2 + cs[cat] % 2] + cs[cat] += 1 + + # rectangle + t1 = obs["begin"] - origin + t2 = obs["end"] - origin + shape = mpatches.Rectangle((0, t1), 1, t2 - t1, ec="none", color=color) + ax.add_artist(shape) + tk1 = obs["begin_kernel"] - origin + tk2 = (obs["begin_kernel"] + obs["duration"]) - origin + ax.plot([0, 1], [tk1, tk1], "b--") + ax.plot([0, 1], [tk2, tk2], "b--") + + # text + y = (tk1 + tk2) / 2 + text = obs["op_type"] + prov = obs["provider"].replace("ExecutionProvider", "") + name = obs["op_name"] + if len(name) >= 10: + name = name[:5] + "..." + name[5:] + ax.text(1, y, f"{i}:{prov}:{text}-{name}", fontsize=fontsize, va="center") + + ax.invert_yaxis() + return ax