From 9a2f3cb853fbd824a2ded3241719b7f8fdcf8f58 Mon Sep 17 00:00:00 2001 From: Shahriar Date: Mon, 17 Jun 2024 00:38:38 +0200 Subject: [PATCH] feat: Move queries in files + Add tests for creating, updating, and deleting labels --- omnivoreql/omnivoreql.py | 407 ++++++++------------------------------- tests/test_omnivoreql.py | 66 ++++++- 2 files changed, 144 insertions(+), 329 deletions(-) diff --git a/omnivoreql/omnivoreql.py b/omnivoreql/omnivoreql.py index 470fd4e..f1ec5f6 100644 --- a/omnivoreql/omnivoreql.py +++ b/omnivoreql/omnivoreql.py @@ -1,7 +1,11 @@ from gql import gql, Client from gql.transport.requests import RequestsHTTPTransport import uuid -from typing import List +from models import CreateLabelInput +from dataclasses import asdict +import os +import glob +from typing import List, Optional class OmnivoreQL: @@ -22,36 +26,38 @@ def __init__( use_json=True, ) self.client = Client(transport=transport, fetch_schema_from_transport=False) - - def save_url(self, url: str, labels: List[str] = None): + self.queries = self._load_queries("queries") + + def _load_queries(self, queries_path: str): + current_dir = os.path.dirname(os.path.abspath(__file__)) + queries_path = os.path.join(current_dir, queries_path) + queries = {} + for file in glob.glob(f"{queries_path}/*.graphql"): + with open(file, "r") as f: + queries[os.path.basename(file).replace(".graphql", "")] = ( + f.read().replace("\n", " ") + ) + return queries + + def save_url( + self, + url: str, + labels: List[str] = None, + clientRequestId: str = str(uuid.uuid4()), + ): """ Save a URL to Omnivore. :param url: The URL to save. :param labels: The labels to assign to the item. + :param clientRequestId: The client request ID. """ labels = [] if labels is None else [{"name": x} for x in labels] - mutation = gql( - """ - mutation SaveUrl($input: SaveUrlInput!) { - saveUrl(input: $input) { - ... on SaveSuccess { - url - clientRequestId - } - ... on SaveError { - errorCodes - message - } - } - } - """ - ) return self.client.execute( - mutation, + gql(self.queries["SaveUrl"]), variable_values={ "input": { - "clientRequestId": str(uuid.uuid4()), + "clientRequestId": clientRequestId, "source": "api", "url": url, "labels": labels, @@ -68,24 +74,8 @@ def save_page(self, url: str, original_content: str, labels: List[str] = None): :param labels: The labels to assign to the item. """ labels = [] if labels is None else [{"name": x} for x in labels] - mutation = gql( - """ - mutation SavePage($input: SavePageInput!) { - savePage(input: $input) { - ... on SaveSuccess { - url - clientRequestId - } - ... on SaveError { - errorCodes - message - } - } - } - """ - ) return self.client.execute( - mutation, + gql(self.queries["SavePage"]), variable_values={ "input": { "clientRequestId": str(uuid.uuid4()), @@ -101,83 +91,19 @@ def get_profile(self): """ Get the profile of the current user. """ - query = gql( - """ - query Viewer { - me { - id - name - isFullUser - profile { - id - username - pictureUrl - bio - } - } - } - """ - ) - return self.client.execute(query) + return self.client.execute(gql(self.queries["Viewer"])) def get_labels(self): """ Get the labels of the current user. """ - query = gql( - """ - query GetLabels { - labels { - ... on LabelsSuccess { - labels { - ...LabelFields - } - } - ... on LabelsError { - errorCodes - } - } - } - fragment LabelFields on Label { - id - name - color - description - createdAt - }""" - ) - return self.client.execute(query) + return self.client.execute(gql(self.queries["Labels"])) def get_subscriptions(self): """ Get the subscriptions of the current user. """ - query = gql( - """ - query GetSubscriptions { - subscriptions(sort: { by: UPDATED_TIME }) { - ... on SubscriptionsSuccess { - subscriptions { - id - name - newsletterEmail - url - description - status - unsubscribeMailTo - unsubscribeHttpUrl - createdAt - updatedAt - } - } - ... on SubscriptionsError { - errorCodes - } - } - } - """ - ) - return self.client.execute(query) + return self.client.execute(gql(self.queries["GetSubscriptions"])) def get_articles( self, @@ -196,104 +122,8 @@ def get_articles( :param query: The query to use for filtering articles. Example of query by date: 'in:inbox published:2024-03-01..*'. See https://docs.omnivore.app/using/search.html#filtering-by-save-publish-dates for more information. :param include_content: Whether to include the content of the articles. """ - q = gql( - """ - query Search($after: String, $first: Int, $query: String, $format: String, $includeContent: Boolean) { - search(after: $after, first: $first, query: $query, format: $format, includeContent: $includeContent) { - ... on SearchSuccess { - edges { - cursor - node { - id - title - slug - url - pageType - contentReader - createdAt - isArchived - readingProgressPercent - readingProgressTopPercent - readingProgressAnchorIndex - author - image - description - publishedAt - ownedByViewer - originalArticleUrl - uploadFileId - labels { - id - name - color - } - pageId - shortId - quote - annotation - state - siteName - subscription - readAt - savedAt - wordsCount - recommendations { - id - name - note - user { - userId - name - username - profileImageURL - } - recommendedAt - } - highlights { - ...HighlightFields - } - } - } - pageInfo { - hasNextPage - hasPreviousPage - startCursor - endCursor - totalCount - } - } - ... on SearchError { - errorCodes - } - } - } - - fragment HighlightFields on Highlight { - id - type - shortId - quote - prefix - suffix - patch - annotation - createdByMe - createdAt - updatedAt - sharedAt - highlightPositionPercent - highlightPositionAnchorIndex - labels { - id - name - color - createdAt - } - } - """ - ) return self.client.execute( - q, + gql(self.queries["Search"]), variable_values={ "first": limit, "after": cursor, @@ -311,100 +141,8 @@ def get_article(self, username: str, slug: str, format: str = None): :param slug: The slug of the article. :param format: The format of the article to return. """ - query = gql( - """ - query GetArticle($username: String!, $slug: String!, $format: String, $includeFriendsHighlights: Boolean) { - article(username: $username, slug: $slug, format: $format) { - ... on ArticleSuccess { - article { - ...ArticleFields - content - highlights(input: { includeFriends: $includeFriendsHighlights }) { - ...HighlightFields - } - labels { - ...LabelFields - } - recommendations { - ...RecommendationFields - } - } - } - ... on ArticleError { - errorCodes - } - } - } - - fragment ArticleFields on Article { - id - title - url - author - image - savedAt - createdAt - publishedAt - contentReader - originalArticleUrl - readingProgressPercent - readingProgressTopPercent - readingProgressAnchorIndex - slug - isArchived - description - linkId - state - wordsCount - } - - fragment HighlightFields on Highlight { - id - type - shortId - quote - prefix - suffix - patch - annotation - createdByMe - createdAt - updatedAt - sharedAt - highlightPositionPercent - highlightPositionAnchorIndex - labels { - id - name - color - createdAt - } - } - - fragment LabelFields on Label { - id - name - color - description - createdAt - } - - fragment RecommendationFields on Recommendation { - id - name - note - user { - userId - name - username - profileImageURL - } - recommendedAt - } - """ - ) return self.client.execute( - query, + gql(self.queries["ArticleContent"]), variable_values={ "username": username, "slug": slug, @@ -419,24 +157,8 @@ def archive_article(self, article_id: str, to_archive: bool = True): :param article_id: The ID of the article to archive. :param to_archive: Whether to archive or unarchive the article. """ - mutation = gql( - """ - mutation SetLinkArchived($input: ArchiveLinkInput!) { - setLinkArchived(input: $input) { - ... on ArchiveLinkSuccess { - linkId - message - } - ... on ArchiveLinkError { - errorCodes - message - } - } - } - """ - ) return self.client.execute( - mutation, + gql(self.queries["ArchiveSavedItem"]), variable_values={"input": {"linkId": article_id, "archived": to_archive}}, ) @@ -454,22 +176,53 @@ def delete_article(self, article_id: str): :param article_id: The ID of the article to delete. """ - mutation = gql( - """ - mutation SetBookmarkArticle($input: SetBookmarkArticleInput!) { - setBookmarkArticle(input: $input) { - ... on SetBookmarkArticleSuccess { - bookmarkedArticle { - id - } - } - ... on SetBookmarkArticleError { - errorCodes - } + q = self.queries["DeleteSavedItem"] + return self.client.execute( + gql(q), + variable_values={"input": {"articleID": article_id, "bookmark": False}}, + ) + + def create_label(self, label: CreateLabelInput): + """ + Create a new label using a dataclass for input. + + :param label: An instance of LabelInput with the label data. + """ + return self.client.execute( + gql(self.queries["CreateLabel"]), + variable_values={"input": asdict(label)}, + ) + + def update_label( + self, label_id: str, name: str, color: str, description: str = None + ): + """ + Update a label. + + :param label_id: The ID of the label to update. + :param name: The name of the label. + :param color: The color of the label. + :param description: The description of the label. + """ + return self.client.execute( + gql(self.queries["UpdateLabel"]), + variable_values={ + "input": { + "labelId": label_id, + "name": name, + "color": color, + "description": description, } - }""" + }, ) + + def delete_label_by_id(self, label_id: str): + """ + Delete a label. + + :param label_id: The ID of the label to delete. + """ return self.client.execute( - mutation, - variable_values={"input": {"articleID": article_id, "bookmark": False}}, + gql(self.queries["DeleteLabel"]), + variable_values={"input": label_id}, ) diff --git a/tests/test_omnivoreql.py b/tests/test_omnivoreql.py index be81720..1e760ee 100644 --- a/tests/test_omnivoreql.py +++ b/tests/test_omnivoreql.py @@ -7,7 +7,8 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) omnivoreql_dir = os.path.join(current_dir, "..", "omnivoreql") sys.path.insert(0, omnivoreql_dir) -from omnivoreql import OmnivoreQL +from omnivoreql import OmnivoreQL, CreateLabelInput + class TestOmnivoreQL(unittest.TestCase): client = None @@ -68,7 +69,7 @@ def test_save_url_with_labels(self): result = self.client.save_url("https://www.google.com", ["test", "google"]) # Then self.assertIsNotNone(result) - self.assertFalse('errorCodes' in result['saveUrl']) + self.assertFalse("errorCodes" in result["saveUrl"]) def test_get_articles(self): # When @@ -125,6 +126,67 @@ def test_delete_article(self): self.assertIsNotNone(result) self.assertIsNotNone(result["setBookmarkArticle"]["bookmarkedArticle"]["id"]) + def test_create_label(self): + # Given + label_name = hash("TestLabel") # create random label name to avoid conflicts + label_input = CreateLabelInput( + name=str(label_name), color="#FF0000", description="label description" + ) + # When + result = self.client.create_label(label_input) + # Then + self.assertIsNotNone(result) + self.assertNotIn("errorCodes", result["createLabel"]) + self.assertEqual(result["createLabel"]["label"]["name"], label_input.name) + self.assertEqual(result["createLabel"]["label"]["color"], label_input.color) + self.assertEqual( + result["createLabel"]["label"]["description"], label_input.description + ) + + def test_update_label(self): + # Given + label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000") + created_label = self.client.create_label(label_input) + # When + new_label_name = f"UpdatedLabel-{label_input.name}" + result = self.client.update_label( + label_id=created_label["createLabel"]["label"]["id"], + name=new_label_name, + color="#0000FF", + description="An updated TestLabel", + ) + # Then + self.assertIsNotNone(result) + self.assertNotIn("errorCodes", result["updateLabel"]) + self.assertEqual(result["updateLabel"]["label"]["name"], new_label_name) + self.assertEqual(result["updateLabel"]["label"]["color"], "#0000FF") + self.assertEqual( + result["updateLabel"]["label"]["description"], "An updated TestLabel" + ) + + def test_delete_label_by_id(self): + # Given + label_input = CreateLabelInput(name=hash("TestLabel"), color="#FF0000") + created_label = self.client.create_label(label_input) + # When + result = self.client.delete_label_by_id( + created_label["createLabel"]["label"]["id"] + ) + # Then + self.assertIsNotNone(result) + self.assertNotIn("errorCodes", result["deleteLabel"]) + self.assertEqual( + result["deleteLabel"]["label"]["id"], + created_label["createLabel"]["label"]["id"], + ) + + def test_get_all_labels_and_delete_them(self): + # Given + labels = self.client.get_labels()["labels"] + # When + for label in labels["labels"]: + self.client.delete_label_by_id(label["id"]) + if __name__ == "__main__": unittest.main()