Update the Gaudi trainer with transformers 4.45.2 #1398
base: main
Conversation
- Add the description
- Remove the `_is_peft_model` function and import it from `transformers` instead
- Remove the duplicate `import datasets` and keep the first one
- Update `_get_train_sampler` by replacing the computed variable `num_samples`
- Update `_inner_training_loop`:
  - Use `num_update_steps_per_epoch` where possible
  - Remove an unused line
  - Add a new `_should_compute_grad_norm` variable to remove an extra conditional check
- Update `_load_best_model`: enable the `if self.is_deepspeed_enabled` branch
- Update `_maybe_log_save_evaluate`: `_grad_norm.item()` -> `_grad_norm.detach().item()`
- Remove `_save_checkpoint`, now exactly the same as in transformers
- Update `autocast_smart_context_manager`, updating the old interface
- Update `save_model`: add an accelerate version check
- Update `evaluation_loop`:
  - Introduce a new variable `_should_update_inputs` for llama, qwen2, starcoder2 and gemma, to avoid a double if condition in the loop
  - Set `logits_dtype` to float32 before the loop
  - Re-order the conditionals for losses, logits and labels to match transformers
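The `_should_compute_grad_norm` change follows a common pattern: hoist a condition that is invariant across training steps out of the hot loop. A minimal sketch of the idea, not the actual trainer code (`max_grad_norm` and `use_deepspeed` stand in for the real trainer state):

```python
# Sketch: precompute a loop-invariant boolean once instead of
# re-evaluating a compound condition on every training step.
class TrainerSketch:
    def __init__(self, max_grad_norm=1.0, use_deepspeed=False):
        self.max_grad_norm = max_grad_norm
        self.use_deepspeed = use_deepspeed
        # Evaluated once at construction time rather than per step.
        self._should_compute_grad_norm = (
            max_grad_norm is not None and max_grad_norm > 0 and not use_deepspeed
        )

    def training_step(self, step):
        # The per-step check is now a single cheap attribute read.
        if self._should_compute_grad_norm:
            return f"step {step}: clipping to {self.max_grad_norm}"
        return f"step {step}: no clipping"
```

The `_grad_norm.detach().item()` change is a related micro-fix: detaching before `.item()` makes explicit that the scalar read is not part of the autograd graph.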
…mandatory for Gaudi
…in train any more" This reverts commit e3f316f.
The env for all tests:
export RUN_SLOW=true
export GAUDI2_CI=1

Some tests which are finished:
- slow_tests_fsdp -> 2 passed in 689.82s (0:11:29) | the same accuracy and loss as main on both tests
- test_trainer_distributed -> 2 passed, 6 warnings in 21.19s | the same as main
- test_trainer -> 75 passed, 8 skipped, 37 warnings in 55.23s | the same as main
- test_trainer_seq2seq -> 2 passed, 6 warnings in 21.19s | the same as main

Other tests are in progress and will be updated here.
- Introduce a new variable `_should_update_inputs` for llama, qwen2, starcoder2 and gemma, to avoid a double if condition in the loop
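The point of this variable is to do the model-type check once, before the evaluation loop, so the loop body only tests a boolean. A hedged sketch of the pattern (the model list matches the PR description, but the `lazy_mode` key and function shape are illustrative placeholders, not the trainer's actual code):

```python
# Sketch: decide once, outside the loop, whether inputs need extra keys,
# then apply a precomputed update dict inside the loop.
MODELS_WITH_EXTRA_INPUTS = {"llama", "qwen2", "starcoder2", "gemma"}

def run_eval(model_type, batches):
    # Single model-type check, hoisted out of the loop.
    _should_update_inputs = model_type in MODELS_WITH_EXTRA_INPUTS
    _inputs_update = {"lazy_mode": True} if _should_update_inputs else {}
    results = []
    for batch in batches:
        inputs = dict(batch)
        if _should_update_inputs:  # one cheap boolean check per step
            inputs.update(_inputs_update)
        results.append(inputs)
    return results
```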
The env for all tests:
export RUN_SLOW=true
export GAUDI2_CI=1

test_examples finished:
- test_examples -> 11 failed, 50 passed, 2 warnings in 21761.68s (6:02:41) | 11 failed, 50 passed, 2 warnings in 23254.39s (6:27:34) on main
On main:
=========================== short test summary info ============================
FAILED tests/test_examples.py::MultiCardQuestionAnsweringExampleTester::test_run_qa_roberta-large_multi_card
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingExampleTester::test_run_clm_gpt2_multi_card
FAILED tests/test_examples.py::DeepspeedCausalLanguageModelingExampleTester::test_run_clm_CodeLlama-13b-Instruct-hf_deepspeed
FAILED tests/test_examples.py::MultiCardSummarizationExampleTester::test_run_summarization_t5-small_multi_card
FAILED tests/test_examples.py::DeepspeedSummarizationExampleTester::test_run_summarization_flan-t5-xxl_deepspeed
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingLORAExampleTester2::test_run_lora_clm_falcon-40b_multi_card
FAILED tests/test_examples.py::MultiCardSeq2SeqSpeechRecognitionExampleTester::test_run_speech_recognition_seq2seq_whisper-small_multi_card
FAILED tests/test_examples.py::DeepspeedSFTExampleTester::test_sft_Qwen2-72B_deepspeed
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingPrefixTuningExampleTester::test_run_prompt_tuning_clm_llama-7b_multi_card
FAILED tests/test_examples.py::MultiCardMultiTastPromptPeftExampleTester::test_run_multitask_prompt_tuning_t5-small_multi_card
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingVeraExampleTester::test_run_lora_clm_llama-7b_multi_card
=========== 11 failed, 50 passed, 2 warnings in 23254.39s (6:27:34) ============
On this PR:
=========================== short test summary info ============================
FAILED tests/test_examples.py::MultiCardQuestionAnsweringExampleTester::test_run_qa_roberta-large_multi_card
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingExampleTester::test_run_clm_gpt2_multi_card
FAILED tests/test_examples.py::DeepspeedCausalLanguageModelingExampleTester::test_run_clm_CodeLlama-13b-Instruct-hf_deepspeed
FAILED tests/test_examples.py::MultiCardSummarizationExampleTester::test_run_summarization_t5-small_multi_card
FAILED tests/test_examples.py::DeepspeedSummarizationExampleTester::test_run_summarization_flan-t5-xxl_deepspeed
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingLORAExampleTester2::test_run_lora_clm_falcon-40b_multi_card
FAILED tests/test_examples.py::MultiCardSeq2SeqSpeechRecognitionExampleTester::test_run_speech_recognition_seq2seq_whisper-small_multi_card
FAILED tests/test_examples.py::DeepspeedSFTExampleTester::test_sft_Qwen2-72B_deepspeed
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingPrefixTuningExampleTester::test_run_prompt_tuning_clm_llama-7b_multi_card
FAILED tests/test_examples.py::MultiCardMultiTastPromptPeftExampleTester::test_run_multitask_prompt_tuning_t5-small_multi_card
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingVeraExampleTester::test_run_lora_clm_llama-7b_multi_card
=========== 11 failed, 50 passed, 2 warnings in 21761.68s (6:02:41) ============
@regisss, would you please check this PR? If we want to move towards v4.46.0, this PR will help: there are updates in the Trainer for the new version.
Please make the title more descriptive to say what the PR is doing.
@yafshar, some comments added for the code changes. Thanks for doing the port.
- Created a `_get_input_update_settings` function to encapsulate the logic for determining input updates.
- The function returns both the `_should_update_inputs` boolean and the `_inputs_update` dictionary.
- This refactor improves maintainability and prepares for potential future expansion of the model list.
@yafshar, please redo `make style` and a test run if you haven't already done so, and confirm that no new test failures are seen. Thanks.
This PR:
>>> python -m pytest tests/test_gaudi_configuration.py tests/test_trainer_distributed.py tests/test_trainer.py tests/test_trainer_seq2seq.py -v -s
81 passed, 8 skipped, 140 warnings in 139.73s (0:02:19)
>>> python -m pytest tests/test_examples.py -v -s -k single_card
12 passed, 50 deselected in 2917.12s (0:48:37)
>>> python -m pytest tests/test_fsdp_examples.py -v -s
2 passed in 799.36s (0:13:19)
>>> python -m pytest tests/test_examples.py -v -s -k multi_card
10 failed, 34 passed, 18 deselected, 2 warnings in 17206.15s (4:46:46)
>>> python -m pytest tests/test_examples.py -v -s -k deepspeed
3 failed, 3 passed, 56 deselected in 5636.71s (1:33:56)

main:
>>> python -m pytest tests/test_gaudi_configuration.py tests/test_trainer_distributed.py tests/test_trainer.py tests/test_trainer_seq2seq.py -v -s
81 passed, 8 skipped, 140 warnings in 181.19s (0:03:01)
>>> python -m pytest tests/test_examples.py -v -s -k single_card
12 passed, 50 deselected in 2805.12s (0:46:45)
>>> python -m pytest tests/test_fsdp_examples.py -v -s
2 passed in 804.29s (0:13:24)
>>> python -m pytest tests/test_examples.py -v -s -k multi_card
10 failed, 34 passed, 18 deselected, 2 warnings in 16397.55s (4:33:17)
>>> python -m pytest tests/test_examples.py -v -s -k deepspeed
3 failed, 3 passed, 56 deselected in 4832.87s (1:20:32)
What does this PR do?
Update the Gaudi trainer with transformers 4.45.2