
[Hackathon 7th PPSCI No.12] Adam and AdamW optimizers support amsgrad -part #68079

Open
wants to merge 45 commits into base: develop

Commits (45)
b45f2c4
[init] amsgrad
megemini Aug 29, 2024
640be9b
[update] refer.h
megemini Aug 29, 2024
2028825
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Sep 4, 2024
caf919a
[Add] amsgrad gpu
megemini Sep 4, 2024
aa289ad
[Add] amsgrad for adamw and fused
megemini Sep 4, 2024
106f817
[Fix] adamw gpu kernel
megemini Sep 5, 2024
fddb46a
[Update] fused adam kernel for gpu
megemini Sep 5, 2024
d206442
[Update] xpu adam/adamw param list
megemini Sep 5, 2024
8cc9b5b
[Update] tests for amsgrad
megemini Sep 6, 2024
eb5de54
[Fix] moment2 max out setting values without amsgrad
megemini Sep 7, 2024
7aa9d60
[Update] unittest passed for adam and adamw
megemini Sep 7, 2024
96216e4
[Update] unittest passed for merged and fused adam
megemini Sep 8, 2024
7398a2f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Sep 9, 2024
98abe71
[Update] make moment2_max optional
megemini Sep 10, 2024
e159b70
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Sep 10, 2024
4564d32
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Sep 11, 2024
e2d2f9b
[Update] test_adamw_op.py with new test case
megemini Sep 11, 2024
7d7ddb1
[Update] adam adamw with amsgrad formula
megemini Sep 12, 2024
d8d97ed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Sep 18, 2024
fc6204f
[Update] adam/adamw for test.cc
megemini Sep 18, 2024
0144890
[Fix] xpu param name
megemini Sep 18, 2024
c6942c0
[Fix] xpu param name & unittest
megemini Sep 18, 2024
f9cb32e
[Fix] xpu param type
megemini Sep 18, 2024
92ad89d
[Fix] xpu unittest
megemini Sep 18, 2024
8e026cd
[Fix] xpu unittest
megemini Sep 19, 2024
56d26df
[Fix] xpu unittest
megemini Sep 19, 2024
26c7e63
[Fix] merged_adam_ op_compat.yaml
megemini Sep 19, 2024
5aa6c40
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Sep 19, 2024
ddb2035
[Fix] remove UNUSED
megemini Sep 19, 2024
e41b66b
[Fix] remove UNUSED
megemini Sep 19, 2024
a751804
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Sep 20, 2024
1f2831a
[Update] unittest adam op
megemini Sep 20, 2024
cfbd173
[Fix] op_compat.yaml
megemini Sep 21, 2024
d371c41
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Sep 26, 2024
9f977ac
[Update] assembly for adam adamw
megemini Sep 29, 2024
1f74eb8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Oct 2, 2024
d6e2652
[Fix] adamw.cc for assembly jit gen
megemini Oct 10, 2024
da2e743
[Update] adam with old ir test
megemini Oct 10, 2024
1b9a6bf
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Oct 11, 2024
d157301
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Oct 12, 2024
1c05064
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Oct 20, 2024
6544a48
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Oct 28, 2024
f17d737
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
megemini Nov 8, 2024
af27337
[Update] codestyle
megemini Nov 8, 2024
d7bb19a
[Update] npu test rtol adamw
megemini Nov 8, 2024
11 changes: 11 additions & 0 deletions paddle/fluid/operators/fused/fused_adam_op.cc
@@ -57,6 +57,9 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("LearningRate", "(Tensor, default Tensor<float>) Learning rate");
AddInput("Moments1", "(Tensor) Input first moments").AsDuplicable();
AddInput("Moments2", "(Tensor) Input second moments").AsDuplicable();
AddInput("Moments2Max", "(Tensor) Input second moments max for amsgrad")
.AsDispensable()
Contributor: This changes the signature of the existing operator, which would be a backward-incompatible upgrade. I'm not sure whether that matches our expectations.

Contributor: Some models may already have been saved following the original protocol; after this change, previously saved models might fail to load.

Contributor (Author): Hmm, as I read https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification, the requirement is to mark it AsDispensable ~ is that still not acceptable?

.AsDuplicable();
AddInput("Beta1Pows",
"(Tensor, default Tensor<float>) Input beta1 power accumulator")
.AsDuplicable();
@@ -72,6 +75,10 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("ParamsOut", "(Tensor) Output parameters").AsDuplicable();
AddOutput("Moments1Out", "(Tensor) Output first moments").AsDuplicable();
AddOutput("Moments2Out", "(Tensor) Output second moments").AsDuplicable();
AddOutput("Moments2MaxOut",
"(Tensor) Output second moments max for amsgrad")
.AsDispensable()
Contributor: Same as above.

.AsDuplicable();
AddOutput("Beta1PowsOut", "(Tensor) Output beta1 power accumulator")
.AsDuplicable();
AddOutput("Beta2PowsOut", "(Tensor) Output beta2 power accumulator")
@@ -122,6 +129,10 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
"Whether to use global beta_pow for whole model instead of "
"creating beta_pow for each parameter.")
.SetDefault(false);
AddAttr<bool>("amsgrad",
"(bool, default false) "
"Whether to use the AMSGrad of this algorithm.")
.SetDefault(false);

AddComment(R"DOC(
Adam Optimizer.
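Note on the new amsgrad attribute: AMSGrad (Reddi et al., 2018, "On the Convergence of Adam and Beyond") divides by the running maximum of the second-moment estimate instead of the current estimate, which is exactly the extra Moments2Max / Moments2MaxOut state threaded through the op above. A sketch of one update step in standard notation (bias correction of m_t and v_t is applied as in plain Adam and omitted here; this is the textbook formulation, not a transcription of the kernel code):

\begin{aligned}
m_t      &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t      &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat v_t &= \max(\hat v_{t-1},\, v_t) \quad \text{(amsgrad = true; plain Adam uses } v_t \text{ directly)} \\
\theta_t &= \theta_{t-1} - \eta\, \frac{m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}

The running maximum \hat v_t is the quantity Moments2Max carries between steps, which is why it has to appear in both the input and the output lists.
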
3 changes: 3 additions & 0 deletions paddle/fluid/operators/ops_signature/adam_sig.cc
@@ -24,13 +24,15 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
"LearningRate",
"Moment1",
"Moment2",
"Moment2Max",
"Beta1Pow",
"Beta2Pow",
"MasterParam",
"SkipUpdate"};
paddle::small_vector<const char*> out_names = {"ParamOut",
"Moment1Out",
"Moment2Out",
"Moment2MaxOut",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"};
@@ -46,6 +48,7 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
attr_names.emplace_back("min_row_size_to_use_multithread");
attr_names.emplace_back("multi_precision");
attr_names.emplace_back("use_global_beta_pow");
attr_names.emplace_back("amsgrad");

if (ctx.IsSelectedRowsInput("Grad")) {
return KernelSignature("adam_dense_param_sparse_grad",
5 changes: 4 additions & 1 deletion paddle/fluid/operators/ops_signature/fused_adam_sig.cc
@@ -25,13 +25,15 @@ KernelSignature FusedAdamOpArgumentMapping(
"LearningRate",
"Moments1",
"Moments2",
"Moments2Max",
"Beta1Pows",
"Beta2Pows",
"MasterParams",
"SkipUpdate"};
paddle::small_vector<const char*> out_names = {"ParamsOut",
"Moments1Out",
"Moments2Out",
"Moments2MaxOut",
"Beta1PowsOut",
"Beta2PowsOut",
"MasterParamsOut"};
@@ -42,7 +44,8 @@ KernelSignature FusedAdamOpArgumentMapping(
"weight_decay",
"use_adamw",
"multi_precision",
"use_global_beta_pow"};
"use_global_beta_pow",
"amsgrad"};

return KernelSignature("fused_adam",
std::move(in_names),
12 changes: 12 additions & 0 deletions paddle/fluid/pybind/eager_generator.cc
@@ -3344,27 +3344,31 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"ParamOut",
"Moment1Out",
"Moment2Out",
"Moment2MaxOut",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"}},
{"merged_adam",
{"ParamOut",
"Moment1Out",
"Moment2Out",
"Moment2MaxOut",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"}},
{"fused_adam",
{"ParamsOut",
"Moments1Out",
"Moments2Out",
"Moments2MaxOut",
"Beta1PowsOut",
"Beta2PowsOut",
"MasterParamsOut"}},
{"adamw",
{"ParamOut",
"Moment1Out",
"Moment2Out",
"Moment2MaxOut",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"}},
@@ -3544,6 +3548,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"LearningRate",
"Moment1",
"Moment2",
"Moment2Max",
"Beta1Pow",
"Beta2Pow",
"MasterParam"}},
@@ -3553,6 +3558,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"LearningRate",
"Moment1",
"Moment2",
"Moment2Max",
"Beta1Pow",
"Beta2Pow",
"MasterParam"}},
@@ -3562,6 +3568,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"LearningRate",
"Moments1",
"Moments2",
"Moments2Max",
"Beta1Pows",
"Beta2Pows",
"MasterParams",
@@ -3572,6 +3579,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"LearningRate",
"Moment1",
"Moment2",
"Moment2Max",
"Beta1Pow",
"Beta2Pow",
"MasterParam"}},
@@ -3723,27 +3731,31 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"ParamOut",
"Moment1Out",
"Moment2Out",
"Moment2MaxOut",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"}},
{"merged_adam",
{"ParamOut",
"Moment1Out",
"Moment2Out",
"Moment2MaxOut",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"}},
{"fused_adam",
{"ParamsOut",
"Moments1Out",
"Moments2Out",
"Moments2MaxOut",
"Beta1PowsOut",
"Beta2PowsOut",
"MasterParamsOut"}},
{"adamw",
{"ParamOut",
"Moment1Out",
"Moment2Out",
"Moment2MaxOut",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"}},
23 changes: 23 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -152,6 +152,7 @@ void AdamInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& moment2_max,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& master_param,
@@ -163,9 +164,11 @@
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow,
bool amsgrad,
MetaTensor* param_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* moment2_max_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs) {
@@ -232,6 +235,10 @@ void AdamInferMeta(const MetaTensor& param,
moment1_out->set_dtype(moment1.dtype());
moment2_out->set_dims(param_dims);
moment2_out->set_dtype(moment2.dtype());
if (amsgrad) {
moment2_max_out->set_dims(param_dims);
moment2_max_out->set_dtype(moment2.dtype());
}

beta1_pow_out->set_dims(beta1_pow_dims);
beta1_pow_out->set_dtype(beta1_pow.dtype());
@@ -328,6 +335,7 @@ void AdamwInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& moment2_max,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& master_param,
@@ -342,9 +350,11 @@
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow,
bool amsgrad,
MetaTensor* param_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* moment2_max_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs) {
@@ -353,6 +363,7 @@
learning_rate,
moment1,
moment2,
moment2_max,
beta1_pow,
beta2_pow,
master_param,
@@ -364,9 +375,11 @@
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
amsgrad,
param_out,
moment1_out,
moment2_out,
moment2_max_out,
beta1_pow_out,
beta2_pow_out,
master_param_outs);
@@ -3856,6 +3869,7 @@ void MergedAdamInferMeta(
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& moment1,
const std::vector<const MetaTensor*>& moment2,
const paddle::optional<std::vector<const MetaTensor*>>& moment2_max,
const std::vector<const MetaTensor*>& beta1_pow,
const std::vector<const MetaTensor*>& beta2_pow,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
@@ -3864,9 +3878,11 @@
const Scalar& epsilon,
bool multi_precision,
bool use_global_beta_pow,
bool amsgrad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> moment1_out,
std::vector<MetaTensor*> moment2_out,
std::vector<MetaTensor*> moment2_max_out,
std::vector<MetaTensor*> beta1_pow_out,
std::vector<MetaTensor*> beta2_pow_out,
std::vector<MetaTensor*> master_param_out) {}
@@ -5784,6 +5800,7 @@ void FusedAdamInferMeta(
const MetaTensor& learning_rate,
const std::vector<const MetaTensor*>& moments1,
const std::vector<const MetaTensor*>& moments2,
const paddle::optional<std::vector<const MetaTensor*>>& moments2_max,
const std::vector<const MetaTensor*>& beta1_pows,
const std::vector<const MetaTensor*>& beta2_pows,
const paddle::optional<std::vector<const MetaTensor*>>& master_params,
@@ -5796,9 +5813,11 @@
bool use_adamw,
bool multi_precision,
bool use_global_beta_pow,
bool amsgrad,
std::vector<MetaTensor*> params_out,
std::vector<MetaTensor*> moments1_out,
std::vector<MetaTensor*> moments2_out,
std::vector<MetaTensor*> moments2_max_out,
std::vector<MetaTensor*> beta1_pows_out,
std::vector<MetaTensor*> beta2_pows_out,
std::vector<MetaTensor*> master_params_out) {
@@ -5810,6 +5829,10 @@
moments1_out[i]->set_dtype(moments1[i]->dtype());
moments2_out[i]->set_dims(moments2[i]->dims());
moments2_out[i]->set_dtype(moments2[i]->dtype());
if (amsgrad) {
moments2_max_out[i]->set_dims(moments2_max.get()[i]->dims());
moments2_max_out[i]->set_dtype(moments2_max.get()[i]->dtype());
}
beta1_pows_out[i]->set_dims(beta1_pows[i]->dims());
beta1_pows_out[i]->set_dtype(beta1_pows[i]->dtype());
beta2_pows_out[i]->set_dims(beta2_pows[i]->dims());
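The if (amsgrad) branches above only wire up shape and dtype for the moment2_max outputs; the numeric behavior the kernels are expected to match corresponds to the reference step below. This is a minimal NumPy sketch for a single dense parameter, written in the spirit of the reference implementations in the op unit tests; the function name and the exact placement of bias correction and epsilon are choices made for this sketch, not taken from the Paddle kernels.

import numpy as np

def adam_amsgrad_step(param, grad, m1, m2, m2_max, beta1_pow, beta2_pow,
                      lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, amsgrad=False):
    # Update biased first and second moment estimates.
    m1 = beta1 * m1 + (1.0 - beta1) * grad
    m2 = beta2 * m2 + (1.0 - beta2) * grad * grad
    # AMSGrad keeps the element-wise running maximum of the second moment.
    if amsgrad:
        m2_max = np.maximum(m2_max, m2)
        denom_moment = m2_max
    else:
        denom_moment = m2          # m2_max is carried through unchanged
    # Bias-corrected estimates (beta*_pow holds beta**t, accumulated outside).
    m1_hat = m1 / (1.0 - beta1_pow)
    v_hat = denom_moment / (1.0 - beta2_pow)
    param = param - lr * m1_hat / (np.sqrt(v_hat) + eps)
    return param, m1, m2, m2_max

When amsgrad is false the m2_max buffer is never read or updated, which is consistent with Moment2Max being dispensable/optional in the op definitions and with the InferMeta functions only setting its output metadata under the flag.
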
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -86,6 +86,7 @@ void AdamInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& moment2_max,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& master_param,
@@ -97,9 +98,11 @@
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow,
bool amsgrad,
MetaTensor* param_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* moment2_max_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);
@@ -109,6 +112,7 @@ void AdamwInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& moment2_max,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& master_param,
@@ -123,9 +127,11 @@
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow,
bool amsgrad,
MetaTensor* param_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* moment2_max_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);
@@ -711,6 +717,7 @@ void MergedAdamInferMeta(
const std::vector<const MetaTensor*>& learning_rate,
const std::vector<const MetaTensor*>& moment1,
const std::vector<const MetaTensor*>& moment2,
const paddle::optional<std::vector<const MetaTensor*>>& moment2_max,
const std::vector<const MetaTensor*>& beta1_pow,
const std::vector<const MetaTensor*>& beta2_pow,
const paddle::optional<std::vector<const MetaTensor*>>& master_param,
@@ -719,9 +726,11 @@
const Scalar& epsilon,
bool multi_precision,
bool use_global_beta_pow,
bool amsgrad,
std::vector<MetaTensor*> param_out,
std::vector<MetaTensor*> moment1_out,
std::vector<MetaTensor*> moment2_out,
std::vector<MetaTensor*> moment2_max_out,
std::vector<MetaTensor*> beta1_pow_out,
std::vector<MetaTensor*> beta2_pow_out,
std::vector<MetaTensor*> master_param_out);
@@ -1117,6 +1126,7 @@ void FusedAdamInferMeta(
const MetaTensor& learning_rate,
const std::vector<const MetaTensor*>& moments1,
const std::vector<const MetaTensor*>& moments2,
const paddle::optional<std::vector<const MetaTensor*>>& moments2_max,
const std::vector<const MetaTensor*>& beta1_pows,
const std::vector<const MetaTensor*>& beta2_pows,
const paddle::optional<std::vector<const MetaTensor*>>& master_params,
@@ -1129,9 +1139,11 @@
bool use_adamw,
bool multi_precision,
bool use_global_beta_pow,
bool amsgrad,
std::vector<MetaTensor*> params_out,
std::vector<MetaTensor*> moments1_out,
std::vector<MetaTensor*> moments2_out,
std::vector<MetaTensor*> moments2_max_out,
std::vector<MetaTensor*> beta1_pows_out,
std::vector<MetaTensor*> beta2_pows_out,
std::vector<MetaTensor*> master_params_out);
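For completeness, a hedged usage sketch of the user-facing side. This PR part only adds the operator, kernel-signature, and InferMeta plumbing, so the amsgrad keyword on the Python optimizers is an assumption here: it mirrors the attribute added to the ops and the PR title, but the final Python signature may differ.

import paddle

# Assumed API: the amsgrad keyword is what this PR series is expected to expose;
# it is not guaranteed to be the final user-facing signature.
linear = paddle.nn.Linear(10, 10)
opt = paddle.optimizer.AdamW(learning_rate=1e-3,
                             parameters=linear.parameters(),
                             amsgrad=True)

loss = linear(paddle.rand([4, 10])).mean()
loss.backward()
opt.step()        # with amsgrad=True the optimizer also tracks Moment2Max state
opt.clear_grad()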