Commit 664667e
change argmax to DeviceSegmentedReduce::ArgMax && replace cudamalloc with legion instance (#896)

* change argmax to DeviceSegmentedReduce::ArgMax

* replace argmax, beam_topk, rms_norm cudamalloc

* replace layernorm, linear, sampling.

* destructor

* format
xinhaoc authored Jul 28, 2023
1 parent 67977f4 commit 664667e
Showing 25 changed files with 357 additions and 175 deletions.
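
For reference, cub::DeviceSegmentedReduce::ArgMax follows CUB's two-call convention: a first call with a null workspace pointer only reports the required temp-storage size, and a second call with real storage performs the reduction. Below is a minimal sketch of that usage; names such as d_logits and segmented_argmax are illustrative, and only the members d_temp_storage, temp_storage_bytes, d_offsets, and d_out added to ArgMaxMeta in this diff follow from it.

#include <cub/cub.cuh>

// Per-row argmax over a [batch_size, length] logits matrix.
// d_offsets holds the batch_size + 1 row boundaries: 0, length, 2*length, ...
void segmented_argmax(float const *d_logits,
                      cub::KeyValuePair<int, float> *d_out, // (index, max) per row
                      int const *d_offsets,
                      int batch_size,
                      void *d_temp_storage, // pre-reserved workspace
                      size_t temp_storage_bytes,
                      cudaStream_t stream) {
  // Calling with d_temp_storage == nullptr would instead write the
  // required workspace size into temp_storage_bytes; the ArgMaxMeta
  // constructor presumably does that once and reuses the buffer here.
  cub::DeviceSegmentedReduce::ArgMax(d_temp_storage,
                                     temp_storage_bytes,
                                     d_logits,
                                     d_out,
                                     batch_size,
                                     d_offsets,
                                     d_offsets + 1,
                                     stream);
}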
25 changes: 14 additions & 11 deletions include/flexflow/ops/argmax.h
@@ -5,25 +5,28 @@
#include "flexflow/model.h"
#include "flexflow/node.h"
#include "flexflow/ops/argmax_params.h"
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {

class ArgMaxMeta : public OpMeta {
public:
bool beam_search;
float *probs;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnTensorDescriptor_t inputTensor, outputTensor;
cudnnReduceTensorDescriptor_t reduceMaxDesc;
#else
miopenTensorDescriptor_t inputTensor, outputTensor;
miopenReduceTensorDescriptor_t reduceMaxDesc;
#endif
void *d_temp_storage;
size_t temp_storage_bytes = 0;
int *d_offsets;
void *d_out;
Realm::RegionInstance reserveInst;
ArgMaxMeta(FFHandler handler,
Op const *op,
Legion::Domain const &input_domain,
Legion::Domain const &output_domain,
GenericTensorAccessorW input);
GenericTensorAccessorW input,
int batch_size,
int total_ele,
MemoryAllocator &gpu_mem_allocator);
~ArgMaxMeta(void);
};

class ArgMax : public Op {
@@ -88,16 +91,16 @@ class ArgMax : public Op {
static void forward_kernel(ArgMaxMeta const *m,
DT *input_ptr,
int *indices_ptr,
DT *prob_ptr,
float *prob_ptr,
int *parent_ptr,
int length,
int batch_size,
ffStream_t stream);
static void forward_kernel_wrapper(ArgMaxMeta const *m,
GenericTensorAccessorW const &input,
GenericTensorAccessorW const &indices,
GenericTensorAccessorW const &value,
GenericTensorAccessorW const &parent);
GenericTensorAccessorW const &parent,
int batch_size);
Params get_params() const;

public:
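The header change above is the template for the rest of this commit: each meta gains a MemoryAllocator & constructor argument and a Realm::RegionInstance reserveInst member, so device scratch buffers come out of a single Legion instance instead of individual cudaMalloc calls. A hedged sketch of that allocation step follows; the helper function is hypothetical, and it assumes MemoryAllocator exposes create_legion_instance() and allocate_instance*() methods.

#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {

// Hypothetical helper (not part of this diff): reserve one Legion
// instance and carve the op's scratch buffers out of it, so the op
// never calls cudaMalloc/cudaFree directly.
void reserve_argmax_buffers(MemoryAllocator &gpu_mem_allocator,
                            Realm::RegionInstance &reserveInst,
                            size_t temp_storage_bytes,
                            int batch_size,
                            void *&d_temp_storage,
                            int *&d_offsets) {
  size_t total_bytes = temp_storage_bytes + (batch_size + 1) * sizeof(int);
  gpu_mem_allocator.create_legion_instance(reserveInst, total_bytes);
  d_temp_storage =
      gpu_mem_allocator.allocate_instance_untyped(temp_storage_bytes);
  d_offsets = gpu_mem_allocator.allocate_instance<int>(batch_size + 1);
}

} // namespace FlexFlow

Backing everything with one RegionInstance keeps the reservation visible to Legion's mapper and reduces teardown to destroying reserveInst.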
7 changes: 6 additions & 1 deletion include/flexflow/ops/beam_topk.h
@@ -5,19 +5,24 @@
#include "flexflow/model.h"
#include "flexflow/node.h"
#include "flexflow/ops/beam_topk_params.h"
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {

class BeamTopKMeta : public OpMeta {
public:
BeamTopKMeta(FFHandler handle, Op const *op);
BeamTopKMeta(FFHandler handle,
Op const *op,
MemoryAllocator &gpu_mem_allocator);
~BeamTopKMeta(void);
bool sorted;
int max_beam_width;
int *parent_ids;
void *acc_probs;
int *block_start_index;
int *request_id;
int *tokens_per_request;
Realm::RegionInstance reserveInst;
};

class BeamTopK : public Op {
2 changes: 2 additions & 0 deletions include/flexflow/ops/kernels/linear_kernels.h
@@ -15,6 +15,7 @@ class LinearMeta : public OpMeta {
Linear const *li,
MemoryAllocator gpu_mem_allocator,
int weightSize);
~LinearMeta(void);
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnTensorDescriptor_t outputTensor;
cudnnActivationDescriptor_t actiDesc;
@@ -34,6 +35,7 @@
float kernel_reg_lambda;
bool use_bias, add_bias_only_once;
char op_name[MAX_OPNAME];
Realm::RegionInstance reserveInst;
};

namespace Kernels {
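The destructors declared in these headers are the other half of the swap. Their bodies are not shown on this page; a plausible sketch, assuming Realm's usual RegionInstance API, is:

// Sketch only (body not shown in this diff): release the backing
// Legion instance on teardown instead of cudaFree-ing each buffer.
LinearMeta::~LinearMeta(void) {
  if (reserveInst != Realm::RegionInstance::NO_INST) {
    reserveInst.destroy();
  }
}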
7 changes: 6 additions & 1 deletion include/flexflow/ops/kernels/rms_norm_kernels.h
@@ -5,6 +5,7 @@
#include "flexflow/device.h"
#include "flexflow/fftype.h"
#include "flexflow/op_meta.h"
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {
using Legion::coord_t;
@@ -13,7 +14,10 @@ class RMSNorm;

class RMSNormMeta : public OpMeta {
public:
RMSNormMeta(FFHandler handler, RMSNorm const *rms);
RMSNormMeta(FFHandler handler,
RMSNorm const *rms,
MemoryAllocator &gpu_mem_allocator);
~RMSNormMeta(void);
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnTensorDescriptor_t inputTensor, outputTensor;
cudnnReduceTensorDescriptor_t reduceDesc;
@@ -34,6 +38,7 @@
int batch_size;
int num_elements;
char op_name[MAX_OPNAME];
Realm::RegionInstance reserveInst;
};

namespace Kernels {
7 changes: 6 additions & 1 deletion include/flexflow/ops/layer_norm.h
@@ -2,6 +2,7 @@

#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/utils/memory_allocator.h"
namespace FlexFlow {

class LayerNormMeta;
@@ -107,14 +108,18 @@

class LayerNormMeta : public OpMeta {
public:
LayerNormMeta(FFHandler handle, LayerNorm const *ln);
LayerNormMeta(FFHandler handle,
LayerNorm const *ln,
MemoryAllocator &gpu_mem_allocator);
~LayerNormMeta(void);

public:
bool elementwise_affine;
int64_t effective_batch_size, effective_num_elements;
float eps;
void *mean_ptr, *rstd_ptr, *ds_ptr, *db_ptr, *scale_ptr, *bias_ptr;
char op_name[MAX_OPNAME];
Realm::RegionInstance reserveInst;
};

}; // namespace FlexFlow
1 change: 1 addition & 0 deletions include/flexflow/ops/rms_norm.h
@@ -4,6 +4,7 @@
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/ops/rms_norm_params.h"
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {

6 changes: 5 additions & 1 deletion include/flexflow/ops/sampling.h
@@ -9,6 +9,7 @@
#include <curand.h>
#include <curand_kernel.h>
#endif
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {

@@ -22,14 +23,17 @@ class SamplingMeta : public OpMeta {
int *idx;
void *d_temp_storage;
size_t temp_storage_bytes;
Realm::RegionInstance reserveInst;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
curandState *state;
#endif
SamplingMeta(FFHandler handle,
Op const *op,
int batch_size,
int total_ele,
GenericTensorAccessorW input);
GenericTensorAccessorW input,
MemoryAllocator &gpu_mem_allocator);
~SamplingMeta(void);
};

class Sampling : public Op {
2 changes: 2 additions & 0 deletions include/flexflow/simulator.h
@@ -38,6 +38,7 @@ class LinearMeta;
class Pool2DMeta;
class ElementUnaryMeta;
class ElementBinaryMeta;
class LayerNormMeta;
// class EmbeddingMeta;
// class SoftmaxMeta;
class BatchMatmulMeta;
@@ -754,6 +755,7 @@ class Simulator {
LinearMeta *linear_meta;
Pool2DMeta *pool2d_meta;
ElementUnaryMeta *ele_unary_meta;
LayerNormMeta *layernorm_meta;
// ElementBinaryMeta *ele_binary_meta;
// EmbeddingMeta *embedding_meta;
// SoftmaxMeta *softmax_meta;
78 changes: 28 additions & 50 deletions src/ops/argmax.cc
@@ -51,7 +51,7 @@ Tensor FFModel::argmax(const Tensor input, bool beam_search, char const *name) {
name,
1 /*inputs*/,
0 /*weights*/,
beam_search ? 3 : 2 /*outputs*/,
beam_search ? 2 : 1 /*outputs*/,
input);
{
int numdims = input->num_dims;
@@ -65,13 +65,9 @@ Tensor FFModel::argmax(const Tensor input, bool beam_search, char const *name) {
// numdims, dims, input->data_type, li, 0, true /*create_grad*/);
li->outputs[0] = create_tensor_legion_ordering(
numdims, dims, DT_INT32, li, 0, false /*create_grad*/);
// logits
li->outputs[1] = create_tensor_legion_ordering(
numdims, dims, input->data_type, li, 1, false /*create_grad*/);

if (beam_search) {
// parent id
li->outputs[2] = create_tensor_legion_ordering(
li->outputs[1] = create_tensor_legion_ordering(
numdims, dims, DT_INT32, li, 1, false /*create_grad*/);
}
}
@@ -116,7 +112,7 @@ ArgMax::ArgMax(FFModel &model,
name,
1 /*inputs*/,
0 /*weights*/,
_beam_search ? 3 : 2 /*outputs*/,
_beam_search ? 2 : 1 /*outputs*/,
_input),
beam_search(_beam_search) {
int numdim = inputs[0]->num_dims;
@@ -131,11 +127,9 @@
// numdim, dims, _input->data_type, this, 0 /*owner_idx*/);
outputs[0] = model.create_parallel_tensor_legion_ordering(
numdim, dims, DT_INT32, this, 0 /*owner_idx*/);
outputs[1] = model.create_parallel_tensor_legion_ordering(
numdim, dims, _input->data_type, this, 1 /*owner_idx*/);
if (_beam_search) {
outputs[2] = model.create_parallel_tensor_legion_ordering(
numdim, dims, DT_INT32, this, 2 /*owner_idx*/);
outputs[1] = model.create_parallel_tensor_legion_ordering(
numdim, dims, DT_INT32, this, 1 /*owner_idx*/);
}
}

@@ -180,12 +174,6 @@ void ArgMax::init_inference(FFModel const &ff,
EXCLUSIVE,
batch_outputs[0]->region));
launcher.add_field(1, FID_DATA);
launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part,
0 /*projection id*/,
WRITE_ONLY,
EXCLUSIVE,
batch_outputs[1]->region));
launcher.add_field(2, FID_DATA);
FutureMap fm = runtime->execute_index_space(ctx, launcher);
fm.wait_all_results();
set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]);
@@ -240,9 +228,22 @@ OpMeta *ArgMax::init_task(Task const *task,
ctx, task->regions[0].region.get_index_space());
Domain output_domain = runtime->get_index_space_domain(
ctx, task->regions[2].region.get_index_space());
int length = acc_input.domain.hi()[0] - acc_input.domain.lo()[0] + 1;
int batch_size = acc_input.domain.get_volume() / length;
Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine())
.only_kind(Memory::GPU_FB_MEM)
.best_affinity_to(task->target_proc)
.first();
MemoryAllocator gpu_mem_allocator(gpu_mem);

ArgMaxMeta *m =
new ArgMaxMeta(handle, s, input_domain, output_domain, acc_input);
ArgMaxMeta *m = new ArgMaxMeta(handle,
s,
input_domain,
output_domain,
acc_input,
batch_size,
length * batch_size,
gpu_mem_allocator);
m->profiling = s->profiling;
m->beam_search = s->beam_search;
return m;
@@ -297,13 +298,6 @@ FutureMap ArgMax::inference(FFModel const &ff,
EXCLUSIVE,
batch_outputs[1]->region));
launcher.add_field(2, FID_DATA);
launcher.add_region_requirement(
RegionRequirement(batch_outputs[2]->part,
0 /*projection id*/,
WRITE_ONLY,
EXCLUSIVE,
batch_outputs[2]->region));
launcher.add_field(3, FID_DATA);
return runtime->execute_index_space(ctx, launcher);
} else {
IndexLauncher launcher(ARGMAX_NORM_INF_TASK_ID,
Expand All @@ -328,13 +322,6 @@ FutureMap ArgMax::inference(FFModel const &ff,
EXCLUSIVE,
batch_outputs[0]->region));
launcher.add_field(1, FID_DATA);
launcher.add_region_requirement(
RegionRequirement(batch_outputs[1]->part,
0 /*projection id*/,
WRITE_ONLY,
EXCLUSIVE,
batch_outputs[1]->region));
launcher.add_field(2, FID_DATA);
return runtime->execute_index_space(ctx, launcher);
}
}
Expand All @@ -344,8 +331,8 @@ BeamInferenceResult
std::vector<PhysicalRegion> const &regions,
Context ctx,
Runtime *runtime) {
assert(regions.size() == 4);
assert(task->regions.size() == 4);
assert(regions.size() == 3);
assert(task->regions.size() == 3);
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
if (bc->num_tokens == 0) {
// Directly return for empty batch config
Expand All @@ -359,21 +346,14 @@ BeamInferenceResult
GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO(
DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime);
int batch_size = bc->num_active_tokens();
GenericTensorAccessorW value = helperGetGenericTensorAccessorWO(
m->input_type[0], regions[2], task->regions[1], FID_DATA, ctx, runtime);
GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO(
DT_INT32, regions[3], task->regions[1], FID_DATA, ctx, runtime);
ArgMax::forward_kernel_wrapper(m, input, indices, value, parent);
DT_INT32, regions[2], task->regions[2], FID_DATA, ctx, runtime);
ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size);

BeamInferenceResult ir;
download_tensor<BatchConfig::TokenId>(
indices.get_int32_ptr(), ir.token_ids, batch_size);
if (m->input_type[0] == DT_FLOAT) {
download_tensor<float>(value.get_float_ptr(), ir.probs, batch_size);
} else if (m->input_type[0] == DT_HALF) {
download_tensor(m->probs, ir.probs, batch_size);
}

download_tensor(m->probs, ir.probs, batch_size);
download_tensor<int>(parent.get_int32_ptr(), ir.parent_id, batch_size);
return ir;
}
Expand All @@ -383,8 +363,8 @@ InferenceResult
std::vector<PhysicalRegion> const &regions,
Context ctx,
Runtime *runtime) {
assert(regions.size() == 3);
assert(task->regions.size() == 3);
assert(regions.size() == 2);
assert(task->regions.size() == 2);
ArgMaxMeta const *m = *((ArgMaxMeta **)task->local_args);
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
if (bc->num_tokens == 0) {
Expand All @@ -397,11 +377,9 @@ InferenceResult
m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO(
DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime);
GenericTensorAccessorW value = helperGetGenericTensorAccessorWO(
m->input_type[0], regions[2], task->regions[1], FID_DATA, ctx, runtime);
GenericTensorAccessorW parent;
int batch_size = bc->num_active_tokens();
ArgMax::forward_kernel_wrapper(m, input, indices, value, parent);
ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size);
InferenceResult ir;
download_tensor<BatchConfig::TokenId>(
indices.get_int32_ptr(), ir.token_ids, batch_size);
(diffs for the remaining 16 changed files not shown)
