Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add login spam checker API (#15838)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston authored Jun 26, 2023
1 parent 52d8131 commit 25c55a9
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 6 deletions.
1 change: 1 addition & 0 deletions changelog.d/15838.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add spam checker module API for logins.
36 changes: 36 additions & 0 deletions docs/modules/spam_checker_callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
callback that does not return `False` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.


### `check_login_for_spam`

_First introduced in Synapse v1.87.0_

```python
async def check_login_for_spam(
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
```

Called when a user logs in.

The arguments passed to this callback are:

* `user_id`: The user ID the user is logging in with
* `device_id`: The device ID the user is re-logging into.
* `initial_display_name`: The device display name, if any.
* `request_info`: A collection of tuples, which first item is a user agent, and which
second item is an IP address. These user agents and IP addresses are the ones that were
used during the login process.
* `auth_provider_id`: The identifier of the SSO authentication provider, if any.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
be used. If this happens, Synapse will not call any of the subsequent implementations of
this callback.

*Note:* This will not be called when a user registers.


## Example

The example below is a module that implements the spam checker callback
Expand Down
11 changes: 11 additions & 0 deletions synapse/http/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,11 @@ def get_client_ip_if_available(self) -> str:
else:
return self.getClientAddress().host

def request_info(self) -> "RequestInfo":
h = self.getHeader(b"User-Agent")
user_agent = h.decode("ascii", "replace") if h else None
return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())


class XForwardedForRequest(SynapseRequest):
"""Request object which honours proxy headers
Expand Down Expand Up @@ -661,3 +666,9 @@ def request_factory(channel: HTTPChannel, queued: bool) -> Request:

def log(self, request: SynapseRequest) -> None:
pass


@attr.s(auto_attribs=True, frozen=True, slots=True)
class RequestInfo:
user_agent: Optional[str]
ip: str
3 changes: 3 additions & 0 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_LOGIN_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
Expand Down Expand Up @@ -302,6 +303,7 @@ def register_spam_checker_callbacks(
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Registers callbacks for spam checking capabilities.
Expand All @@ -319,6 +321,7 @@ def register_spam_checker_callbacks(
check_username_for_spam=check_username_for_spam,
check_registration_for_spam=check_registration_for_spam,
check_media_file_for_spam=check_media_file_for_spam,
check_login_for_spam=check_login_for_spam,
)

def register_account_validity_callbacks(
Expand Down
80 changes: 80 additions & 0 deletions synapse/module_api/callbacks/spamchecker_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,26 @@
]
],
]
CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
[
str,
Optional[str],
Optional[str],
Collection[Tuple[Optional[str], str]],
Optional[str],
],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
]
],
]


def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
Expand Down Expand Up @@ -315,6 +335,7 @@ def __init__(self, hs: "synapse.server.HomeServer") -> None:
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []

def register_callbacks(
self,
Expand All @@ -335,6 +356,7 @@ def register_callbacks(
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
Expand Down Expand Up @@ -378,6 +400,9 @@ def register_callbacks(
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)

if check_login_for_spam is not None:
self._check_login_for_spam_callbacks.append(check_login_for_spam)

@trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
Expand Down Expand Up @@ -819,3 +844,58 @@ async def check_media_file_for_spam(
return synapse.api.errors.Codes.FORBIDDEN, {}

return self.NOT_SPAM

async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if we should allow the given registration request.
Args:
user_id: The request user ID
request_info: List of tuples of user agent and IP that
were used during the registration process.
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
"cas". If any. Note this does not include users registered
via a password provider.
Returns:
Enum for how the request should be handled
"""

for callback in self._check_login_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(
callback(
user_id,
device_id,
initial_display_name,
request_info,
auth_provider_id,
)
)
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is self.NOT_SPAM:
continue
elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res
else:
logger.warning(
"Module returned invalid value, rejecting login as spam"
)
return synapse.api.errors.Codes.FORBIDDEN, {}

return self.NOT_SPAM
52 changes: 48 additions & 4 deletions synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.http.site import RequestInfo, SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
Expand Down Expand Up @@ -114,6 +114,7 @@ def __init__(self, hs: "HomeServer"):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
self._spam_checker = hs.get_module_api_callbacks().spam_checker

self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
Expand Down Expand Up @@ -197,6 +198,8 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
self._refresh_tokens_enabled and client_requested_refresh_token
)

request_info = request.request_info()

try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
requester = await self.auth.get_user_by_req(request)
Expand All @@ -216,6 +219,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission,
appservice,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif (
self.jwt_enabled
Expand All @@ -227,6 +231,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
Expand All @@ -235,6 +240,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
Expand All @@ -243,6 +249,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
Expand All @@ -265,6 +272,8 @@ async def _do_appservice_login(
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
Expand Down Expand Up @@ -300,10 +309,15 @@ async def _do_appservice_login(
# The user represented by an appservice's configured sender_localpart
# is not actually created in Synapse.
should_check_deactivated=qualified_user_id != appservice.sender,
request_info=request_info,
)

async def _do_other_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Handle non-token/saml/jwt logins
Expand Down Expand Up @@ -333,6 +347,7 @@ async def _do_other_login(
login_submission,
callback,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
return result

Expand All @@ -347,6 +362,8 @@ async def _complete_login(
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
should_check_deactivated: bool = True,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
Expand All @@ -371,6 +388,7 @@ async def _complete_login(
This exists purely for appservice's configured sender_localpart
which doesn't have an associated user in the database.
request_info: The user agent/IP address of the user.
Returns:
Dictionary of account information after successful login.
Expand Down Expand Up @@ -417,6 +435,22 @@ async def _complete_login(
)

initial_display_name = login_submission.get("initial_device_display_name")
spam_check = await self._spam_checker.check_login_for_spam(
user_id,
device_id=device_id,
initial_display_name=initial_display_name,
request_info=[(request_info.user_agent, request_info.ip)],
auth_provider_id=auth_provider_id,
)
if spam_check != self._spam_checker.NOT_SPAM:
logger.info("Blocking login due to spam checker")
raise SynapseError(
403,
msg="Login was blocked by the server",
errcode=spam_check[0],
additional_fields=spam_check[1],
)

(
device_id,
access_token,
Expand Down Expand Up @@ -451,7 +485,11 @@ async def _complete_login(
return result

async def _do_token_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle token login.
Expand All @@ -474,10 +512,15 @@ async def _do_token_login(
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=res.auth_provider_session_id,
request_info=request_info,
)

async def _do_jwt_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle the custom JWT login.
Expand All @@ -496,6 +539,7 @@ async def _do_jwt_login(
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)


Expand Down
Loading

0 comments on commit 25c55a9

Please sign in to comment.