diff --git a/src/app/domain/accounts/repositories.py b/src/app/domain/accounts/repositories.py index 18e32778..5c622bcd 100644 --- a/src/app/domain/accounts/repositories.py +++ b/src/app/domain/accounts/repositories.py @@ -1,15 +1,42 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any +from uuid import UUID from advanced_alchemy.repository import SQLAlchemyAsyncRepository, SQLAlchemyAsyncSlugRepository +from advanced_alchemy.repository import SQLAlchemyAsyncSlugRepository +from sqlalchemy import ColumnElement, select +from sqlalchemy.orm import joinedload, InstrumentedAttribute from app.db.models import Role, User, UserRole, Tenant +if TYPE_CHECKING: + from advanced_alchemy.filters import FilterTypes + from advanced_alchemy.repository._util import LoadSpec + class UserRepository(SQLAlchemyAsyncRepository[User]): """User SQLAlchemy Repository.""" model_type = User + async def get_user( + self, + user_id: UUID, + tenant_id: UUID, + *, + load: LoadSpec | None = None, + execution_options: dict[str, Any] | None = None, + auto_expunge: bool | None = None, + ) -> User: + """Get a user along with it's associated details.""" + return await self.get_one( + id=user_id, + auto_expunge=auto_expunge, + statement=select(User).where((User.id == user_id) & (User.tenant_id == tenant_id)).options(), + load=load, + execution_options=execution_options, + ) + class RoleRepository(SQLAlchemyAsyncSlugRepository[Role]): """User SQLAlchemy Repository.""" diff --git a/src/app/domain/accounts/services.py b/src/app/domain/accounts/services.py index 92727e2b..5671a9cf 100644 --- a/src/app/domain/accounts/services.py +++ b/src/app/domain/accounts/services.py @@ -36,6 +36,15 @@ def __init__(self, **repo_kwargs: Any) -> None: self.repository: UserRepository = self.repository_type(**repo_kwargs) self.model_type = self.repository.model_type + async def get_user( + self, + user_id: UUID, + tenant_id: UUID, + **kwargs: Any, + ) -> tuple[list[User], int]: + """Get user details.""" + return await self.repository.get_user(user_id=user_id, tenant_id=tenant_id, **kwargs) + async def create( self, data: ModelDictT[User], diff --git a/src/app/domain/opportunities/controllers/opportunities.py b/src/app/domain/opportunities/controllers/opportunities.py index 489a1772..90a66ac6 100644 --- a/src/app/domain/opportunities/controllers/opportunities.py +++ b/src/app/domain/opportunities/controllers/opportunities.py @@ -80,9 +80,7 @@ async def create_opportunity( # Verify is the owner exists in this tenant owner_id = obj.get("owner_id") if owner_id: - db_obj = await users_service.get_one( - (UserModel.tenant_id == current_user.tenant_id) & (UserModel.id == owner_id) - ) + db_obj = await users_service.get_user(owner_id, tenant_id=current_user.tenant_id) if not db_obj: raise ValidationException("Owner does not exist") @@ -120,7 +118,7 @@ async def get_opportunity( ], ) -> Opportunity: """Get details about a comapny.""" - db_obj = await opportunities_service.get(opportunity_id) + db_obj = await opportunities_service.get_opportunity(opportunity_id, tenant_id=current_user.tenant_id) return opportunities_service.to_schema(schema_type=Opportunity, data=db_obj) @patch( @@ -149,12 +147,12 @@ async def update_opportunity( # Verify is the owner exists for in tenant owner_id = obj.get("owner_id") if owner_id: - db_obj = await users_service.get_one(owner_id, tenant_id=current_user.tenant_id) + db_obj = await users_service.get_user(owner_id, tenant_id=current_user.tenant_id) if not db_obj: raise ValidationException("Owner does not exist") # Verify if the user is part of the same tenant as the opportunity - opportunity = await OpportunityService.get_one(opportunity_id) + opportunity = await opportunities_service.get_opportunity(opportunity_id, tenant_id=current_user.tenant_id) if not opportunity: raise ValidationException("Opportunity does not exist") diff --git a/src/app/domain/opportunities/repositories.py b/src/app/domain/opportunities/repositories.py index 76f297e2..ed9d80ef 100644 --- a/src/app/domain/opportunities/repositories.py +++ b/src/app/domain/opportunities/repositories.py @@ -13,10 +13,7 @@ from advanced_alchemy.filters import FilterTypes from advanced_alchemy.repository._util import LoadSpec -__all__ = ( - "OpportunityRepository", - "OpportunityAuditLogRepository" -) +__all__ = ("OpportunityRepository", "OpportunityAuditLogRepository") class OpportunityRepository(SQLAlchemyAsyncSlugRepository[Opportunity]): @@ -37,8 +34,8 @@ async def get_opportunities( return await self.list_and_count( *filters, statement=select(Opportunity) - .where(Opportunity.tenant_id == tenant_id), - #.order_by(Opportunity.score.desc(), Opportunity.created_at.desc()) + .where(Opportunity.tenant_id == tenant_id) + .order_by(Opportunity.created_at.desc()), auto_expunge=auto_expunge, force_basic_query_mode=force_basic_query_mode, **kwargs, @@ -49,28 +46,22 @@ async def get_opportunity( opportunity_id: UUID, tenant_id: UUID, *, - id_attribute: str | InstrumentedAttribute[Any] | None = None, load: LoadSpec | None = None, execution_options: dict[str, Any] | None = None, auto_expunge: bool | None = None, ) -> Opportunity: """Get an opportunity along with it's associated details.""" - return await self.repository.get_one( - item_id=opportunity_id, + return await self.get_one( + id=opportunity_id, auto_expunge=auto_expunge, statement=select(Opportunity) .where((Opportunity.id == opportunity_id) & (Opportunity.tenant_id == tenant_id)) - #.order_by(Opportunity.score.desc(), Opportunity.created_at.desc()) - .options( - joinedload(Opportunity.contacts, innerjoin=False), - joinedload(Opportunity.job_posts, innerjoin=False), - joinedload(Opportunity.logs, innerjoin=False), - ), - id_attribute=id_attribute, + .options(), load=load, execution_options=execution_options, ) + class OpportunityAuditLogRepository(SQLAlchemyAsyncSlugRepository[OpportunityAuditLog]): """OpportunityAuditLog Repository.""" diff --git a/src/app/domain/opportunities/services.py b/src/app/domain/opportunities/services.py index 213427b1..bb9f1e0c 100644 --- a/src/app/domain/opportunities/services.py +++ b/src/app/domain/opportunities/services.py @@ -68,10 +68,8 @@ async def get_opportunity( tenant_id: UUID, **kwargs: Any, ) -> tuple[list[Opportunity], int]: - """Get all opportunities for a tenant.""" - return await self.repository.get_opportunity( - opportunity_id=opportunity_id, tenant_id=tenant_id, **kwargs - ) + """Get opportunity details.""" + return await self.repository.get_opportunity(opportunity_id=opportunity_id, tenant_id=tenant_id, **kwargs) async def update( self,