Skip to content

Commit

Permalink
type safety
Browse files Browse the repository at this point in the history
  • Loading branch information
AaronLieb committed Oct 5, 2024
1 parent 6cab421 commit 2fb9b1d
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 77 deletions.
8 changes: 3 additions & 5 deletions Client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def receive_message(self, from_: str, text: str):
pass

@abstractmethod
def handle_beta_message(self, from_: str, text: str):
def handle_beta_message(self, from_: str, text: str) -> str:
pass

def handle_message(self, from_: str, text: str) -> None:
Expand Down Expand Up @@ -65,7 +65,7 @@ def handle_message(self, from_: str, text: str) -> None:

def reg_state_0(self, from_: str, text: str) -> None:
if contains(text, START_KEYWORDS):
self.db.update_reg(from_, 1)
_ = self.db.update_reg(from_, 1)
self.send_message(from_, ENTER_USERNAME)
else:
self.send_message(from_, HOW_TO_START)
Expand All @@ -75,7 +75,7 @@ def reg_state_1(self, from_: str, text: str) -> None:
self.send_message(from_, BAD_USERNAME)
return

self.db.update_reg(from_, 2, text)
_ = self.db.update_reg(from_, 2, text)
self.send_message(from_, CONFIRM_USERNAME.format(text))

def reg_state_2(self, from_: str, text: str, reg: Registration) -> None:
Expand All @@ -93,8 +93,6 @@ def reg_state_2(self, from_: str, text: str, reg: Registration) -> None:
def handle_image(self, from_: str, url: str) -> None:
logger.info("handle_image %s %s", from_, url)

self.handle_beta_message(from_, "")

user = self.db.get_user_by_phone(from_)

if user is None:
Expand Down
7 changes: 3 additions & 4 deletions SmsClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import override

from fastapi import HTTPException
from sqlalchemy.orm import Session
from twilio.rest import Client as TwilioClient

from Client import Client
Expand All @@ -19,11 +18,11 @@ def __init__(self, settings: Settings, db: Database | None = None):
self.twilio_client = TwilioClient(
settings.twilio_account_sid, settings.twilio_auth_token
)
self.reroute_next_msg_users = set()
self.reroute_next_msg_users: set[str] = set()

@override
def send_message(self, to: str, text: str) -> None:
self.twilio_client.messages.create(
self.twilio_client.messages.create( # pyright: ignore [reportUnknownMemberType]
to=to, from_=settings.twilio_phone_number, body=text
)

Expand All @@ -40,7 +39,7 @@ def handle_beta_message(self, from_: str, text: str) -> str:
elif self.settings.beta_code in text:
prompt_text = " ".join(text.split(" ")[1:])
logger.info("handle_beta_message %s", prompt_text)
self.reroute_to_beta()
self.reroute_to_beta(from_)
return prompt_text
elif from_ in self.reroute_next_msg_users:
logger.info("handle_beta_image")
Expand Down
10 changes: 6 additions & 4 deletions constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
BETA_ENV = "BETA"
DEV_ENV = "DEV"

BASE_URL = "https://snapshot.lieber.men/"
if settings.environment == BETA_ENV:
BASE_URL = "https://dev.snapshot.lieber.men/"
if settings.environment == DEV_ENV:
BASE_URL = "http://localhost:8000/"
base_domain = "dev.snapshot.lieber.men/"
elif settings.environment == DEV_ENV:
base_domain = "localhost:8000/"
else:
base_domain = "snapshot.lieber.men/"
BASE_URL = "https://" + base_domain

NAME = "Snapshot"
EMOJI = "📸" if settings.environment == PROD_ENV else "📷"
Expand Down
81 changes: 41 additions & 40 deletions database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Generator
from datetime import datetime
from typing import List

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
Expand All @@ -16,7 +16,7 @@
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def get_db():
def get_db() -> Generator[Session]:
db = SessionLocal()
try:
yield db
Expand All @@ -26,9 +26,10 @@ def get_db():

class Database:
def __init__(self):
self.db: Session = next(get_db(), None)
if self.db is None:
db = next(get_db(), None)
if db is None:
raise Exception("Could not connect to database")
self.db: Session = db

def get_user(self, username: str) -> User | None:
return self.db.query(User).filter(User.username == username).first()
Expand All @@ -39,7 +40,7 @@ def get_user_by_phone(self, phone: str) -> User | None:
def get_user_by_hash(self, hash: str) -> User | None:
return self.db.query(User).filter(User.hash == hash).first()

def get_users(self, skip: int = 0, limit: int = 100) -> List[User]:
def get_users(self, skip: int = 0, limit: int = 100) -> list[User]:
return self.db.query(User).offset(skip).limit(limit).all()

def create_user(self, phone: str, username: str) -> User:
Expand All @@ -56,43 +57,41 @@ def create_user(self, phone: str, username: str) -> User:
return user

def update_user(self, user: User) -> User | None:
self.db_user = self.db.query(User).filter(User.phone == user.phone).first()
db_user = self.db.query(User).filter(User.phone == user.phone).first()

if self.db_user is None:
if db_user is None:
return None

self.db_user.username = user.username
self.db_user.active = user.active
db_user.username = user.username
db_user.active = user.active
self.db.commit()

return self.db_user
return db_user

def get_reg(self, phone: str) -> Registration | None:
return self.db.query(Registration).filter(Registration.phone == phone).first()

def create_reg(self, phone: str) -> Registration:
self.db_reg = Registration(phone=phone, state=0)
self.db.add(self.db_reg)
db_reg = Registration(phone=phone, state=0)
self.db.add(db_reg)
self.db.commit()
self.db.refresh(self.db_reg)
return self.db_reg
self.db.refresh(db_reg)
return db_reg

def update_reg(
self, phone: str, state: int, username: Optional[int] = None
self, phone: str, state: int, username: str | None = None
) -> Registration | None:
self.db_reg = (
self.db.query(Registration).filter(Registration.phone == phone).first()
)
db_reg = self.db.query(Registration).filter(Registration.phone == phone).first()

if self.db_reg is None:
if db_reg is None:
return None

self.db_reg.state = state
db_reg.state = state # pyright: ignore [reportAttributeAccessIssue]
if username is not None:
self.db_reg.username = username
db_reg.username = username # pyright: ignore [reportAttributeAccessIssue]
self.db.commit()

return self.db_reg
return db_reg

def get_prompt(self, prompt_id: int) -> Prompt | None:
return self.db.query(Prompt).filter(Prompt.id == prompt_id).first()
Expand All @@ -101,13 +100,13 @@ def get_current_prompt(self) -> Prompt | None:
return self.db.query(Prompt).order_by(Prompt.id.desc()).first()

def create_prompt(self, prompt_text: str):
self.db_prompt = Prompt(prompt=prompt_text, date=datetime.now().date())
self.db.add(self.db_prompt)
db_prompt = Prompt(prompt=prompt_text, date=datetime.now().date())
self.db.add(db_prompt)
self.db.commit()
self.db.refresh(self.db_prompt)
return self.db_prompt
self.db.refresh(db_prompt)
return db_prompt

def get_all_prompts(self) -> List[Prompt]:
def get_all_prompts(self) -> list[Prompt]:
return self.db.query(Prompt).all()

def get_pic(self, username: str, prompt_id: int):
Expand All @@ -117,18 +116,20 @@ def get_pic(self, username: str, prompt_id: int):
.first()
)

def get_pics_by_hash(self, user_hash: str) -> List[Pic]:
user: User = self.db.query(User).filter(User.hash == user_hash).first()
return self.db.query(Pic).filter(Pic.user == user.username)
def get_pics_by_hash(self, user_hash: str) -> list[Pic]:
user = self.db.query(User).filter(User.hash == user_hash).first()
if user is None:
return []
return list(self.db.query(Pic).filter(Pic.user == user.username))

def get_pics_by_prompt(self, prompt_id: int) -> List[Pic]:
return (
def get_pics_by_prompt(self, prompt_id: int) -> list[Pic]:
return list(
self.db.query(Pic)
.filter(Pic.prompt == prompt_id)
.order_by(Pic.winner.desc())
)

def get_winner_by_prompt(self, prompt_id: int) -> List[Pic]:
def get_winner_by_prompt(self, prompt_id: int) -> Pic | None:
return (
self.db.query(Pic)
.filter(Pic.prompt == prompt_id, Pic.winner == True)
Expand All @@ -140,28 +141,28 @@ def get_submission_status(self, user_hash: str, prompt_id: int) -> bool:
print(user_hash)
if user is None:
return False
pic = self.get_pic(user.username, prompt_id)
pic = self.get_pic(str(user.username), prompt_id)

return pic is not None

def create_pic(self, url: str, prompt_id: int, username: str) -> Pic:
def create_pic(self, url: str, prompt_id: int, username: str) -> Pic | None:
if self.get_pic(username, prompt_id) is not None:
return None

picModel = Pic(url=url, prompt=prompt_id, user=username)
pic = Pic(url=url, prompt=prompt_id, user=username)

self.db.add(picModel)
self.db.add(pic)
self.db.commit()
self.db.refresh(picModel)
self.db.refresh(pic)

return picModel
return pic

def set_winner(self, pic_id: int):
pic = self.db.query(Pic).filter(Pic.id == pic_id).first()

if pic is None:
return None

pic.winner = not pic.winner
pic.winner = ~pic.winner # pyright: ignore [reportAttributeAccessIssue]
self.db.commit()
return pic
Loading

0 comments on commit 2fb9b1d

Please sign in to comment.