Stochastic Rounding Optimizers #17

Open · wants to merge 18 commits into master

Changes from 1 commit
5 changes: 3 additions & 2 deletions aten/src/ATen/native/cuda/StochasticRounding.cu
@@ -15,9 +15,10 @@ __global__ void stochastic_rounding_kernel(
   curandStatePhilox4_32_10_t state;
   curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state);
 
+  round_stochastically<output_t, input_t, at::Half> rounder;
+
   for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
-    float inp = static_cast<float>(input[i]);
-    output[i] = round_stochastically<output_t>(inp, curand_uniform(&state));
+    output[i] = rounder(input[i], curand_uniform(&state));
   }
 }

10 changes: 6 additions & 4 deletions aten/src/ATen/native/cuda/StochasticRoundingAdam.cu
@@ -21,6 +21,8 @@ __global__ void stochastic_rounding_adam_step_kernel(
   curandStatePhilox4_32_10_t state;
   curand_init(seeds.first, tid, seeds.second, &state);
 
+  round_stochastically<scalar_t, float, at::Half> rounder;
+
   float m_correction = 1.0 - powf(beta1, step);
   float v_correction = 1.0 - powf(beta2, step);
 
@@ -54,11 +56,11 @@ __global__ void stochastic_rounding_adam_step_kernel(
 
     weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps);
 
-    weights[i] = round_stochastically<scalar_t>(weight, random_values.x);
-    exp_avg[i] = round_stochastically<scalar_t>(m, random_values.y);
-    exp_avg_sq[i] = round_stochastically<scalar_t>(sqrtf(v), random_values.z);
+    weights[i] = rounder(weight, random_values.x);
+    exp_avg[i] = rounder(m, random_values.y);
+    exp_avg_sq[i] = rounder(sqrtf(v), random_values.z);
     if (is_amsgrad) {
-      max_exp_avg_sq[i] = round_stochastically<scalar_t>(sqrtf(max_v), random_values.w);
+      max_exp_avg_sq[i] = rounder(sqrtf(max_v), random_values.w);
     }
   }
 }
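For orientation, the bias-corrected weight update in the second hunk amounts to the host-side sketch below. It restates only the visible lines; the hidden parts of the diff (gradient scaling, moment updates, how v is read back) are not reproduced, and the helper name is illustrative, not code from this PR.

#include <cmath>

// Sketch of the per-element AMSGrad-style Adam weight update shown in the
// visible hunk, before the stochastically rounded stores. Illustrative only.
inline float adam_weight_update_ref(float weight, float m, float max_v,
                                    float lr, float beta1, float beta2,
                                    float eps, int step) {
  float m_correction = 1.0f - std::pow(beta1, static_cast<float>(step));
  float v_correction = 1.0f - std::pow(beta2, static_cast<float>(step));
  return weight - (lr / m_correction) * m
                  / (std::sqrt(max_v / v_correction) + eps);
}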
6 changes: 4 additions & 2 deletions aten/src/ATen/native/cuda/StochasticRoundingSGD.cu
@@ -19,6 +19,8 @@ __global__ void stochastic_rounding_sgd_step_kernel(
   curandStatePhilox4_32_10_t state;
   curand_init(seeds.first, tid, seeds.second, &state);
 
+  round_stochastically<scalar_t, float, at::Half> rounder;
+
   for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
     float weight = static_cast<float>(weights[i]);
     float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
@@ -42,9 +44,9 @@
 
     weight -= lr * gradient;
 
-    weights[i] = round_stochastically<scalar_t>(weight, random_values.x);
+    weights[i] = rounder(weight, random_values.x);
     if (momentum != 0.0f)
-      momentum_buffer[i] = round_stochastically<scalar_t>(velocity, random_values.y);
+      momentum_buffer[i] = rounder(velocity, random_values.y);
   }
 }

33 changes: 20 additions & 13 deletions aten/src/ATen/native/cuda/stochastic_rounding.cuh
@@ -44,17 +44,24 @@ __device__ __forceinline__ float get_delta_fp16(float x) {
 }
 
 // Natalia magic

Review comment: Keep this comment.

-template <typename scalar_t>
-__device__ __forceinline__ scalar_t round_stochastically(float x, float random_value) {
-  if (x == 0.0) {
-    return scalar_t(0.0);
-  }
-  float delta = get_delta_fp16(x);
-  float val;
-  if (x < 0.0) {
-    val = x - random_value * delta;
-  } else {
-    val = x + random_value * delta;
-  }
-  return maybe_upcast<scalar_t>(__float2half_rz(val));
-}
+template <typename out_type, typename in_type, typename round_to_prec=at::Half>
+struct round_stochastically {
+  static_assert(std::is_same<round_to_prec, at::Half>::value, "round_stochastically only supports round_to_prec=at::Half");
+};
+
+template <typename out_type, typename in_type>
+struct round_stochastically<out_type, in_type, at::Half> {
+  __device__ __forceinline__ out_type operator()(in_type x, float random_value) {
+    if (x == 0.0) {
+      return out_type(0.0);
+    }
+    float delta = get_delta_fp16(static_cast<float>(x));

Review comment from @mcarilli (May 6, 2020): Hardcoding float here is probably fine IMO, but natalia may ask you to change this to in_type (which might require making get_delta_fp16 a template) and to replace the __float2half_rz call with a wrapper function that has several overloads, where the float overload calls __float2half_rz.
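For illustration, the wrapper being suggested might look like the sketch below; the name to_half_rz and the non-float overload are assumptions, not code from this PR.

#include <cuda_fp16.h>

// Hypothetical overload set so the truncating half conversion can be called
// generically from templated code; only the float overload maps directly
// onto the __float2half_rz intrinsic.
__device__ __forceinline__ __half to_half_rz(float x) {
  return __float2half_rz(x);
}

__device__ __forceinline__ __half to_half_rz(double x) {
  // Assumed behavior: narrow to float first, then truncate toward zero.
  return __float2half_rz(static_cast<float>(x));
}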

+    float val;
+    if (x < 0.0) {
+      val = x - random_value * delta;
+    } else {
+      val = x + random_value * delta;
+    }
+    return maybe_upcast<out_type>(__float2half_rz(val));
+  }
+};
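As a sanity reference for what the functor computes: the input is nudged away from zero by random_value times the local fp16 spacing and then truncated toward zero, so each value rounds to one of its two fp16 neighbors with probability proportional to proximity. A minimal host-side sketch, assuming get_delta_fp16 returns the fp16 ULP at the magnitude of x (normal range, single binade only; all names here are illustrative):

#include <cmath>

// fp16 spacing (ULP) at the magnitude of x, for normal fp16 values.
inline float fp16_ulp(float x) {
  int exp;
  std::frexp(std::fabs(x), &exp);    // |x| = m * 2^exp, m in [0.5, 1)
  return std::ldexp(1.0f, exp - 11); // fp16 has 10 explicit mantissa bits
}

// Host-side stand-in for the device functor: truncating x + r * delta
// toward zero selects the nearer fp16 neighbor with probability
// proportional to how close x is to it.
inline float round_stochastically_ref(float x, float random_value) {
  if (x == 0.0f) return 0.0f;
  float delta = fp16_ulp(x);
  float val = (x < 0.0f) ? x - random_value * delta
                         : x + random_value * delta;
  return std::trunc(val / delta) * delta; // stand-in for __float2half_rz
}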