diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index 93b0444cb9..1f7ed6989c 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -47,6 +47,10 @@ FF_NEW_OPAQUE_TYPE(flexflow_dlrm_config_t); FF_NEW_OPAQUE_TYPE(flexflow_dataloader_4d_t); FF_NEW_OPAQUE_TYPE(flexflow_dataloader_2d_t); FF_NEW_OPAQUE_TYPE(flexflow_single_dataloader_t); +// Inference +FF_NEW_OPAQUE_TYPE(flexflow_batch_config_t); +FF_NEW_OPAQUE_TYPE(flexflow_tree_verify_batch_config_t); +FF_NEW_OPAQUE_TYPE(flexflow_beam_search_batch_config_t); // ----------------------------------------------------------------------- // FFConfig @@ -73,6 +77,7 @@ int flexflow_config_get_epochs(flexflow_config_t handle); bool flexflow_config_get_enable_control_replication(flexflow_config_t handle); int flexflow_config_get_python_data_loader_type(flexflow_config_t handle); + // ----------------------------------------------------------------------- // FFModel // ----------------------------------------------------------------------- @@ -713,6 +718,34 @@ void flexflow_op_forward(flexflow_op_t handle, flexflow_model_t model); void flexflow_perform_registration(void); +// ----------------------------------------------------------------------- +// BatchConfig +// ----------------------------------------------------------------------- + +flexflow_batch_config_t flexflow_batch_config_create(void); + +void flexflow_batch_config_destroy(flexflow_batch_config_t handle); + +// ----------------------------------------------------------------------- +// TreeVerifyBatchConfig +// ----------------------------------------------------------------------- + +flexflow_tree_verify_batch_config_t + flexflow_tree_verify_batch_config_create(void); + +void flexflow_tree_verify_batch_config_destroy( + flexflow_tree_verify_batch_config_t handle); + +// ----------------------------------------------------------------------- +// BeamSearchBatchConfig +// ----------------------------------------------------------------------- + +flexflow_beam_search_batch_config_t + flexflow_beam_search_batch_config_create(void); + +void flexflow_beam_search_batch_config_destroy( + flexflow_beam_search_batch_config_t handle); + #ifdef __cplusplus } #endif diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index bb63dc153e..5862a78223 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -2566,3 +2566,33 @@ def __init__(self, shape, data_type, base_ptr, strides, read_only): 'data': (base_ptr, read_only), 'strides': strides, } + +# ----------------------------------------------------------------------- +# BatchConfig +# ----------------------------------------------------------------------- + +class BatchConfig(object): + __slots__ = ['handle', '_handle'] + def __init__(self): + self.handle = ffc.flexflow_batch_config_create() + self._handle = ffi.gc(self.handle, ffc.flexflow_batch_config_destroy) + +# ----------------------------------------------------------------------- +# TreeVerifyBatchConfig +# ----------------------------------------------------------------------- + +class TreeVerifyBatchConfig(object): + __slots__ = ['handle', '_handle'] + def __init__(self): + self.handle = ffc.flexflow_tree_verify_batch_config_create() + self._handle = ffi.gc(self.handle, ffc.flexflow_tree_verify_batch_config_destroy) + +# ----------------------------------------------------------------------- +# BeamSearchBatchConfig +# ----------------------------------------------------------------------- + +class BatchConfig(object): + __slots__ = ['handle', '_handle'] + def __init__(self): + self.handle = ffc.flexflow_beam_search_batch_config_create() + self._handle = ffi.gc(self.handle, ffc.flexflow_beam_search_batch_config_destroy) diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 09258d8206..671016a0fa 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -55,6 +55,12 @@ class FFCObjectWrapper { FF_NEW_OPAQUE_WRAPPER(flexflow_net_config_t, NetConfig *); FF_NEW_OPAQUE_WRAPPER(flexflow_dlrm_config_t, DLRMConfig *); FF_NEW_OPAQUE_WRAPPER(flexflow_single_dataloader_t, SingleDataLoader *); + // inference + FF_NEW_OPAQUE_WRAPPER(flexflow_batch_config_t, BatchConfig *); + FF_NEW_OPAQUE_WRAPPER(flexflow_tree_verify_batch_config_t, + TreeVerifyBatchConfig *); + FF_NEW_OPAQUE_WRAPPER(flexflow_beam_search_batch_config_t, + BeamSearchBatchConfig *); }; Logger ffc_log("flexflow_c"); @@ -1961,3 +1967,55 @@ void flexflow_perform_registration(void) { Runtime::perform_registration_callback(FFMapper::update_mappers, true /*global*/); } + +// ----------------------------------------------------------------------- +// BatchConfig +// ----------------------------------------------------------------------- + +flexflow_batch_config_t flexflow_batch_config_create(void) { + BatchConfig *config = new BatchConfig(); + DEBUG_PRINT("[BatchConfig] new %p", config); + return FFCObjectWrapper::wrap(config); +} + +void flexflow_batch_config_destroy(flexflow_batch_config_t handle_) { + BatchConfig *handle = FFCObjectWrapper::unwrap(handle_); + DEBUG_PRINT("[BatchConfig] delete %p", handle); + delete handle; +} + +// ----------------------------------------------------------------------- +// TreeVerifyBatchConfig +// ----------------------------------------------------------------------- + +flexflow_tree_verify_batch_config_t + flexflow_tree_verify_batch_config_create(void) { + TreeVerifyBatchConfig *config = new TreeVerifyBatchConfig(); + DEBUG_PRINT("[TreeVerifyBatchConfig] new %p", config); + return FFCObjectWrapper::wrap(config); +} + +void flexflow_tree_verify_batch_config_destroy( + flexflow_tree_verify_batch_config_t handle_) { + TreeVerifyBatchConfig *handle = FFCObjectWrapper::unwrap(handle_); + DEBUG_PRINT("[TreeVerifyBatchConfig] delete %p", handle); + delete handle; +} + +// ----------------------------------------------------------------------- +// BeamSearchBatchConfig +// ----------------------------------------------------------------------- + +flexflow_beam_search_batch_config_t + flexflow_beam_search_batch_config_create(void) { + BeamSearchBatchConfig *config = new BeamSearchBatchConfig(); + DEBUG_PRINT("[BeamSearchBatchConfig] new %p", config); + return FFCObjectWrapper::wrap(config); +} + +void flexflow_beam_search_batch_config_destroy( + flexflow_beam_search_batch_config_t handle_) { + BeamSearchBatchConfig *handle = FFCObjectWrapper::unwrap(handle_); + DEBUG_PRINT("[BeamSearchBatchConfig] delete %p", handle); + delete handle; +}