FusionExecutorCache determines index type by taking all tensors into consideration (#108)

Fixes #79. Stacked on top of #159.

Currently we determine the index type by only looking at the sizes and strides of fusion inputs. That's not correct when a fusion has an output that's larger than its inputs, for example a matmul of `{M, K} * {K, N}` where both inputs are small enough to use 32-bit indexing but the output, `{M, N}`, is not.

This turned out to require a lot of changes in the code around the fusion executor and kernel cache, because we were assuming that a `KernelArgumentHolder` holding only fusion inputs is enough to determine the index type. That isn't a valid assumption, so one of the major changes in this PR is that the index type is no longer dictated by `KernelArgumentHolder`. Another major change is that we can no longer create a valid `TensorArgCodegen` for `KernelArgumentHolder` until the fusion index type is determined, since `TensorArgCodegen` is templated on the index type. However, `KernelArgumentHolder` is also used before the index type is determined, so we need some form of placeholder for tensor arguments until then. The placeholder object is replaced with a `TensorArg` once the index type is resolved. Getting the placeholder pointer, i.e., `KernelArgumentHolder::at`, is invalid before that point, as the placeholder is not a valid `TensorArg`.

Resolution of the index type is currently based on the complete, unsegmented original fusion. Fusion inputs and outputs may be non-contiguous, so if we have strides for them, those are considered as well; if not, we conservatively fall back to 64-bit indexing. For non-input, non-output intermediate tensors, rfactor root domains are used to find the smallest index type required. Note that intermediate tensors are assumed to be contiguous.

This analysis is not ideal at all, though. We should only need to check global memory tensors, as shared and local tensors should be smaller than what 32-bit indexing can represent. However, when the index type is determined, the fusion is not yet scheduled or segmented, so we don't know exactly which tensors are global memory tensors. Furthermore, none of the tensors are inlined yet, so we need to look at rfactor root domains, which is likely an overestimate for many of the shared and local tensors. Consider, for example, a fusion with a matmul-like pattern, where two 2D tensors are broadcast and fed into a binary expression, producing a 3D tensor that is then reduced back to a 2D tensor and returned as a fusion output. As long as the 3D intermediate tensor is inlined into the 2D output, we never instantiate the whole 3D tensor, so it should not affect the fusion index type, but the current analysis makes the conservative assumption that it may indeed be instantiated and decides to use 64-bit indexing.

As a temporary workaround, `FusionExecutorCache::runFusionWithInputs` has an optional `forced_index_type` parameter. When given, no analysis of actual sizes is done, and the given type is simply used with no validation. No guarantee of correctness is made when this option is used. A sketch of the matmul-like pattern and the workaround is shown below.

To completely fix the conservative analysis, we would need to do this index-type analysis after segmentation and scheduling, but a problem is that there is a circular dependency between index type and scheduling, since the index type is used by the schedulers as well. I haven't looked into whether those uses really require the actual final type, so I have no idea how easily this circular dependency could be resolved.
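As a rough illustration, here is a minimal sketch of the matmul-like pattern and the `forced_index_type` workaround. The header names, the `TensorViewBuilder` construction, and the spelling `DataType::Int32` for the 32-bit index type are illustrative assumptions, not exact details of this PR:

```cpp
#include <fusion.h>
#include <kernel_cache.h>
#include <ops/all_ops.h>

using namespace nvfuser;

// Matmul-like pattern: two 2D inputs, broadcast into a 3D intermediate,
// then reduced back to a 2D output.
auto fusion = std::make_unique<Fusion>();
{
  FusionGuard fg(fusion.get());
  TensorView* tv0 = TensorViewBuilder().ndims(2).dtype(DataType::Float).build(); // {M, K}
  TensorView* tv1 = TensorViewBuilder().ndims(2).dtype(DataType::Float).build(); // {K, N}
  fusion->addInput(tv0);
  fusion->addInput(tv1);
  TensorView* tv0b = broadcast(tv0, {false, false, true}); // {M, K, 1}
  TensorView* tv1b = broadcast(tv1, {true, false, false}); // {1, K, N}
  TensorView* tv2 = mul(tv0b, tv1b);                       // {M, K, N}, never materialized if inlined
  TensorView* tv3 = sum(tv2, {1});                         // {M, N}
  fusion->addOutput(tv3);
}

// Both inputs and the output easily fit 32-bit indexing, but the 3D
// intermediate would have 4096 * 1024 * 4096 = 2^34 elements if it were
// ever instantiated, which exceeds what 32-bit indexing can represent.
const int64_t M = 4096, K = 1024, N = 4096;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
std::vector<c10::IValue> inputs = {
    at::randn({M, K}, options), at::randn({K, N}, options)};

FusionExecutorCache fec(std::move(fusion));

// Default path: the index type is inferred from inputs, outputs, and the
// rfactor root domains of intermediates, so this conservatively ends up
// with 64-bit indexing because of the 3D intermediate.
auto outputs = fec.runFusionWithInputs(inputs);

// Temporary workaround: force 32-bit indexing. No size validation is done,
// so correctness is the caller's responsibility.
auto outputs_i32 = fec.runFusionWithInputs(inputs, DataType::Int32);
```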
Maybe we could do something like the following: segment and schedule the original fusion optimistically with 32-bit indexing without considering intermediates, and then, once scheduling is done, check the intermediates; if any scheduled tensor actually requires 64-bit indexing, scrap the scheduled fusion and restart with 64-bit indexing. For now, I suppose the above workaround is sufficient.

I'm still doing some final tests, but nothing seems to be failing. The change of logic doesn't seem to affect any of the existing benchmarks, at least on A100-80G.

Note that this index type analysis is only done when a fusion is executed through `FusionExecutorCache`. When a fusion is manually scheduled and executed with `FusionExecutor::compileFusion` and `FusionExecutor::runFusion`, the default type is 64-bit, but it can be overridden through `CompileParams`, as sketched below.
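For the manually scheduled path, a rough sketch of overriding the 64-bit default through `CompileParams` might look like the following; the `index_type` field name, the `compileFusion`/`runFusion` overloads taking ATen inputs, and `DataType::Int32` are assumptions for illustration, and no size validation happens on this path:

```cpp
#include <executor.h>

using namespace nvfuser;

// `fusion` is assumed to be built and manually scheduled already, and
// `inputs` to be the corresponding std::vector<c10::IValue> of ATen tensors.
FusionExecutor fe;

CompileParams cparams;
cparams.index_type = DataType::Int32; // override the 64-bit default

fe.compileFusion(&fusion, inputs, LaunchParams(), cparams);
auto outputs = fe.runFusion(inputs);
```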
---------

Co-authored-by: Ryan Spring <rdspring1@gmail.com>
Co-authored-by: Naoya Maruyama <nmaruyama@nvidia.com>