cpu_kernel's refactor, generic tensor accessor indexing
oOTigger committed Oct 8, 2024
1 parent 366bd94 commit c9c33fd
Showing 68 changed files with 619 additions and 556 deletions.
1 change: 1 addition & 0 deletions lib/kernels/CMakeLists.txt
@@ -30,6 +30,7 @@ target_link_libraries(
cudnn
nccl
utils
pcg
)

define_ff_vars(${project_target})
112 changes: 94 additions & 18 deletions lib/kernels/include/kernels/accessor.h
@@ -5,6 +5,7 @@
#include "device.h"
#include "kernels/ff_handle.h"
#include "op-attrs/datatype.h"
#include "pcg/device_type.dtg.h"
#include "utils/exception.h"
#include "utils/required.h"
#include "utils/variant.h"
@@ -29,20 +30,65 @@ class GenericTensorAccessorW {
double *get_double_ptr() const;
half *get_half_ptr() const;

GenericTensorAccessorW(DataType dt,
ArrayShape sh,
req<void *> p,
bool on_dev = true)
: data_type(dt), shape(sh), ptr(p), on_device(on_dev) {}
GenericTensorAccessorW() = delete;

GenericTensorAccessorW(DataType data_type,
ArrayShape const &shape,
void *ptr,
DeviceType device_type);

bool operator==(GenericTensorAccessorW const &) const;
bool operator!=(GenericTensorAccessorW const &) const;

template <DataType DT, typename... Indices>
real_type_t<DT> &at(Indices... indices) {
if (this->device_type != DeviceType::CPU) {
throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
}
if (this->data_type != DT) {
throw mk_runtime_error(
"Invalid access data type ({} != {})", this->data_type, DT);
}

using T = real_type_t<DT>;

T *data_ptr = static_cast<T *>(this->ptr);
size_t offset = calculate_index_offset({static_cast<size_t>(indices)...});

return data_ptr[offset];
}

template <DataType DT, typename... Indices>
real_type_t<DT> const &at(Indices... indices) const {
if (this->device_type != DeviceType::CPU) {
throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
}
if (this->data_type != DT) {
throw mk_runtime_error(
"Invalid access data type ({} != {})", this->data_type, DT);
}

using T = real_type_t<DT>;

T const *data_ptr = static_cast<T const *>(this->ptr);
size_t offset = calculate_index_offset({static_cast<size_t>(indices)...});

return data_ptr[offset];
}

public:
DataType data_type;
ArrayShape shape;
req<void *> ptr;
bool on_device;
void *ptr;
DeviceType device_type;

private:
std::tuple<decltype(data_type) const &,
decltype(shape) const &,
decltype(ptr) const &,
decltype(device_type) const &>
tie() const;

size_t calculate_index_offset(
std::initializer_list<size_t> const &indices) const;
};
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(
GenericTensorAccessorW, data_type, shape, ptr, on_device);

std::string format_as(GenericTensorAccessorW const &);
std::ostream &operator<<(std::ostream &, GenericTensorAccessorW const &);
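(calculate_index_offset is only declared in this header; its definition is not part of this excerpt. A plausible row-major sketch of such an offset computation, written against a plain vector of dimension sizes rather than ArrayShape and assuming the last dimension varies fastest, which this commit does not specify, is:)

```cpp
#include <cstddef>
#include <initializer_list>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for calculate_index_offset: maps an n-dimensional
// index to a flat offset under a row-major layout (last dimension varies
// fastest).
size_t row_major_offset(std::initializer_list<size_t> indices,
                        std::vector<size_t> const &dims) {
  if (indices.size() != dims.size()) {
    throw std::runtime_error("number of indices does not match tensor rank");
  }
  size_t offset = 0;
  size_t dim_idx = 0;
  for (size_t idx : indices) {
    if (idx >= dims[dim_idx]) {
      throw std::out_of_range("index out of bounds for dimension");
    }
    offset = offset * dims[dim_idx] + idx;
    ++dim_idx;
  }
  return offset;
}
```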
@@ -65,20 +111,50 @@ class GenericTensorAccessorR {
double const *get_double_ptr() const;
half const *get_half_ptr() const;

GenericTensorAccessorR(DataType dt,
ArrayShape sh,
req<void const *> p,
bool on_dev = true)
: data_type(dt), shape(sh), ptr(p), on_device(on_dev) {}
GenericTensorAccessorR() = delete;

GenericTensorAccessorR(DataType data_type,
ArrayShape const &shape,
void const *ptr,
DeviceType device_type);

bool operator==(GenericTensorAccessorR const &) const;
bool operator!=(GenericTensorAccessorR const &) const;

template <DataType DT, typename... Indices>
real_type_t<DT> const &at(Indices... indices) const {
if (this->device_type != DeviceType::CPU) {
throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
}
if (this->data_type != DT) {
throw mk_runtime_error(
"Invalid access data type ({} != {})", this->data_type, DT);
}

using T = real_type_t<DT>;

T const *data_ptr = static_cast<T const *>(this->ptr);
size_t offset = calculate_index_offset({static_cast<size_t>(indices)...});

return data_ptr[offset];
}

public:
DataType data_type;
ArrayShape shape;
req<void const *> ptr;
bool on_device;
void const *ptr;
DeviceType device_type;

private:
std::tuple<decltype(data_type) const &,
decltype(shape) const &,
decltype(ptr) const &,
decltype(device_type) const &>
tie() const;

size_t calculate_index_offset(
std::initializer_list<size_t> const &indices) const;
};
FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(
GenericTensorAccessorR, data_type, shape, ptr, on_device);

std::string format_as(GenericTensorAccessorR const &);
std::ostream &operator<<(std::ostream &, GenericTensorAccessorR const &);
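A minimal usage sketch of the new at() indexing API (the accessors here are assumed to already be CPU-allocated with a FLOAT data type; nothing about how they were created is taken from this commit):

```cpp
#include <cstddef>
#include "kernels/accessor.h"

using namespace FlexFlow;

// Hypothetical helpers operating on CPU-resident FLOAT tensors; at<DT>()
// throws if the accessor is not CPU-allocated or if DT does not match the
// accessor's runtime data type.
void fill_diagonal(GenericTensorAccessorW &acc, size_t n) {
  for (size_t i = 0; i < n; i++) {
    acc.at<DataType::FLOAT>(i, i) = 1.0f;
  }
}

float read_element(GenericTensorAccessorR const &acc, size_t i, size_t j) {
  return acc.at<DataType::FLOAT>(i, j);
}
```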
12 changes: 4 additions & 8 deletions lib/kernels/include/kernels/allocation.h
@@ -5,29 +5,27 @@
#include <cstddef>
#include <memory>

enum class AllocLocation { HOST, DEVICE };

namespace FlexFlow {

struct IAllocator {
virtual void *allocate(size_t) = 0;
virtual void *allocate_and_zero(size_t) = 0;
virtual void deallocate(void *) = 0;

virtual DeviceType get_allocation_device_type() const = 0;

virtual ~IAllocator() = default;
};

struct Allocator {
Allocator() = delete;

GenericTensorAccessorW allocate_tensor(TensorShape const &tensor_shape);
GenericTensorAccessorW
allocate_tensor_and_zero(TensorShape const &tensor_shape);

void *allocate(size_t mem_size);
void *allocate_and_zero(size_t mem_size);
void deallocate(void *ptr);

DeviceType get_allocation_device_type() const;

template <typename T, typename... Args>
static typename std::enable_if<std::is_base_of<IAllocator, T>::value,
Allocator>::type
@@ -37,8 +35,6 @@ struct Allocator {

Allocator(std::shared_ptr<IAllocator> ptr) : i_allocator(ptr){};

AllocLocation alloc_location;

private:
std::shared_ptr<IAllocator> i_allocator;
};
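For context on the new get_allocation_device_type() hook, a hypothetical host-memory backend could implement IAllocator as follows (the name and behavior are illustrative; the real CPU/GPU allocators are not shown in this excerpt):

```cpp
#include <cstdlib>

#include "kernels/allocation.h"

namespace FlexFlow {

// Hypothetical IAllocator backend that hands out plain host memory, so
// tensors allocated through it report DeviceType::CPU and are usable with
// GenericTensorAccessor{R,W}::at().
struct ExampleHostAllocator : public IAllocator {
  void *allocate(size_t size) override {
    return std::malloc(size);
  }

  void *allocate_and_zero(size_t size) override {
    return std::calloc(1, size);
  }

  void deallocate(void *ptr) override {
    std::free(ptr);
  }

  DeviceType get_allocation_device_type() const override {
    return DeviceType::CPU;
  }
};

} // namespace FlexFlow
```

Such a backend would then be wrapped in an Allocator through the static factory shown above, and tensors it allocates would carry DeviceType::CPU into the accessors' at() checks.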
6 changes: 2 additions & 4 deletions lib/kernels/include/kernels/attention_kernels.h
@@ -64,8 +64,7 @@ FF_VISITABLE_STRUCT_NO_EQ(MHAPerDeviceState,
std::string format_as(MHAPerDeviceState const &x);
std::ostream &operator<<(std::ostream &s, MHAPerDeviceState const &x);

namespace Kernels {
namespace MultiHeadAttention {
namespace Kernels::MultiHeadAttention {

MHAPerDeviceState init_kernel(PerDeviceFFHandle const &,
Allocator &,
@@ -105,8 +104,7 @@ void backward_kernel(ffStream_t stream,
void cleanup_kernel(Allocator &allocator,
MHAPerDeviceState const &device_state);

} // namespace MultiHeadAttention
} // namespace Kernels
} // namespace Kernels::MultiHeadAttention
} // namespace FlexFlow

#endif
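The namespace change above (repeated in the headers that follow) uses C++17 nested namespace definitions; the folded and unfolded spellings declare exactly the same namespaces, illustrated here with a hypothetical declaration:

```cpp
// Spelling used before this commit:
namespace FlexFlow {
namespace Kernels {
namespace MultiHeadAttention {
void example_declaration(); // hypothetical, for illustration only
} // namespace MultiHeadAttention
} // namespace Kernels
} // namespace FlexFlow

// Equivalent C++17 nested namespace definition used after this commit:
namespace FlexFlow::Kernels::MultiHeadAttention {
void example_declaration();
} // namespace FlexFlow::Kernels::MultiHeadAttention
```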
8 changes: 2 additions & 6 deletions lib/kernels/include/kernels/batch_matmul_kernels.h
@@ -5,9 +5,7 @@
#include "kernels/allocation.h"
#include "kernels/ff_handle.h"

namespace FlexFlow {
namespace Kernels {
namespace BatchMatmul {
namespace FlexFlow::Kernels::BatchMatmul {

void forward_kernel(ffStream_t stream,
PerDeviceFFHandle const &handle,
@@ -35,8 +33,6 @@ void backward_kernel(ffStream_t stream,
int k,
int batch);

} // namespace BatchMatmul
} // namespace Kernels
} // namespace FlexFlow
} // namespace FlexFlow::Kernels::BatchMatmul

#endif
6 changes: 2 additions & 4 deletions lib/kernels/include/kernels/batch_norm_kernels.h
@@ -43,8 +43,7 @@ FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(BatchNormPerDeviceState,
output_w,
relu);

namespace Kernels {
namespace BatchNorm {
namespace Kernels::BatchNorm {

BatchNormPerDeviceState init_kernel(PerDeviceFFHandle handle,
Allocator allocator,
@@ -81,8 +80,7 @@ void cleanup_kernel(Allocator allocator,
bool relu,
float *runningMean);

} // namespace BatchNorm
} // namespace Kernels
} // namespace Kernels::BatchNorm
} // namespace FlexFlow

#endif
8 changes: 2 additions & 6 deletions lib/kernels/include/kernels/cast_kernels.h
@@ -4,9 +4,7 @@
#include "device.h"
#include "kernels/accessor.h"

namespace FlexFlow {
namespace Kernels {
namespace Cast {
namespace FlexFlow::Kernels::Cast {

void forward_kernel(ffStream_t stream,
GenericTensorAccessorR const &input,
@@ -20,8 +18,6 @@ void backward_kernel(ffStream_t stream,
DataType input_type,
DataType output_type);

} // namespace Cast
} // namespace Kernels
} // namespace FlexFlow
} // namespace FlexFlow::Kernels::Cast

#endif
8 changes: 2 additions & 6 deletions lib/kernels/include/kernels/cast_kernels_cpu.h
@@ -4,9 +4,7 @@
#include "device.h"
#include "kernels/accessor.h"

namespace FlexFlow {
namespace Kernels {
namespace Cast {
namespace FlexFlow::Kernels::Cast {

void cpu_forward_kernel(GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
@@ -18,8 +16,6 @@ void cpu_backward_kernel(GenericTensorAccessorR const &input,
DataType input_type,
DataType output_type);

} // namespace Cast
} // namespace Kernels
} // namespace FlexFlow
} // namespace FlexFlow::Kernels::Cast

#endif
8 changes: 2 additions & 6 deletions lib/kernels/include/kernels/combine_kernels.h
@@ -4,9 +4,7 @@
#include "device.h"
#include "kernels/accessor.h"

namespace FlexFlow {
namespace Kernels {
namespace Combine {
namespace FlexFlow::Kernels::Combine {

void forward_kernel(ffStream_t stream,
GenericTensorAccessorR const &input,
@@ -16,8 +14,6 @@ void backward_kernel(ffStream_t stream,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad);

} // namespace Combine
} // namespace Kernels
} // namespace FlexFlow
} // namespace FlexFlow::Kernels::Combine

#endif // _FLEXFLOW_OPS_KERNELS_COMBINE_KERNELS_H
8 changes: 2 additions & 6 deletions lib/kernels/include/kernels/combine_kernels_cpu.h
@@ -4,18 +4,14 @@
#include "device.h"
#include "kernels/accessor.h"

namespace FlexFlow {
namespace Kernels {
namespace Combine {
namespace FlexFlow::Kernels::Combine {

void cpu_forward_kernel(GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);

void cpu_backward_kernel(GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad);

} // namespace Combine
} // namespace Kernels
} // namespace FlexFlow
} // namespace FlexFlow::Kernels::Combine

#endif // _FLEXFLOW_OPS_KERNELS_COMBINE_KERNELS_CPU_H
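A minimal sketch of calling the CPU combine kernels declared above (the accessors are assumed to be CPU-allocated and to agree on shape and data type; Combine's forward pass is expected to forward the input values to the output unchanged):

```cpp
#include "kernels/combine_kernels_cpu.h"

using namespace FlexFlow;

// Hypothetical call sites for the Combine CPU kernels.
void run_combine_forward(GenericTensorAccessorR const &input,
                         GenericTensorAccessorW const &output) {
  Kernels::Combine::cpu_forward_kernel(input, output);
}

void run_combine_backward(GenericTensorAccessorR const &output_grad,
                          GenericTensorAccessorW const &input_grad) {
  Kernels::Combine::cpu_backward_kernel(output_grad, input_grad);
}
```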
8 changes: 2 additions & 6 deletions lib/kernels/include/kernels/concat_kernels.h
@@ -4,9 +4,7 @@
#include "device.h"
#include "kernels/accessor.h"

namespace FlexFlow {
namespace Kernels {
namespace Concat {
namespace FlexFlow::Kernels::Concat {

void forward_kernel(ffStream_t stream,
GenericTensorAccessorW const &output,
@@ -18,8 +16,6 @@ void backward_kernel(ffStream_t stream,
std::vector<GenericTensorAccessorW> const &input_grads,
ff_dim_t axis);

} // namespace Concat
} // namespace Kernels
} // namespace FlexFlow
} // namespace FlexFlow::Kernels::Concat

#endif
6 changes: 2 additions & 4 deletions lib/kernels/include/kernels/conv_2d_kernels.h
@@ -34,8 +34,7 @@ FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(Conv2DPerDeviceState,
bwdFilterAlgo,
bwdDataAlgo);

namespace Kernels {
namespace Conv2D {
namespace Kernels::Conv2D {

Conv2DPerDeviceState init_kernel(PerDeviceFFHandle handle,
std::optional<Activation> activation,
@@ -70,8 +69,7 @@ void backward_kernel(ffStream_t stream,
float *bias_grad_ptr,
std::optional<Activation> activation);

} // namespace Conv2D
} // namespace Kernels
} // namespace Kernels::Conv2D
} // namespace FlexFlow

#endif // _FLEXFLOW_OPS_KERNELS_CONV_2D_KERNELS_H
3 changes: 2 additions & 1 deletion lib/kernels/include/kernels/datatype_dispatch.h
@@ -1,7 +1,8 @@
#ifndef _FLEXFLOW_KERNELS_DATATYPE_DISPATCH_H
#define _FLEXFLOW_KERNELS_DATATYPE_DISPATCH_H

#include "accessor.h"
#include "op-attrs/datatype.h"
#include "utils/exception.h"

namespace FlexFlow {
