Skip to content

Commit

Permalink
First iteration for tensorflow wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
yimuchen committed Aug 26, 2024
1 parent 4de0159 commit a692d7b
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/coffea/ml_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from coffea.ml_tools.helper import numpy_call_wrapper
from coffea.ml_tools.torch_wrapper import torch_wrapper
from coffea.ml_tools.tf_wrapper import tf_wrapper
from coffea.ml_tools.triton_wrapper import triton_wrapper
from coffea.ml_tools.xgboost_wrapper import xgboost_wrapper

Expand All @@ -15,4 +16,5 @@
"torch_wrapper",
"triton_wrapper",
"xgboost_wrapper",
"tf_wrapper",
]
2 changes: 1 addition & 1 deletion src/coffea/ml_tools/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class numpy_call_wrapper(abc.ABC):
{awkward.Array: awkward.to_numpy}, default_conv=container_converter.no_action
)
_np_to_ak_ = container_converter(
{numpy.ndarray: awkward.from_numpy}, default_conv=container_converter.no_action
{numpy.ndarray: awkward.from_regular}, default_conv=container_converter.no_action
)

def __init__(self):
Expand Down
77 changes: 77 additions & 0 deletions src/coffea/ml_tools/tf_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import warnings

import numpy

_tf_import_error = None
try:
import tensorflow
except (ImportError, ModuleNotFoundError) as err:
_tf_import_error = err

from .helper import nonserializable_attribute, numpy_call_wrapper


class tf_wrapper(nonserializable_attribute, numpy_call_wrapper):
"""
Wrapper for running tensorflow inference with awkward/dask-awkward inputs.
"""

def __init__(self, tf_model: str):
"""
As models are not guaranteed to be directly serializable, the use will
need to pass the model as files saved using the `tf.keras.save` method
[1]. If the user is attempting to run on the clusters, the model file
will need to be passed to the worker nodes in a way which preserves the
file path.
[1]
https://www.tensorflow.org/guide/keras/serialization_and_saving#saving
Parameters ----------
- tf_model: Path to the tensorflow model file to load
"""
if _tf_import_error is not None:
warnings.warn(
"Users should make sure the torch package is installed before proceeding!\n"
"> pip install tensorflow\n"
"or\n"
"> conda install tensorflow",
UserWarning,
)
raise _tf_import_error

nonserializable_attribute.__init__(self, ["model", "device"])
self.tf_model = tf_model

def _create_device(self):
"""
TODO: is this needed?
"""
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _create_model(self):
"""
Loading in the model from the model file. Tensorflow automatically
determines if GPU are available or not and load the resources
accordingly.
"""
return tensorflow.keras.models.load_model(self.tf_model)

def validate_numpy_input(self, *args: numpy.array, **kwargs: numpy.array) -> None:
# Pytorch's model.parameters is not a reliable way to extract input
# information for arbitrary models, so we will leave this to the user.
super().validate_numpy_input(*args, **kwargs)

def numpy_call(self, *args: numpy.array, **kwargs: numpy.array) -> numpy.array:
"""
Evaluating the numpy inputs via the model. Here we are assuming all
inputs can be trivially passed to the underlying model instance after a trivial
`tensorflow.convert_to_tensor method`. The return result will also be cased as
non-available Returning the results also as as numpy array.
TODO: Non-copy conversions?
"""
args = [tensorflow.convert_to_tensor(arr) for arr in args]
kwargs = {key: tensorflow.convert_to_tensor(arr) for key, arr in kwargs.items()}
return self.model(*args, **kwargs).numpy()
Binary file added tests/samples/tf_model.keras
Binary file not shown.
58 changes: 52 additions & 6 deletions tests/test_ml_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def prepare_jets_array(njets):
"phi": ak.from_regular(np.random.random(size=(njets, NFEAT))),
"feat1": ak.from_regular(np.random.random(size=(njets, NFEAT))),
"feat2": ak.from_regular(np.random.random(size=(njets, NFEAT))),
# Extra features for testing Tensorflow model for PFCandidate classification
**{
f"feat{i}": ak.from_regular(np.random.random(size=(njets, NFEAT)))
for i in range(3, 19)
},
}
)

Expand All @@ -37,8 +42,8 @@ def prepare_jets_array(njets):
return ak_jets, dak_jets


def common_prepare_awkward(array_lib, jets):
ak = array_lib
def common_prepare_awkward(jets):
"""Common jet parsing routing for pytorch and triton inference"""

def my_pad(arr):
return ak.fill_none(ak.pad_none(arr, 100, axis=1, clip=True), 0.0)
Expand Down Expand Up @@ -81,10 +86,9 @@ def test_triton():
# Defining custom wrapper function with awkward padding requirements.
class triton_wrapper_test(triton_wrapper):
def prepare_awkward(self, output_list, jets):
ak = self.get_awkward_lib(jets)
return [], {
"output_list": output_list,
"input_dict": common_prepare_awkward(ak, jets),
"input_dict": common_prepare_awkward(jets),
}

# Running the evaluation in lazy and non-lazy forms
Expand Down Expand Up @@ -127,8 +131,7 @@ def test_torch():

class torch_wrapper_test(torch_wrapper):
def prepare_awkward(self, jets):
ak = self.get_awkward_lib(jets)
default = common_prepare_awkward(ak, jets)
default = common_prepare_awkward(jets)
return [], {
"points": ak.values_astype(default["points"], np.float32),
"features": ak.values_astype(default["features"], np.float32),
Expand Down Expand Up @@ -156,6 +159,49 @@ def prepare_awkward(self, jets):
client.close()


def test_tensorflow():
_ = pytest.importorskip("tensorflow")

from coffea.ml_tools.tf_wrapper import tf_wrapper

client = Client() # Spawn local cluster

class tf_wrapper_test(tf_wrapper):
def prepare_awkward(self, jets):
# List of PF candidate features used for computation
features = [f"feat{i}" for i in range(1, 19)]
# Padding PFCandidate list length to target length (64)
cands = ak.pad_none(jets.pfcands, 64, axis=1)
# Folding features into inner most axis
cands = ak.unflatten(
ak.concatenate([ak.fill_none(cands[f], 100) for f in features], axis=1),
len(features),
axis=-1,
)
# This should now be a trivially convert-able via ak.to_numpy

return [cands], {}

def postprocess_awkward(self, ret, jets):
# First arguments is the return object of the models method
ret = ret[:, :, 0] # Flattening to get the per candidate entry
ret = ret[ak.local_index(ret) < jets.ncands]
return ret

# The tensorflow model here is used to classify jet constitutes
tfw = tf_wrapper_test("tests/samples/tf_model.keras")
ak_jets, dak_jets = prepare_jets_array(njets=256)

ak_res = tfw(ak_jets)
dak_res = tfw(dak_jets)

assert np.all(np.isclose(ak_res, dak_res.compute()))
expected_columns = {f"pfcands.feat{i}" for i in range(1, 19)}
columns = set(list(dak.necessary_columns(dak_res).values())[0])
assert columns == expected_columns
client.close()


def test_xgboost():
_ = pytest.importorskip("xgboost")

Expand Down

0 comments on commit a692d7b

Please sign in to comment.