diff --git a/omnivoreql/omnivoreql.py b/omnivoreql/omnivoreql.py index 7c6bc30..c028c85 100644 --- a/omnivoreql/omnivoreql.py +++ b/omnivoreql/omnivoreql.py @@ -27,7 +27,6 @@ def __init__( self.client = Client(transport=transport, fetch_schema_from_transport=False) self.queries = {} - def _get_query(self, query_name: str) -> str: if query_name not in self.queries: current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -223,41 +222,59 @@ def delete_label(self, label_id: str): variable_values={"id": label_id}, ) - def set_labels( - self, - page_id: str, - labels: Optional[List[CreateLabelInput]] = None, - label_ids: Optional[List[str]] = None, - source: str = "api", + def set_page_labels_by_create_label_inputs( + self, page_id: str, labels: List[CreateLabelInput] ) -> dict: """ Set labels for a page. :param page_id: The ID of the page to set labels for. - :param label_ids: The IDs of the labels to set. :param labels: The labels to set. - :param source: The source of the call. """ - parsed_labels = ( - [ + return self.set_page_labels_by_labels(page_id, parsed_labels) + + def set_page_labels_by_labels(self, page_id: str, labels: List[dict]) -> dict: + """ + Set labels for a page. + + :param page_id: The ID of the page to set labels for. + :param labels: The labels to set. + """ + parsed_labels = [] + for label in labels: + if isinstance(label, CreateLabelInput): + label = asdict(label) + parsed_labels.append( { - "name": label.name, - "color": label.color, - "description": label.description, + "name": label["name"], + "color": label["color"], + "description": label["description"], } - for label in labels - ] - if labels - else None + ) + + return self.client.execute( + self._get_query("ApplyLabels"), + variable_values={ + "input": { + "pageId": page_id, + "labels": parsed_labels, + } + }, ) + + def set_page_labels_by_label_ids(self, page_id: str, label_ids: List[str]) -> dict: + """ + Set labels for a page. + + :param page_id: The ID of the page to set labels for. + :param label_ids: The IDs of the labels to set. + """ return self.client.execute( self._get_query("ApplyLabels"), variable_values={ "input": { "pageId": page_id, "labelIds": label_ids, - "labels": parsed_labels, - "source": source, } }, ) diff --git a/tests/test_omnivoreql.py b/tests/test_omnivoreql.py index 3f1f4a5..7831983 100644 --- a/tests/test_omnivoreql.py +++ b/tests/test_omnivoreql.py @@ -10,6 +10,8 @@ To run the tests, execute the following command: python -m unittest discover -s tests """ + + class TestOmnivoreQL(unittest.TestCase): client = None @@ -169,9 +171,7 @@ def test_delete_label(self): label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000") created_label = self.client.create_label(label_input) # When - result = self.client.delete_label( - created_label["createLabel"]["label"]["id"] - ) + result = self.client.delete_label(created_label["createLabel"]["label"]["id"]) # Then self.assertIsNotNone(result) self.assertNotIn("errorCodes", result["deleteLabel"]) @@ -188,18 +188,48 @@ def test_clean_up_created_labels(self): except Exception as e: print(f"Error cleaning up labels: {e}") - def test_set_labels(self): + def test_set_page_labels_by_labels(self): + # Given + page = self.client.get_articles(limit=1)["search"]["edges"][0]["node"] + label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000") + created_label = self.client.create_label(label_input)["createLabel"]["label"] + # When + result = self.client.set_page_labels_by_labels(page["id"], [created_label]) + # Then + self.assertIsNotNone(result) + self.assertNotIn("errorCodes", result["setLabels"]) + self.assertEqual(result["setLabels"]["labels"][0]["id"], created_label["id"]) + + def test_set_page_labels_by_create_label_inputs(self): # Given page = self.client.get_articles(limit=1)["search"]["edges"][0]["node"] - page_id = page["id"] - label_ids = [self.client.get_labels()["labels"]["labels"][0]["id"]] + label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000") + created_label = self.client.create_label(label_input)["createLabel"]["label"] + created_label_input = CreateLabelInput( + created_label["name"], + created_label["color"], + created_label["description"], + ) # When - result = self.client.set_labels(page_id, label_ids=label_ids) + result = self.client.set_page_labels_by_labels(page["id"], [created_label]) + # Then + self.assertIsNotNone(result) + self.assertNotIn("errorCodes", result["setLabels"]) + self.assertEqual(result["setLabels"]["labels"][0]["id"], created_label["id"]) + + def test_set_page_labels_by_label_ids(self): + # Given + page = self.client.get_articles(limit=1)["search"]["edges"][0]["node"] + label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000") + created_label = self.client.create_label(label_input)["createLabel"]["label"] + # When + result = self.client.set_page_labels_by_label_ids( + page["id"], label_ids=[created_label["id"]] + ) # Then self.assertIsNotNone(result) self.assertNotIn("errorCodes", result["setLabels"]) - self.assertEqual(result["setLabels"]["pageId"], page_id) - self.assertListEqual(result["setLabels"]["labelIds"], label_ids) + self.assertEqual(result["setLabels"]["labels"][0]["id"], created_label["id"]) if __name__ == "__main__":