Skip to content

Commit

Permalink
added test for rest_client
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Oct 16, 2024
1 parent 73e4030 commit be76510
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 22 deletions.
38 changes: 18 additions & 20 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def __init__(
cliAuthConfig: CLIAuthConfig | None = None,
) -> None:
self._config = config or RestConfig()
self._oauth_config = authConfig
self._cli_auth_config = cliAuthConfig
self._tokenHandler = None
if authConfig and cliAuthConfig:
self._tokenHandler = TokenManager(authConfig, cliAuthConfig)

def get_plans(self) -> PlanResponse:
return self._request_and_deserialize("/plans", PlanResponse)
Expand Down Expand Up @@ -139,24 +140,21 @@ def _request_and_deserialize(
headers = {
"content-type": "application/json; charset=UTF-8",
}
if self._oauth_config and self._cli_auth_config:
jwt_token_manager = TokenManager(self._oauth_config, self._cli_auth_config)
if jwt_token_manager.token and jwt_token_manager.token["access_token"]:
try:
valid_token = jwt_token_manager.authenticator.verify_token(
jwt_token_manager.token["access_token"]
)
if valid_token:
access_token = jwt_token_manager.token["access_token"]
headers["Authorization"] = f"Bearer {access_token}"
else:
raise jwt.ExpiredSignatureError
except jwt.ExpiredSignatureError:
if jwt_token_manager.refresh_auth_token():
access_token = jwt_token_manager.token["access_token"]
headers["Authorization"] = f"Bearer {access_token}"
except Exception:
pass
if (
self._tokenHandler
and self._tokenHandler.token
and self._tokenHandler.token["access_token"]
):
try:
access_token = self._tokenHandler.token["access_token"]
self._tokenHandler.authenticator.verify_token(access_token)
headers["Authorization"] = f"Bearer {access_token}"
except jwt.ExpiredSignatureError:
if self._tokenHandler.refresh_auth_token():
access_token = self._tokenHandler.token["access_token"]
headers["Authorization"] = f"Bearer {access_token}"
except Exception:
pass
if data:
response = requests.request(method, url, json=data, headers=headers)
else:
Expand Down
90 changes: 90 additions & 0 deletions tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,43 @@
from unittest.mock import Mock, patch

import jwt
import pytest
import responses
from pydantic import BaseModel

from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError
from blueapi.config import CLIAuthConfig, OauthConfig
from blueapi.core.bluesky_types import Plan
from blueapi.service.model import PlanModel, PlanResponse


@pytest.fixture
def rest() -> BlueapiRestClient:
return BlueapiRestClient()


@pytest.fixture
@responses.activate
def rest_with_auth() -> BlueapiRestClient:
responses.add(
responses.GET,
"http://example.com",
json={
"device_authorization_endpoint": "https://example.com/device_authorization",
"authorization_endpoint": "https://example.com/authorization",
"token_endpoint": "https://example.com/token",
"issuer": "https://example.com",
"jwks_uri": "https://example.com/realms/master/protocol/openid-connect/certs",
"end_session_endpoint": "https://example.com/logout",
},
status=200,
)
return BlueapiRestClient(
cliAuthConfig=CLIAuthConfig(client_id="foo", client_audience="bar"),
authConfig=OauthConfig(oidc_config_url="http://example.com"),
)


@pytest.mark.parametrize(
"code,expected_exception",
[
Expand All @@ -30,3 +58,65 @@ def test_rest_error_code(
mock_request.return_value = response
with pytest.raises(expected_exception):
rest.get_plans()


class MyModel(BaseModel):
id: str


@responses.activate
def test_auth_request_functionality(rest_with_auth: BlueapiRestClient):
plan = Plan(name="my-plan", model=MyModel)
responses.add(
responses.GET,
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
status=200,
)
with patch("blueapi.service.Authenticator.verify_token") as mock_verify_token:
# Mock the verify_token function to return True (indicating a valid token)
mock_verify_token.return_value = True

result = rest_with_auth.get_plans()
# Add assertions as needed
assert result == PlanResponse(plans=[PlanModel.from_plan(plan)])


@responses.activate
def test_refresh_if_signature_expired(rest_with_auth: BlueapiRestClient):
plan = Plan(name="my-plan", model=MyModel)
responses.add(
responses.GET,
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
status=200,
)
with (
patch("blueapi.service.Authenticator.verify_token") as mock_verify_token,
patch("blueapi.service.TokenManager.refresh_auth_token") as mock_refresh_token,
):
# Mock the verify_token function to return True (indicating a valid token)
mock_verify_token.side_effect = jwt.ExpiredSignatureError
mock_refresh_token.return_value = True
result = rest_with_auth.get_plans()
# Add assertions as needed
assert result == PlanResponse(plans=[PlanModel.from_plan(plan)])


@responses.activate
def test_verify_token_ignore_other_exceptions(rest_with_auth: BlueapiRestClient):
plan = Plan(name="my-plan", model=MyModel)
responses.add(
responses.GET,
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
status=200,
)
with (
patch("blueapi.service.Authenticator.verify_token") as mock_verify_token,
):
# Mock the verify_token function to return True (indicating a valid token)
mock_verify_token.side_effect = Exception
result = rest_with_auth.get_plans()
# Add assertions as needed
assert result == PlanResponse(plans=[PlanModel.from_plan(plan)])
39 changes: 37 additions & 2 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,13 +721,20 @@ def test_login_success(runner: CliRunner, valid_auth_config: str):

@responses.activate
def test_logout_success(runner: CliRunner, valid_auth_config: str, tmp_path: Path):
with open(tmp_path / "token", "w+") as token_file:
with open(tmp_path / "token", "w") as token_file:
# base64 encoded token
token_file.write(base64.b64encode(b'{"access_token":"token"}').decode("utf-8"))
response = responses.add(
responses.GET,
"https://auth.example.com/realms/sample/.well-known/openid-configuration",
json=EnvironmentResponse(initialized=False).model_dump(),
json={
"device_authorization_endpoint": "https://example.com/device_authorization",
"authorization_endpoint": "https://example.com/authorization",
"token_endpoint": "https://example.com/token",
"issuer": "https://example.com",
"jwks_uri": "https://example.com/realms/master/protocol/openid-connect/certs",
"end_session_endpoint": "https://example.com/logout",
},
status=200,
)
assert tmp_path.joinpath("token").exists() is True
Expand All @@ -736,3 +743,31 @@ def test_logout_success(runner: CliRunner, valid_auth_config: str, tmp_path: Pat
assert result.exit_code == 0
assert response.call_count == 1
assert tmp_path.joinpath("token").exists() is False


@responses.activate
def test_controller_plan_with_auth(runner: CliRunner, valid_auth_config):
with responses.RequestsMock(assert_all_requests_are_fired=True) as requests_mock:
plan = Plan(name="my-plan", model=MyModel)
requests_mock.add(
responses.GET,
"https://auth.example.com/realms/sample/.well-known/openid-configuration",
json={
"device_authorization_endpoint": "https://example.com/device_authorization",
"authorization_endpoint": "https://example.com/authorization",
"token_endpoint": "https://example.com/token",
"issuer": "https://example.com",
"jwks_uri": "https://example.com/realms/master/protocol/openid-connect/certs",
"end_session_endpoint": "https://example.com/logout",
},
status=200,
)
requests_mock.add(
responses.GET,
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
status=200,
)
result = runner.invoke(main, ["-c", valid_auth_config, "controller", "plans"])
assert result.exit_code == 0
assert "Please login to access the plans" in result.output

0 comments on commit be76510

Please sign in to comment.