Skip to content

Commit

Permalink
more fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
goliaro authored and Gabriele Oliaro committed Aug 16, 2024
1 parent 440ad3d commit 5cbe1a4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
18 changes: 12 additions & 6 deletions src/parallel_ops/parallel_identity.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -337,7 +337,7 @@ void ParallelIdentity::inference_task(
assert(regions.size() == 2);
assert(task->regions.size() == 2);

ParallelIdentityMeta const *m = *((ParallelIdentityMeta **)task->local_args);
ParallelIdentityMeta *m = *((ParallelIdentityMeta **)task->local_args);
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
if (bc->num_active_tokens() == 0) {
return;
Expand All @@ -349,10 +349,13 @@ void ParallelIdentity::inference_task(
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input.data_type == output.data_type);
inference_kernel_wrapper(m, bc, input, output);
if (m->inference_debugging) {
std::cout << "INF " << m->op_name << std::endl;
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
ParallelIdentity::save_inference_tensors_to_file(
m, shard_id, bc, {input}, {}, {output});
}
inference_kernel_wrapper(m, bc, input, output);
}

FutureMap
Expand Down Expand Up @@ -406,7 +409,7 @@ void ParallelIdentity::peft_bwd_task(Task const *task,
assert(regions.size() == 2);
assert(task->regions.size() == 2);

ParallelIdentityMeta const *m = *((ParallelIdentityMeta **)task->local_args);
ParallelIdentityMeta *m = *((ParallelIdentityMeta **)task->local_args);
BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
if (bc->num_active_peft_tokens() == 0) {
return;
Expand All @@ -417,10 +420,13 @@ void ParallelIdentity::peft_bwd_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input_grad.data_type == output_grad.data_type);
peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad);
if (m->inference_debugging) {
std::cout << "BWD " << m->op_name << std::endl;
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
ParallelIdentity::save_inference_tensors_to_file(
m, shard_id, bc, {input_grad}, {}, {output_grad}, false);
}
peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad);
}

bool ParallelIdentity::measure_operator_cost(Simulator *sim,
Expand Down
19 changes: 17 additions & 2 deletions src/runtime/model.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -3540,8 +3540,23 @@ void FFModel::create_operators_from_layers() {
}
} else if (need_to_add_parallel_identity(layer_idx)) {
assert(op->numOutputs == 2);
ParallelIdentity *parallel_identity = new ParallelIdentity(
*this, op->outputs[1], op->outputs[1]->num_dims - 1);
size_t transformer_layer_id = op->layer_guid.transformer_layer_id;
if (transformer_layer_parallel_identity_count.find(
transformer_layer_id) ==
transformer_layer_parallel_identity_count.end()) {
transformer_layer_parallel_identity_count[transformer_layer_id] = 0;
}
std::string parallel_identity_name = std::string(
"layers." + std::to_string(transformer_layer_id) +
".parallel_identity." +
std::to_string(
transformer_layer_parallel_identity_count[transformer_layer_id]));
transformer_layer_parallel_identity_count[transformer_layer_id]++;
ParallelIdentity *parallel_identity =
new ParallelIdentity(*this,
op->outputs[1],
op->outputs[1]->num_dims - 1,
parallel_identity_name.c_str());
operators.push_back(parallel_identity);
assert(op->numOutputs == l->numOutputs);
// output 0 is taken from the residual rms norm
Expand Down

0 comments on commit 5cbe1a4

Please sign in to comment.