Merge pull request #1171 from yimuchen/ml_tools_tf
feat: tensorflow wrapper
lgray authored Aug 28, 2024
2 parents eccdd2f + d33193b commit 2aaf4f0
Showing 6 changed files with 238 additions and 19 deletions.
72 changes: 61 additions & 11 deletions binder/mltools.ipynb
@@ -115,10 +115,7 @@
" awkward arrays, the return should also be non-dask awkward arrays that can be\n",
" trivially converted to `numpy` arrays via an `ak.to_numpy` call; if the inputs\n",
" are dask awkward arrays, the return should still be dask awkward arrays\n",
" that can be trivially converted via a `to_awkward().to_numpy()` call. To\n",
" minimize changes to the code, a simple `dask_awkward/awkward` switcher\n",
" `get_awkward_lib` is provided, as there should be (near)-perfect feature\n",
" parity between the dask and non-dask arrays.\n",
" that can be trivially converted via a `to_awkward().to_numpy()` call.\n",
"\n",
" In this ParticleNet-like example, the model expects the following inputs:\n",
"\n",
@@ -565,20 +562,19 @@
"## Comments about generalizing to other ML tools\n",
"\n",
"All ML wrappers provided in the `coffea.mltools` module (`triton_wrapper` for\n",
"[triton][triton] server inference, `torch_wrapper` for pytorch, and\n",
"`xgboost_wrapper` for [xgboost][xgboost] inference) follow the same design:\n",
"analyzers is responsible for providing the model of interest, along with\n",
"providing an inherited class that overloads of the following methods to data\n",
"type conversion:\n",
"[triton][triton] server inference, `torch_wrapper` for pytorch,\n",
"`xgboost_wrapper` for [xgboost][xgboost] inference, `tf_wrapper` for tensorflow) \n",
"follow the same design: analyzers is responsible for providing the model of \n",
"interest, along with providing an inherited class that overloads of the following\n",
"methods to data type conversion:\n",
"\n",
"- `prepare_awkward`: converting awkward arrays to `numpy`-compatible awkward\n",
" arrays, the output arrays should be in the format of a tuple `a` and a\n",
" dictionary `b`, which can be expanded out to the input of the ML tool like\n",
" `model(*a, **b)`. Notice that some additional trivial conversions, such as the\n",
" conversion to available kernels for `pytorch`, conversion to a matrix format\n",
" for `xgboost`, and array slicing for `triton`, are handled automatically by the\n",
" respective wrappers. To handle both dask/non-dask arrays, the user should use\n",
" the provided `get_awkward_lib` library switcher.\n",
" respective wrappers.\n",
"- `postprocess_awkward` (optional): converting the trivially converted numpy\n",
" array results back to the analysis-specific format. If this is not provided,\n",
" then a simple `ak.from_numpy` conversion of the results is returned.\n",
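"\n",
"As a minimal sketch (the base class and field names here are illustrative),\n",
"an inherited class may look like:\n",
"\n",
"```python\n",
"class my_wrapper(torch_wrapper):\n",
"    def prepare_awkward(self, events):\n",
"        # Stack the fields of interest into a numpy-compatible awkward array\n",
"        arr = ak.concatenate(\n",
"            [events[f][:, np.newaxis] for f in (\"feat1\", \"feat2\")], axis=1\n",
"        )\n",
"        return [arr], {}  # Expanded out as model(arr)\n",
"```\n",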
@@ -594,6 +590,60 @@
"[triton]: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver\n",
"[xgboost]: https://xgboost.readthedocs.io/en/stable/\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Additional comments on common `prepare_awkward` patterns\n",
"\n",
"The key requirement of all wrapper classes in `ml_tools` pacakge, is that to convert\n",
"awkward arrays into `numpy`-compatible formats using just `awkward` related tools, \n",
"which ensures that no eager data conversion is performed on dask arrays. Below are\n",
"some common patterns that are useful when defining a user-level class.\n",
"\n",
"### Casting multiple fields a collection to be separate axis\n",
"\n",
"Given our collection of particles of length $N$, our tool is interested in just a \n",
"sub-set of fields is to be represented as an $N\\time M$ array. You can do acheive this \n",
"using just `ak.concatenate` and dimension expansion with `np.newaxis`:\n",
"\n",
"```python\n",
"fields_of_interest = [\"field1\", \"field2\", \"field3\"]\n",
"part_np_array = ak.concatenate(\n",
" [\n",
" part[field][:,np.newaxis] # Expanding length N array to Nx1\n",
" for field in fields_of_interest\n",
" ],\n",
" axis=1,\n",
") # This should now be a Nx3 array\n",
"```\n",
"\n",
"### Fixing collection dimensions\n",
"\n",
"Many ML inteference tools work with fixed dimension inputs, with missing entries \n",
"being set to a placeholder values. A common method for achieving this in awkward\n",
"is with `pad_none`/`fill_none` calls, for example to pad the number of particles\n",
"passed to the inference tool in each event to be a fixed length of 128:\n",
"\n",
"```python\n",
"part_padded = ak.fill_none(\n",
" ak.pad_none(part, 128, axis=1, clip=True),\n",
" -1000, # Placeholder value\n",
" axis=1,\n",
")\n",
"```\n",
"\n",
"The dimensions of this resulting `part_padded` array is still `N x var`, indicating\n",
"that the number of entries `axis=1` can potentially be variable. Depending on the \n",
"ML tools being used, this axis dimension may to be fixed. To strictly convert this \n",
"to a `Nx128` array, one can call `flatten`/`unflatten` pairs:\n",
"\n",
"```python\n",
"part_padded = ak.flatten(part_padded)\n",
"part_padded = ak.unflatten(part_padded, 128) # Now this is a Nx128 array\n",
"```"
]
}
],
"metadata": {
2 changes: 2 additions & 0 deletions src/coffea/ml_tools/__init__.py
@@ -6,6 +6,7 @@
"""

from coffea.ml_tools.helper import numpy_call_wrapper
from coffea.ml_tools.tf_wrapper import tf_wrapper
from coffea.ml_tools.torch_wrapper import torch_wrapper
from coffea.ml_tools.triton_wrapper import triton_wrapper
from coffea.ml_tools.xgboost_wrapper import xgboost_wrapper
@@ -15,4 +16,5 @@
"torch_wrapper",
"triton_wrapper",
"xgboost_wrapper",
"tf_wrapper",
]
3 changes: 2 additions & 1 deletion src/coffea/ml_tools/helper.py
@@ -159,7 +159,8 @@ class numpy_call_wrapper(abc.ABC):
        {awkward.Array: awkward.to_numpy}, default_conv=container_converter.no_action
    )
    _np_to_ak_ = container_converter(
        {numpy.ndarray: awkward.from_numpy},
        default_conv=container_converter.no_action,
    )

    def __init__(self):
113 changes: 113 additions & 0 deletions src/coffea/ml_tools/tf_wrapper.py
@@ -0,0 +1,113 @@
import warnings

import numpy

_tf_import_error = None
try:
    import tensorflow
except (ImportError, ModuleNotFoundError) as err:
    _tf_import_error = err

from .helper import nonserializable_attribute, numpy_call_wrapper


class tf_wrapper(nonserializable_attribute, numpy_call_wrapper):
    """
    Wrapper for running tensorflow inference with awkward/dask-awkward inputs.
    """

    def __init__(self, tf_model: str):
        """
        As models are not guaranteed to be directly serializable, the user will
        need to pass the model as a file saved using the `tf.keras.Model.save`
        method [1]. If the user is attempting to run on a cluster, the model
        file will need to be passed to the worker nodes in a way which
        preserves the file path.

        [1]
        https://www.tensorflow.org/guide/keras/serialization_and_saving#saving

        Parameters
        ----------
        - tf_model: Path to the tensorflow model file to load
        """
        if _tf_import_error is not None:
            warnings.warn(
                "Users should make sure the tensorflow package is installed before proceeding!\n"
                "> pip install tensorflow\n"
                "or\n"
                "> conda install tensorflow",
                UserWarning,
            )
            raise _tf_import_error

        nonserializable_attribute.__init__(self, ["model"])
        self.tf_model = tf_model

    def _create_model(self):
        """
        Loading in the model from the model file. We simply rely on Tensorflow
        to automatically load the accelerator resources.
        # TODO: More control over accelerator resources?
        """
        return tensorflow.keras.models.load_model(self.tf_model)

    def validate_numpy_input(self, *args: numpy.array, **kwargs: numpy.array) -> None:
        """
        Here we are assuming that the model contains the required information
        for parsing the input numpy array(s), and that the input numpy array(s)
        is the first argument of the user method call.
        """
        model_input = self.model.input_shape
        input_arr = args[0]  # Getting the input array

        def _equal_shape(mod_in: tuple, arr_shape: tuple) -> None:
            """Tuple of input shape and array shape"""
            assert len(mod_in) == len(
                arr_shape
            ), f"Mismatched number of axes (model: {mod_in}; received: {arr_shape})"
            match = [
                (m == a if m is not None else True) for m, a in zip(mod_in, arr_shape)
            ]
            assert numpy.all(
                match
            ), f"Mismatched shape (model: {mod_in}; received: {arr_shape})"

        if isinstance(model_input, tuple):
            # Single input model
            _equal_shape(model_input, input_arr.shape)
        else:
            assert len(input_arr) == len(
                model_input
            ), f"Mismatched number of inputs (model: {len(model_input)}; received: {len(input_arr)})"
            for model_shape, arr in zip(model_input, input_arr):
                _equal_shape(model_shape, arr.shape)

    def numpy_call(self, *args: numpy.array, **kwargs: numpy.array) -> numpy.array:
        """
        Evaluating the numpy inputs via the `model.__call__` method, with a
        trivial conversion of the numpy inputs to tensors.

        TODO: Do we need to evaluate using `predict` [1]? Since array batching
        is already handled by dask.

        [1]
        https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call
        """
        args = [
            (
                tensorflow.convert_to_tensor(arr)
                if arr.flags["WRITEABLE"]
                else tensorflow.convert_to_tensor(numpy.copy(arr))
            )
            for arr in args
        ]
        kwargs = {
            key: (
                tensorflow.convert_to_tensor(arr)
                if arr.flags["WRITEABLE"]
                else tensorflow.convert_to_tensor(numpy.copy(arr))
            )
            for key, arr in kwargs.items()
        }
        return self.model(*args, **kwargs).numpy()
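For reference, a minimal usage sketch of the new wrapper (the toy model, field names, and file path below are illustrative, not part of this commit):

```python
import awkward as ak
import numpy as np
import tensorflow as tf

from coffea.ml_tools.tf_wrapper import tf_wrapper

# Build and save a toy Keras model (stand-in for a real trained model).
toy = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
toy.save("toy_model.keras")


class my_tf_wrapper(tf_wrapper):
    def prepare_awkward(self, events):
        # Stack three hypothetical fields into an N x 3 numpy-compatible array
        arr = ak.concatenate(
            [events[f][:, np.newaxis] for f in ("x", "y", "z")], axis=1
        )
        return [arr], {}


events = ak.Array({"x": [1.0, 2.0], "y": [3.0, 4.0], "z": [5.0, 6.0]})
tfw = my_tf_wrapper("toy_model.keras")
scores = tfw(events)  # awkward in, awkward out; dask-awkward inputs work the same way
```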
Binary file added tests/samples/tf_model.keras
67 changes: 60 additions & 7 deletions tests/test_ml_tools.py
@@ -24,6 +24,11 @@ def prepare_jets_array(njets):
"phi": ak.from_regular(np.random.random(size=(njets, NFEAT))),
"feat1": ak.from_regular(np.random.random(size=(njets, NFEAT))),
"feat2": ak.from_regular(np.random.random(size=(njets, NFEAT))),
# Extra features for testing Tensorflow model for PFCandidate classification
**{
f"feat{i}": ak.from_regular(np.random.random(size=(njets, NFEAT)))
for i in range(3, 19)
},
}
)

@@ -37,8 +42,8 @@
    return ak_jets, dak_jets


def common_prepare_awkward(array_lib, jets):
    ak = array_lib
def common_prepare_awkward(jets):
    """Common jet parsing routine for pytorch and triton inference"""

    def my_pad(arr):
        return ak.fill_none(ak.pad_none(arr, 100, axis=1, clip=True), 0.0)
@@ -81,10 +86,9 @@ def test_triton():
    # Defining custom wrapper function with awkward padding requirements.
    class triton_wrapper_test(triton_wrapper):
        def prepare_awkward(self, output_list, jets):
            ak = self.get_awkward_lib(jets)
            return [], {
                "output_list": output_list,
                "input_dict": common_prepare_awkward(ak, jets),
                "input_dict": common_prepare_awkward(jets),
            }

    # Running the evaluation in lazy and non-lazy forms
@@ -127,8 +131,7 @@ def test_torch():

    class torch_wrapper_test(torch_wrapper):
        def prepare_awkward(self, jets):
            ak = self.get_awkward_lib(jets)
            default = common_prepare_awkward(ak, jets)
            default = common_prepare_awkward(jets)
            return [], {
                "points": ak.values_astype(default["points"], np.float32),
                "features": ak.values_astype(default["features"], np.float32),
@@ -156,6 +159,57 @@ def prepare_awkward(self, jets):
    client.close()


def test_tensorflow():
    _ = pytest.importorskip("tensorflow")

    from coffea.ml_tools.tf_wrapper import tf_wrapper

    client = Client()  # Spawn local cluster

    class tf_wrapper_test(tf_wrapper):
        def prepare_awkward(self, jets):
            # List of PF candidate features used for computation
            features = [f"feat{i}" for i in range(1, 19)]

            cands = ak.concatenate(
                [
                    # Filling pad with dummy value
                    ak.fill_none(
                        ak.pad_none(jets.pfcands[f], 64),
                        0,
                        axis=1,
                    )[..., np.newaxis]
                    for f in features
                ],
                axis=2,
            )
            cands = ak.flatten(cands, axis=None)  # Flatten everything
            cands = ak.unflatten(cands, 18)  # Number of features
            cands = ak.unflatten(cands, 64)  # Number of target entries

            return [cands], {}

        def postprocess_awkward(self, ret, jets):
            # The first argument is the return object of the model's method
            ret = ret[:, :, 0]  # Flattening to get the per-candidate entry
            ret = ak.from_regular(ret)  # Making this into a jagged array
            ret = ret[ak.local_index(ret) < jets.ncands]
            return ret

    # The tensorflow model here is used to classify jet constituents
    tfw = tf_wrapper_test("tests/samples/tf_model.keras")
    ak_jets, dak_jets = prepare_jets_array(njets=256)

    ak_res = tfw(ak_jets)
    dak_res = tfw(dak_jets)

    assert np.all(np.isclose(ak_res, dak_res.compute()))
    expected_columns = {"ncands"} | {f"pfcands.feat{i}" for i in range(1, 19)}
    columns = set(list(dak.necessary_columns(dak_res).values())[0])
    assert columns == expected_columns
    client.close()


def test_xgboost():
_ = pytest.importorskip("xgboost")

@@ -167,7 +221,6 @@ def test_xgboost():

    class xgboost_test(xgboost_wrapper):
        def prepare_awkward(self, events):
            ak = self.get_awkward_lib(events)
            ret = ak.concatenate(
                [events[name][:, np.newaxis] for name in feature_list], axis=1
            )