From c66112d741b5e946ea94734c86a70003e1f6d632 Mon Sep 17 00:00:00 2001 From: Zhihao Jia Date: Sun, 6 Oct 2024 00:07:49 -0400 Subject: [PATCH] Reused memory usage for FusedOp --- src/ops/fused.cu | 82 +++++++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 46 deletions(-) diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 8f1212beb4..971ef145e5 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -101,40 +101,40 @@ __host__ void assert((int)regions.size() == fused->numInputs + fused->numWeights + fused->numOutputs + softmax_grad_additional_region); - GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS]; - GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS]; - GenericTensorAccessorW output_accessor[MAX_NUM_OUTPUTS]; + std::vector<GenericTensorAccessorR> input_accessor; + std::vector<GenericTensorAccessorR> weight_accessor; + std::vector<GenericTensorAccessorW> output_accessor; assert(fused->numInputs <= MAX_NUM_INPUTS); for (int i = 0; i < fused->numInputs; i++) { - input_accessor[i] = + input_accessor.push_back( helperGetGenericTensorAccessorRO(fused->input_data_types[i], regions[i], task->regions[i], FID_DATA, ctx, - runtime); + runtime)); } int roff = fused->numInputs; assert(fused->numWeights <= MAX_NUM_WEIGHTS); for (int i = 0; i < fused->numWeights; i++) { - weight_accessor[i] = + weight_accessor.push_back( helperGetGenericTensorAccessorRO(fused->weight_data_types[i], regions[i + roff], task->regions[i + roff], FID_DATA, ctx, - runtime); + runtime)); } roff += fused->numWeights; assert(fused->numOutputs <= MAX_NUM_OUTPUTS); for (int i = 0; i < fused->numOutputs; i++) { - output_accessor[i] = + output_accessor.push_back( helperGetGenericTensorAccessorWO(fused->output_data_types[i], regions[i + roff], task->regions[i + roff], FID_DATA, ctx, - runtime); + runtime)); } roff += fused->numOutputs; // Assert that all meta share the same dnn/blas handler @@ -153,39 +153,28 @@ __host__ void int ioff = 0, woff = 0, ooff = 0; for (int op = 0; op < fused->numOperators; op++) { -#if 0 - std::cout << 
get_operator_type_name(fused->op_op_type[op]) << std::endl; -#endif - GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS]; - GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS]; - GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS]; + std::vector<GenericTensorAccessorR> my_input_accessor; + std::vector<GenericTensorAccessorR> my_weight_accessor; + std::vector<GenericTensorAccessorW> my_output_accessor; for (int i = 0; i < fused->op_num_inputs[op]; i++) { int my_off = fused->op_input_idx[i + ioff]; if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { - my_input_accessor[i] = input_accessor[my_off]; -#if 0 - printf("\tmy_input_accessor[%i] = input_accessor[%i]\n", i, my_off); -#endif + my_input_accessor.push_back(input_accessor[my_off]); } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { - my_input_accessor[i] = output_accessor[my_off]; -#if 0 - printf("\tmy_input_accessor[%i] = output_accessor[%i]\n", i, my_off); -#endif + my_input_accessor.push_back(output_accessor[my_off]); } else { assert(false); } } for (int i = 0; i < fused->op_num_weights[op]; i++) { assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); - my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]]; + my_weight_accessor.push_back( + weight_accessor[fused->op_weight_idx[i + woff]]); } for (int i = 0; i < fused->op_num_outputs[op]; i++) { int my_off = fused->op_output_idx[i + ooff]; assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT); - my_output_accessor[i] = output_accessor[my_off]; -#if 0 - printf("\tmy_output_accessor[%i] = output_accessor[%i]\n", i, my_off); -#endif + my_output_accessor.push_back(output_accessor[my_off]); } switch (fused->op_op_type[op]) { case OP_CONCAT: { @@ -195,7 +184,7 @@ __host__ void int num_inputs = fused->op_num_inputs[op]; Kernels::Concat::forward_kernel_wrapper(m, my_output_accessor[0], - my_input_accessor, + my_input_accessor.data(), num_inputs, m->legion_axis); break; @@ -1242,40 +1231,40 @@ __host__ void FusedOp::forward_task(Task const *task, assert(regions.size() == 
task->regions.size()); assert((int)regions.size() == fused->numInputs + fused->numWeights + fused->numOutputs); - GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS]; - GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS]; - GenericTensorAccessorW output_accessor[MAX_NUM_OUTPUTS]; + std::vector<GenericTensorAccessorR> input_accessor; + std::vector<GenericTensorAccessorR> weight_accessor; + std::vector<GenericTensorAccessorW> output_accessor; assert(fused->numInputs <= MAX_NUM_INPUTS); for (int i = 0; i < fused->numInputs; i++) { - input_accessor[i] = + input_accessor.push_back( helperGetGenericTensorAccessorRO(fused->input_data_types[i], regions[i], task->regions[i], FID_DATA, ctx, - runtime); + runtime)); } int roff = fused->numInputs; assert(fused->numWeights <= MAX_NUM_WEIGHTS); for (int i = 0; i < fused->numWeights; i++) { - weight_accessor[i] = + weight_accessor.push_back( helperGetGenericTensorAccessorRO(fused->weight_data_types[i], regions[i + roff], task->regions[i + roff], FID_DATA, ctx, - runtime); + runtime)); } roff += fused->numWeights; assert(fused->numOutputs <= MAX_NUM_OUTPUTS); for (int i = 0; i < fused->numOutputs; i++) { - output_accessor[i] = + output_accessor.push_back( helperGetGenericTensorAccessorWO(fused->output_data_types[i], regions[i + roff], task->regions[i + roff], FID_DATA, ctx, - runtime); + runtime)); } // Assert that all meta share the same dnn/blas handler int start = 0; @@ -1293,17 +1282,17 @@ __host__ void FusedOp::forward_task(Task const *task, int ioff = 0, woff = 0, ooff = 0; for (int op = 0; op < fused->numOperators; op++) { - GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS]; - GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS]; - GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS]; + std::vector<GenericTensorAccessorR> my_input_accessor; + std::vector<GenericTensorAccessorR> my_weight_accessor; + std::vector<GenericTensorAccessorW> my_output_accessor; for (int i = 0; i < fused->op_num_inputs[op]; i++) { int my_off = fused->op_input_idx[i + ioff]; if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { assert(my_off < fused->numInputs); - 
my_input_accessor[i] = input_accessor[my_off]; + my_input_accessor.push_back(input_accessor[my_off]); } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { assert(my_off < fused->numOutputs); - my_input_accessor[i] = output_accessor[my_off]; + my_input_accessor.push_back(output_accessor[my_off]); } else { assert(false); } @@ -1311,13 +1300,14 @@ __host__ void FusedOp::forward_task(Task const *task, for (int i = 0; i < fused->op_num_weights[op]; i++) { assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); assert(fused->op_weight_idx[i + woff] < fused->numWeights); - my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]]; + my_weight_accessor.push_back( + weight_accessor[fused->op_weight_idx[i + woff]]); } for (int i = 0; i < fused->op_num_outputs[op]; i++) { int my_off = fused->op_output_idx[i + ooff]; assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT); assert(my_off < fused->numOutputs); - my_output_accessor[i] = output_accessor[my_off]; + my_output_accessor.push_back(output_accessor[my_off]); } switch (fused->op_op_type[op]) { case OP_CONCAT: { @@ -1327,7 +1317,7 @@ __host__ void FusedOp::forward_task(Task const *task, int num_inputs = fused->op_num_inputs[op]; Kernels::Concat::forward_kernel_wrapper(m, my_output_accessor[0], - my_input_accessor, + my_input_accessor.data(), num_inputs, m->legion_axis); break;