diff --git a/industry_game/handlers/games/lobby/read_lobby.py b/industry_game/handlers/games/lobby/read_lobby.py index 5c5b585..31c86e9 100644 --- a/industry_game/handlers/games/lobby/read_lobby.py +++ b/industry_game/handlers/games/lobby/read_lobby.py @@ -1,4 +1,4 @@ -from aiohttp.web import HTTPNotFound, Response, View +from aiohttp.web import Response, View from industry_game.utils.http.auth.base import ( AuthMixin, @@ -7,6 +7,7 @@ from industry_game.utils.http.deps import DependenciesMixin from industry_game.utils.http.params import parse_path_param from industry_game.utils.http.response import msgspec_json_response +from industry_game.utils.lobby.models import LobbyStatus, LobbyStatusType class ReadGameUserLobbyHandler(View, DependenciesMixin, AuthMixin): @@ -14,7 +15,12 @@ class ReadGameUserLobbyHandler(View, DependenciesMixin, AuthMixin): async def get(self) -> Response: game_id = parse_path_param(self.request, "game_id", int) - game = await self.game_storage.read_by_id(game_id=game_id) - if game is None: - raise HTTPNotFound - return msgspec_json_response(game) + lobby = await self.lobby_storage.read_by_id( + game_id=game_id, + user_id=self.user.id, + ) + if lobby is None: + status = LobbyStatusType.NOT_CHECKED_IN + else: + status = LobbyStatusType.CHECKED_IN + return msgspec_json_response(LobbyStatus(status=status)) diff --git a/industry_game/utils/lobby/models.py b/industry_game/utils/lobby/models.py index 892647c..6935446 100644 --- a/industry_game/utils/lobby/models.py +++ b/industry_game/utils/lobby/models.py @@ -1,3 +1,5 @@ +from enum import StrEnum, unique + import msgspec from industry_game.db.models import UserGameLobby as UserGameLobbyDb @@ -20,3 +22,13 @@ def from_model(cls, obj: UserGameLobbyDb) -> "Lobby": class LobbyPagination(msgspec.Struct, frozen=True): meta: MetaPagination items: list[ShortUser] + + +@unique +class LobbyStatusType(StrEnum): + CHECKED_IN = "CHECKED_IN" + NOT_CHECKED_IN = "NOT_CHECKED_IN" + + +class LobbyStatus(msgspec.Struct, frozen=True): + status: StrEnum diff --git a/industry_game/utils/users/storage.py b/industry_game/utils/users/storage.py index 8d42e6d..4b5e9ec 100644 --- a/industry_game/utils/users/storage.py +++ b/industry_game/utils/users/storage.py @@ -53,10 +53,7 @@ async def read_by_username( session: AsyncSession, username: str, ) -> FullUser | None: - stmt = select(UserDb).where( - UserDb.type == UserType.PLAYER, - UserDb.username == username, - ) + stmt = select(UserDb).where(UserDb.username == username) obj = (await session.scalars(stmt)).first() return FullUser.from_model(obj) if obj else None diff --git a/tests/api/players/test_player_register.py b/tests/api/players/test_player_register.py index 0ea9d18..4a53d8b 100644 --- a/tests/api/players/test_player_register.py +++ b/tests/api/players/test_player_register.py @@ -1,3 +1,6 @@ +from http import HTTPStatus + +import pytest from aiohttp.test_utils import TestClient from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -21,7 +24,7 @@ async def test_player_register_successful_status_created( "name": "your name", }, ) - assert response.status == 201 + assert response.status == HTTPStatus.CREATED async def test_player_register_successful_check_db( @@ -43,3 +46,21 @@ async def test_player_register_successful_check_db( assert user.username == "username" assert user.type == UserType.PLAYER + + +@pytest.mark.parametrize("user_type", (UserType.ADMIN, UserType.PLAYER)) +async def test_player_register_same_username_error_conflict( + api_client: TestClient, create_user, user_type +): + user = await create_user(type=user_type) + + response = await api_client.post( + API_URL, + json={ + "username": user.username, + "password": "password", + "telegram": "telegram", + "name": "your name", + }, + ) + assert response.status == HTTPStatus.BAD_REQUEST