Skip to content

Commit

Permalink
Merge pull request #28 from ROCmSoftwarePlatform/mixed-dim-new
Browse files Browse the repository at this point in the history
Support Mixed-dimension Tables in The Forward Pass
  • Loading branch information
liligwu authored Oct 20, 2022
2 parents ad123f2 + 87dcbf9 commit 9d181e5
Show file tree
Hide file tree
Showing 4 changed files with 361 additions and 205 deletions.
23 changes: 11 additions & 12 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -526,19 +526,16 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
/*
* current limitations
1. sparse, and bag
2. yet to support mixed embedding dims (loosely guarded below)
3. yet to support non-uniform table locations (all be on devs)
4. yet to support duplicate tables from some cases in exact optim (fbgemm_gpu/split_embedding_configs.py)
2. yet to support non-uniform table locations (all be on devs)
3. yet to support duplicate tables from some cases in exact optim (fbgemm_gpu/split_embedding_configs.py)
*/
{% if not nobag %}
{% if not dense %}

// weight param cnt
int64_t wcnts = dev_weights.numel();
// mixed hypothesis
bool mixed_ls = (total_D != (max_D * T));
// execution guards
bool guard_ex = (wcnts > 0 && !mixed_ls);
bool guard_ex = (wcnts > 0);

// all Ts on device
std::vector<int32_t> wplas(weights_placements.data_ptr<int32_t>(), weights_placements.data_ptr<int32_t>() + weights_placements.numel());
Expand Down Expand Up @@ -568,11 +565,12 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
void *emb_table;
const int64_t *indices;
const int64_t *offsets;
const int32_t* D_offsets;
const int64_t* weights_offsets;
int64_t pooling_mode;
{% if weighted %}
float *indice_weights;
{% endif %}
uint32_t emb_dim;
uint32_t batch;
uint32_t num_rows;
uint32_t num_tables;
Expand All @@ -585,11 +583,12 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
args.emb_table = dev_weights.packed_accessor64<float, 1, at::RestrictPtrTraits>().data();
args.indices = indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>().data();
args.offsets = offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>().data();
args.D_offsets = D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>().data();
args.weights_offsets = weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>().data();
args.pooling_mode = pooling_mode;
{% if weighted %}
args.indice_weights = indice_weights.packed_accessor32<float, 1, at::RestrictPtrTraits>().data();
{% endif %}
args.emb_dim = (uint32_t) max_D;
args.batch = (uint32_t) B;
args.num_rows = E;
args.num_tables = (uint32_t) T;
Expand All @@ -601,21 +600,21 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
dim3(grids[0], grids[1], grids[2]),
dim3(blocks[0], blocks[1], blocks[2]),
0, 0,
(float *)args.output, (const half *)args.emb_table, args.indices, args.offsets, args.pooling_mode,
(float *)args.output, (const half *)args.emb_table, args.indices, args.offsets, args.D_offsets, args.weights_offsets, args.pooling_mode,
{% if weighted %}
args.indice_weights,
{% endif %}
args.emb_dim, args.batch, args.num_rows, args.num_tables);
args.batch, args.num_rows, args.num_tables);
} else { // only 2 emb_t: fp16, fp32 for now
hipLaunchKernelGGL(split_tbe_fwd_{{ wdesc }}_hip_kernel_fp32_e{{ kDimSize }},
dim3(grids[0], grids[1], grids[2]),
dim3(blocks[0], blocks[1], blocks[2]),
0, 0,
(float *)args.output, (const float *)args.emb_table, args.indices, args.offsets, args.pooling_mode,
(float *)args.output, (const float *)args.emb_table, args.indices, args.offsets, args.D_offsets, args.weights_offsets, args.pooling_mode,
{% if weighted %}
args.indice_weights,
{% endif %}
args.emb_dim, args.batch, args.num_rows, args.num_tables);
args.batch, args.num_rows, args.num_tables);
}
return output;
}
Expand Down
6 changes: 4 additions & 2 deletions fbgemm_gpu/hip_kernel/split_tbe_fwd.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
const emb_type * p_emb_table, \
const int64_t * p_indices, \
const int64_t * p_offsets, \
const int32_t * D_offsets, \
const int64_t * weights_offsets, \
const int64_t pooling_mode, \
uint32_t emb_dim, \
uint32_t batch, \
uint32_t num_rows, \
uint32_t num_tables); \
Expand All @@ -42,9 +43,10 @@
const emb_type * p_emb_table, \
const int64_t * p_indices, \
const int64_t * p_offsets, \
const int32_t * D_offsets, \
const int64_t * weights_offsets, \
const int64_t pooling_mode, \
const float * p_indice_weights,\
uint32_t emb_dim, \
uint32_t batch, \
uint32_t num_rows, \
uint32_t num_tables)
Expand Down
Loading

0 comments on commit 9d181e5

Please sign in to comment.