-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stochastic Rounding Optimizers #17
base: master
Are you sure you want to change the base?
Changes from 1 commit
eca05da
71a0c29
e2112f0
88b9850
50cbbe4
46750ee
17696a2
2e130eb
907f568
31bfb75
ae652ea
b765059
6f7a93a
9159b03
59523c8
8ea3245
ccab446
31bd573
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,17 +44,24 @@ __device__ __forceinline__ float get_delta_fp16(float x) { | |
} | ||
|
||
// Natalia magic | ||
template <typename scalar_t> | ||
__device__ __forceinline__ scalar_t round_stochastically(float x, float random_value) { | ||
if (x == 0.0) { | ||
return scalar_t(0.0); | ||
} | ||
float delta = get_delta_fp16(x); | ||
float val; | ||
if (x < 0.0) { | ||
val = x - random_value * delta; | ||
} else { | ||
val = x + random_value * delta; | ||
template <typename out_type, typename in_type, typename round_to_prec=at::Half> | ||
struct round_stochastically { | ||
static_assert(std::is_same<round_to_prec, at::Half>::value, "round_stochastically only supports round_to_prec=at::Half"); | ||
}; | ||
|
||
template <typename out_type, typename in_type> | ||
struct round_stochastically<out_type, in_type, at::Half> { | ||
__device__ __forceinline__ out_type operator()(in_type x, float random_value) { | ||
if (x == 0.0) { | ||
return out_type(0.0); | ||
} | ||
float delta = get_delta_fp16(static_cast<float>(x)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hardcoding float here is probably fine IMO, but natalia may ask you to change this to |
||
float val; | ||
if (x < 0.0) { | ||
val = x - random_value * delta; | ||
} else { | ||
val = x + random_value * delta; | ||
} | ||
return maybe_upcast<out_type>(__float2half_rz(val)); | ||
} | ||
return maybe_upcast<scalar_t>(__float2half_rz(val)); | ||
} | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keep this comment.