Skip to content

Commit

Permalink
Add mergekit-moe script (arcee-ai#141)
Browse files Browse the repository at this point in the history
Also incidentally adds support for merging Mixtral models.
  • Loading branch information
cg123 authored Jan 30, 2024
1 parent 508348a commit 031a3c2
Show file tree
Hide file tree
Showing 6 changed files with 561 additions and 15 deletions.
38 changes: 38 additions & 0 deletions docs/moe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# mergekit-moe

`mergekit-moe` is a script for combining Mistral or Llama models of the same size into Mixtral Mixture of Experts models. The script will combine the self-attention and layer normalization parameters from a "base" model with the MLP parameters from a set of "expert" models. `mergekit-moe` uses its own YML configuration syntax, which looks like so:

```yml
base_model: path/to/self_attn_donor
gate_mode: hidden # one of "hidden", "cheap_embed", or "random"
dtype: bfloat16 # output dtype (float32, float16, or bfloat16)
## (optional)
# experts_per_token: 2
experts:
- source_model: expert_model_1
positive_prompts:
- "This is a prompt that is demonstrative of what expert_model_1 excels at"
## (optional)
# negative_prompts:
# - "This is a prompt expert_model_1 should not be used for"
- source_model: expert_model_2
# ... and so on
```

The script takes two arguments, an input config and an output path: `mergekit-moe ./config.yml ./my-clowncar-moe-12x180B`

## Gate Modes

There are three methods for populating the MoE gates implemented.

### "hidden"

Uses the hidden state representations of the positive/negative prompts for MoE gate parameters. Best quality and most effective option; the default. Requires evaluating each prompt using the base model so you might not be able to use this on constrained hardware (depending on the model). You can use `--load-in-8bit` or `--load-in-4bit` to reduce VRAM usage.

### "cheap_embed"

Uses only the raw token embedding of the prompts, using the same gate parameters for every layer. Distinctly less effective than "hidden". Can be run on much, much lower end hardware.

### "random"

Randomly initializes the MoE gates. Good for if you are going to fine tune the model afterwards, or maybe if you want something a little unhinged? I won't judge.
60 changes: 46 additions & 14 deletions mergekit/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

from abc import ABC, abstractmethod
from typing import List, Optional
from typing import ClassVar, List, Optional

from pydantic import BaseModel
from transformers import PretrainedConfig
Expand Down Expand Up @@ -117,6 +117,40 @@ def all_weights(self, config: PretrainedConfig) -> List[str]:
)


class MixtralTensorNames(ArchitectureInfo, BaseModel):
ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
num_local_experts: int

@classmethod
def from_config(cls, config: PretrainedConfig):
return MixtralTensorNames(num_local_experts=config.num_local_experts)

def pre_weights(self) -> List[str]:
return MISTRAL_INFO.pre_weights()

def post_weights(self) -> List[str]:
return MISTRAL_INFO.post_weights()

def embed_weights(self) -> List[str]:
return MISTRAL_INFO.embed_weights()

def num_layers_config_key(self) -> str:
return MISTRAL_INFO.num_layers_config_key()

def layer_weight_formats(self) -> List[str]:
num_experts = self.num_local_experts
res = [fmt for fmt in MISTRAL_INFO.layer_weight_formats() if ".mlp." not in fmt]
for expert_idx in range(num_experts):
for param in ("w1", "w2", "w3"):
fmt = (
MISTRAL_INFO.layer_prefix_format
+ f".block_sparse_moe.experts.{expert_idx}.{param}.weight"
)
res.append(fmt)
res.append(MISTRAL_INFO.layer_prefix_format + ".block_sparse_moe.gate.weight")
return res


STABLELM_INFO = StaticTensorNames(
name="StableLMEpochForCausalLM",
post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
Expand Down Expand Up @@ -289,29 +323,25 @@ def all_weights(self, config: PretrainedConfig) -> List[str]:
)


class PhiTensorNames(ArchitectureInfo):
architecture_name: str = "MixFormerSequentialForCausalLM"

def __init__(self, config: PretrainedConfig):
self.config = config
class PhiTensorNames(ArchitectureInfo, BaseModel):
ARCHITECTURE_NAME: ClassVar[str] = "MixFormerSequentialForCausalLM"
n_layer: int

def __eq__(self, rhs: "PhiTensorNames"):
if not isinstance(rhs, PhiTensorNames):
return False
return self.num_layers() == rhs.num_layers()
def from_config(cls, config: PretrainedConfig):
return PhiTensorNames(n_layer=config.n_layer)

def pre_weights(self) -> List[str]:
return ["layers.0.wte.weight"]

def post_weights(self) -> List[str]:
fake_layer_idx = self.config.n_layer + 1
fake_layer_idx = self.n_layer
return [
f"layers.{fake_layer_idx}.{suffix}"
for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
]

def embed_weights(self) -> List[str]:
fake_layer_idx = self.config.n_layer + 1
fake_layer_idx = self.n_layer
return [
"layers.0.wte.weight",
f"layers.{fake_layer_idx}.linear.weight",
Expand Down Expand Up @@ -423,8 +453,10 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
raise RuntimeError("More than one architecture in config?")

arch_name = config.architectures[0]
if arch_name == PhiTensorNames.architecture_name:
return PhiTensorNames(config)
if arch_name == PhiTensorNames.ARCHITECTURE_NAME:
return PhiTensorNames.from_config(config)
if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
return MixtralTensorNames.from_config(config)

if arch_name == PHI2_INFO.name:
if config.model_type == "phi-msft":
Expand Down
3 changes: 3 additions & 0 deletions mergekit/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ def __str__(self) -> str:


def dtype_from_name(name: Optional[str]) -> torch.dtype:
if name.startswith("torch."):
name = name[len("torch.") :]

if name == "bfloat16":
return torch.bfloat16
elif name == "float16":
Expand Down
2 changes: 1 addition & 1 deletion mergekit/io/lazy_tensor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def from_disk(cls, base_path: str) -> "ShardedTensorIndex":
tensor_paths = {key: shard_name for key in st.keys()}
else:
# this is ugly but not much else can be done
shard = torch.load(model_path)
shard = torch.load(model_path, map_location="meta")
if "state_dict" in shard:
shard = shard["state_dict"]

Expand Down
Loading

0 comments on commit 031a3c2

Please sign in to comment.