
Update the Gaudi trainer with transformers 4.45.2 #1398

Open
wants to merge 26 commits into main

Conversation

@yafshar (Contributor) commented Oct 4, 2024

What does this PR do?

Update the Gaudi trainer with transformers 4.45.2

  • Add the description
  • Remove the local _is_peft_model function and import it from transformers instead
  • Remove the duplicate import datasets and keep the first one
  • Update _get_train_sampler by replacing the computed variable num_samples
  • Update _inner_training_loop:
    • use num_update_steps_per_epoch where possible
    • remove an unused line
    • add a new _should_compute_grad_norm variable to remove an extra conditional check
  • Update _load_best_model: enable the if self.is_deepspeed_enabled branch
  • Update _maybe_log_save_evaluate: _grad_norm.item() -> _grad_norm.detach().item()
  • Remove _save_checkpoint, which was identical to the transformers implementation
  • Update autocast_smart_context_manager, updating the old interface
  • Update save_model: add an accelerate version check
  • Update evaluation_loop:
    • introduce a new variable _should_update_inputs for llama, qwen2, starcoder2 and gemma, to avoid a double if condition in the loop
    • set logits_dtype to float32 before the loop
    • re-order the conditionals for losses, logits and labels to match transformers
  • Update _inner_training_loop:
    • introduce the same _should_update_inputs variable for llama, qwen2, starcoder2 and gemma, to avoid a double if condition in the loop
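The `_should_compute_grad_norm` and `_should_update_inputs` changes above follow the same pattern: evaluate a compound condition once, before the loop, and keep only a cheap boolean test per step. A minimal sketch of that pattern, with assumed logic — the real trainer also consults FSDP/DeepSpeed state, and `should_compute_grad_norm`/`training_loop` here are illustrative helpers, not optimum-habana APIs:

```python
def should_compute_grad_norm(max_grad_norm, is_deepspeed_enabled=False):
    # Hypothetical condition: DeepSpeed handles gradient clipping
    # internally, so the trainer only clips when it owns the step.
    if is_deepspeed_enabled:
        return False
    return max_grad_norm is not None and max_grad_norm > 0

def training_loop(num_steps, max_grad_norm, is_deepspeed_enabled=False):
    # Hoisted out of the loop: computed once instead of per step.
    _should_compute_grad_norm = should_compute_grad_norm(
        max_grad_norm, is_deepspeed_enabled
    )
    logged_norms = []
    for _ in range(num_steps):
        if _should_compute_grad_norm:
            # Real code would call clip_grad_norm_(...) here and log
            # _grad_norm.detach().item() -- calling detach() before
            # item() keeps the logged scalar out of the autograd graph.
            logged_norms.append(max_grad_norm)
    return logged_norms
```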

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@yafshar yafshar marked this pull request as ready for review October 10, 2024 11:45
@yafshar yafshar requested a review from regisss as a code owner October 10, 2024 11:45
@yafshar (Contributor, Author) commented Oct 11, 2024

The environment used for all tests:

export RUN_SLOW=true
export GAUDI2_CI=1

Tests that have finished:

- slow_tests_fsdp -> 2 passed in 689.82s (0:11:29) | same accuracy and loss as main on both tests

- test_trainer_distributed -> 2 passed, 6 warnings in 21.19s | the same as main

- test_trainer -> 75 passed, 8 skipped, 37 warnings in 55.23s | the same as main

- test_trainer_seq2seq -> 2 passed, 6 warnings in 21.19s | the same as main

Other tests are in progress and will be updated here

@yafshar (Contributor, Author) commented Oct 14, 2024

The environment used for all tests:

export RUN_SLOW=true
export GAUDI2_CI=1

test_examples has finished:

- test_examples -> 11 failed, 50 passed, 2 warnings in 21761.68s (6:02:41) | 11 failed, 50 passed, 2 warnings in 23254.39s (6:27:34) on main
  • main branch
=========================== short test summary info ============================
FAILED tests/test_examples.py::MultiCardQuestionAnsweringExampleTester::test_run_qa_roberta-large_multi_card
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingExampleTester::test_run_clm_gpt2_multi_card
FAILED tests/test_examples.py::DeepspeedCausalLanguageModelingExampleTester::test_run_clm_CodeLlama-13b-Instruct-hf_deepspeed
FAILED tests/test_examples.py::MultiCardSummarizationExampleTester::test_run_summarization_t5-small_multi_card
FAILED tests/test_examples.py::DeepspeedSummarizationExampleTester::test_run_summarization_flan-t5-xxl_deepspeed
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingLORAExampleTester2::test_run_lora_clm_falcon-40b_multi_card
FAILED tests/test_examples.py::MultiCardSeq2SeqSpeechRecognitionExampleTester::test_run_speech_recognition_seq2seq_whisper-small_multi_card
FAILED tests/test_examples.py::DeepspeedSFTExampleTester::test_sft_Qwen2-72B_deepspeed
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingPrefixTuningExampleTester::test_run_prompt_tuning_clm_llama-7b_multi_card
FAILED tests/test_examples.py::MultiCardMultiTastPromptPeftExampleTester::test_run_multitask_prompt_tuning_t5-small_multi_card
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingVeraExampleTester::test_run_lora_clm_llama-7b_multi_card
=========== 11 failed, 50 passed, 2 warnings in 23254.39s (6:27:34) ============
  • Current PR
=========================== short test summary info ============================
FAILED tests/test_examples.py::MultiCardQuestionAnsweringExampleTester::test_run_qa_roberta-large_multi_card
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingExampleTester::test_run_clm_gpt2_multi_card
FAILED tests/test_examples.py::DeepspeedCausalLanguageModelingExampleTester::test_run_clm_CodeLlama-13b-Instruct-hf_deepspeed
FAILED tests/test_examples.py::MultiCardSummarizationExampleTester::test_run_summarization_t5-small_multi_card
FAILED tests/test_examples.py::DeepspeedSummarizationExampleTester::test_run_summarization_flan-t5-xxl_deepspeed
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingLORAExampleTester2::test_run_lora_clm_falcon-40b_multi_card
FAILED tests/test_examples.py::MultiCardSeq2SeqSpeechRecognitionExampleTester::test_run_speech_recognition_seq2seq_whisper-small_multi_card
FAILED tests/test_examples.py::DeepspeedSFTExampleTester::test_sft_Qwen2-72B_deepspeed
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingPrefixTuningExampleTester::test_run_prompt_tuning_clm_llama-7b_multi_card
FAILED tests/test_examples.py::MultiCardMultiTastPromptPeftExampleTester::test_run_multitask_prompt_tuning_t5-small_multi_card
FAILED tests/test_examples.py::MultiCardCausalLanguageModelingVeraExampleTester::test_run_lora_clm_llama-7b_multi_card
=========== 11 failed, 50 passed, 2 warnings in 21761.68s (6:02:41) ============

@yafshar (Contributor, Author) commented Oct 24, 2024

@regisss, would you please review this PR? If we want to move towards v4.46.0, this PR will help; there are updates to the Trainer for the new version.

@emascarenhas (Contributor)

Please make the title more descriptive to say what the PR is doing.

@yafshar yafshar changed the title Trainer Update the Gaudi trainer with transformers 4.45.0 Nov 5, 2024
@yafshar yafshar changed the title Update the Gaudi trainer with transformers 4.45.0 Update the Gaudi trainer with transformers 4.45.2 Nov 5, 2024
@emascarenhas (Contributor) left a comment

@yafshar, some comments added for the code changes. Thanks for doing the port.

Review comments on optimum/habana/transformers/trainer.py (resolved)
- Created `_get_input_update_settings` function to encapsulate the
  logic for determining input updates.
- The function returns both `_should_update_inputs` boolean and
  `_inputs_update` dictionary.
- This refactor improves maintainability and prepares for potential
  future expansion of the model list.
@emascarenhas (Contributor)
@yafshar, please rerun `make style` and a test run if you haven't already done so, and confirm that no new test failures are seen. Thanks.

@yafshar (Contributor, Author) commented Nov 10, 2024

This PR:

>>> python -m pytest tests/test_gaudi_configuration.py tests/test_trainer_distributed.py tests/test_trainer.py tests/test_trainer_seq2seq.py -v -s
81 passed, 8 skipped, 140 warnings in 139.73s (0:02:19)

>>> python -m pytest tests/test_examples.py -v -s -k single_card
12 passed, 50 deselected in 2917.12s (0:48:37)

>>> python -m pytest tests/test_fsdp_examples.py -v -s
2 passed in 799.36s (0:13:19)

>>> python -m pytest tests/test_examples.py -v -s -k multi_card
10 failed, 34 passed, 18 deselected, 2 warnings in 17206.15s (4:46:46)

>>> python -m pytest tests/test_examples.py -v -s -k deepspeed
3 failed, 3 passed, 56 deselected in 5636.71s (1:33:56)

main:

>>> python -m pytest tests/test_gaudi_configuration.py tests/test_trainer_distributed.py tests/test_trainer.py tests/test_trainer_seq2seq.py -v -s
81 passed, 8 skipped, 140 warnings in 181.19s (0:03:01)

>>> python -m pytest tests/test_examples.py -v -s -k single_card
12 passed, 50 deselected in 2805.12s (0:46:45)

>>> python -m pytest tests/test_fsdp_examples.py -v -s
2 passed in 804.29s (0:13:24)

>>> python -m pytest tests/test_examples.py -v -s -k multi_card
10 failed, 34 passed, 18 deselected, 2 warnings in 16397.55s (4:33:17)

>>> python -m pytest tests/test_examples.py -v -s -k deepspeed
3 failed, 3 passed, 56 deselected in 4832.87s (1:20:32)

@emascarenhas (Contributor)

@libinta, please add run-test and the other labels to allow this to merge into OH.
@regisss, this is ready for review and merge.
