Add tests for enabling LORA on a Dense and Embedding layer (keras-team#19079)

* Add tests for enabling LORA on a Dense layer

* Add tests for enabling LORA on Embedding layer
Faisal-Alsrheed authored Jan 22, 2024
1 parent 19187d8 commit dad5342
Showing 2 changed files with 44 additions and 0 deletions.
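
For context, here is a minimal sketch of the enable_lora workflow these tests exercise on a Dense layer. The calls mirror the tests below; the comments about what LoRA freezes and adds describe assumed Keras 3 behavior, not something these tests assert.

from keras import layers

# Build the layer first: enable_lora raises on an unbuilt layer.
layer = layers.Dense(units=2)
layer.build((None, 2))

# Enable a rank-2 LoRA update. Assumed behavior: the original kernel is
# frozen and two small trainable factor matrices are added, so only the
# low-rank update (plus the bias) continues to train. Calling
# enable_lora a second time raises, as tested below.
layer.enable_lora(rank=2)
print([w.shape for w in layer.trainable_weights])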
21 changes: 21 additions & 0 deletions keras/layers/core/dense_test.py
@@ -263,3 +263,24 @@ def test_lora_rank_argument(self):
expected_num_losses=2, # we have 2 regularizers.
supports_masking=True,
)

def test_enable_lora_with_kernel_constraint(self):
layer = layers.Dense(units=2, kernel_constraint="max_norm")
with self.assertRaisesRegex(
ValueError, "incompatible with kernel constraints"
):
layer.enable_lora(rank=2)

def test_enable_lora_on_unbuilt_layer(self):
layer = layers.Dense(units=2)
with self.assertRaisesRegex(
ValueError, "Cannot enable lora on a layer that isn't yet built"
):
layer.enable_lora(rank=2)

def test_enable_lora_when_already_enabled(self):
layer = layers.Dense(units=2)
layer.build((None, 2))
layer.enable_lora(rank=2)
with self.assertRaisesRegex(ValueError, "lora is already enabled"):
layer.enable_lora(rank=2)
23 changes: 23 additions & 0 deletions keras/layers/core/embedding_test.py
@@ -183,3 +183,26 @@ def test_lora_rank_argument(self):
expected_num_losses=0,
supports_masking=False,
)

def test_enable_lora_with_embeddings_constraint(self):
layer = layers.Embedding(
input_dim=10, output_dim=16, embeddings_constraint="max_norm"
)
with self.assertRaisesRegex(
ValueError, "incompatible with embedding constraints"
):
layer.enable_lora(rank=2)

def test_enable_lora_on_unbuilt_layer(self):
layer = layers.Embedding(input_dim=10, output_dim=16)
with self.assertRaisesRegex(
ValueError, "Cannot enable lora on a layer that isn't yet built"
):
layer.enable_lora(rank=2)

def test_enable_lora_when_already_enabled(self):
layer = layers.Embedding(input_dim=10, output_dim=16)
layer.build()
layer.enable_lora(rank=2)
with self.assertRaisesRegex(ValueError, "lora is already enabled"):
layer.enable_lora(rank=2)
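
As a complement to the error cases above, a minimal happy-path sketch for Embedding (the factor shapes are an assumption about how Keras decomposes the low-rank update, not something these tests assert):

from keras import layers

# Embedding.build() takes no input shape here, matching the test above.
layer = layers.Embedding(input_dim=10, output_dim=16)
layer.build()

# Rank-2 LoRA: assumed to freeze the (10, 16) embeddings matrix and add
# two trainable factors of shapes (10, 2) and (2, 16).
layer.enable_lora(rank=2)
print([w.shape for w in layer.trainable_weights])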
