diff --git a/littlepay/api/__init__.py b/littlepay/api/__init__.py index 1a0627e..9650d97 100644 --- a/littlepay/api/__init__.py +++ b/littlepay/api/__init__.py @@ -99,3 +99,22 @@ def _post(self, endpoint: str, data: dict, response_cls: TResponse = dict, **kwa A TResponse instance of the JSON response. """ pass + + def _put(self, endpoint: str, data: dict, response_cls: TResponse = ListResponse, **kwargs) -> TResponse: + """Make a PUT request to a JSON endpoint. + + Args: + self (ClientProtocol): The current ClientProtocol reference. + + endpoint (str): The fully-formed endpoint where the PUT request should be made. + + data (dict): Data to send as JSON in the PUT body. + + response_cls (TResponse): A dataclass representing the JSON response to the PUT. By default, returns a ListResponse. # noqa + + Extra kwargs are passed to requests.put(...) + + Returns (TResponse): + A TResponse instance of the PUT response. + """ + pass diff --git a/littlepay/api/client.py b/littlepay/api/client.py index c5ee380..d37f62a 100644 --- a/littlepay/api/client.py +++ b/littlepay/api/client.py @@ -158,3 +158,13 @@ def _post(self, endpoint: str, data: dict, response_cls: TResponse = dict, **kwa except json.JSONDecodeError: data = {"status_code": response.status_code} return response_cls(**data) + + def _put(self, endpoint: str, data: dict, response_cls: TResponse = ListResponse, **kwargs) -> TResponse: + response = self.oauth.put(endpoint, headers=self.headers, json=data, **kwargs) + response.raise_for_status() + try: + # response body may be empty, cannot be decoded + data = response.json() + except json.JSONDecodeError: + data = {"status_code": response.status_code} + return response_cls(**data) diff --git a/littlepay/api/groups.py b/littlepay/api/groups.py index 5a925a4..b44da07 100644 --- a/littlepay/api/groups.py +++ b/littlepay/api/groups.py @@ -2,7 +2,7 @@ from datetime import datetime, timezone from typing import Generator -from littlepay.api import ClientProtocol +from littlepay.api import ClientProtocol, ListResponse from littlepay.api.funding_sources import FundingSourcesMixin @@ -25,6 +25,37 @@ def csv_header() -> str: return ",".join(vars(instance).keys()) +@dataclass +class GroupFundingSourceResponse: + id: str + participant_id: str + concession_expiry: datetime | None = None + concession_created_at: datetime | None = None + concession_updated_at: datetime | None = None + + def __post_init__(self): + """Parses any date parameters into Python datetime objects. + + Includes a workaround for Python 3.10 where datetime.fromisoformat() can only parse the format output + by datetime.isoformat(), i.e. without a trailing 'Z' offset character and with UTC offset expressed + as +/-HH:mm + + https://docs.python.org/3.11/library/datetime.html#datetime.datetime.fromisoformat + """ + if self.concession_expiry: + self.concession_expiry = datetime.fromisoformat(self.concession_expiry.replace("Z", "+00:00", 1)) + else: + self.concession_expiry = None + if self.concession_created_at: + self.concession_created_at = datetime.fromisoformat(self.concession_created_at.replace("Z", "+00:00", 1)) + else: + self.concession_created_at = None + if self.concession_updated_at: + self.concession_updated_at = datetime.fromisoformat(self.concession_updated_at.replace("Z", "+00:00", 1)) + else: + self.concession_updated_at = None + + class GroupsMixin(ClientProtocol): """Mixin implements APIs for concession groups.""" @@ -86,4 +117,13 @@ def link_concession_group_funding_source( return self._post(endpoint, data, dict) - return self._post(endpoint, data, dict) + def update_concession_group_funding_source_expiry( + self, group_id: str, funding_source_id: str, concession_expiry: datetime + ) -> GroupFundingSourceResponse: + """Update the expiry of a funding source already linked to a concession group.""" + endpoint = self.concession_group_funding_source_endpoint(group_id) + data = {"id": funding_source_id, "concession_expiry": self._format_concession_expiry(concession_expiry)} + + response = self._put(endpoint, data, ListResponse) + + return GroupFundingSourceResponse(**response.list[0]) diff --git a/tests/api/test_client.py b/tests/api/test_client.py index 064b44f..1101d27 100644 --- a/tests/api/test_client.py +++ b/tests/api/test_client.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +import dataclasses from json import JSONDecodeError import time from typing import Callable, Generator, TypeAlias @@ -42,7 +42,7 @@ def mock_active_Config(mocker, credentials, token, url): return config -@dataclass +@dataclasses.dataclass class SampleResponse: one: str two: str @@ -54,11 +54,6 @@ def SampleResponse_json(): return {"one": "single", "two": "double", "three": 3} -@pytest.fixture -def ListResponse_sample(): - return ListResponse(list=[{"one": 1}, {"two": 2}, {"three": 3}], total_count=3) - - @pytest.fixture def default_list_params(): return dict(page=1, perPage=100) @@ -372,3 +367,73 @@ def test_Client_post_error_status(mocker, make_client: ClientFunc, url): client._post(url, data, dict) req_spy.assert_called_once_with(url, headers=client.headers, json=data) + + +def test_Client_put(mocker, make_client: ClientFunc, url, SampleResponse_json): + client = make_client() + mock_response = mocker.Mock( + raise_for_status=mocker.Mock(return_value=False), json=mocker.Mock(return_value=SampleResponse_json) + ) + req_spy = mocker.patch.object(client.oauth, "put", return_value=mock_response) + + data = {"data": "123"} + result = client._put(url, data, SampleResponse) + + req_spy.assert_called_once_with(url, headers=client.headers, json=data) + assert isinstance(result, SampleResponse) + assert result.one == "single" + assert result.two == "double" + assert result.three == 3 + + +def test_Client_put_default_cls(mocker, make_client: ClientFunc, url, ListResponse_sample): + client = make_client() + mock_response = mocker.Mock( + raise_for_status=mocker.Mock(return_value=False), + json=mocker.Mock(return_value=dataclasses.asdict(ListResponse_sample)), + ) + req_spy = mocker.patch.object(client.oauth, "put", return_value=mock_response) + + data = {"data": "123"} + result = client._put(url, data) + + req_spy.assert_called_once_with(url, headers=client.headers, json=data) + assert isinstance(result, ListResponse) + assert result.total_count == ListResponse_sample.total_count + assert len(result.list) == len(ListResponse_sample.list) + + for list_item in result.list: + assert list_item == ListResponse_sample.list[result.list.index(list_item)] + + +def test_Client_put_empty_response(mocker, make_client: ClientFunc, url): + client = make_client() + mock_response = mocker.Mock( + # json() throws a JSONDecodeError, simulating an empty response + json=mocker.Mock(side_effect=JSONDecodeError("msg", "doc", 0)), + # raise_for_status() returns None + raise_for_status=mocker.Mock(return_value=False), + # fake a 201 status_code + status_code=201, + ) + req_spy = mocker.patch.object(client.oauth, "put", return_value=mock_response) + + data = {"data": "123"} + + result = client._put(url, data, dict) + + req_spy.assert_called_once_with(url, headers=client.headers, json=data) + assert result == {"status_code": 201} + + +def test_Client_put_error_status(mocker, make_client: ClientFunc, url): + client = make_client() + mock_response = mocker.Mock(raise_for_status=mocker.Mock(side_effect=HTTPError)) + req_spy = mocker.patch.object(client.oauth, "put", return_value=mock_response) + + data = {"data": "123"} + + with pytest.raises(HTTPError): + client._put(url, data, dict) + + req_spy.assert_called_once_with(url, headers=client.headers, json=data) diff --git a/tests/api/test_groups.py b/tests/api/test_groups.py index 8d7c696..efb873a 100644 --- a/tests/api/test_groups.py +++ b/tests/api/test_groups.py @@ -3,7 +3,30 @@ import pytest -from littlepay.api.groups import GroupResponse, GroupsMixin +from littlepay.api import ListResponse +from littlepay.api.groups import GroupFundingSourceResponse, GroupResponse, GroupsMixin + + +@pytest.fixture +def ListResponse_GroupFundingSources(): + items = [ + dict( + id="0", + participant_id="zero_0", + concession_expiry="2024-03-19T20:00:00Z", + concession_created_at="2024-03-19T20:00:00Z", + concession_updated_at="2024-03-19T20:00:00Z", + ), + dict( + id="1", + participant_id="one_1", + concession_expiry="2024-03-19T20:00:00Z", + concession_created_at="2024-03-19T20:00:00Z", + concession_updated_at="2024-03-19T20:00:00Z", + ), + dict(id="2", participant_id="two_2", concession_expiry="", concession_created_at=""), + ] + return ListResponse(list=items, total_count=3) @pytest.fixture @@ -28,6 +51,13 @@ def mock_ClientProtocol_post_link_concession_group_funding_source(mocker): return mocker.patch("littlepay.api.ClientProtocol._post", side_effect=lambda *args, **kwargs: response) +@pytest.fixture +def mock_ClientProtocol_put_update_concession_group_funding_source(mocker, ListResponse_GroupFundingSources): + return mocker.patch( + "littlepay.api.ClientProtocol._put", side_effect=lambda *args, **kwargs: ListResponse_GroupFundingSources + ) + + def test_GroupResponse_csv(): group = GroupResponse("id", "label", "participant") assert group.csv() == "id,label,participant" @@ -40,6 +70,42 @@ def test_GroupResponse_csv_header(): assert GroupResponse.csv_header() == "id,label,participant_id" +def test_GroupFundingSourceResponse_no_dates(): + response = GroupFundingSourceResponse("id", "participant_id") + + assert response.id == "id" + assert response.participant_id == "participant_id" + assert response.concession_expiry is None + assert response.concession_created_at is None + assert response.concession_updated_at is None + + +def test_GroupFundingSourceResponse_empty_dates(): + response = GroupFundingSourceResponse("id", "participant_id", "", "", "") + + assert response.id == "id" + assert response.participant_id == "participant_id" + assert response.concession_expiry is None + assert response.concession_created_at is None + assert response.concession_updated_at is None + + +def test_GroupFundingSourceResponse_with_dates(): + str_date = "2024-03-19T20:00:00Z" + expected_date = datetime(2024, 3, 19, 20, 0, 0, tzinfo=timezone.utc) + + response = GroupFundingSourceResponse("id", "participant_id", str_date, str_date, str_date) + + assert response.id == "id" + assert response.participant_id == "participant_id" + assert response.concession_expiry == expected_date + assert response.concession_expiry.tzinfo == timezone.utc + assert response.concession_created_at == expected_date + assert response.concession_created_at.tzinfo == timezone.utc + assert response.concession_updated_at == expected_date + assert response.concession_updated_at.tzinfo == timezone.utc + + def test_GroupsMixin_concession_groups_endpoint(url): client = GroupsMixin() @@ -171,3 +237,20 @@ def test_GroupsMixin_link_concession_group_funding_source_expiry( endpoint, {"id": "funding-source-1234", "concession_expiry": "formatted concession expiry"}, dict ) assert result == {"status_code": 201} + + +def test_GroupsMixin_update_concession_group_funding_source_expiry( + mock_ClientProtocol_put_update_concession_group_funding_source, ListResponse_GroupFundingSources, mocker +): + client = GroupsMixin() + mocker.patch.object(client, "_format_concession_expiry", return_value="formatted concession expiry") + + result = client.update_concession_group_funding_source_expiry("group-1234", "funding-source-1234", datetime.now()) + + endpoint = client.concession_group_funding_source_endpoint("group-1234") + mock_ClientProtocol_put_update_concession_group_funding_source.assert_called_once_with( + endpoint, {"id": "funding-source-1234", "concession_expiry": "formatted concession expiry"}, ListResponse + ) + + expected = GroupFundingSourceResponse(**ListResponse_GroupFundingSources.list[0]) + assert result == expected diff --git a/tests/conftest.py b/tests/conftest.py index a334424..8954812 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from pytest_socket import disable_socket from littlepay import __version__ +from littlepay.api import ListResponse import littlepay.config from littlepay.commands import RESULT_SUCCESS @@ -140,3 +141,8 @@ def mock_ClientProtocol_make_endpoint(mocker, url): mocker.patch( "littlepay.api.ClientProtocol._make_endpoint", side_effect=lambda *args: f"{url}/{'/'.join([a for a in args if a])}" ) + + +@pytest.fixture +def ListResponse_sample(): + return ListResponse(list=[{"one": 1}, {"two": 2}, {"three": 3}], total_count=3)