diff --git a/Client.py b/Client.py index 494002d..6ff0771 100644 --- a/Client.py +++ b/Client.py @@ -32,8 +32,14 @@ def handle_beta_message(self, from_: str, text: str): def handle_message(self, from_: str, text: str) -> None: logger.info("handle_message %s %s", from_, text) - if self.settings.beta_code in text: - text = self.handle_beta_message(from_, text) + # If the beta code is detected in a message, it will be rerouted to the beta environment + # by throwing a 501. The backup webhook (beta env) will instead handle the message + text = self.handle_beta_message(from_, text) + + # This value is used by handle_beta_message, or can be used for testing purposes + if text == IGNORE_MESSAGE: + logger.info("ignoring message...") + return if self.settings.admin_pass in text: self.handle_admin_message(text) @@ -87,6 +93,8 @@ def reg_state_2(self, from_: str, text: str, reg: Registration) -> None: def handle_image(self, from_: str, url: str) -> None: logger.info("handle_image %s %s", from_, url) + self.handle_beta_message(from_, "") + user = self.db.get_user_by_phone(from_) if user is None: diff --git a/SmsClient.py b/SmsClient.py index 37bd186..4d22486 100644 --- a/SmsClient.py +++ b/SmsClient.py @@ -7,7 +7,7 @@ from Client import Client from config import Settings, settings -from constants import PROD_ENV +from constants import PROD_ENV, IGNORE_MESSAGE from database import Database logger = logging.getLogger(__name__) @@ -19,6 +19,7 @@ def __init__(self, settings: Settings, db: Database | None = None): self.twilio_client = TwilioClient( settings.twilio_account_sid, settings.twilio_auth_token ) + self.reroute_next_msg_users = set() @override def send_message(self, to: str, text: str) -> None: @@ -32,8 +33,22 @@ def receive_message(self, from_: str, text: str) -> None: @override def handle_beta_message(self, from_: str, text: str) -> str: - prompt_text = " ".join(text.split(" ")[1:]) - logger.info("handle_beta_message %s", prompt_text) + if text == self.settings.beta_code: + logger.info("handle_beta_reroute_next_msg") + self.reroute_next_msg_users.add(from_) + return IGNORE_MESSAGE + elif self.settings.beta_code in text: + prompt_text = " ".join(text.split(" ")[1:]) + logger.info("handle_beta_message %s", prompt_text) + self.reroute_to_beta() + return prompt_text + elif from_ in self.reroute_next_msg_users: + logger.info("handle_beta_image") + self.reroute_next_msg_users.remove(_from) + self.reroute_to_beta() + return text + + def reroute_to_beta(self, from_: str): if self.settings.environment == PROD_ENV: logger.info("text routed to beta env") raise HTTPException(status_code=501, detail="beta code detected.") @@ -42,4 +57,3 @@ def handle_beta_message(self, from_: str, text: str) -> str: "%s attempted to use BETA environment but was not allowlisted.", from_ ) raise HTTPException(status_code=401, detail="Number not allowlisted.") - return prompt_text diff --git a/constants.py b/constants.py index 19085ea..b6dc6db 100644 --- a/constants.py +++ b/constants.py @@ -38,3 +38,5 @@ FAILED_PIC_SAVE = SNAPSHOT + "Couldn't save your pic, oops!" VIEW_SUBMISSIONS = SNAPSHOT + "Thanks for submitting! View all submissions here:\n{}" + +IGNORE_MESSAGE = "IGNORE_MESSAGE"