deciLM support (#1133)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi authored Aug 2, 2024
1 parent abc4f0e commit 7738595
Showing 7 changed files with 564 additions and 0 deletions.
1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -98,6 +98,7 @@
"llava_next",
"stablelm",
"mamba",
"deci",
]


5 changes: 5 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -26,6 +26,8 @@
gaudi_StoppingCriteriaList_call,
)
from .models import (
DeciLMConfig,
DeciLMForCausalLM,
GaudiBloomForCausalLM,
GaudiBloomMLP,
GaudiCLIPAttention,
@@ -553,3 +555,6 @@ def adapt_transformers_to_gaudi():
transformers.models.mamba.modeling_mamba.MambaForCausalLM._update_model_kwargs_for_generation = (
gaudi_MambaForCausalLM_update_model_kwargs_for_generation
)

transformers.AutoConfig.register("deci", DeciLMConfig)
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)
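A minimal usage sketch (not part of the commit): once adapt_transformers_to_gaudi() has run, the two register() calls above let the stock Auto classes resolve the "deci" model type directly, assuming the checkpoint's config.json declares model_type == "deci" as the Deci/DeciLM-7B reference checkpoint is expected to.

from transformers import AutoConfig, AutoModelForCausalLM

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # performs the register() calls shown above

# Assumes the checkpoint's config.json declares model_type == "deci", so
# AutoConfig dispatches to DeciLMConfig and AutoModelForCausalLM to
# DeciLMForCausalLM.
config = AutoConfig.from_pretrained("Deci/DeciLM-7B")
model = AutoModelForCausalLM.from_pretrained("Deci/DeciLM-7B", config=config)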
4 changes: 4 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -45,6 +45,10 @@
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
)
from .decilm import (
DeciLMConfig,
DeciLMForCausalLM,
)
from .detr import gaudi_DetrConvModel_forward
from .esm import (
gaudi_esm_for_protein_folding_forward,
4 changes: 4 additions & 0 deletions optimum/habana/transformers/models/decilm/__init__.py
@@ -0,0 +1,4 @@
from .configuration_decilm import DeciLMConfig
from .modeling_decilm import (
DeciLMForCausalLM,
)
24 changes: 24 additions & 0 deletions optimum/habana/transformers/models/decilm/configuration_decilm.py
@@ -0,0 +1,24 @@
"""
Adapted from the following source:
https://huggingface.co/Deci/DeciLM-7B/blob/main/configuration_decilm.py
"""

from transformers.models.llama.configuration_llama import LlamaConfig


class DeciLMConfig(LlamaConfig):
r"""
Args:
num_key_value_heads_per_layer (`List[int]`):
The number of key-value heads per layer.
"""

model_type = "deci"

def __init__(
self,
num_key_value_heads_per_layer: list = None,
**kwargs,
):
self.num_key_value_heads_per_layer = num_key_value_heads_per_layer
super().__init__(**kwargs)
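A hypothetical instantiation (not from the commit) illustrating the per-layer field: each entry overrides the global num_key_value_heads for one decoder layer, so the list length should match num_hidden_layers. The layer and head counts below are made up for illustration.

from optimum.habana.transformers.models.decilm import DeciLMConfig

# Hypothetical 4-layer toy config with variable grouped-query attention;
# the real DeciLM-7B checkpoint defines its own per-layer schedule.
config = DeciLMConfig(
    num_hidden_layers=4,
    num_attention_heads=32,
    num_key_value_heads_per_layer=[4, 2, 2, 1],  # one KV-head count per layer
)
assert len(config.num_key_value_heads_per_layer) == config.num_hidden_layers
print(config.model_type)  # "deci"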