
Commit

fix readme typo
Fzilan committed Sep 23, 2024
1 parent 6c5baa6 commit 6c98961
Showing 1 changed file with 6 additions and 6 deletions.
examples/kohya_sd_scripts/Limitations.md
We've tried to provide a consistent implementation with the torch Kohya SD train...

The concept `network` in Kohya refers to the LoRA network attached to the target net being tuned, such as SD's UNet.

In torch Kohya, the `LoRANetwork` is made of `LoRAModule`s. Once the `apply_to` method is called, torch Kohya replaces the forward methods of the original Linear layers in the target modules of `UNet`, instead of directly replacing the original Linear modules. So the `UNet` and the `LoRANetwork` remain two independent `nn.Module`s, and some linear layers of `UNet` use the respective forward methods of the `LoRAModule`s in `LoRANetwork` when computing.
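For illustration only, here is a minimal sketch of this forward-hijacking pattern; the names and arguments are simplified and are not copied from the Kohya sources.

```python
import torch.nn as nn

class LoRAModuleSketch(nn.Module):
    """Sketch: adds a low-rank branch to an existing Linear by hijacking its forward."""

    def __init__(self, org_module: nn.Linear, rank: int = 4, multiplier: float = 1.0):
        super().__init__()
        self.lora_down = nn.Linear(org_module.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, org_module.out_features, bias=False)
        self.multiplier = multiplier
        self.org_module = org_module  # only a reference; UNet still owns this Linear

    def apply_to(self):
        # remember the original forward, then point the Linear's forward at ours
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward

    def forward(self, x):
        # original output plus the scaled low-rank update
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier
```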

In MindSpore graph mode, replacing forward methods raises errors during graph construction: the respective linear layers from the original net and the added LoRA layers must be in the same subgraph. Thus the MindSpore implementation replaces the original linear modules of `UNet` when creating the `LoRANetwork`, and the `UNet`, as an `nn.Cell`, is packaged by the `LoRANetwork` into a bigger `nn.Cell`.
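A minimal sketch of this module-replacement idea in MindSpore follows; `LoRADenseSketch` and `replace_dense` are hypothetical names for illustration, not the classes used in this repository.

```python
import mindspore.nn as nn

class LoRADenseSketch(nn.Cell):
    """Sketch: wraps an existing Dense layer and adds a low-rank branch inside one cell."""

    def __init__(self, org_dense: nn.Dense, rank: int = 4, multiplier: float = 1.0):
        super().__init__()
        self.org_dense = org_dense  # the original layer now lives inside the LoRA cell
        self.lora_down = nn.Dense(org_dense.in_channels, rank, has_bias=False)
        self.lora_up = nn.Dense(rank, org_dense.out_channels, has_bias=False)
        self.multiplier = multiplier

    def construct(self, x):
        return self.org_dense(x) + self.lora_up(self.lora_down(x)) * self.multiplier

def replace_dense(parent: nn.Cell, attr_name: str, rank: int = 4):
    # swap the child cell in place, so the UNet itself now contains the LoRA layer
    org = getattr(parent, attr_name)
    setattr(parent, attr_name, LoRADenseSketch(org, rank))
```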


### class - `LoRANetwork`
class LoRANetwork(nn.Cell):
self.text_encoders = text_encoders
self.unet = unet

# modified from the `create_modules` method, but it does more than creation: once we find the target module, we replace it
def replace_modules(...):
...
# search the unet and text encoders for TARGET_REPLACE_MODULE and swap in the respective LoRA modules

### APIs - `create_network` and `create_network_from_weights`

The two APIs are for LoRA network creation in the training and inference scripts. Torch uses `create_network` to initialize a `LoRANetwork` and then calls `LoRANetwork.apply_to` to replace the forward methods as described above. MindSpore creates and replaces the modules during initialization, so there is no need for an apply method. Similarly, in inference, Torch creates the `LoRANetwork` first and then calls `LoRANetwork.merge_to` to load and merge the LoRA weights, while MindSpore directly loads and merges the weights into the unet or text encoders through the `create_network_from_weights` API, without any network creation.
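To make the "load and merge" step concrete, the sketch below folds a LoRA delta directly into a Dense weight; it only illustrates the arithmetic and is not the actual `create_network_from_weights` code.

```python
import numpy as np
from mindspore import Tensor, nn

def merge_lora_into_dense(dense: nn.Dense, lora_down: np.ndarray, lora_up: np.ndarray,
                          multiplier: float = 1.0, alpha: float = 4.0):
    """Sketch: fold W += multiplier * (alpha / rank) * (up @ down) into a Dense weight."""
    rank = lora_down.shape[0]                      # lora_down: (rank, in), lora_up: (out, rank)
    delta = (lora_up @ lora_down) * multiplier * (alpha / rank)
    merged = dense.weight.asnumpy() + delta        # nn.Dense weight has shape (out, in)
    dense.weight.set_data(Tensor(merged, dense.weight.dtype))
```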

*Torch*


## 3. forward and backward wrapper

Kohya defines a base LoRA trainer, `NetworkTrainer`, in `train_network.py`, with `load_target_model`, `load_tokenizer`, `get_text_cond`, `call_unet`, and other methods. Most of these methods are designed for SD training and are overridden in `SDXLNetworkTrainer`, except the `train` method. In torch Kohya, `NetworkTrainer.train` implements a general LoRA training procedure covering the whole process, including the dataset, network, optimizer, forward pass and backward pass.

Different from Torch, automatic differentiation in MindSpore graph mode is based on the graph structure, with separate forward and backward graphs. So we define a `TrainStepForSDXLLoRA` to wrap the forward and backward process, and the `train` method has to choose the train-step wrapper for the specific model in the training loop; that means another model, such as SD1.5 or Flux, needs another train-step wrapper definition. The wrapper inherits from `TrainStep` in `mindone.diffusers.training_utils`, a base class for training steps implemented in MindSpore.
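As a rough, self-contained sketch of what such a wrapper does, the class below inherits from `nn.Cell` directly instead of the actual `TrainStep` base class, and drops details such as loss scaling and gradient accumulation.

```python
import mindspore as ms
from mindspore import nn

class TrainStepSketch(nn.Cell):
    """Sketch: bundles forward, backward, and optimizer update into one graph-mode cell."""

    def __init__(self, network_with_loss: nn.Cell, optimizer: nn.Optimizer):
        super().__init__()
        self.network_with_loss = network_with_loss
        self.optimizer = optimizer
        # value_and_grad builds both the forward and the backward graph of the wrapped cell
        self.forward_and_backward = ms.value_and_grad(
            network_with_loss, grad_position=None, weights=optimizer.parameters)

    def construct(self, *inputs):
        loss, grads = self.forward_and_backward(*inputs)
        self.optimizer(grads)
        return loss
```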

*Torch*

