Skip to content

Commit

Permalink
feat: finally impl set_page_labels_by_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
yazdipour committed Jun 18, 2024
1 parent e6f741e commit fb2cec0
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 29 deletions.
57 changes: 37 additions & 20 deletions omnivoreql/omnivoreql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
self.client = Client(transport=transport, fetch_schema_from_transport=False)
self.queries = {}


def _get_query(self, query_name: str) -> str:
if query_name not in self.queries:
current_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -223,41 +222,59 @@ def delete_label(self, label_id: str):
variable_values={"id": label_id},
)

def set_labels(
self,
page_id: str,
labels: Optional[List[CreateLabelInput]] = None,
label_ids: Optional[List[str]] = None,
source: str = "api",
def set_page_labels_by_create_label_inputs(
self, page_id: str, labels: List[CreateLabelInput]
) -> dict:
"""
Set labels for a page.
:param page_id: The ID of the page to set labels for.
:param label_ids: The IDs of the labels to set.
:param labels: The labels to set.
:param source: The source of the call.
"""
parsed_labels = (
[
return self.set_page_labels_by_labels(page_id, parsed_labels)

def set_page_labels_by_labels(self, page_id: str, labels: List[dict]) -> dict:
"""
Set labels for a page.
:param page_id: The ID of the page to set labels for.
:param labels: The labels to set.
"""
parsed_labels = []
for label in labels:
if isinstance(label, CreateLabelInput):
label = asdict(label)
parsed_labels.append(
{
"name": label.name,
"color": label.color,
"description": label.description,
"name": label["name"],
"color": label["color"],
"description": label["description"],
}
for label in labels
]
if labels
else None
)

return self.client.execute(
self._get_query("ApplyLabels"),
variable_values={
"input": {
"pageId": page_id,
"labels": parsed_labels,
}
},
)

def set_page_labels_by_label_ids(self, page_id: str, label_ids: List[str]) -> dict:
"""
Set labels for a page.
:param page_id: The ID of the page to set labels for.
:param label_ids: The IDs of the labels to set.
"""
return self.client.execute(
self._get_query("ApplyLabels"),
variable_values={
"input": {
"pageId": page_id,
"labelIds": label_ids,
"labels": parsed_labels,
"source": source,
}
},
)
48 changes: 39 additions & 9 deletions tests/test_omnivoreql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
To run the tests, execute the following command:
python -m unittest discover -s tests
"""


class TestOmnivoreQL(unittest.TestCase):
client = None

Expand Down Expand Up @@ -169,9 +171,7 @@ def test_delete_label(self):
label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000")
created_label = self.client.create_label(label_input)
# When
result = self.client.delete_label(
created_label["createLabel"]["label"]["id"]
)
result = self.client.delete_label(created_label["createLabel"]["label"]["id"])
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["deleteLabel"])
Expand All @@ -188,18 +188,48 @@ def test_clean_up_created_labels(self):
except Exception as e:
print(f"Error cleaning up labels: {e}")

def test_set_labels(self):
def test_set_page_labels_by_labels(self):
# Given
page = self.client.get_articles(limit=1)["search"]["edges"][0]["node"]
label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000")
created_label = self.client.create_label(label_input)["createLabel"]["label"]
# When
result = self.client.set_page_labels_by_labels(page["id"], [created_label])
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["setLabels"])
self.assertEqual(result["setLabels"]["labels"][0]["id"], created_label["id"])

def test_set_page_labels_by_create_label_inputs(self):
# Given
page = self.client.get_articles(limit=1)["search"]["edges"][0]["node"]
page_id = page["id"]
label_ids = [self.client.get_labels()["labels"]["labels"][0]["id"]]
label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000")
created_label = self.client.create_label(label_input)["createLabel"]["label"]
created_label_input = CreateLabelInput(
created_label["name"],
created_label["color"],
created_label["description"],
)
# When
result = self.client.set_labels(page_id, label_ids=label_ids)
result = self.client.set_page_labels_by_labels(page["id"], [created_label])
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["setLabels"])
self.assertEqual(result["setLabels"]["labels"][0]["id"], created_label["id"])

def test_set_page_labels_by_label_ids(self):
# Given
page = self.client.get_articles(limit=1)["search"]["edges"][0]["node"]
label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000")
created_label = self.client.create_label(label_input)["createLabel"]["label"]
# When
result = self.client.set_page_labels_by_label_ids(
page["id"], label_ids=[created_label["id"]]
)
# Then
self.assertIsNotNone(result)
self.assertNotIn("errorCodes", result["setLabels"])
self.assertEqual(result["setLabels"]["pageId"], page_id)
self.assertListEqual(result["setLabels"]["labelIds"], label_ids)
self.assertEqual(result["setLabels"]["labels"][0]["id"], created_label["id"])


if __name__ == "__main__":
Expand Down

0 comments on commit fb2cec0

Please sign in to comment.