From 063bd049b19fb6c2910d1dbd76e6a39a34011b41 Mon Sep 17 00:00:00 2001 From: AaronLieb Date: Mon, 11 Nov 2024 20:33:09 -0800 Subject: [PATCH] Mypy --- Client.py | 4 ++-- SmsClient.py | 2 +- TextTestClient.py | 2 +- database.py | 8 ++++---- main.py | 22 ++++++++------------ models.py | 52 ++++++++++++++++++++++++++++++----------------- mypy.ini | 2 ++ 7 files changed, 52 insertions(+), 40 deletions(-) create mode 100644 mypy.ini diff --git a/Client.py b/Client.py index 141710e..68e212a 100644 --- a/Client.py +++ b/Client.py @@ -65,7 +65,7 @@ def handle_message(self, from_: str, text: str) -> None: def reg_state_0(self, from_: str, text: str) -> None: if contains(text, START_KEYWORDS): - _ = self.db.update_reg(from_, 1) + self.db.update_reg(from_, 1) self.send_message(from_, ENTER_USERNAME) else: self.send_message(from_, HOW_TO_START) @@ -75,7 +75,7 @@ def reg_state_1(self, from_: str, text: str) -> None: self.send_message(from_, BAD_USERNAME) return - _ = self.db.update_reg(from_, 2, text) + self.db.update_reg(from_, 2, text) self.send_message(from_, CONFIRM_USERNAME.format(text)) def reg_state_2(self, from_: str, text: str, reg: Registration) -> None: diff --git a/SmsClient.py b/SmsClient.py index 2516738..0bb5c25 100644 --- a/SmsClient.py +++ b/SmsClient.py @@ -22,7 +22,7 @@ def __init__(self, settings: Settings, db: Database | None = None): @override def send_message(self, to: str, text: str) -> None: - self.twilio_client.messages.create( # pyright: ignore [reportUnknownMemberType] + self.twilio_client.messages.create( to=to, from_=settings.twilio_phone_number, body=text ) diff --git a/TextTestClient.py b/TextTestClient.py index 98da56b..e80a068 100644 --- a/TextTestClient.py +++ b/TextTestClient.py @@ -24,7 +24,7 @@ def handle_beta_message(self, from_: str, text: str): if __name__ == "__main__": - client = TextTestClient() + client = TextTestClient(settings) while True: text = input() diff --git a/database.py b/database.py index bc20286..6135fff 100644 --- a/database.py +++ b/database.py @@ -25,7 +25,7 @@ def get_db() -> Generator[Session]: class Database: - def __init__(self): + def __init__(self) -> None: db = next(get_db(), None) if db is None: raise Exception("Could not connect to database") @@ -86,9 +86,9 @@ def update_reg( if db_reg is None: return None - db_reg.state = state # pyright: ignore [reportAttributeAccessIssue] + db_reg.state = state if username is not None: - db_reg.username = username # pyright: ignore [reportAttributeAccessIssue] + db_reg.username = username self.db.commit() return db_reg @@ -162,6 +162,6 @@ def set_winner(self, pic_id: int): if pic is None: return None - pic.winner = not pic.winner # pyright: ignore [reportAttributeAccessIssue] + pic.winner = not pic.winner self.db.commit() return pic diff --git a/main.py b/main.py index 2edc200..9a5b357 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,3 @@ -# pyright: reportCallInDefaultInitializer=false - import logging from typing import Annotated @@ -25,7 +23,7 @@ logger = logging.getLogger(__name__) -models.Base.metadata.create_all(bind=engine) # pyright: ignore [reportAny] +models.Base.metadata.create_all(bind=engine) app = FastAPI() twilio_client = SmsClient(settings) @@ -61,7 +59,7 @@ async def receive_message( validator = RequestValidator(settings.twilio_auth_token) form_ = await request.form() - if not validator.validate( # pyright: ignore [reportUnknownMemberType] + if not validator.validate( str(request.url), form_, request.headers.get("X-Twilio-Signature", "") ): raise HTTPException(status_code=400, detail="Error in Twilio Signature") @@ -101,21 +99,19 @@ def images_page( if prompt is None: raise HTTPException(status_code=404, detail="Unable to find prompt") - if not db.get_submission_status( - user_hash, prompt.id # pyright: ignore [reportArgumentType] - ): + if not db.get_submission_status(user_hash, prompt.id): raise HTTPException(status_code=401, detail="No submission for this prompt") - pics = db.get_pics_by_prompt(prompt.id) # pyright: ignore [reportArgumentType] + pics = db.get_pics_by_prompt(prompt.id) pics = [vars(pic) for pic in pics] for pic in pics: pic["click_url"] = pic["url"] - date_str: str = prompt.date.strftime("%b %-d, %Y") # pyright: ignore [reportAny] + date_str: str = prompt.date.strftime("%b %-d, %Y") og = {"display": False} - winner = db.get_winner_by_prompt(prompt.id) # pyright: ignore [reportArgumentType] + winner = db.get_winner_by_prompt(prompt.id) if winner is not None: og["display"] = True og["url"] = winner.url @@ -132,7 +128,7 @@ def history_page(user_hash: str): pics = db.get_pics_by_hash(user_hash) html_list: list[str] = [] for pic in pics: - prompt = db.get_prompt(pic.prompt) # pyright: ignore [reportArgumentType] + prompt = db.get_prompt(pic.prompt) assert prompt is not None url = BASE_URL + "{}?n={}".format(user_hash, prompt.id) html_list.append('
  • {}
  • '.format(url, prompt.prompt)) @@ -187,12 +183,12 @@ def winner_admin_page( if prompt is None: raise HTTPException(status_code=404, detail="Unable to find prompt") - pics = db.get_pics_by_prompt(prompt.id) # pyright: ignore [reportArgumentType] + pics = db.get_pics_by_prompt(prompt.id) pics = [vars(pic) for pic in pics] for pic in pics: pic["click_url"] = BASE_URL + "win/{}".format(pic["id"]) - date_str: str = prompt.date.strftime("%b %-d, %Y") # pyright: ignore [reportAny] + date_str: str = prompt.date.strftime("%b %-d, %Y") return templates.TemplateResponse( request=request, diff --git a/models.py b/models.py index feaebe8..a5c5912 100644 --- a/models.py +++ b/models.py @@ -1,18 +1,32 @@ from typing import Optional -from sqlalchemy import Boolean, Column, Date, ForeignKey, Integer, String -from sqlalchemy.orm import Mapped, declarative_base, relationship - -Base = declarative_base() +from sqlalchemy import ( + Boolean, + Date, + ForeignKey, + Integer, + String, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + mapped_column, + declarative_base, + relationship, +) + + +class Base(DeclarativeBase): + __allow_unmapped__ = True class User(Base): __tablename__ = "users" - username = Column(String, primary_key=True) - phone = Column(String, unique=True) - active = Column(Boolean, default=True) - hash = Column(String, unique=True) + username = mapped_column(String, primary_key=True) + phone = mapped_column(String, unique=True) + active = mapped_column(Boolean, default=True) + hash = mapped_column(String, unique=True) pics: Mapped[Optional["Pic"]] = relationship("Pic", back_populates="uploader") @@ -20,17 +34,17 @@ class User(Base): class Registration(Base): __tablename__ = "registrations" - phone = Column(String, primary_key=True) - username = Column(String) - state = Column(Integer) + phone = mapped_column(String, primary_key=True) + username = mapped_column(String) + state = mapped_column(Integer) class Prompt(Base): __tablename__ = "prompts" - id = Column(Integer, primary_key=True) - prompt = Column(String, nullable=False) - date = Column(Date, nullable=False) + id = mapped_column(Integer, primary_key=True) + prompt = mapped_column(String, nullable=False) + date = mapped_column(Date, nullable=False) pics: Mapped[Optional["Pic"]] = relationship("Pic", back_populates="parent") @@ -38,11 +52,11 @@ class Prompt(Base): class Pic(Base): __tablename__ = "pics" - id = Column(Integer, primary_key=True) - url = Column(String) - prompt = Column(Integer, ForeignKey("prompts.id")) - user = Column(String, ForeignKey("users.username")) - winner = Column(Boolean, default=False) + id = mapped_column(Integer, primary_key=True) + url = mapped_column(String) + prompt = mapped_column(Integer, ForeignKey("prompts.id")) + user = mapped_column(String, ForeignKey("users.username")) + winner = mapped_column(Boolean, default=False) parent = relationship("Prompt", back_populates="pics") uploader = relationship("User", back_populates="pics") diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..b0eebaf --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +python_executable=venv/bin/python