Skip to content

Commit

Permalink
add batch config support in cffi
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro committed Jul 21, 2023
1 parent c616060 commit 1fa3c77
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 0 deletions.
33 changes: 33 additions & 0 deletions include/flexflow/flexflow_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ FF_NEW_OPAQUE_TYPE(flexflow_dlrm_config_t);
FF_NEW_OPAQUE_TYPE(flexflow_dataloader_4d_t);
FF_NEW_OPAQUE_TYPE(flexflow_dataloader_2d_t);
FF_NEW_OPAQUE_TYPE(flexflow_single_dataloader_t);
// Inference
FF_NEW_OPAQUE_TYPE(flexflow_batch_config_t);
FF_NEW_OPAQUE_TYPE(flexflow_tree_verify_batch_config_t);
FF_NEW_OPAQUE_TYPE(flexflow_beam_search_batch_config_t);

// -----------------------------------------------------------------------
// FFConfig
Expand All @@ -73,6 +77,7 @@ int flexflow_config_get_epochs(flexflow_config_t handle);
bool flexflow_config_get_enable_control_replication(flexflow_config_t handle);

int flexflow_config_get_python_data_loader_type(flexflow_config_t handle);

// -----------------------------------------------------------------------
// FFModel
// -----------------------------------------------------------------------
Expand Down Expand Up @@ -713,6 +718,34 @@ void flexflow_op_forward(flexflow_op_t handle, flexflow_model_t model);

void flexflow_perform_registration(void);

// -----------------------------------------------------------------------
// BatchConfig
// -----------------------------------------------------------------------

flexflow_batch_config_t flexflow_batch_config_create(void);

void flexflow_batch_config_destroy(flexflow_batch_config_t handle);

// -----------------------------------------------------------------------
// TreeVerifyBatchConfig
// -----------------------------------------------------------------------

flexflow_tree_verify_batch_config_t
flexflow_tree_verify_batch_config_create(void);

void flexflow_tree_verify_batch_config_destroy(
flexflow_tree_verify_batch_config_t handle);

// -----------------------------------------------------------------------
// BeamSearchBatchConfig
// -----------------------------------------------------------------------

flexflow_beam_search_batch_config_t
flexflow_beam_search_batch_config_create(void);

void flexflow_beam_search_batch_config_destroy(
flexflow_beam_search_batch_config_t handle);

#ifdef __cplusplus
}
#endif
Expand Down
30 changes: 30 additions & 0 deletions python/flexflow/core/flexflow_cffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2566,3 +2566,33 @@ def __init__(self, shape, data_type, base_ptr, strides, read_only):
'data': (base_ptr, read_only),
'strides': strides,
}

# -----------------------------------------------------------------------
# BatchConfig
# -----------------------------------------------------------------------

class BatchConfig(object):
__slots__ = ['handle', '_handle']
def __init__(self):
self.handle = ffc.flexflow_batch_config_create()
self._handle = ffi.gc(self.handle, ffc.flexflow_batch_config_destroy)

# -----------------------------------------------------------------------
# TreeVerifyBatchConfig
# -----------------------------------------------------------------------

class TreeVerifyBatchConfig(object):
__slots__ = ['handle', '_handle']
def __init__(self):
self.handle = ffc.flexflow_tree_verify_batch_config_create()
self._handle = ffi.gc(self.handle, ffc.flexflow_tree_verify_batch_config_destroy)

# -----------------------------------------------------------------------
# BeamSearchBatchConfig
# -----------------------------------------------------------------------

class BatchConfig(object):
__slots__ = ['handle', '_handle']
def __init__(self):
self.handle = ffc.flexflow_beam_search_batch_config_create()
self._handle = ffi.gc(self.handle, ffc.flexflow_beam_search_batch_config_destroy)
58 changes: 58 additions & 0 deletions src/c/flexflow_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ class FFCObjectWrapper {
FF_NEW_OPAQUE_WRAPPER(flexflow_net_config_t, NetConfig *);
FF_NEW_OPAQUE_WRAPPER(flexflow_dlrm_config_t, DLRMConfig *);
FF_NEW_OPAQUE_WRAPPER(flexflow_single_dataloader_t, SingleDataLoader *);
// inference
FF_NEW_OPAQUE_WRAPPER(flexflow_batch_config_t, BatchConfig *);
FF_NEW_OPAQUE_WRAPPER(flexflow_tree_verify_batch_config_t,
TreeVerifyBatchConfig *);
FF_NEW_OPAQUE_WRAPPER(flexflow_beam_search_batch_config_t,
BeamSearchBatchConfig *);
};

Logger ffc_log("flexflow_c");
Expand Down Expand Up @@ -1961,3 +1967,55 @@ void flexflow_perform_registration(void) {
Runtime::perform_registration_callback(FFMapper::update_mappers,
true /*global*/);
}

// -----------------------------------------------------------------------
// BatchConfig
// -----------------------------------------------------------------------

flexflow_batch_config_t flexflow_batch_config_create(void) {
BatchConfig *config = new BatchConfig();
DEBUG_PRINT("[BatchConfig] new %p", config);
return FFCObjectWrapper::wrap(config);
}

void flexflow_batch_config_destroy(flexflow_batch_config_t handle_) {
BatchConfig *handle = FFCObjectWrapper::unwrap(handle_);
DEBUG_PRINT("[BatchConfig] delete %p", handle);
delete handle;
}

// -----------------------------------------------------------------------
// TreeVerifyBatchConfig
// -----------------------------------------------------------------------

flexflow_tree_verify_batch_config_t
flexflow_tree_verify_batch_config_create(void) {
TreeVerifyBatchConfig *config = new TreeVerifyBatchConfig();
DEBUG_PRINT("[TreeVerifyBatchConfig] new %p", config);
return FFCObjectWrapper::wrap(config);
}

void flexflow_tree_verify_batch_config_destroy(
flexflow_tree_verify_batch_config_t handle_) {
TreeVerifyBatchConfig *handle = FFCObjectWrapper::unwrap(handle_);
DEBUG_PRINT("[TreeVerifyBatchConfig] delete %p", handle);
delete handle;
}

// -----------------------------------------------------------------------
// BeamSearchBatchConfig
// -----------------------------------------------------------------------

flexflow_beam_search_batch_config_t
flexflow_beam_search_batch_config_create(void) {
BeamSearchBatchConfig *config = new BeamSearchBatchConfig();
DEBUG_PRINT("[BeamSearchBatchConfig] new %p", config);
return FFCObjectWrapper::wrap(config);
}

void flexflow_beam_search_batch_config_destroy(
flexflow_beam_search_batch_config_t handle_) {
BeamSearchBatchConfig *handle = FFCObjectWrapper::unwrap(handle_);
DEBUG_PRINT("[BeamSearchBatchConfig] delete %p", handle);
delete handle;
}

0 comments on commit 1fa3c77

Please sign in to comment.