From f2b6d2a2a53463390794b01050c01f23cc932b6c Mon Sep 17 00:00:00 2001 From: Victor Li Date: Wed, 30 Oct 2024 18:18:06 -0700 Subject: [PATCH] Changed ff_dim to ff_dim_t, added in nonnegative_int type --- lib/kernels/include/kernels/legion_dim.h | 2 +- lib/kernels/src/array_shape.cc | 4 +- lib/kernels/src/cuda/ops/transpose_kernels.cu | 2 +- lib/kernels/src/legion_dim.cc | 4 +- .../local-execution/legion_tensor_shape.h | 6 +- .../src/legion_tensor_shape.cc | 8 +- lib/local-execution/src/ops/gather.cc | 2 +- lib/local-execution/src/ops/layer_norm.cc | 2 +- lib/local-execution/src/ops/linear.cc | 2 +- .../op-attrs/dim_ordered/dim_ordered.h | 2 +- lib/op-attrs/include/op-attrs/ff_dim_t.h | 17 ++ .../include/op-attrs/ff_dim_t.struct.toml | 14 ++ .../op-attrs/ops/combine_attrs.struct.toml | 4 +- .../op-attrs/ops/concat_attrs.struct.toml | 4 +- .../op-attrs/ops/flat_attrs.struct.toml | 4 +- .../op-attrs/ops/gather_attrs.struct.toml | 4 +- .../op-attrs/ops/layer_norm_attrs.struct.toml | 4 +- .../op-attrs/ops/reduce_attrs.struct.toml | 4 +- .../ops/repartition_attrs.struct.toml | 4 +- .../op-attrs/ops/reverse_attrs.struct.toml | 4 +- .../op-attrs/ops/softmax_attrs.struct.toml | 4 +- .../op-attrs/ops/split_attrs.struct.toml | 4 +- .../op-attrs/ops/transpose_attrs.struct.toml | 4 +- .../parallel_tensor_dim_idx_t.variant.toml | 2 +- .../include/op-attrs/relative_ff_dim_t.h | 17 ++ .../op-attrs/relative_ff_dim_t.struct.toml | 14 ++ .../test/src/op-attrs/dim_ordered/zip.cc | 2 +- lib/runtime/src/parallel_op_info.h | 2 +- .../operator_attribute_value.variant.toml | 2 +- .../utils/nonnegative_int/nonnegative_int.h | 108 ++++++++++ .../utils/nonnegative_int/nonnegative_int.cc | 173 ++++++++++++++++ .../utils/nonnegative_int/nonnegative_int.cc | 190 ++++++++++++++++++ 32 files changed, 576 insertions(+), 43 deletions(-) create mode 100644 lib/op-attrs/include/op-attrs/ff_dim_t.h create mode 100644 lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/relative_ff_dim_t.h create mode 100644 lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml create mode 100644 lib/utils/include/utils/nonnegative_int/nonnegative_int.h create mode 100644 lib/utils/src/utils/nonnegative_int/nonnegative_int.cc create mode 100644 lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index e4dd9723b8..22347c278e 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -8,7 +8,7 @@ namespace FlexFlow { legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value); -legion_dim_t legion_dim_from_ff_dim(ff_dim_t, int num_dimensions); +legion_dim_t legion_dim_from_ff_dim_t(ff_dim_t, int num_dimensions); template using LegionOrdered = DimOrdered; diff --git a/lib/kernels/src/array_shape.cc b/lib/kernels/src/array_shape.cc index d5e2f1167d..6488a13ce2 100644 --- a/lib/kernels/src/array_shape.cc +++ b/lib/kernels/src/array_shape.cc @@ -47,7 +47,7 @@ std::size_t ArrayShape::at(legion_dim_t idx) const { } std::size_t ArrayShape::at(ff_dim_t idx) const { - return dims.at(legion_dim_from_ff_dim(idx, this->num_dims())); + return dims.at(legion_dim_from_ff_dim_t(idx, this->num_dims())); } ArrayShape ArrayShape::sub_shape( @@ -65,7 +65,7 @@ std::optional ArrayShape::at_maybe(legion_dim_t index) const { } std::optional ArrayShape::at_maybe(ff_dim_t index) const { - return this->at_maybe(legion_dim_from_ff_dim(index, this->num_dims())); + return this->at_maybe(legion_dim_from_ff_dim_t(index, this->num_dims())); } size_t get_volume(ArrayShape const &shape) { diff --git a/lib/kernels/src/cuda/ops/transpose_kernels.cu b/lib/kernels/src/cuda/ops/transpose_kernels.cu index 3b3f80944d..ce1c6b48e4 100644 --- a/lib/kernels/src/cuda/ops/transpose_kernels.cu +++ b/lib/kernels/src/cuda/ops/transpose_kernels.cu @@ -36,7 +36,7 @@ TransposePerDeviceState init_kernel(int num_dim, std::vector perm_vector; assert(length <= MAX_TENSOR_DIM); for (int i = 0; i < length; ++i) { - perm_vector.push_back(legion_dim_from_ff_dim(perm[i], num_dim)); + perm_vector.push_back(legion_dim_from_ff_dim_t(perm[i], num_dim)); } return {num_dim, perm_vector}; diff --git a/lib/kernels/src/legion_dim.cc b/lib/kernels/src/legion_dim.cc index 9ef47d40ae..b615b39647 100644 --- a/lib/kernels/src/legion_dim.cc +++ b/lib/kernels/src/legion_dim.cc @@ -6,8 +6,8 @@ legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value) { return legion_dim_t(legion_dim.value + value); } -legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, int num_dimensions) { - return legion_dim_t(num_dimensions - ff_dim.value - 1); +legion_dim_t legion_dim_from_ff_dim_t(ff_dim_t ff_dim_t, int num_dimensions) { + return legion_dim_t(num_dimensions - ff_dim_t.value - 1); } } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/legion_tensor_shape.h b/lib/local-execution/include/local-execution/legion_tensor_shape.h index 2f2ed50d41..5bb5885844 100644 --- a/lib/local-execution/include/local-execution/legion_tensor_shape.h +++ b/lib/local-execution/include/local-execution/legion_tensor_shape.h @@ -3,7 +3,7 @@ #include "kernels/legion_dim.h" #include "op-attrs/datatype.h" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim_t.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" #include "utils/visitable.h" @@ -30,10 +30,10 @@ struct LegionTensorShape : public use_visitable_cmp, }; ff_dim_t to_ff(legion_dim_t, size_t num_dims); -legion_dim_t legion_dim_from_ff_dim(ff_dim_t, size_t num_dims); +legion_dim_t legion_dim_from_ff_dim_t(ff_dim_t, size_t num_dims); ff_dim_t to_ff(legion_dim_t, TensorShape const &); -legion_dim_t legion_dim_from_ff_dim(ff_dim_t, TensorShape const &); +legion_dim_t legion_dim_from_ff_dim_t(ff_dim_t, TensorShape const &); } // namespace FlexFlow diff --git a/lib/local-execution/src/legion_tensor_shape.cc b/lib/local-execution/src/legion_tensor_shape.cc index bce29fafeb..2a660860c5 100644 --- a/lib/local-execution/src/legion_tensor_shape.cc +++ b/lib/local-execution/src/legion_tensor_shape.cc @@ -3,12 +3,12 @@ namespace FlexFlow { -legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, size_t num_dims) { - return legion_dim_t(num_dims - ff_dim.value - 1); +legion_dim_t legion_dim_from_ff_dim_t(ff_dim_t ff_dim_t, size_t num_dims) { + return legion_dim_t(num_dims - ff_dim_t.value - 1); } -legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, TensorShape const &shape) { - return legion_dim_t(num_dims(shape) - ff_dim.value - 1); +legion_dim_t legion_dim_from_ff_dim_t(ff_dim_t ff_dim_t, TensorShape const &shape) { + return legion_dim_t(num_dims(shape) - ff_dim_t.value - 1); } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/gather.cc b/lib/local-execution/src/ops/gather.cc index a015c64f4d..adcbee7d3a 100644 --- a/lib/local-execution/src/ops/gather.cc +++ b/lib/local-execution/src/ops/gather.cc @@ -67,7 +67,7 @@ static DeviceSpecificDeviceStates PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto const &attrs = acc.get_argument(ATTRS); legion_dim_t legion_dim = - legion_dim_from_ff_dim(attrs.dim, input.shape.num_dims()); + legion_dim_from_ff_dim_t(attrs.dim, input.shape.num_dims()); assert(input.shape.get_dim() == index.shape.get_dim()); assert(output.shape.get_dim() == index.shape.get_dim()); diff --git a/lib/local-execution/src/ops/layer_norm.cc b/lib/local-execution/src/ops/layer_norm.cc index e99d27319c..1203286220 100644 --- a/lib/local-execution/src/ops/layer_norm.cc +++ b/lib/local-execution/src/ops/layer_norm.cc @@ -123,7 +123,7 @@ static DeviceSpecificDeviceStates int64_t effective_batch_size, effective_num_elements; int M = 1; for (int i = 0; i < attrs.axes.size(); i++) { - legion_dim_t legion_dim = legion_dim_from_ff_dim( + legion_dim_t legion_dim = legion_dim_from_ff_dim_t( attrs.axes[i], get_tensor_shape(input.shape, input.data_type)); M *= input.shape.at(legion_dim); } diff --git a/lib/local-execution/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc index 9934e2a45c..f3ef022175 100644 --- a/lib/local-execution/src/ops/linear.cc +++ b/lib/local-execution/src/ops/linear.cc @@ -1,7 +1,7 @@ #include "linear.h" #include "kernels/linear_kernels.h" #include "local-execution/task_argument_accessor.h" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim_t.h" #include "op-attrs/get_output_shapes.h" #include "utils/exception.h" #include "utils/hash-utils.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h index 6aa23d40fc..9a8d8adc0b 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H -#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim_t.dtg.h" #include "utils/fmt/vector.h" #include "utils/stack_vector.h" #include diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.h b/lib/op-attrs/include/op-attrs/ff_dim_t.h new file mode 100644 index 0000000000..8e3dca3bc7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H + +#include "op-attrs/ff_dim_t.dtg.h" +#include "rapidcheck.h" + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::construct( + gen::inRange(0, MAX_TENSOR_DIM)); + } +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H \ No newline at end of file diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml new file mode 100644 index 0000000000..441f9826ca --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ff_dim_t" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml index 585295fe1c..e7eeedec06 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index fab8132993..f3c66d0416 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h" + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h" ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml index 7349e2a8c4..301df8bca4 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -11,14 +11,14 @@ features = [ includes = [ "", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.dtg.h", ] src_includes = [ "utils/fmt/optional.h", "utils/json/optional.h", "utils/rapidcheck/optional.h", - "op-attrs/ff_dim.h", + "op-attrs/ff_dim_t.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml index c8bb88dcc7..66d475aa46 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h" + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h" ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml index ec60d39f7f..401eaeeec4 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml index 717e7954e8..88e57ef7c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml @@ -11,8 +11,8 @@ features = [ includes = [ "op-attrs/operator_type.dtg.h", - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml index 25a33c0c15..69c4b7580f 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml index 198346e5dd..2577ac1398 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml index 8b839c122a..49172f44b0 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml index 8cdf7728af..f10aa7c3fd 100644 --- a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml @@ -11,8 +11,8 @@ features = [ includes = [ "utils/stack_vector.h", - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml index 0dc30d9a79..b1c5f60382 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", "op-attrs/dim_ordered/dim_ordered.h", ] diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml index 9396cbcbe8..7e7356a5e7 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml @@ -9,7 +9,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.dtg.h", "op-attrs/replica_type.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h new file mode 100644 index 0000000000..a208e42664 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H + +#include "op-attrs/relative_ff_dim_t.dtg.h" +#include "rapidcheck.h" + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary() { + return gen::construct( + gen::inRange(0, MAX_TENSOR_DIM)); + } +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H \ No newline at end of file diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml new file mode 100644 index 0000000000..a93b649052 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "relative_ff_dim_t" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc index 8e3d0f1b80..b77bb8f71e 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc @@ -1,5 +1,5 @@ #include "op-attrs/dim_ordered/zip.h" -#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim_t.dtg.h" #include "test/utils/doctest/fmt/pair.h" #include diff --git a/lib/runtime/src/parallel_op_info.h b/lib/runtime/src/parallel_op_info.h index ebd44f012b..49ad22be74 100644 --- a/lib/runtime/src/parallel_op_info.h +++ b/lib/runtime/src/parallel_op_info.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_PARALLEL_OPS_PARALLEL_OP_INFO_H #define _FLEXFLOW_PARALLEL_OPS_PARALLEL_OP_INFO_H -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim_t.h" #include "op-attrs/operator_type.h" #include "utils/visitable.h" #include diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index 02a856f59a..8fe4a9494d 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -13,7 +13,7 @@ includes = [ "", "", "op-attrs/operator_type.dtg.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.dtg.h", "op-attrs/activation.dtg.h", "op-attrs/aggregate_op.dtg.h", "op-attrs/regularizer_attrs.dtg.h", diff --git a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h new file mode 100644 index 0000000000..b5d7551a8d --- /dev/null +++ b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h @@ -0,0 +1,108 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NONNEGATIVE_INT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NONNEGATIVE_INT_H + +#include "rapidcheck.h" + +#include +#include +#include +#include +#include + +namespace FlexFlow { +class nonnegative_int { + public: + nonnegative_int() = delete; + explicit nonnegative_int(int const &value); + explicit nonnegative_int(int &&value); + + explicit operator int() const noexcept; + operator int&() noexcept; + operator int const&() const noexcept; + + template ::value && + !std::is_same::value), + bool>::type = true> + operator T() const { + return (this->value_); + } + + friend void swap(nonnegative_int &a, nonnegative_int &b) noexcept; + + bool operator<(nonnegative_int const &other); + bool operator==(nonnegative_int const &other); + bool operator>(nonnegative_int const &other); + bool operator<=(nonnegative_int const &other); + bool operator!=(nonnegative_int const &other); + bool operator>=(nonnegative_int const &other); + + nonnegative_int &operator+=(nonnegative_int const &other); + nonnegative_int &operator+=(int const &other); + nonnegative_int &operator++(); + nonnegative_int operator++(int); + nonnegative_int operator+(nonnegative_int const &other); + nonnegative_int operator+(int const &other); + friend nonnegative_int operator+(int const &lhs, nonnegative_int const &rhs); + + nonnegative_int &operator-=(nonnegative_int const &other); + nonnegative_int &operator-=(int const &other); + nonnegative_int &operator--(); + nonnegative_int operator--(int); + nonnegative_int operator-(nonnegative_int const &other); + nonnegative_int operator-(int const &other); + friend nonnegative_int operator-(int const &lhs, nonnegative_int const &rhs); + + nonnegative_int &operator*=(nonnegative_int const &other); + nonnegative_int &operator*=(int const &other); + nonnegative_int operator*(nonnegative_int const &other); + nonnegative_int operator*(int const &other); + friend nonnegative_int operator*(int const &lhs, nonnegative_int const &rhs); + + nonnegative_int &operator/=(nonnegative_int const &other); + nonnegative_int &operator/=(int const &other); + nonnegative_int operator/(nonnegative_int const &other); + nonnegative_int operator/(int const &other); + friend nonnegative_int operator/(int const &lhs, nonnegative_int const &rhs); + + template + nonnegative_int fmap(F const &f) { + static_assert( + std::is_same()(std::declval())), + int>::value, + "Function must return an value of the underlying type"); + + return nonnegative_int(f(this->value_)); + } + + friend std::ostream& operator<<(std::ostream& os, const nonnegative_int& n); + + int get_value() const; + + private: + int value_; +}; +}// namespace FlexFlow + +namespace nlohmann { + template <> + struct adl_serializer<::FlexFlow::nonnegative_int> { + static ::FlexFlow::nonnegative_int from_json(const json& j) { + return ::FlexFlow::nonnegative_int{j.template get()}; + } + static void to_json(json& j, ::FlexFlow::nonnegative_int t) { + j = t.get_value(); + } + }; +} //namepsace nlohmann + +namespace std { + template <> + struct hash { + std::size_t operator()(const FlexFlow::nonnegative_int& n) const noexcept { + return std::hash{}(n.get_value()); + } + }; +} // namespace std + +#endif diff --git a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc new file mode 100644 index 0000000000..9d9b10ad10 --- /dev/null +++ b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc @@ -0,0 +1,173 @@ +#include "utils/nonnegative_int/nonnegative_int.h" + +namespace FlexFlow { + +nonnegative_int::nonnegative_int(int const &value) { + if (value < 0) { + throw std::invalid_argument("Value of nonnegative_int type must be nonnegative."); + } + this->value_ = value; +} +nonnegative_int::nonnegative_int(int &&value) { + if (std::move(value) < 0) { + throw std::invalid_argument("Value of nonnegative_int type must be nonnegative."); + } + this->value_ = std::move(value); +} + +void swap(nonnegative_int &a, nonnegative_int &b) noexcept { + using std::swap; + swap(static_cast(a), static_cast(b)); +} + +nonnegative_int::operator int() const noexcept { + return this->value_; +} +nonnegative_int::operator int&() noexcept { + return this->value_; +} +nonnegative_int::operator int const&() const noexcept { + return this->value_; +} + +bool nonnegative_int::operator<(nonnegative_int const &other) { + return this->value_ < other.value_; +} +bool nonnegative_int::operator==(nonnegative_int const &other) { + return this->value_ == other.value_; +} +bool nonnegative_int::operator>(nonnegative_int const &other) { + return this->value_ > other.value_; +} +bool nonnegative_int::operator<=(nonnegative_int const &other) { + return this->value_ <= other.value_; +} +bool nonnegative_int::operator!=(nonnegative_int const &other) { + return this->value_ != other.value_; +} +bool nonnegative_int::operator>=(nonnegative_int const &other) { + return this->value_ >= other.value_; +} + +nonnegative_int &nonnegative_int::operator+=(nonnegative_int const &other) { + static_cast(*this) += static_cast(other); + return *this; +} +nonnegative_int &nonnegative_int::operator+=(int const &other) { + static_cast(*this) += other; + return *this; +} +nonnegative_int &nonnegative_int::operator++() { + return *this += 1; +} +nonnegative_int nonnegative_int::operator++(int) { + nonnegative_int tmp = *this; + ++*this; + return tmp; +} +nonnegative_int nonnegative_int::operator+(nonnegative_int const &other) { + return nonnegative_int(this->value_ + other.value_); +} +nonnegative_int nonnegative_int::operator+(int const &other) { + return nonnegative_int(this->value_ + other); +} +nonnegative_int operator+(int const &lhs, nonnegative_int const &rhs) { + return nonnegative_int(lhs + rhs.value_); +} + +nonnegative_int &nonnegative_int::operator-=(nonnegative_int const &other) { + if (*this < other) { + throw std::out_of_range("Invalid subtraction"); + } + static_cast(*this) -= static_cast(other); + return *this; +} +nonnegative_int &nonnegative_int::operator-=(int const &other) { + if (*this < other) { + throw std::out_of_range("Invalid subtraction"); + } + static_cast(*this) -= other; + return *this; +} +nonnegative_int &nonnegative_int::operator--() { + return *this -= 1; +} +nonnegative_int nonnegative_int::operator--(int) { + nonnegative_int tmp = *this; + --*this; + return tmp; +} +nonnegative_int nonnegative_int::operator-(nonnegative_int const &other) { + if (*this < other) { + throw std::out_of_range("Invalid subtraction"); + } + return nonnegative_int(this->value_ - other.value_); +} +nonnegative_int nonnegative_int::operator-(int const &other) { + if (*this < other) { + throw std::out_of_range("Invalid subtraction"); + } + return nonnegative_int(this->value_ - other); +} +nonnegative_int operator-(int const &lhs, nonnegative_int const &rhs) { + if (lhs < rhs) { + throw std::out_of_range("Invalid subtraction"); + } + return nonnegative_int(lhs - rhs.value_); +} + +nonnegative_int &nonnegative_int::operator*=(nonnegative_int const &other) { + static_cast(*this) *= static_cast(other); + return *this; +} +nonnegative_int &nonnegative_int::operator*=(int const &other) { + static_cast(*this) *= other; + return *this; +} +nonnegative_int nonnegative_int::operator*(nonnegative_int const &other) { + return nonnegative_int(this->value_ * other.value_); +} +nonnegative_int nonnegative_int::operator*(int const &other) { + return nonnegative_int(this->value_ * other); +} +nonnegative_int operator*(int const &lhs, nonnegative_int const &rhs) { + return nonnegative_int(lhs * rhs.value_); +} + +nonnegative_int &nonnegative_int::operator/=(nonnegative_int const &other) { + if (other == 0) { + throw std::invalid_argument("Cannot divide by zero"); + } + static_cast(*this) /= static_cast(other); + return *this; +} +nonnegative_int &nonnegative_int::operator/=(int const &other) { + if (other == 0) { + throw std::invalid_argument("Cannot divide by zero"); + } + static_cast(*this) /= other; + return *this; +} +nonnegative_int nonnegative_int::operator/(nonnegative_int const &other) { + return (other != 0 ? nonnegative_int(this->value_ / other.value_) : throw std::invalid_argument("Cannot divide by zero")); +} +nonnegative_int nonnegative_int::operator/(int const &other) { + return (other != 0 ? nonnegative_int(this->value_ / other) : throw std::invalid_argument("Cannot divide by zero")); +} +nonnegative_int operator/(int const &lhs, nonnegative_int const &rhs) { + return (rhs != 0 ? nonnegative_int(lhs / rhs.value_) : throw std::invalid_argument("Cannot divide by zero")); +} + +std::ostream& operator<<(std::ostream& os, const nonnegative_int& n) { + os << n.value_; + return os; + } + +int nonnegative_int::get_value() const { + return this->value_; +} +} // namespace FlexFlow + +namespace std { + +} // namespace std diff --git a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc new file mode 100644 index 0000000000..89e19f9c9b --- /dev/null +++ b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc @@ -0,0 +1,190 @@ +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("nonnegative_int initialization") { + SUBCASE("positive int initialization") { + CHECK_NOTHROW(nonnegative_int(1)); + } + + SUBCASE("zero initialization") { + CHECK_NOTHROW(nonnegative_int(0)); + } + + SUBCASE("negative int initialization") { + CHECK_THROWS(nonnegative_int(-1)); + } + } + + TEST_CASE("nonnegative_int comparisons") { + SUBCASE("equality") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK((nn_int_1 == nn_int_2) == false); + CHECK((nn_int_1 == nn_int_1) == true); + CHECK((nn_int_1 == 2) == false); + CHECK((nn_int_1 == 1) == true); + CHECK((1 == nn_int_2) == false); + CHECK((1 == nn_int_1) == true); + } + + SUBCASE("not equals") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK((nn_int_1 != nn_int_2) == true); + CHECK((nn_int_1 != nn_int_1) == false); + CHECK((nn_int_1 != 2) == true); + CHECK((nn_int_1 != 1) == false); + CHECK((1 != nn_int_2) == true); + CHECK((1 != nn_int_1) == false); + } + + SUBCASE("less than") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK((nn_int_1 < nn_int_2) == true); + CHECK((nn_int_2 < nn_int_1) == false); + CHECK((nn_int_1 < 2) == true); + CHECK((nn_int_2 < 1) == false); + CHECK((1 < nn_int_2) == true); + CHECK((2 < nn_int_1) == false); + } + + SUBCASE("less than or equal to") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK((nn_int_1 <= nn_int_2) == true); + CHECK((nn_int_2 <= nn_int_1) == false); + CHECK((nn_int_1 <= 2) == true); + CHECK((nn_int_2 <= 1) == false); + CHECK((1 <= nn_int_2) == true); + CHECK((2 <= nn_int_1) == false); + } + + SUBCASE("greater than") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK((nn_int_1 > nn_int_2) == false); + CHECK((nn_int_2 > nn_int_1) == true); + CHECK((nn_int_1 > 2) == false); + CHECK((nn_int_2 > 1) == true); + CHECK((1 > nn_int_2) == false); + CHECK((2 > nn_int_1) == true); + } + + SUBCASE("greater than or equal to") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK((nn_int_1 >= nn_int_2) == false); + CHECK((nn_int_2 >= nn_int_1) == true); + CHECK((nn_int_1 >= 2) == false); + CHECK((nn_int_2 >= 1) == true); + CHECK((1 >= nn_int_2) == false); + CHECK((2 >= nn_int_1) == true); + } + } + + TEST_CASE("nonnegative_int arithmetic operations") { + SUBCASE("addition") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK((nn_int_2 == nn_int_1 + 1) == true); + CHECK((nn_int_2 == 1 + nn_int_1) == true); + CHECK((nn_int_2 == nn_int_1 + nn_int_1) == true); + nn_int_1 += nn_int_1; + CHECK((2 == nn_int_1) == true); + nn_int_1 += 1; + CHECK((3 == nn_int_1) == true); + CHECK((++nn_int_1) == 4); + nn_int_1++; + CHECK((5 == nn_int_1) == true); + } + + SUBCASE("subtraction") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK((nn_int_1 == nn_int_2 - 1) == true); + CHECK((nn_int_1 == 2 - nn_int_1) == true); + CHECK((nn_int_1 == nn_int_2 - nn_int_1) == true); + nonnegative_int nn_int_5 = nonnegative_int(5); + nn_int_5 -= nn_int_1; + CHECK((4 == nn_int_5) == true); + nn_int_5 -= 1; + CHECK((3 == nn_int_5) == true); + CHECK((--nn_int_5) == 2); + nn_int_5--; + CHECK((1 == nn_int_5) == true); + } + + SUBCASE("subtraction negative result") { + nonnegative_int nn_int_1 = nonnegative_int(1); + nonnegative_int nn_int_2 = nonnegative_int(2); + CHECK_THROWS(nn_int_1 - 2); + CHECK_THROWS(1 - nn_int_2); + CHECK_THROWS(nn_int_1 - nn_int_2); + CHECK_THROWS(nn_int_1 -= 2); + CHECK_THROWS(nn_int_1 -= nn_int_2); + nonnegative_int nn_int_0 = nonnegative_int(0); + CHECK_THROWS(nn_int_0--); + CHECK_THROWS(--nn_int_0); + } + + SUBCASE("multiplication") { + nonnegative_int nn_int_2 = nonnegative_int(2); + nonnegative_int nn_int_3 = nonnegative_int(3); + nonnegative_int nn_int_6 = nonnegative_int(6); + CHECK((nn_int_6 == (nn_int_2 * 3)) == true); + CHECK((nn_int_6 == (2 * nn_int_3)) == true); + CHECK((nn_int_6 == (nn_int_2 * nn_int_3)) == true); + nn_int_2 *= 3; + CHECK((nn_int_6 == nn_int_2) == true); + nn_int_2 *= nn_int_3; + CHECK((18 == nn_int_2) == true); + } + + SUBCASE("division") { + nonnegative_int nn_int_2 = nonnegative_int(2); + nonnegative_int nn_int_3 = nonnegative_int(3); + nonnegative_int nn_int_6 = nonnegative_int(6); + CHECK((nn_int_3 == (nn_int_6 / 2)) == true); + CHECK((nn_int_3 == (6 / nn_int_2)) == true); + CHECK((nn_int_3 == (nn_int_6 / nn_int_2)) == true); + nn_int_6 /= 3; + CHECK((nn_int_6 == nn_int_2) == true); + nn_int_6 /= nn_int_2; + CHECK((1 == nn_int_6) == true); + } + + SUBCASE("divide by 0") { + nonnegative_int nn_int_0 = nonnegative_int(0); + nonnegative_int nn_int_1 = nonnegative_int(1); + CHECK_THROWS(nn_int_1 / 0); + CHECK_THROWS(1 / nn_int_0); + CHECK_THROWS(nn_int_1 / nn_int_0); + CHECK_THROWS(nn_int_1 /= 0); + CHECK_THROWS(nn_int_1 /= nn_int_0); + } + } + + TEST_CASE("nonnegative_int adl_serializer") { + SUBCASE("to_json") { + nonnegative_int input = nonnegative_int{5}; + + nlohmann::json result = input; + nlohmann::json correct = 5; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + nlohmann::json input = 5; + + nonnegative_int result = input.template get(); + nonnegative_int correct = nonnegative_int{5}; + + CHECK(result == correct); + } + } +} \ No newline at end of file