Skip to content

Commit

Permalink
Add support for authenticated media downloads
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jul 9, 2024
1 parent 55c53e0 commit 1dbdc3f
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 6 deletions.
8 changes: 7 additions & 1 deletion mautrix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def get_download_url(
mxc_uri: str,
download_type: Literal["download", "thumbnail"] = "download",
file_name: str | None = None,
authenticated: bool = False,
) -> URL:
"""
Get the full HTTP URL to download a ``mxc://`` URI.
Expand All @@ -470,6 +471,7 @@ def get_download_url(
mxc_uri: The MXC URI whose full URL to get.
download_type: The type of download ("download" or "thumbnail").
file_name: Optionally, a file name to include in the download URL.
authenticated: Whether to use the new authenticated download endpoint in Matrix v1.11.
Returns:
The full HTTP URL.
Expand All @@ -485,7 +487,11 @@ def get_download_url(
"https://matrix-client.matrix.org/_matrix/media/v3/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6/hello.png"
"""
server_name, media_id = self.parse_mxc_uri(mxc_uri)
url = self.base_url / str(APIPath.MEDIA) / "v3" / download_type / server_name / media_id
if authenticated:
url = self.base_url / str(APIPath.CLIENT) / "v1" / "media"
else:
url = self.base_url / str(APIPath.MEDIA) / "v3"
url = url / download_type / server_name / media_id
if file_name:
url /= file_name
return url
Expand Down
2 changes: 2 additions & 0 deletions mautrix/appservice/api/intent.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def __init__(
) -> None:
super().__init__(mxid=mxid, api=api, state_store=state_store)
self.bot = bot
if bot is not None:
self.versions_cache = bot.versions_cache
self.log = api.base_log.getChild("intent")

for method in ENSURE_REGISTERED_METHODS:
Expand Down
2 changes: 1 addition & 1 deletion mautrix/appservice/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aiohttp import web
import aiohttp

from mautrix.types import JSON, RoomAlias, UserID
from mautrix.types import JSON, RoomAlias, UserID, VersionsResponse
from mautrix.util.logging import TraceLogger

from ..api import HTTPAPI
Expand Down
21 changes: 17 additions & 4 deletions mautrix/client/api/modules/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import asyncio
import time

from yarl import URL

from mautrix import __optional_imports__
from mautrix.api import MediaPath, Method
from mautrix.errors import MatrixResponseError, make_request_error
Expand All @@ -19,6 +21,7 @@
MediaRepoConfig,
MXOpenGraph,
SerializerError,
SpecVersions,
)
from mautrix.util import background_task
from mautrix.util.async_body import async_iter_bytes
Expand Down Expand Up @@ -178,13 +181,17 @@ async def download_media(self, url: ContentURI, timeout_ms: int | None = None) -
Returns:
The raw downloaded data.
"""
url = self.api.get_download_url(url)
authenticated = (await self.versions()).supports(SpecVersions.V111)
url = self.api.get_download_url(url, authenticated=authenticated)
query_params: dict[str, Any] = {"allow_redirect": "true"}
if timeout_ms is not None:
query_params["timeout_ms"] = timeout_ms
headers: dict[str, str] = {}
if authenticated:
headers["Authorization"] = f"Bearer {self.api.token}"
req_id = self.api.log_download_request(url, query_params)
start = time.monotonic()
async with self.api.session.get(url, params=query_params) as response:
async with self.api.session.get(url, params=query_params, headers=headers) as response:
try:
response.raise_for_status()
return await response.read()
Expand Down Expand Up @@ -223,7 +230,10 @@ async def download_thumbnail(
Returns:
The raw downloaded data.
"""
url = self.api.get_download_url(url, download_type="thumbnail")
authenticated = (await self.versions()).supports(SpecVersions.V111)
url = self.api.get_download_url(
url, download_type="thumbnail", authenticated=authenticated
)
query_params: dict[str, Any] = {"allow_redirect": "true"}
if width is not None:
query_params["width"] = width
Expand All @@ -235,9 +245,12 @@ async def download_thumbnail(
query_params["allow_remote"] = str(allow_remote).lower()
if timeout_ms is not None:
query_params["timeout_ms"] = timeout_ms
headers: dict[str, str] = {}
if authenticated:
headers["Authorization"] = f"Bearer {self.api.token}"
req_id = self.api.log_download_request(url, query_params)
start = time.monotonic()
async with self.api.session.get(url, params=query_params) as response:
async with self.api.session.get(url, params=query_params, headers=headers) as response:
try:
response.raise_for_status()
return await response.read()
Expand Down
4 changes: 4 additions & 0 deletions mautrix/types/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class SpecVersions:
V15 = Version.deserialize("v1.5")
V16 = Version.deserialize("v1.6")
V17 = Version.deserialize("v1.7")
V18 = Version.deserialize("v1.8")
V19 = Version.deserialize("v1.9")
V110 = Version.deserialize("v1.10")
V111 = Version.deserialize("v1.11")


@dataclass
Expand Down

0 comments on commit 1dbdc3f

Please sign in to comment.