
update sdv2 to ms2.3 (#626)
* update

* update readme

* fix ema bug

* update readme

* update

* update readme

* update typo

* update flash attention

* update benchmarking

* update

---------

Co-authored-by: songyuanwei <song.yuanwei@huawei.com>
Songyuanwei and songyuanwei authored Sep 12, 2024
1 parent 5831703 commit 7159716
Showing 27 changed files with 279 additions and 148 deletions.
3 changes: 1 addition & 2 deletions examples/stable_diffusion_v2/README.md
@@ -73,6 +73,7 @@ The compatible framework versions that are well-tested are listed as follows.
| 910 | 2.0 | 6.3 RC1 | 23.0.rc1 | 3.7.16 | master (4c33849) |
| 910 | 2.1 | 6.3 RC2 | 23.0.rc2 | 3.9.18 | master (4c33849) |
| 910* | 2.2.1 (20231124) | 7.1 | 23.0.rc3.6 | 3.7.16 | master (4c33849) |
| 910* | 2.3.0 | 7.3 | 23.0.3 | 3.8.8 | master |

</div>

@@ -281,7 +282,6 @@ To run vanilla fine-tuning, we will use the `train_text_to_image.py` script foll
--output_path {path to output directory} \
--pretrained_model_path {path to pretrained checkpoint file}
```
> Please enable INFNAN mode by `export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE"` for Ascend 910* if overflow found.

Take fine-tuning SD1.5 on the Pokemon dataset as an example:

@@ -319,7 +319,6 @@ For parallel training on multiple Ascend NPUs, please refer to the instructions
```shell
bash scripts/run_train_distributed.sh
```
> Please enable INFNAN mode by `export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE"` for Ascend 910* if overflow found.

Once launched, the training process can be traced by running `tail -f outputs/train_txt2img/rank_0/train.log`.

85 changes: 47 additions & 38 deletions examples/stable_diffusion_v2/benchmark.md
@@ -4,41 +4,45 @@

### Training

| SD Model | Context | Method | Global Batch Size x Grad. Accu. | Resolution | Acceleration | FPS (img/s) |
|---------------|---------------|--------------|:-------------------:|:------------------:|:----------------:|:----------------:|
| 1.5 | D910x1-MS2.1 | Vanilla | 3x1 | 512x512 | Graph, DS, FP16, | 5.98 |
| 1.5 | D910x8-MS2.1 | Vanilla | 24x1 | 512x512 | Graph, DS, FP16, | 31.18 |
| 1.5 | D910x1-MS2.1 | LoRA | 4x1 | 512x512 | Graph, DS, FP16, | 8.25 |
| 1.5 | D910x8-MS2.1 | LoRA | 32x1 | 512x512 | Graph, DS, FP16, | 63.85 |
| 1.5 | D910x1-MS2.1 | Dreambooth | 1x1 | 512x512 | Graph, DS, FP16, | 2.09 |
| 2.0 | D910x1-MS2.1 | Vanilla | 3x1 | 512x512 | Graph, DS, FP16, | 6.19 |
| 2.0 | D910x8-MS2.1 | Vanilla | 24x1 | 512x512 | Graph, DS, FP16, | 33.50 |
| 2.0 | D910x1-MS2.1 | LoRA | 4x1 | 512x512 | Graph, DS, FP16, | 9.46 |
| 2.0 | D910x8-MS2.1 | LoRA | 32x1 | 512x512 | Graph, DS, FP16, | 73.51 |
| 2.0 | D910x1-MS2.1 | Dreambooth | 1x1 | 512x512 | Graph, DS, FP16, | 2.18 |
| 2.1-v | D910x1-MS2.1 | Vanilla | 3x1 | 768x768 | Graph, DS, FP16, FA | 3.16 |
| 2.1-v | D910x8-MS2.1 | Vanilla | 24x1 | 768x768 | Graph, DS, FP16, FA | 18.98 |
| 2.1-v | D910x1-MS2.1 | LoRA | 4x1 | 768x768 | Graph, DS, FP16, FA | 3.39 |
| 2.1-v | D910x8-MS2.1 | LoRA | 32x1 | 768x768 | Graph, DS, FP16, FA | 23.45 |
| 1.5 | D910*x1-MS2.2.10 | Vanilla | 3x1 | 512x512 | Graph, DS, FP16, | 9.22 |
| 1.5 | D910*x8-MS2.2.10 | Vanilla | 24x1 | 512x512 | Graph, DS, FP16, | 52.30 |
| 1.5 | D910*x1-MS2.2.10 | LoRA | 4x1 | 512x512 | Graph, DS, FP16, | 13.58 |
| 1.5 | D910*x8-MS2.2.10 | LoRA | 32x1 | 512x512 | Graph, DS, FP16, | 105.08 |
| 1.5 | D910*x1-MS2.2.10 | Dreambooth | 1x1 | 512x512 | Graph, DS, FP16, | 2.92 |
| 2.0 | D910*x1-MS2.2.10 | Vanilla | 3x1 | 512x512 | Graph, DS, FP16, | 10.03 |
| 2.0 | D910*x8-MS2.2.10 | Vanilla | 24x1 | 512x512 | Graph, DS, FP16, | 55.69 |
| 2.0 | D910*x1-MS2.2.10 | LoRA | 4x1 | 512x512 | Graph, DS, FP16, | 15.88 |
| 2.0 | D910*x8-MS2.2.10 | LoRA | 32x1 | 512x512 | Graph, DS, FP16, | 119.74 |
| 2.0 | D910*x1-MS2.2.10 | Dreambooth | 1x1 | 512x512 | Graph, DS, FP16, | 2.93 |
| 2.1-v | D910*x1-MS2.2.10 | Vanilla | 3x1 | 768x768 | Graph, DS, FP16, | 5.80 |
| 2.1-v | D910*x8-MS2.2.10 | Vanilla | 24x1 | 768x768 | Graph, DS, FP16, | 46.02 |
| 2.1-v | D910*x1-MS2.2.10 | LoRA | 4x1 | 768x768 | Graph, DS, FP16, | 6.65 |
| 2.1-v | D910*x8-MS2.2.10 | LoRA | 32x1 | 768x768 | Graph, DS, FP16, | 52.57 |
| SD Model | Context | Method | Global Batch Size x Grad. Accu. | Resolution | Acceleration | jit_level | FPS (img/s) |
|---------------|---------------|--------------|:-------------------:|:------------------:|:----------------:|:----------------:|----------:|
| 1.5 | D910x1-MS2.1 | Vanilla | 3x1 | 512x512 | Graph, DS, FP16, | N/A | 5.98 |
| 1.5 | D910x8-MS2.1 | Vanilla | 24x1 | 512x512 | Graph, DS, FP16, | N/A | 31.18 |
| 1.5 | D910x1-MS2.1 | LoRA | 4x1 | 512x512 | Graph, DS, FP16, | N/A | 8.25 |
| 1.5 | D910x8-MS2.1 | LoRA | 32x1 | 512x512 | Graph, DS, FP16, | N/A | 63.85 |
| 1.5 | D910x1-MS2.1 | Dreambooth | 1x1 | 512x512 | Graph, DS, FP16, | N/A | 2.09 |
| 2.0 | D910x1-MS2.1 | Vanilla | 3x1 | 512x512 | Graph, DS, FP16, | N/A | 6.19 |
| 2.0 | D910x8-MS2.1 | Vanilla | 24x1 | 512x512 | Graph, DS, FP16, | N/A | 33.50 |
| 2.0 | D910x1-MS2.1 | LoRA | 4x1 | 512x512 | Graph, DS, FP16, | N/A | 9.46 |
| 2.0 | D910x8-MS2.1 | LoRA | 32x1 | 512x512 | Graph, DS, FP16, | N/A | 73.51 |
| 2.0 | D910x1-MS2.1 | Dreambooth | 1x1 | 512x512 | Graph, DS, FP16, | N/A | 2.18 |
| 2.1-v | D910x1-MS2.1 | Vanilla | 3x1 | 768x768 | Graph, DS, FP16, FA | N/A | 3.16 |
| 2.1-v | D910x8-MS2.1 | Vanilla | 24x1 | 768x768 | Graph, DS, FP16, FA | N/A | 18.98 |
| 2.1-v | D910x1-MS2.1 | LoRA | 4x1 | 768x768 | Graph, DS, FP16, FA | N/A | 3.39 |
| 2.1-v | D910x8-MS2.1 | LoRA | 32x1 | 768x768 | Graph, DS, FP16, FA | N/A | 23.45 |
| 1.5 | D910*x1-MS2.3.0 | Vanilla | 3x1 | 512x512 | Graph, DS, FP16, | O2 | 11.86 |
| 1.5 | D910*x8-MS2.3.0 | Vanilla | 24x1 | 512x512 | Graph, DS, FP16, | O2 | 75.53 |
| 1.5 | D910*x1-MS2.3.0 | LoRA | 4x1 | 512x512 | Graph, DS, FP16, | O2 | 15.27 |
| 1.5 | D910*x8-MS2.3.0 | LoRA | 32x1 | 512x512 | Graph, DS, FP16, | O2 | 119.94 |
| 1.5 | D910*x1-MS2.3.0 | Dreambooth | 1x1 | 512x512 | Graph, DS, FP16, | O2 | 3.86 |
| 2.0 | D910*x1-MS2.3.0 | Vanilla | 3x1 | 512x512 | Graph, DS, FP16, | O2 | 12.75 |
| 2.0 | D910*x8-MS2.3.0 | Vanilla | 24x1 | 512x512 | Graph, DS, FP16, | O2 | 79.67 |
| 2.0 | D910*x1-MS2.3.0 | LoRA | 4x1 | 512x512 | Graph, DS, FP16, | O2 | 16.53 |
| 2.0 | D910*x8-MS2.3.0 | LoRA | 32x1 | 512x512 | Graph, DS, FP16, | O2 | 129.70 |
| 2.0 | D910*x1-MS2.3.0 | Dreambooth | 1x1 | 512x512 | Graph, DS, FP16, | O2 | 3.76 |
| 2.1-v | D910*x1-MS2.3.0 | Vanilla | 3x1 | 768x768 | Graph, DS, FP16, FA | O2 | 7.16 |
| 2.1-v | D910*x8-MS2.3.0 | Vanilla | 24x1 | 768x768 | Graph, DS, FP16, FA | O2 | 49.27 |
| 2.1-v | D910*x1-MS2.3.0 | LoRA | 4x1 | 768x768 | Graph, DS, FP16, FA | O2 | 9.51 |
| 2.1-v | D910*x8-MS2.3.0 | LoRA | 32x1 | 768x768 | Graph, DS, FP16, FA | O2 | 71.51 |
> Context: {Ascend chip}x{number of NPUs}-{MindSpore version}.
>
> Acceleration: DS: data sink mode; FP16: float16 computation; FA: flash attention.
>
> FPS: images processed per second during training. Average training time (s/step) = batch_size / FPS; e.g., SD1.5 vanilla training at global batch size 3 and 11.86 FPS averages about 0.25 s/step.
>
> jit_level: controls the compilation optimization level. N/A means the MindSpore version in use does not support setting jit_level; jit_level can only be set on MindSpore 2.3 or later.

Note that the performance of SD2.1 should be similar to that of SD2.0, since they share the same network architecture.
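For reference, the jit_level column maps onto the guarded `ms.set_context` call that this commit adds in `common.py`; a minimal sketch of the same pattern:

```python
import mindspore as ms

# Mirrors the guard added in common.py: jit_level requires MindSpore >= 2.3.0,
# so older versions fall back to the framework's default execution method.
try:
    ms.set_context(jit_config={"jit_level": "O2"})  # one of "O0", "O1", "O2"
except Exception:
    pass  # MindSpore < 2.3.0: jit_level is not supported
```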

@@ -57,20 +61,25 @@ Flash Attention,

### Inference

| SD Model | Context | Scheduler | Steps | Resolution | Batch Size | Speed (step/s) | FPS (img/s) |
|---------------|:-----------|:------------:|:------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
| 1.5 | D910x1-MS2.2.10 | DDIM | 30 | 512x512 | 4 | 3.58 | 0.44 |
| 2.0 | D910x1-MS2.2.10 | DDIM | 30 | 512x512 | 4 | 4.12 | 0.49 |
| 2.1-v | D910x1-MS2.2.10 | DDIM | 30 | 768x768 | 4 | 1.14 | 0.14 |
| 1.5 | D910*x1-MS2.2.10 | DDIM | 30 | 512x512 | 4 | 6.19 | 0.71 |
| 2.0 | D910*x1-MS2.2.10 | DDIM | 30 | 512x512 | 4 | 7.65 | 0.83 |
| 2.1-v | D910*x1-MS2.2.10 | DDIM | 30 | 768x768 | 4 | 2.79 | 0.32 |
| SD Model | Context | Scheduler | Steps | Resolution | Batch Size | jit_level | Speed (step/s) | FPS (img/s) |
|---------------|------------|--------------|:-------------------:|:-------------:|:----------------:|:----------------:|----------:|----------|
| 1.5 | D910x1-MS2.2.10 | DDIM | 30 | 512x512 | 4 | N/A | 3.58 | 0.44 |
| 2.0 | D910x1-MS2.2.10 | DDIM | 30 | 512x512 | 4 | N/A | 4.12 | 0.49 |
| 2.1-v | D910x1-MS2.2.10 | DDIM | 30 | 768x768 | 4 | N/A | 1.14 | 0.14 |
| 1.5 | D910*x1-MS2.3.0 | DDIM | 30 | 512x512 | 4 | O2 | 6.69 | 0.77 |
| 2.0 | D910*x1-MS2.3.0 | DDIM | 30 | 512x512 | 4 | O2 | 8.30 | 0.91 |
| 2.1-v | D910*x1-MS2.3.0 | DDIM | 30 | 768x768 | 4 | O2 | 2.91 | 0.36 |


> Context: {Ascend chip}x{number of NPUs}-{MindSpore version}.
>
> Speed (step/s): sampling speed, measured in sampling steps per second.
>
> FPS (img/s): image generation throughput, measured in images generated per second.
>
> jit_level: controls the compilation optimization level. N/A means the MindSpore version in use does not support setting jit_level; jit_level can only be set on MindSpore 2.3 or later.

Note that the performance of SD2.1 should be similar to that of SD2.0, since they share the same network architecture. Per-NPU performance in multi-NPU parallel mode matches single-NPU performance.


18 changes: 18 additions & 0 deletions examples/stable_diffusion_v2/common.py
@@ -18,9 +18,11 @@ def init_env(
    seed: int = 42,
    distributed: bool = False,
    device_target: Optional[str] = "Ascend",
    jit_level: str = "O2",
    enable_modelarts: bool = False,
    num_workers: int = 1,
    json_data_path: Optional[str] = None,
    max_device_memory: Optional[str] = "1024GB",
) -> Tuple[int, int, int]:
    """
    Initialize MindSpore environment.
@@ -40,6 +42,20 @@
        A tuple containing the device ID, rank ID and number of devices.
    """
    set_random_seed(seed)
    if mode == ms.GRAPH_MODE:
        try:
            if jit_level in ["O0", "O1", "O2"]:
                ms.set_context(jit_config={"jit_level": jit_level})
                _logger.info(f"set jit_level: {jit_level}.")
            else:
                _logger.warning(
                    f"Unsupported jit_level: {jit_level}. The framework will select the execution method automatically."
                )
        except Exception:
            _logger.warning(
                "The current jit_level is not suitable because the current MindSpore version does not match; "
                "please ensure the MindSpore version is >= 2.3.0."
            )

    if debug and mode == ms.GRAPH_MODE:  # force PyNative mode when debugging
        _logger.warning("Debug mode is on, switching execution mode to PyNative.")
@@ -52,6 +68,7 @@
            device_target=device_target,
            device_id=device_id,
            ascend_config={"precision_mode": "allow_fp32_to_fp16"},  # Only effective on Ascend 910*
            max_device_memory=max_device_memory,
        )
        init()
        device_num = get_group_size()
@@ -80,6 +97,7 @@
            device_id=device_id,
            ascend_config={"precision_mode": "allow_fp32_to_fp16"},  # Only effective on Ascend 910*
            pynative_synchronize=debug,
            max_device_memory=max_device_memory,
        )

    return device_id, rank_id, device_num
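For context, a minimal usage sketch (not part of the commit) of the updated `init_env`; it assumes the module is imported as `common` and that the leading parameters not shown in this hunk (e.g., the execution mode) keep their defaults:

```python
from common import init_env  # examples/stable_diffusion_v2/common.py

# jit_level and max_device_memory are the parameters added by this commit.
device_id, rank_id, device_num = init_env(
    seed=42,
    distributed=False,
    jit_level="O2",  # effective only on MindSpore >= 2.3.0
    max_device_memory="1024GB",
)
```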
@@ -17,6 +17,7 @@ ckpt_save_interval: 1
epochs: 20
use_ema: True
clip_grad: False
enable_flash_attention: True

# lr scheduler
scheduler: "cosine_decay"
@@ -32,6 +32,7 @@ model:
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64
use_spatial_transformer: True
enable_flash_attention: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
25 changes: 24 additions & 1 deletion examples/stable_diffusion_v2/depth_to_image.py
@@ -258,7 +258,6 @@ def main(args):
        rank=0,
        log_level=eval(args.log_level),
    )

    # init
    device_id = int(os.getenv("DEVICE_ID", 0))
    ms.context.set_context(
@@ -268,6 +267,20 @@
        device_id=device_id,
        max_device_memory="30GB",
    )
    if args.ms_mode == ms.GRAPH_MODE:
        try:
            if args.jit_level in ["O0", "O1", "O2"]:
                ms.set_context(jit_config={"jit_level": args.jit_level})
                logger.info(f"set jit_level: {args.jit_level}.")
            else:
                logger.warning(
                    f"Unsupported jit_level: {args.jit_level}. The framework will select the execution method automatically."
                )
        except Exception:
            logger.warning(
                "The current jit_level is not suitable because the current MindSpore version does not match; "
                "please ensure the MindSpore version is >= 2.3.0."
            )

    if args.save_graph:
        save_graphs_path = "graph"
@@ -416,6 +429,16 @@ def load_model_from_config(config, ckpt, verbose=False):
    parser.add_argument(
        "--ms_mode", type=int, default=0, help="Running in GRAPH_MODE(0) or PYNATIVE_MODE(1) (default=0)"
    )
    parser.add_argument(
        "--jit_level",
        default="O2",
        type=str,
        choices=["O0", "O1", "O2"],
        help="Used to control the compilation optimization level. Supports ['O0', 'O1', 'O2']. "
        "O0: all optimizations except those required for functionality are turned off; uses the KernelByKernel execution mode. "
        "O1: enables commonly used optimizations and automatic operator fusion; uses the KernelByKernel execution mode. "
        "O2: ultimate performance optimization; uses the Sink execution mode.",
    )
parser.add_argument("--num_samples", type=int, default=4, help="num of total samples")
parser.add_argument(
"--img_size",
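For illustration, a small sketch (hypothetical argument values) of how the new `--jit_level` flag parses, given the `parser` definition above:

```python
# `parser` is the argparse.ArgumentParser assembled above in depth_to_image.py.
args = parser.parse_args(["--ms_mode", "0", "--jit_level", "O1"])
assert args.jit_level in ["O0", "O1", "O2"]  # guaranteed by choices=
print(args.ms_mode, args.jit_level)  # -> 0 O1
```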
