Skip to content

Commit

Permalink
Add test to verify setting of SemaphoreKey and MutexName fields in KF…
Browse files Browse the repository at this point in the history
…P DSL
  • Loading branch information
DharmitD committed Nov 14, 2024
1 parent 8db319b commit 3d73418
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3846,6 +3846,44 @@ def outer():
task = inner()
foo_platform_set_bar_feature(task, 12)

class TestPipelineSemaphoreMutex(unittest.TestCase):

def test_pipeline_with_semaphore_and_mutex(self):
from kfp import dsl
from kfp import compiler
from kfp.dsl.pipeline_config import PipelineConfig

config = PipelineConfig()
config.set_semaphore_key("semaphore")
config.set_mutex_name("mutex")

@dsl.pipeline(pipeline_config=config)
def my_pipeline():
task = comp()

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_docs = list(yaml.safe_load_all(f))

pipeline_spec = None
for doc in pipeline_docs:
if 'platforms' in doc:
pipeline_spec = doc
break

if pipeline_spec:
try:
kubernetes_spec = pipeline_spec['platforms']['kubernetes']['pipelineConfig']
assert kubernetes_spec['semaphoreKey'] == "semaphore"
assert kubernetes_spec['mutexName'] == "mutex"
except KeyError:
print("platforms or expected keys not found in the compiled pipeline spec.")
else:
print("No document with 'platforms' found in the compiled pipeline spec.")

class ExtractInputOutputDescription(unittest.TestCase):

Expand Down

0 comments on commit 3d73418

Please sign in to comment.