Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(camelcase-api-field): Allow camelcase for API result #10

Merged
merged 7 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/vaultwarden/clients/bitwarden.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def _refresh_connect_token(self):
)
self._connect_token = ConnectToken.model_validate_json(resp.text)

self._connect_token.master_key = make_master_key(
password=self.password,
salt=self.email,
iterations=self._connect_token.KdfIterations,
)

def _set_connect_token(self):
headers = {
"content-type": "application/x-www-form-urlencoded; charset=utf-8",
Expand Down
4 changes: 2 additions & 2 deletions src/vaultwarden/clients/vaultwarden.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import http
from http.cookiejar import Cookie
from typing import Any, Literal, Optional
from typing import Any, Literal
from uuid import UUID

from httpx import Client, HTTPStatusError, Response
Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
if preload_users:
self._load_users()

def _get_admin_cookie(self) -> Optional[Cookie]:
def _get_admin_cookie(self) -> Cookie | None:
"""Get the session cookie, required to authenticate requests"""
bw_cookies = (
c for c in self._http_client.cookies.jar if c.name == "VW_ADMIN"
Expand Down
31 changes: 18 additions & 13 deletions src/vaultwarden/models/bitwarden.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
from typing import Generic, TypeVar
from uuid import UUID

from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import AliasChoices, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import FieldValidationInfo

from vaultwarden.clients.bitwarden import BitwardenAPIClient
from vaultwarden.models.enum import CipherType, OrganizationUserType
from vaultwarden.models.exception_models import BitwardenError
from vaultwarden.models.permissive_model import PermissiveBaseModel
from vaultwarden.utils.crypto import decrypt, encrypt

# Pydantic models for Bitwarden data structures

T = TypeVar("T", bound="BitwardenBaseModel")


class ResplistBitwarden(BaseModel, Generic[T]):
class ResplistBitwarden(PermissiveBaseModel, Generic[T]):
Data: list[T]


class BitwardenBaseModel(
BaseModel, extra="allow", arbitrary_types_allowed=True
):
class BitwardenBaseModel(PermissiveBaseModel):
bitwarden_client: BitwardenAPIClient | None = Field(
default=None, validate_default=True, exclude=True
)
Expand Down Expand Up @@ -105,7 +104,11 @@ class CollectionAccess(BitwardenBaseModel):

class CollectionUser(CollectionAccess):
CollectionId: UUID | None = Field(None, validate_default=True)
UserId: UUID | None = Field(None, alias="Id", serialization_alias="Id")
UserId: UUID | None = Field(
None,
validation_alias=AliasChoices("id", "Id"),
serialization_alias="id",
)

@field_validator("CollectionId")
@classmethod
Expand All @@ -117,7 +120,9 @@ def set_id(cls, v, info: FieldValidationInfo):

class UserCollection(CollectionAccess):
CollectionId: UUID | None = Field(
None, alias="Id", serialization_alias="Id"
None,
validation_alias=AliasChoices("id", "Id"),
serialization_alias="id",
)
UserId: UUID | None = Field(None, validate_default=True)

Expand All @@ -133,7 +138,7 @@ class OrganizationCollection(BitwardenBaseModel):
Id: UUID | None = None
OrganizationId: UUID | None = Field(None, validate_default=True)
Name: str
ExternalId: str | None
ExternalId: str | None = None

@field_validator("OrganizationId")
@classmethod
Expand Down Expand Up @@ -219,7 +224,7 @@ def add_collections(self, collections: list[UUID]):
if collection in _current_collections:
continue
user = UserCollection(
Id=collection,
CollectionId=collection,
UserId=self.Id,
ReadOnly=False,
HidePasswords=False,
Expand Down Expand Up @@ -284,7 +289,7 @@ def update_collection(self, collections: list[UUID]):
self.Collections = [
UserCollection(
UserId=self.Id,
Id=coll,
CollectionId=coll,
ReadOnly=False,
HidePasswords=False,
)
Expand Down Expand Up @@ -484,9 +489,9 @@ def collections(
def create_collection(self, name: str) -> OrganizationCollection:
org_key = self.key()
data = {
"Name": encrypt(2, name, self.key()),
"Groups": [],
"Users": [],
"name": encrypt(2, name, self.key()),
"groups": [],
"users": [],
}
resp = self.api_client.api_request(
"POST", f"api/organizations/{self.Id}/collections", json=data
Expand Down
13 changes: 13 additions & 0 deletions src/vaultwarden/models/permissive_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import BaseModel

from vaultwarden.utils.string_cases import pascal_case_to_camel_case


class PermissiveBaseModel(
BaseModel,
extra="allow",
alias_generator=pascal_case_to_camel_case,
Lowaiz marked this conversation as resolved.
Show resolved Hide resolved
populate_by_name=True,
arbitrary_types_allowed=True,
):
pass
37 changes: 19 additions & 18 deletions src/vaultwarden/models/sync.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import time
from uuid import UUID

from pydantic import BaseModel, Field, field_validator
from pydantic import AliasChoices, Field, field_validator

from vaultwarden.models.enum import VaultwardenUserStatus
from vaultwarden.models.permissive_model import PermissiveBaseModel
from vaultwarden.utils.crypto import decrypt


class ConnectToken(BaseModel, extra="allow"):
class ConnectToken(PermissiveBaseModel):
Kdf: int = 0
KdfIterations: int = 0
KdfMemory: int | None = None
Expand All @@ -31,9 +32,7 @@ def expires_in_to_time(cls, v):
def is_expired(self, now=None):
if now is None:
now = time.time()
if (self.expires_in is not None) and (self.expires_in <= now):
return True
return False
return (self.expires_in is not None) and (self.expires_in <= now)

@property
def user_key(self):
Expand All @@ -44,7 +43,7 @@ def orgs_key(self):
return decrypt(self.PrivateKey, self.user_key)


class ProfileOrganization(BaseModel, extra="allow"):
class ProfileOrganization(PermissiveBaseModel):
Id: UUID
Name: str
Key: str | None = None
Expand All @@ -67,7 +66,7 @@ class ProfileOrganization(BaseModel, extra="allow"):
UseTotp: bool


class UserProfile(BaseModel, extra="allow"):
class UserProfile(PermissiveBaseModel):
AvatarColor: str | None
Culture: str
Email: str
Expand All @@ -78,14 +77,16 @@ class UserProfile(BaseModel, extra="allow"):
MasterPasswordHint: str | None
Name: str
Object: str | None
Organizations: list[ProfileOrganization] = []
Organizations: list[ProfileOrganization]
Premium: bool
PrivateKey: str | None
ProviderOrganizations: list = []
Providers: list = []
ProviderOrganizations: list
Providers: list
SecurityStamp: str
TwoFactorEnabled: bool
status: VaultwardenUserStatus = Field(alias="_Status")
status: VaultwardenUserStatus = Field(
validation_alias=AliasChoices("_status", "_Status")
)


class VaultwardenUser(UserProfile):
Expand All @@ -95,11 +96,11 @@ class VaultwardenUser(UserProfile):


# TODO: add definition of attribute's types
class SyncData(BaseModel, extra="allow"):
Ciphers: list[dict] = []
Collections: list[dict] = []
Domains: dict = {}
Folders: list[dict] = []
Policies: list[dict] = []
class SyncData(PermissiveBaseModel):
Ciphers: list[dict]
Collections: list[dict]
Domains: dict | None
Folders: list[dict]
Policies: list[dict]
Profile: UserProfile
Sends: list[dict] = []
Sends: list[dict]
10 changes: 10 additions & 0 deletions src/vaultwarden/utils/string_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def pascal_case_to_camel_case(pascal: str) -> str:
"""Convert a PascalCase string to camelCase.

Args:
pascal: The string to convert.

Returns:
The converted camelCase string.
"""
return pascal[0].lower() + pascal[1:]
Loading