-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(components): Create Data Labeling Kubeflow container component
PiperOrigin-RevId: 617930970
- Loading branch information
Googler
committed
Apr 16, 2024
1 parent
60a443e
commit 3103f62
Showing
6 changed files
with
375 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 14 additions & 0 deletions
14
...ogle_cloud_pipeline_components/_implementation/model_evaluation/data_labeling/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright 2023 The Kubeflow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Google Cloud Pipeline Evaluation Data Labeling Component.""" |
83 changes: 83 additions & 0 deletions
83
...gle_cloud_pipeline_components/_implementation/model_evaluation/data_labeling/component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright 2024 The Kubeflow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Data Labeling Evaluation component.""" | ||
|
||
|
||
from google_cloud_pipeline_components import _image | ||
from kfp import dsl | ||
|
||
|
||
@dsl.container_component
def evaluation_data_labeling(
    project: str,
    location: str,
    gcp_resources: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
    job_display_name: str,
    dataset_name: str,
    instruction_uri: str,
    inputs_schema_uri: str,
    annotation_spec: str,
    labeler_count: int,
    annotation_label: str,
):
  """Builds a container spec that launches a data labeling job.

  Args:
    project: Project to run the job in.
    location: Location to run the job in.
    gcp_resources: GCP resources that can be used to track the job.
    job_display_name: Display name of the data labeling job.
    dataset_name: Name of the dataset to use for the data labeling job.
    instruction_uri: URI of the instruction for the data labeling job.
    inputs_schema_uri: URI of the inputs schema for the data labeling job.
    annotation_spec: Annotation spec for the data labeling job.
    labeler_count: Number of labelers to use for the data labeling job.
    annotation_label: Label of the data labeling job.

  Returns:
    Container spec that launches a data labeling job with the specified
    payload.
  """
  return dsl.ContainerSpec(
      image=_image.GCPC_IMAGE_TAG,
      command=[
          'python3',
          '-u',
          '-m',
          'google_cloud_pipeline_components.container._implementation.model_evaluation.data_labeling_job.launcher',
      ],
      args=[
          '--type',
          'DataLabelingJob',
          '--project',
          project,
          '--location',
          location,
          '--gcp_resources',
          gcp_resources,
          '--job_display_name',
          job_display_name,
          '--dataset_name',
          dataset_name,
          '--instruction_uri',
          instruction_uri,
          '--inputs_schema_uri',
          inputs_schema_uri,
          '--annotation_spec',
          annotation_spec,
          '--labeler_count',
          labeler_count,
          '--annotation_label',
          annotation_label,
      ],
  )
14 changes: 14 additions & 0 deletions
14
...eline_components/container/_implementation/model_evaluation/data_labeling_job/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright 2024 The Kubeflow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Google Cloud Pipeline Components - Data Labeling Job Launcher and Remote Runner.""" |
154 changes: 154 additions & 0 deletions
154
...eline_components/container/_implementation/model_evaluation/data_labeling_job/launcher.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Copyright 2024 The Kubeflow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""GCP launcher for data labeling jobs based on the AI Platform SDK.""" | ||
|
||
import argparse | ||
import logging | ||
import os | ||
import sys | ||
from typing import Any, Dict | ||
|
||
from google_cloud_pipeline_components.container._implementation.model_evaluation.data_labeling_job import remote_runner | ||
|
||
|
||
def _make_parent_dirs_and_return_path(file_path: str): | ||
os.makedirs(os.path.dirname(file_path), exist_ok=True) | ||
return file_path | ||
|
||
|
||
def _parse_args(args) -> Dict[str, Any]: | ||
"""Parse command line arguments. | ||
Args: | ||
args: A list of arguments. | ||
Returns: | ||
A tuple containing an argparse.Namespace class instance holding parsed args, | ||
and a list containing all unknonw args. | ||
""" | ||
parser = argparse.ArgumentParser( | ||
prog='Dataflow python job Pipelines service launcher', description='' | ||
) | ||
parser.add_argument( | ||
'--type', | ||
dest='type', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--project', | ||
dest='project', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--location', | ||
dest='location', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--gcp_resources', | ||
dest='gcp_resources', | ||
type=_make_parent_dirs_and_return_path, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--job_display_name', | ||
dest='job_display_name', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--dataset_name', | ||
dest='dataset_name', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--instruction_uri', | ||
dest='instruction_uri', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--inputs_schema_uri', | ||
dest='inputs_schema_uri', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--annotation_spec', | ||
dest='annotation_spec', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--labeler_count', | ||
dest='labeler_count', | ||
type=int, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parser.add_argument( | ||
'--annotation_label', | ||
dest='annotation_label', | ||
type=str, | ||
required=True, | ||
default=argparse.SUPPRESS, | ||
) | ||
parsed_args, _ = parser.parse_known_args(args) | ||
return vars(parsed_args) | ||
|
||
|
||
def main(argv):
  """Main entry point for the data labeling job launcher.

  Expected input args are as follows:
    type - Required. Must be the string 'DataLabelingJob'; this launcher
      only handles data labeling jobs.
    project - Required. The project in which the job will be launched.
    location - Required. The location in which the job will be launched.
    gcp_resources - Required. Placeholder output path for returning the
      job_id.
    The remaining flags (job_display_name, dataset_name, instruction_uri,
    inputs_schema_uri, annotation_spec, labeler_count, annotation_label)
    describe the job and are forwarded to the remote runner.

  Args:
    argv: A list of system arguments.

  Raises:
    ValueError: If the parsed --type is not 'DataLabelingJob'.
  """
  parsed_args = _parse_args(argv)
  job_type = parsed_args['type']

  if job_type != 'DataLabelingJob':
    raise ValueError('Incorrect job type: ' + job_type)

  logging.info(
      'Starting DataLabelingJob using the following arguments: %s',
      parsed_args,
  )

  remote_runner.create_data_labeling_job(**parsed_args)


if __name__ == '__main__':
  main(sys.argv[1:])
108 changes: 108 additions & 0 deletions
108
..._components/container/_implementation/model_evaluation/data_labeling_job/remote_runner.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright 2024 The Kubeflow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""GCP launcher for data labeling jobs based on the AI Platform SDK.""" | ||
|
||
import json | ||
|
||
from google.api_core import retry | ||
from google_cloud_pipeline_components.container.v1.gcp_launcher import job_remote_runner | ||
from google_cloud_pipeline_components.container.v1.gcp_launcher.utils import error_util | ||
|
||
# Total deadline (in seconds) for retrying the get_data_labeling_job RPC:
# 10 minutes.
_DATA_LABELING_JOB_RETRY_DEADLINE_SECONDS = 10.0 * 60.0
|
||
|
||
def create_data_labeling_job_with_client(job_client, parent, job_spec):
  """Issues the create_data_labeling_job RPC through the given client.

  Returns the client's response; on connection or runtime failures it
  exits with an internal error via error_util.
  """
  try:
    return job_client.create_data_labeling_job(
        parent=parent, data_labeling_job=job_spec
    )
  except (ConnectionError, RuntimeError) as err:
    error_util.exit_with_internal_error(err.args[0])
|
||
|
||
def get_data_labeling_job_with_client(job_client, job_name):
  """Fetches the data labeling job named ``job_name`` via the given client.

  The get RPC retries for up to
  _DATA_LABELING_JOB_RETRY_DEADLINE_SECONDS; on connection or runtime
  failures it exits with an internal error via error_util.
  """
  try:
    return job_client.get_data_labeling_job(
        name=job_name,
        retry=retry.Retry(deadline=_DATA_LABELING_JOB_RETRY_DEADLINE_SECONDS),
    )
  except (ConnectionError, RuntimeError) as err:
    error_util.exit_with_internal_error(err.args[0])
|
||
|
||
def create_data_labeling_job(
    type,
    project,
    location,
    gcp_resources,
    job_display_name,
    dataset_name,
    instruction_uri,
    inputs_schema_uri,
    annotation_spec,
    labeler_count,
    annotation_label,
):
  """Creates a data labeling job and polls it to completion.

  This follows the typical launching logic:

  1. Read if the data labeling job already exists in gcp_resources.
     - If it already exists, jump to step 3 and poll the job status. This
       happens if the launcher container experienced unexpected
       termination, such as preemption.
  2. Deserialize the params into the job spec and create the data labeling
     job.
  3. Poll the data labeling job status every
     job_remote_runner._POLLING_INTERVAL_IN_SECONDS seconds.
     - If the data labeling job succeeded, return succeeded.
     - If the data labeling job is cancelled/paused, it's an unexpected
       scenario so return failed.
     - If the data labeling job is running, continue polling the status.

  Also retries on ConnectionError up to
  job_remote_runner._CONNECTION_ERROR_RETRY_LIMIT times during the poll.

  Args:
    type: Job type forwarded to the JobRemoteRunner (expected to be
      'DataLabelingJob').
    project: Project to run the job in.
    location: Location to run the job in.
    gcp_resources: Output file path used to track the created job.
    job_display_name: Display name of the data labeling job.
    dataset_name: Name of the dataset to label.
    instruction_uri: URI of the labeling instruction document.
    inputs_schema_uri: URI of the inputs schema.
    annotation_spec: Annotation spec for the job.
    labeler_count: Number of labelers to use.
    annotation_label: Value stored under the
      'aiplatform.googleapis.com/annotation_set_name' annotation label key.
  """
  # Payload handed (JSON-serialized) to the create RPC.
  job_spec = {
      'display_name': job_display_name,
      'datasets': [dataset_name],
      'instruction_uri': instruction_uri,
      'inputs_schema_uri': inputs_schema_uri,
      'inputs': annotation_spec,
      'annotation_labels': {
          'aiplatform.googleapis.com/annotation_set_name': annotation_label
      },
      'labeler_count': labeler_count,
  }

  runner = job_remote_runner.JobRemoteRunner(
      type, project, location, gcp_resources
  )

  try:
    # Only create the job when gcp_resources does not already record one
    # (i.e. this is not a relaunch after unexpected termination).
    job_name = runner.check_if_job_exists()
    if job_name is None:
      job_name = runner.create_job(
          create_data_labeling_job_with_client,
          json.dumps(job_spec),
      )

    # Poll until the job reaches "JobState.JOB_STATE_SUCCEEDED" (or fails).
    runner.poll_job(get_data_labeling_job_with_client, job_name)
  except (ConnectionError, RuntimeError) as err:
    error_util.exit_with_internal_error(err.args[0])