Skip to content

Commit

Permalink
Fix user and opportunity get method to include tenancy filters
Browse files Browse the repository at this point in the history
  • Loading branch information
shri committed Aug 26, 2024
1 parent a1d1ad0 commit 2dac7a6
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 26 deletions.
27 changes: 27 additions & 0 deletions src/app/domain/accounts/repositories.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,42 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from uuid import UUID

from advanced_alchemy.repository import SQLAlchemyAsyncRepository, SQLAlchemyAsyncSlugRepository
from advanced_alchemy.repository import SQLAlchemyAsyncSlugRepository
from sqlalchemy import ColumnElement, select
from sqlalchemy.orm import joinedload, InstrumentedAttribute

from app.db.models import Role, User, UserRole, Tenant

if TYPE_CHECKING:
from advanced_alchemy.filters import FilterTypes
from advanced_alchemy.repository._util import LoadSpec


class UserRepository(SQLAlchemyAsyncRepository[User]):
"""User SQLAlchemy Repository."""

model_type = User

async def get_user(
self,
user_id: UUID,
tenant_id: UUID,
*,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
auto_expunge: bool | None = None,
) -> User:
"""Get a user along with it's associated details."""
return await self.get_one(
id=user_id,
auto_expunge=auto_expunge,
statement=select(User).where((User.id == user_id) & (User.tenant_id == tenant_id)).options(),
load=load,
execution_options=execution_options,
)


class RoleRepository(SQLAlchemyAsyncSlugRepository[Role]):
"""User SQLAlchemy Repository."""
Expand Down
9 changes: 9 additions & 0 deletions src/app/domain/accounts/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ def __init__(self, **repo_kwargs: Any) -> None:
self.repository: UserRepository = self.repository_type(**repo_kwargs)
self.model_type = self.repository.model_type

async def get_user(
self,
user_id: UUID,
tenant_id: UUID,
**kwargs: Any,
) -> tuple[list[User], int]:
"""Get user details."""
return await self.repository.get_user(user_id=user_id, tenant_id=tenant_id, **kwargs)

async def create(
self,
data: ModelDictT[User],
Expand Down
10 changes: 4 additions & 6 deletions src/app/domain/opportunities/controllers/opportunities.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ async def create_opportunity(
# Verify is the owner exists in this tenant
owner_id = obj.get("owner_id")
if owner_id:
db_obj = await users_service.get_one(
(UserModel.tenant_id == current_user.tenant_id) & (UserModel.id == owner_id)
)
db_obj = await users_service.get_user(owner_id, tenant_id=current_user.tenant_id)
if not db_obj:
raise ValidationException("Owner does not exist")

Expand Down Expand Up @@ -120,7 +118,7 @@ async def get_opportunity(
],
) -> Opportunity:
"""Get details about a comapny."""
db_obj = await opportunities_service.get(opportunity_id)
db_obj = await opportunities_service.get_opportunity(opportunity_id, tenant_id=current_user.tenant_id)
return opportunities_service.to_schema(schema_type=Opportunity, data=db_obj)

@patch(
Expand Down Expand Up @@ -149,12 +147,12 @@ async def update_opportunity(
# Verify is the owner exists for in tenant
owner_id = obj.get("owner_id")
if owner_id:
db_obj = await users_service.get_one(owner_id, tenant_id=current_user.tenant_id)
db_obj = await users_service.get_user(owner_id, tenant_id=current_user.tenant_id)
if not db_obj:
raise ValidationException("Owner does not exist")

# Verify if the user is part of the same tenant as the opportunity
opportunity = await OpportunityService.get_one(opportunity_id)
opportunity = await opportunities_service.get_opportunity(opportunity_id, tenant_id=current_user.tenant_id)
if not opportunity:
raise ValidationException("Opportunity does not exist")

Expand Down
23 changes: 7 additions & 16 deletions src/app/domain/opportunities/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
from advanced_alchemy.filters import FilterTypes
from advanced_alchemy.repository._util import LoadSpec

__all__ = (
"OpportunityRepository",
"OpportunityAuditLogRepository"
)
__all__ = ("OpportunityRepository", "OpportunityAuditLogRepository")


class OpportunityRepository(SQLAlchemyAsyncSlugRepository[Opportunity]):
Expand All @@ -37,8 +34,8 @@ async def get_opportunities(
return await self.list_and_count(
*filters,
statement=select(Opportunity)
.where(Opportunity.tenant_id == tenant_id),
#.order_by(Opportunity.score.desc(), Opportunity.created_at.desc())
.where(Opportunity.tenant_id == tenant_id)
.order_by(Opportunity.created_at.desc()),
auto_expunge=auto_expunge,
force_basic_query_mode=force_basic_query_mode,
**kwargs,
Expand All @@ -49,28 +46,22 @@ async def get_opportunity(
opportunity_id: UUID,
tenant_id: UUID,
*,
id_attribute: str | InstrumentedAttribute[Any] | None = None,
load: LoadSpec | None = None,
execution_options: dict[str, Any] | None = None,
auto_expunge: bool | None = None,
) -> Opportunity:
"""Get an opportunity along with it's associated details."""
return await self.repository.get_one(
item_id=opportunity_id,
return await self.get_one(
id=opportunity_id,
auto_expunge=auto_expunge,
statement=select(Opportunity)
.where((Opportunity.id == opportunity_id) & (Opportunity.tenant_id == tenant_id))
#.order_by(Opportunity.score.desc(), Opportunity.created_at.desc())
.options(
joinedload(Opportunity.contacts, innerjoin=False),
joinedload(Opportunity.job_posts, innerjoin=False),
joinedload(Opportunity.logs, innerjoin=False),
),
id_attribute=id_attribute,
.options(),
load=load,
execution_options=execution_options,
)


class OpportunityAuditLogRepository(SQLAlchemyAsyncSlugRepository[OpportunityAuditLog]):
"""OpportunityAuditLog Repository."""

Expand Down
6 changes: 2 additions & 4 deletions src/app/domain/opportunities/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,8 @@ async def get_opportunity(
tenant_id: UUID,
**kwargs: Any,
) -> tuple[list[Opportunity], int]:
"""Get all opportunities for a tenant."""
return await self.repository.get_opportunity(
opportunity_id=opportunity_id, tenant_id=tenant_id, **kwargs
)
"""Get opportunity details."""
return await self.repository.get_opportunity(opportunity_id=opportunity_id, tenant_id=tenant_id, **kwargs)

async def update(
self,
Expand Down

0 comments on commit 2dac7a6

Please sign in to comment.