Skip to content

Commit

Permalink
Remove dependency on eval_saved_model from legacy_feature_extractor.p…
Browse files Browse the repository at this point in the history
…y and legacy_meta_feature_extractor.py

PiperOrigin-RevId: 682377308
  • Loading branch information
zhouhao138 authored and tfx-copybara committed Oct 4, 2024
1 parent 4bb7a93 commit cff152f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.eval_saved_model import encoding
from tensorflow_model_analysis.extractors import extractor
from tensorflow_model_analysis.utils import util

_FEATURE_EXTRACTOR_STAGE_NAME = 'ExtractFeatures'
_ENCODING_NODE_SUFFIX = 'node'


def FeatureExtractor(
Expand Down Expand Up @@ -71,8 +71,8 @@ def _AugmentExtracts(
continue
# If data originated from FeaturesPredictionsLabels, then the value will be
# stored under a 'node' key.
if isinstance(val, dict) and encoding.NODE_SUFFIX in val:
val = val.get(encoding.NODE_SUFFIX)
if isinstance(val, dict) and _ENCODING_NODE_SUFFIX in val:
val = val.get(_ENCODING_NODE_SUFFIX)

if name in (prefix, util.KEY_SEPARATOR + prefix):
col_name = prefix
Expand Down Expand Up @@ -128,7 +128,7 @@ def _ParseExample(
if materialize_columns:
extracts[key] = types.MaterializedColumn(name=key, value=values)
if name not in features:
features[name] = {encoding.NODE_SUFFIX: np.array([values])}
features[name] = {_ENCODING_NODE_SUFFIX: np.array([values])}


def _MaterializeFeatures(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@
import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.api import types
from tensorflow_model_analysis.eval_saved_model import encoding


_ENCODING_NODE_SUFFIX = 'node'


def get_feature_value(
fpl: types.FeaturesPredictionsLabels, feature_key: str
) -> Any:
"""Helper to get value from FPL dict."""
node_value = fpl.features[feature_key][encoding.NODE_SUFFIX]
node_value = fpl.features[feature_key][_ENCODING_NODE_SUFFIX]
if isinstance(node_value, tf.compat.v1.SparseTensorValue):
return node_value.values
return node_value
Expand All @@ -47,7 +49,7 @@ def _set_feature_value(
feature_value, tf.compat.v1.SparseTensorValue
):
feature_value = np.array([feature_value])
features[feature_key] = {encoding.NODE_SUFFIX: feature_value}
features[feature_key] = {_ENCODING_NODE_SUFFIX: feature_value}
return features # pytype: disable=bad-return-type


Expand Down

0 comments on commit cff152f

Please sign in to comment.