Skip to content

Commit

Permalink
Add types and ignores to make mypy happy
Browse files Browse the repository at this point in the history
  • Loading branch information
felddy committed Aug 26, 2024
1 parent 0d9d163 commit 8b9fd65
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 15 deletions.
12 changes: 6 additions & 6 deletions src/cyhy_db/db.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Third-Party Libraries
from beanie import init_beanie
from motor.motor_asyncio import AsyncIOMotorClient
from beanie import Document, View, init_beanie
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase

from .models import *

ALL_MODELS = [
ALL_MODELS: list[type[Document] | type[View] | str] = [
CVE,
HostDoc,
HostScanDoc,
Expand All @@ -22,10 +22,10 @@
]


async def initialize_db(db_uri: str, db_name: str) -> None:
async def initialize_db(db_uri: str, db_name: str) -> AsyncIOMotorDatabase:
try:
client = AsyncIOMotorClient(db_uri)
db = client[db_name]
client: AsyncIOMotorClient = AsyncIOMotorClient(db_uri)
db: AsyncIOMotorDatabase = client[db_name]
await init_beanie(database=db, document_models=ALL_MODELS)
return db
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion src/cyhy_db/models/cve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class CVE(Document):
# Validate on assignment so ip_int is recalculated as ip is set
model_config = ConfigDict(extra="forbid", validate_assignment=True)

id: str = Indexed(primary_field=True) # CVE ID
# CVE ID as a string
id: str = Indexed(primary_field=True) # type: ignore[assignment]
cvss_score: float = Field(ge=0.0, le=10.0)
cvss_version: CVSSVersion = Field(default=CVSSVersion.V3_1)
severity: int = Field(ge=1, le=4, default=1)
Expand Down
3 changes: 2 additions & 1 deletion src/cyhy_db/models/host_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class State(BaseModel):
class HostDoc(Document):
model_config = ConfigDict(extra="forbid")

id: int = Field() # IP address as an integer
# IP address as an integer
id: int = Field(default_factory=int) # type: ignore[assignment]
ip: IPv4Address = Field(...)
owner: str = Field(...)
last_change: datetime = Field(default_factory=utcnow)
Expand Down
4 changes: 2 additions & 2 deletions src/cyhy_db/models/kev_doc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Third-Party Libraries
from beanie import Document
from pydantic import ConfigDict
from pydantic import ConfigDict, Field


class KEVDoc(Document):
model_config = ConfigDict(extra="forbid")

id: str # CVE
id: str = Field(default_factory=str) # type: ignore[assignment]
known_ransomware: bool

class Settings:
Expand Down
5 changes: 3 additions & 2 deletions src/cyhy_db/models/place_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from typing import Optional

# Third-Party Libraries
from beanie import Document
from beanie import Document, PydanticObjectId
from pydantic import ConfigDict, Field


class PlaceDoc(Document):
model_config = ConfigDict(extra="forbid")

id: int # GNIS FEATURE_ID (INCITS 446-2008) - https://geonames.usgs.gov/domestic/index.html
# GNIS FEATURE_ID (INCITS 446-2008) - https://geonames.usgs.gov/domestic/index.html
id: int = Field(default_factory=int) # type: ignore[assignment]
name: str
clazz: str = Field(alias="class") # 'class' is a reserved keyword in Python
state: str
Expand Down
12 changes: 10 additions & 2 deletions src/cyhy_db/models/request_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from typing import List, Optional

# Third-Party Libraries
from beanie import Document, Insert, Link, Replace, ValidateOnSave, before_event
from beanie import (
Document,
Insert,
Link,
PydanticObjectId,
Replace,
ValidateOnSave,
before_event,
)
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator

from ..utils import utcnow
Expand Down Expand Up @@ -81,7 +89,7 @@ def validate_start(cls, v):
class RequestDoc(Document):
model_config = ConfigDict(extra="forbid")

id: str = Field(default=BOGUS_ID)
id: str = Field(default=BOGUS_ID) # type: ignore[assignment]
agency: Agency
children: List[Link["RequestDoc"]] = Field(default=[])
enrolled: datetime = Field(default_factory=utcnow)
Expand Down
2 changes: 1 addition & 1 deletion src/cyhy_db/models/system_control_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def wait_for_completion(cls, document_id, timeout: Optional[int] = None):
start_time = utcnow()
while True:
doc = await cls.get(document_id)
if doc.completed:
if doc and doc.completed:
return True
if timeout and (utcnow() - start_time).total_seconds() > timeout:
return False
Expand Down

0 comments on commit 8b9fd65

Please sign in to comment.