diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index cc3acf51e12..f8a95607844 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -359,9 +359,10 @@ async def select_or_generate_remote_thumbnail( desired_method: str, desired_type: str, max_timeout_ms: int, + ip_address: str, ) -> None: media_info = await self.media_repo.get_remote_media_info( - server_name, media_id, max_timeout_ms + server_name, media_id, max_timeout_ms, ip_address ) if not media_info: respond_404(request) @@ -422,12 +423,13 @@ async def respond_remote_thumbnail( method: str, m_type: str, max_timeout_ms: int, + ip_address: str, ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. media_info = await self.media_repo.get_remote_media_info( - server_name, media_id, max_timeout_ms + server_name, media_id, max_timeout_ms, ip_address ) if not media_info: return diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py index 172d2407838..0c089163c13 100644 --- a/synapse/rest/client/media.py +++ b/synapse/rest/client/media.py @@ -174,6 +174,7 @@ async def on_GET( respond_404(request) return + ip_address = request.getClientAddress().host remote_resp_function = ( self.thumbnailer.select_or_generate_remote_thumbnail if self.dynamic_thumbnails @@ -188,6 +189,7 @@ async def on_GET( method, m_type, max_timeout_ms, + ip_address, ) self.media_repo.mark_recently_accessed(server_name, media_id) diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 13705c87b4d..5e8d327e8d4 100644 --- a/synapse/rest/media/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -72,6 +72,7 @@ async def on_GET( ) -> None: # Validate the server name, raising if invalid parse_and_validate_server_name(server_name) + set_cors_headers(request) set_corp_headers(request) width = parse_integer(request, "width", required=True) @@ -121,4 +122,3 @@ async def on_GET( ip_address, ) self.media_repo.mark_recently_accessed(server_name, media_id) -