diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml
new file mode 100644
index 0000000..bfad777
--- /dev/null
+++ b/.github/workflows/pyright.yml
@@ -0,0 +1,18 @@
+name: Pyright
+on: [push, pull_request]
+jobs:
+ pyright:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v4
+ with:
+ cache: 'pip'
+
+ - run: |
+ python -m venv .venv
+ source .venv/bin/activate
+ pip install -e '.[test,cli]'
+
+ - run: echo "$PWD/.venv/bin" >> $GITHUB_PATH
+ - uses: jakebailey/pyright-action@v2
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..b268138
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,8 @@
+name: Ruff
+on: [push, pull_request]
+jobs:
+ ruff:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: chartboost/ruff-action@v1
diff --git a/pyproject.toml b/pyproject.toml
index 05602b2..287497e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,4 +34,8 @@ version_file = "pypush/_version.py"
[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "-q"]
-testpaths = ["tests"]
\ No newline at end of file
+testpaths = ["tests"]
+
+[tool.ruff.lint]
+select = ["E", "F", "B", "SIM", "I"]
+ignore = ["E501", "B010"]
\ No newline at end of file
diff --git a/pypush/apns/__init__.py b/pypush/apns/__init__.py
index ff6398a..3c954b8 100644
--- a/pypush/apns/__init__.py
+++ b/pypush/apns/__init__.py
@@ -1,5 +1,5 @@
-__all__ = ["protocol", "create_apns_connection", "activate"]
+__all__ = ["protocol", "create_apns_connection", "activate", "filters"]
-from . import protocol
-from .lifecycle import create_apns_connection
+from . import filters, protocol
from .albert import activate
+from .lifecycle import create_apns_connection
diff --git a/pypush/apns/_protocol.py b/pypush/apns/_protocol.py
index 140d9e3..bd6c4b5 100644
--- a/pypush/apns/_protocol.py
+++ b/pypush/apns/_protocol.py
@@ -3,7 +3,7 @@
import logging
from dataclasses import MISSING, field
from dataclasses import fields as dataclass_fields
-from typing import Any, TypeVar, get_origin, get_args, Union
+from typing import Any, TypeVar, Union, get_args, get_origin
from pypush.apns.transport import Packet
@@ -67,14 +67,14 @@ def from_packet(cls, packet: Packet):
)
# Check for extra fields
- for field in packet.fields:
- if field.id not in [
+ for current_field in packet.fields:
+ if current_field.id not in [
f.metadata["packet_id"]
for f in dataclass_fields(cls)
if f.metadata is not None and "packet_id" in f.metadata
]:
logging.warning(
- f"Unexpected field with packet ID {field.id} in packet {packet}"
+ f"Unexpected field with packet ID {current_field.id} in packet {packet}"
)
return cls(**field_values)
@@ -122,15 +122,15 @@ def fid(
:param byte_len: The length of the field in bytes (for int fields)
:param default: The default value of the field
"""
- if not default == MISSING and not default_factory == MISSING:
+ if default != MISSING and default_factory != MISSING:
raise ValueError("Cannot specify both default and default_factory")
- if not default == MISSING:
+ if default != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default=default,
repr=repr,
)
- if not default_factory == MISSING:
+ if default_factory != MISSING:
return field(
metadata={"packet_id": packet_id, "packet_bytes": byte_len},
default_factory=default_factory,
diff --git a/pypush/apns/_util.py b/pypush/apns/_util.py
index 09e9574..3564892 100644
--- a/pypush/apns/_util.py
+++ b/pypush/apns/_util.py
@@ -3,25 +3,40 @@
from typing import Generic, TypeVar
import anyio
-from anyio.abc import ObjectSendStream
+from anyio.abc import ObjectReceiveStream, ObjectSendStream
+
+from . import filters
T = TypeVar("T")
class BroadcastStream(Generic[T]):
- def __init__(self):
+ def __init__(self, backlog: int = 50):
self.streams: list[ObjectSendStream[T]] = []
+ self.backlog: list[T] = []
+ self._backlog_size = backlog
async def broadcast(self, packet):
+ logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams")
for stream in self.streams:
try:
await stream.send(packet)
except anyio.BrokenResourceError:
- self.streams.remove(stream)
+ logging.error("Broken resource error")
+ # self.streams.remove(stream)
+ # If we have a backlog, add the packet to it
+ if len(self.backlog) >= self._backlog_size:
+ self.backlog.pop(0)
+ self.backlog.append(packet)
@asynccontextmanager
- async def open_stream(self):
- send, recv = anyio.create_memory_object_stream[T]()
+ async def open_stream(self, backlog: bool = True):
+ # 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will
+ # start stalling the APNs connection itself
+ send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000)
+ if backlog:
+ for packet in self.backlog:
+ await send.send(packet)
self.streams.append(send)
async with recv:
yield recv
@@ -29,6 +44,31 @@ async def open_stream(self):
await send.aclose()
+W = TypeVar("W")
+F = TypeVar("F")
+
+
+class FilteredStream(ObjectReceiveStream[F]):
+ """
+ A stream that filters out unwanted items
+
+ filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
+ """
+
+ def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]):
+ self.source = source
+ self.filter = filter
+
+ async def receive(self) -> F:
+ async for item in self.source:
+ if (filtered := self.filter(item)) is not None:
+ return filtered
+ raise anyio.EndOfStream
+
+ async def aclose(self):
+ await self.source.aclose()
+
+
def exponential_backoff(f):
async def wrapper(*args, **kwargs):
backoff = 1
diff --git a/pypush/apns/albert.py b/pypush/apns/albert.py
index 024e449..3706807 100644
--- a/pypush/apns/albert.py
+++ b/pypush/apns/albert.py
@@ -4,7 +4,7 @@
import re
import uuid
from base64 import b64decode
-from typing import Tuple, Optional
+from typing import Optional, Tuple
import httpx
from cryptography import x509
@@ -96,10 +96,10 @@ async def activate(
try:
protocol = re.search("(.*)", resp.text).group(1) # type: ignore
- except AttributeError:
+ except AttributeError as e:
# Search for error text between and
error = re.search("(.*)", resp.text).group(1) # type: ignore
- raise Exception(f"Failed to get certificate from Albert: {error}")
+ raise Exception(f"Failed to get certificate from Albert: {error}") from e
protocol = plistlib.loads(protocol.encode("utf-8"))
diff --git a/pypush/apns/filters.py b/pypush/apns/filters.py
new file mode 100644
index 0000000..63bb784
--- /dev/null
+++ b/pypush/apns/filters.py
@@ -0,0 +1,44 @@
+import logging
+from typing import Callable, Optional, Type, TypeVar
+
+from pypush.apns import protocol
+
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
+Filter = Callable[[T1], Optional[T2]]
+
+# Chain with proper types so that subsequent filters only need to take output type of previous filter
+T_IN = TypeVar("T_IN", bound=protocol.Command)
+T_MIDDLE = TypeVar("T_MIDDLE", bound=protocol.Command)
+T_OUT = TypeVar("T_OUT", bound=protocol.Command)
+
+
+def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]):
+ def filter(command: T_IN) -> Optional[T_OUT]:
+ logging.debug(f"Filtering {command} with {first} and {second}")
+ filtered = first(command)
+ if filtered is None:
+ return None
+ return second(filtered)
+
+ return filter
+
+
+T = TypeVar("T", bound=protocol.Command)
+
+
+def cmd(type: Type[T]) -> Filter[protocol.Command, T]:
+ def filter(command: protocol.Command) -> Optional[T]:
+ if isinstance(command, type):
+ return command
+ return None
+
+ return filter
+
+
+def ALL(c):
+ return c
+
+
+def NONE(_):
+ return None
diff --git a/pypush/apns/lifecycle.py b/pypush/apns/lifecycle.py
index 49b4fcf..23d3f94 100644
--- a/pypush/apns/lifecycle.py
+++ b/pypush/apns/lifecycle.py
@@ -6,6 +6,7 @@
import time
import typing
from contextlib import asynccontextmanager
+from hashlib import sha1
import anyio
from anyio.abc import TaskGroup
@@ -13,7 +14,7 @@
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
-from . import protocol, transport, _util
+from . import _util, filters, protocol, transport
@asynccontextmanager
@@ -21,13 +22,18 @@ async def create_apns_connection(
certificate: x509.Certificate,
private_key: rsa.RSAPrivateKey,
token: typing.Optional[bytes] = None,
+ sandbox: bool = False,
courier: typing.Optional[str] = None,
):
async with anyio.create_task_group() as tg:
- conn = Connection(tg, certificate, private_key, token, courier)
+ conn = Connection(
+ tg, certificate, private_key, token, sandbox, courier
+ ) # Await connected for first time here, so that base token is set
yield conn
tg.cancel_scope.cancel() # Cancel the task group when the context manager exits
- await conn.aclose() # Make sure to close the connection after the task group is cancelled
+ await (
+ conn.aclose()
+ ) # Make sure to close the connection after the task group is cancelled
class Connection:
@@ -37,26 +43,44 @@ def __init__(
certificate: x509.Certificate,
private_key: rsa.RSAPrivateKey,
token: typing.Optional[bytes] = None,
+ sandbox: bool = False,
courier: typing.Optional[str] = None,
):
-
self.certificate = certificate
self.private_key = private_key
- self.base_token = token
+ self._base_token = token
+
+ self._filters: dict[str, int] = {} # topic -> use count
+
+ self._connected = anyio.Event() # Only use for base_token property
self._conn = None
self._tg = task_group
self._broadcast = _util.BroadcastStream[protocol.Command]()
self._reconnect_lock = anyio.Lock()
+ self._send_lock = anyio.Lock()
+ self.sandbox = sandbox
if courier is None:
# Pick a random courier server from 1 to 50
- courier = f"{random.randint(1, 50)}-courier.push.apple.com"
+ courier = (
+ f"{random.randint(1, 50)}-courier.push.apple.com"
+ if not sandbox
+ else f"{random.randint(1, 10)}-courier.sandbox.push.apple.com"
+ )
+ logging.debug(f"Using courier: {courier}")
self.courier = courier
self._tg.start_soon(self.reconnect)
self._tg.start_soon(self._ping_task)
+ @property
+ async def base_token(self) -> bytes:
+ if self._base_token is None:
+ await self._connected.wait()
+ assert self._base_token is not None
+ return self._base_token
+
async def _receive_task(self):
assert self._conn is not None
async for command in self._conn:
@@ -68,8 +92,10 @@ async def _ping_task(self):
while True:
await anyio.sleep(30)
logging.debug("Sending keepalive")
- await self.send(protocol.KeepAliveCommand())
- await self.receive(protocol.KeepAliveAck)
+ await self._send(protocol.KeepAliveCommand())
+ await self._receive(
+ filters.cmd(protocol.KeepAliveAck), backlog=False
+ ) # Explicitly disable the backlog since we don't want to receive old acks
@_util.exponential_backoff
async def reconnect(self):
@@ -77,8 +103,11 @@ async def reconnect(self):
if self._conn is not None:
logging.warning("Closing existing connection")
await self._conn.aclose()
- self._conn = protocol.CommandStream(
- await transport.create_courier_connection(courier=self.courier)
+
+ self._broadcast.backlog = [] # Clear the backlog
+
+ conn = protocol.CommandStream(
+ await transport.create_courier_connection(self.sandbox, self.courier)
)
cert = self.certificate.public_bytes(serialization.Encoding.DER)
nonce = (
@@ -89,53 +118,150 @@ async def reconnect(self):
signature = b"\x01\x01" + self.private_key.sign(
nonce, padding.PKCS1v15(), hashes.SHA1()
)
- await self._conn.send(
+ await conn.send(
protocol.ConnectCommand(
- push_token=self.base_token,
+ push_token=self._base_token,
state=1,
- flags=69,
+ flags=65, # 69
certificate=cert,
nonce=nonce,
signature=signature,
)
)
+
+ # Don't set self._conn until we've sent the connect command
+ self._conn = conn
+
self._tg.start_soon(self._receive_task)
- ack = await self.receive(protocol.ConnectAck)
+ ack = await self._receive(
+ filters.chain(
+ filters.cmd(protocol.ConnectAck),
+ lambda c: (
+ c
+ if (
+ c.token == self._base_token
+ if self._base_token is not None
+ else True
+ )
+ else None
+ ),
+ )
+ )
logging.debug(f"Connected with ack: {ack}")
assert ack.status == 0
- if self.base_token is None:
- self.base_token = ack.token
+ if self._base_token is None:
+ self._base_token = ack.token
else:
- assert ack.token == self.base_token
+ assert ack.token == self._base_token
+ if not self._connected.is_set():
+ self._connected.set()
+
+ await self._update_filter()
async def aclose(self):
if self._conn is not None:
await self._conn.aclose()
# Note: Will be reopened if task group is still running and ping task is still running
- T = typing.TypeVar("T", bound=protocol.Command)
+ T = typing.TypeVar("T")
- async def receive_stream(
- self, filter: typing.Type[T], max: int = -1
- ) -> typing.AsyncIterator[T]:
- async with self._broadcast.open_stream() as stream:
+ @asynccontextmanager
+ async def _receive_stream(
+ self,
+ filter: filters.Filter[protocol.Command, T] = lambda c: c,
+ backlog: bool = True,
+ ):
+ async with self._broadcast.open_stream(backlog) as stream:
+ yield _util.FilteredStream(stream, filter)
+
+ async def _receive(
+ self, filter: filters.Filter[protocol.Command, T], backlog: bool = True
+ ):
+ async with self._receive_stream(filter, backlog) as stream:
async for command in stream:
- if isinstance(command, filter):
- yield command
- max -= 1
- if max == 0:
- break
-
- async def receive(self, filter: typing.Type[T]) -> T:
- async for command in self.receive_stream(filter, 1):
- return command
- raise ValueError("No matching command received")
+ return command
+ raise ValueError("Did not receive expected command")
- async def send(self, command: protocol.Command):
+ async def _send(self, command: protocol.Command):
try:
- assert self._conn is not None
- await self._conn.send(command)
- except Exception as e:
- logging.warning(f"Error sending command, reconnecting")
+ async with self._send_lock:
+ assert self._conn is not None
+ await self._conn.send(command)
+ except Exception:
+ logging.warning("Error sending command, reconnecting")
await self.reconnect()
- await self.send(command)
+ await self._send(command)
+
+ async def _update_filter(self):
+ await self._send(
+ protocol.FilterCommand(
+ token=await self.base_token,
+ enabled_topic_hashes=[
+ sha1(topic.encode()).digest() for topic in self._filters
+ ],
+ )
+ )
+
+ @asynccontextmanager
+ async def _filter(self, topics: list[str]):
+ for topic in topics:
+ self._filters[topic] = self._filters.get(topic, 0) + 1
+ await self._update_filter()
+ yield
+ for topic in topics:
+ self._filters[topic] -= 1
+ if self._filters[topic] == 0:
+ del self._filters[topic]
+ await self._update_filter()
+
+ async def mint_scoped_token(self, topic: str) -> bytes:
+ topic_hash = sha1(topic.encode()).digest()
+ await self._send(
+ protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash)
+ )
+ ack = await self._receive(filters.cmd(protocol.ScopedTokenAck))
+ assert ack.status == 0
+ return ack.scoped_token
+
+ @asynccontextmanager
+ async def notification_stream(
+ self,
+ topic: str,
+ token: typing.Optional[bytes] = None,
+ filter: filters.Filter[
+ protocol.SendMessageCommand, protocol.SendMessageCommand
+ ] = filters.ALL,
+ ):
+ if token is None:
+ token = await self.base_token
+ async with self._filter([topic]), self._receive_stream(
+ filters.chain(
+ filters.chain(
+ filters.chain(
+ filters.cmd(protocol.SendMessageCommand),
+ lambda c: c if c.token == token else None,
+ ),
+ lambda c: (c if c.topic == topic else None),
+ ),
+ filter,
+ )
+ ) as stream:
+ yield stream
+
+ async def ack(self, command: protocol.SendMessageCommand, status: int = 0):
+ await self._send(
+ protocol.SendMessageAck(status=status, token=command.token, id=command.id)
+ )
+
+ async def expect_notification(
+ self,
+ topic: str,
+ token: typing.Optional[bytes] = None,
+ filter: filters.Filter[
+ protocol.SendMessageCommand, protocol.SendMessageCommand
+ ] = filters.ALL,
+ ) -> protocol.SendMessageCommand:
+ async with self.notification_stream(topic, token, filter) as stream:
+ command = await stream.receive()
+ await self.ack(command)
+ return command
diff --git a/pypush/apns/protocol.py b/pypush/apns/protocol.py
index ea0f7d3..147119c 100644
--- a/pypush/apns/protocol.py
+++ b/pypush/apns/protocol.py
@@ -2,7 +2,7 @@
from hashlib import sha1
from typing import Optional, Union
-from anyio.abc import ByteStream, ObjectStream
+from anyio.abc import ObjectStream
from pypush.apns._protocol import command, fid
from pypush.apns.transport import Packet
@@ -87,12 +87,7 @@ class FilterCommand(Command):
def _lookup_hashes(self, hashes: Optional[list[bytes]]):
return (
- [
- KNOWN_TOPICS_LOOKUP[hash] if hash in KNOWN_TOPICS_LOOKUP else hash
- for hash in hashes
- ]
- if hashes
- else []
+ [KNOWN_TOPICS_LOOKUP.get(hash, hash) for hash in hashes] if hashes else []
)
@property
@@ -140,6 +135,7 @@ class KeepAliveAck(Command):
PacketType = Packet.Type.KeepAliveAck
unknown: Optional[int] = fid(1)
+
@command
@dataclass
class SetStateCommand(Command):
@@ -182,7 +178,7 @@ def __post_init__(self):
) and not (self._token_topic_1 is not None and self._token_topic_2 is not None):
raise ValueError("topic, token, and outgoing must be set.")
- if self.outgoing == True:
+ if self.outgoing is True:
assert self.topic and self.token
self._token_topic_1 = (
sha1(self.topic.encode()).digest()
@@ -190,7 +186,7 @@ def __post_init__(self):
else self.topic
)
self._token_topic_2 = self.token
- elif self.outgoing == False:
+ elif self.outgoing is False:
assert self.topic and self.token
self._token_topic_1 = self.token
self._token_topic_2 = (
@@ -201,18 +197,14 @@ def __post_init__(self):
else:
assert self._token_topic_1 and self._token_topic_2
if len(self._token_topic_1) == 20: # SHA1 hash, topic
- self.topic = (
- KNOWN_TOPICS_LOOKUP[self._token_topic_1]
- if self._token_topic_1 in KNOWN_TOPICS_LOOKUP
- else self._token_topic_1
+ self.topic = KNOWN_TOPICS_LOOKUP.get(
+ self._token_topic_1, self._token_topic_1
)
self.token = self._token_topic_2
self.outgoing = True
else:
- self.topic = (
- KNOWN_TOPICS_LOOKUP[self._token_topic_2]
- if self._token_topic_2 in KNOWN_TOPICS_LOOKUP
- else self._token_topic_2
+ self.topic = KNOWN_TOPICS_LOOKUP.get(
+ self._token_topic_2, self._token_topic_2
)
self.token = self._token_topic_1
self.outgoing = False
@@ -229,6 +221,27 @@ class SendMessageAck(Command):
unknown6: Optional[bytes] = fid(6, default=None)
+@command
+@dataclass
+class ScopedTokenCommand(Command):
+ PacketType = Packet.Type.ScopedToken
+
+ token: bytes = fid(1)
+ topic: bytes = fid(2)
+ app_id: Optional[bytes] = fid(3, default=None)
+
+
+@command
+@dataclass
+class ScopedTokenAck(Command):
+ PacketType = Packet.Type.ScopedTokenAck
+
+ status: int = fid(1)
+ scoped_token: bytes = fid(2)
+ topic: bytes = fid(3)
+ app_id: Optional[bytes] = fid(4, default=None)
+
+
@dataclass
class UnknownCommand(Command):
id: Packet.Type
@@ -240,7 +253,7 @@ def from_packet(cls, packet: Packet):
def to_packet(self) -> Packet:
return Packet(id=self.id, fields=self.fields)
-
+
def __repr__(self):
if self.id.value in [29, 30, 32]:
return f"UnknownCommand(id={self.id}, fields=[SUPPRESSED])"
@@ -259,6 +272,8 @@ def command_from_packet(packet: Packet) -> Command:
Packet.Type.SetState: SetStateCommand,
Packet.Type.SendMessage: SendMessageCommand,
Packet.Type.SendMessageAck: SendMessageAck,
+ Packet.Type.ScopedToken: ScopedTokenCommand,
+ Packet.Type.ScopedTokenAck: ScopedTokenAck,
# Add other mappings here...
}
command_class = command_classes.get(packet.id, None)
diff --git a/pypush/apns/transport.py b/pypush/apns/transport.py
index 864f3eb..d0de86b 100644
--- a/pypush/apns/transport.py
+++ b/pypush/apns/transport.py
@@ -30,6 +30,8 @@ class Type(Enum):
KeepAlive = 12
KeepAliveAck = 13
NoStorage = 14
+ ScopedToken = 17
+ ScopedTokenAck = 18
SetState = 20
UNKNOWN = "Unknown"
@@ -38,20 +40,19 @@ def __new__(cls, value):
obj = object.__new__(cls)
obj._value_ = value
return obj
-
+
@classmethod
def _missing_(cls, value):
# Handle unknown values
instance = cls.UNKNOWN
instance._value_ = value # Assign the unknown value
return instance
-
+
def __str__(self):
if self is Packet.Type.UNKNOWN:
return f"Unknown({self._value_})"
return self.name
-
id: Type
fields: list[Field]
@@ -60,18 +61,25 @@ def fields_for_id(self, id: int) -> list[bytes]:
async def create_courier_connection(
+ sandbox: bool = False,
courier: str = "1-courier.push.apple.com",
) -> PacketStream:
context = ssl.create_default_context()
context.set_alpn_protocols(ALPN)
+ sni = "courier.sandbox.push.apple.com" if sandbox else "courier.push.apple.com"
+
# TODO: Verify courier certificate
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
return PacketStream(
await anyio.connect_tcp(
- courier, COURIER_PORT, ssl_context=context, tls_standard_compatible=False
+ courier,
+ COURIER_PORT,
+ ssl_context=context,
+ tls_standard_compatible=False,
+ tls_hostname=sni,
)
)
diff --git a/pypush/cli/__init__.py b/pypush/cli/__init__.py
index 83e70a0..1495dd0 100644
--- a/pypush/cli/__init__.py
+++ b/pypush/cli/__init__.py
@@ -1,12 +1,17 @@
+import contextlib
import logging
+from asyncio import CancelledError
+import anyio
import typer
from rich.logging import RichHandler
from typing_extensions import Annotated
+from pypush import apns
+
from . import proxy as _proxy
-logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s")
+logging.basicConfig(level=logging.INFO, handlers=[RichHandler()], format="%(message)s")
app = typer.Typer()
@@ -22,12 +27,12 @@ def proxy(
Attach requires SIP to be disabled and to be running as root
"""
-
- _proxy.main(attach)
+ with contextlib.suppress(CancelledError):
+ _proxy.main(attach)
@app.command()
-def client(
+def notifications(
topic: Annotated[str, typer.Argument(help="app topic to listen on")],
sandbox: Annotated[
bool, typer.Option("--sandbox/--production", help="APNs courier to use")
@@ -36,8 +41,29 @@ def client(
"""
Connect to the APNs courier and listen for app notifications on the given topic
"""
- typer.echo("Running APNs client")
- raise NotImplementedError("Not implemented yet")
+ logging.getLogger("httpx").setLevel(logging.WARNING)
+ with contextlib.suppress(CancelledError):
+ anyio.run(notifications_async, topic, sandbox)
+
+
+async def notifications_async(topic: str, sandbox: bool):
+ async with apns.create_apns_connection(
+ *await apns.activate(),
+ courier="1-courier.sandbox.push.apple.com"
+ if sandbox
+ else "1-courier.push.apple.com",
+ ) as connection:
+ token = await connection.mint_scoped_token(topic)
+
+ async with connection.notification_stream(topic, token) as stream:
+ logging.info(
+ f"Listening for notifications on topic {topic} ({'sandbox' if sandbox else 'production'})"
+ )
+ logging.info(f"Token: {token.hex()}")
+
+ async for notification in stream:
+ await connection.ack(notification)
+ logging.info(notification.payload.decode())
def main():
diff --git a/pypush/cli/_frida.py b/pypush/cli/_frida.py
index dc30ce5..3a71ae4 100644
--- a/pypush/cli/_frida.py
+++ b/pypush/cli/_frida.py
@@ -1,6 +1,7 @@
-import frida
import logging
+import frida
+
def attach_to_apsd() -> frida.core.Session:
frida.kill("apsd")
diff --git a/pypush/cli/proxy.py b/pypush/cli/proxy.py
index b801d43..8c43bc4 100644
--- a/pypush/cli/proxy.py
+++ b/pypush/cli/proxy.py
@@ -2,7 +2,6 @@
import logging
import ssl
import tempfile
-from typing import Optional
import anyio
import anyio.abc
@@ -12,11 +11,10 @@
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.hashes import SHA256
-from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
+from cryptography.hazmat.primitives.serialization import Encoding
# from pypush import apns
-from pypush.apns import transport
-from pypush.apns import protocol
+from pypush.apns import protocol, transport
from . import _frida
@@ -71,7 +69,7 @@ async def handle(client: TLSStream):
else "1-courier.sandbox.push.apple.com"
)
name = f"prod-{connection_cnt}" if not sandbox else f"sandbox-{connection_cnt}"
- async with await transport.create_courier_connection(forward) as conn:
+ async with await transport.create_courier_connection(sandbox, forward) as conn:
logging.debug("Connected to courier")
async with anyio.create_task_group() as tg:
tg.start_soon(forward_packets, client_pkt, conn, f"client-{name}")
diff --git a/pypush/cli/pushclient.py b/pypush/cli/pushclient.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/assets/dev.jjtech.pypush.tests.pem b/tests/assets/dev.jjtech.pypush.tests.pem
new file mode 100644
index 0000000..0188045
--- /dev/null
+++ b/tests/assets/dev.jjtech.pypush.tests.pem
@@ -0,0 +1,75 @@
+Bag Attributes
+ friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests
+ localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3
+subject=UID=dev.jjtech.pypush.tests, CN=Apple Sandbox Push Services: dev.jjtech.pypush.tests, OU=C4492JYJR3, C=US
+issuer=CN=Apple Worldwide Developer Relations Certification Authority, OU=G4, O=Apple Inc., C=US
+-----BEGIN CERTIFICATE-----
+MIIGnzCCBYegAwIBAgIQRLQgelpeA0ozi3PDbx2ZmTANBgkqhkiG9w0BAQsFADB1
+MUQwQgYDVQQDDDtBcHBsZSBXb3JsZHdpZGUgRGV2ZWxvcGVyIFJlbGF0aW9ucyBD
+ZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTELMAkGA1UECwwCRzQxEzARBgNVBAoMCkFw
+cGxlIEluYy4xCzAJBgNVBAYTAlVTMB4XDTI0MDUxNjAwMTUwM1oXDTI1MDYxNTAw
+MTUwMlowgYoxJzAlBgoJkiaJk/IsZAEBDBdkZXYuamp0ZWNoLnB5cHVzaC50ZXN0
+czE9MDsGA1UEAww0QXBwbGUgU2FuZGJveCBQdXNoIFNlcnZpY2VzOiBkZXYuamp0
+ZWNoLnB5cHVzaC50ZXN0czETMBEGA1UECwwKQzQ0OTJKWUpSMzELMAkGA1UEBhMC
+VVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD3BvhGnrBtXpVLVvdi
+HFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+pcYa
+XK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt8J+Y
+RHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0GVcI
+0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0taBW
+rdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6p8mC
+TzZhAgMBAAGjggMTMIIDDzAMBgNVHRMBAf8EAjAAMB8GA1UdIwQYMBaAFFvZ+h3n
+mhoLo5l2IlCGPpHIW3eoMHAGCCsGAQUFBwEBBGQwYjAtBggrBgEFBQcwAoYhaHR0
+cDovL2NlcnRzLmFwcGxlLmNvbS93d2RyZzQuZGVyMDEGCCsGAQUFBzABhiVodHRw
+Oi8vb2NzcC5hcHBsZS5jb20vb2NzcDAzLXd3ZHJnNDAzMIIBHgYDVR0gBIIBFTCC
+AREwggENBgkqhkiG92NkBQEwgf8wgcMGCCsGAQUFBwICMIG2DIGzUmVsaWFuY2Ug
+b24gdGhpcyBjZXJ0aWZpY2F0ZSBieSBhbnkgcGFydHkgYXNzdW1lcyBhY2NlcHRh
+bmNlIG9mIHRoZSB0aGVuIGFwcGxpY2FibGUgc3RhbmRhcmQgdGVybXMgYW5kIGNv
+bmRpdGlvbnMgb2YgdXNlLCBjZXJ0aWZpY2F0ZSBwb2xpY3kgYW5kIGNlcnRpZmlj
+YXRpb24gcHJhY3RpY2Ugc3RhdGVtZW50cy4wNwYIKwYBBQUHAgEWK2h0dHBzOi8v
+d3d3LmFwcGxlLmNvbS9jZXJ0aWZpY2F0ZWF1dGhvcml0eS8wEwYDVR0lBAwwCgYI
+KwYBBQUHAwIwMgYDVR0fBCswKTAnoCWgI4YhaHR0cDovL2NybC5hcHBsZS5jb20v
+d3dkcmc0LTMuY3JsMB0GA1UdDgQWBBQKyU1l8TlEc1+oBby5AEcULBKa8zAOBgNV
+HQ8BAf8EBAMCB4Awgb8GCiqGSIb3Y2QGAwYEgbAwga0MF2Rldi5qanRlY2gucHlw
+dXNoLnRlc3RzMAcMBXRvcGljDBxkZXYuamp0ZWNoLnB5cHVzaC50ZXN0cy52b2lw
+MAYMBHZvaXAMJGRldi5qanRlY2gucHlwdXNoLnRlc3RzLmNvbXBsaWNhdGlvbjAO
+DAxjb21wbGljYXRpb24MIGRldi5qanRlY2gucHlwdXNoLnRlc3RzLnZvaXAtcHR0
+MAsMCS52b2lwLXB0dDAQBgoqhkiG92NkBgMBBAIFADANBgkqhkiG9w0BAQsFAAOC
+AQEAwQac2q1BMnAH1vdZgfDunc+b7SKO6rJIG6w/wl4211YyNBBS5oabQnQDfB8y
+8iOeWnoWXry60gI2fwWN/rRaQn4QCy72jNeTGz/T/s2jwoGj89114JjcBhRAHvQl
+/HN4QjSt5rWVRcxTE4cKKbJIqVCm7Uq9VROgbxXrmsZsRnyk1ASvLGboibtGbmty
+wmXZWns5NXNDbv1wP+PF5HSFXtDWodPYnhvzJe0s9lRvo4yGAt1KL5mNaZM3kKp0
+74kdzKK/iT7954EQK4ZWPQbDnS1A+/BzHQjK0rWTwjDQkbKvNE9bb+KJbNHH3+DX
+5s0ybZYoG5meGKUplwu7A2bfFw==
+-----END CERTIFICATE-----
+Bag Attributes
+ friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests Private Key
+ localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3
+Key Attributes:
+-----BEGIN PRIVATE KEY-----
+MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQD3BvhGnrBtXpVL
+VvdiHFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+
+pcYaXK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt
+8J+YRHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0
+GVcI0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0
+taBWrdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6
+p8mCTzZhAgMBAAECggEBAKADb8eu+3GdFvAagVyYI5wq5Vik1uu0vFKD+cfFeQQT
+bCTxe/TTkAYSybwJEb0Zjy0spE1rgfzHbTFsiIqDBs1TqsZnPuPEhrzXMfVcyTqt
+I3yjlMAFPeAkEqcfmdUiPgp64zHHNmI8lBSoDXlAwypY6PnwArtAI3MItTFcElhX
+gWB44xVGuJRjRP4UVqXg0ML/Ic2yuYT9DRsDRilYhm8RGRSHkdZKdzCicMZcLtC7
+bs6/evmIrk9V5AzF6YiXlfT0dOp6yy9mFwhLljXF3Z2/LdrOTAmhLPQRMbUrJrcW
+ZPd0kMybGIlEoprQEA/6nZkdtIiDo2OJtufCs8g+nJECgYEA/+v4uTJzEI1igKOB
+myJtADECZAsJUaJaKSAM7VHn1hNOKgNLhUHOuroWvIWEhEomWeMvCbZIG42eOwNW
+BXGtG7ruT79E6655dljU6E/029FaxONqXXCTD9ZPh031R293KcydMwgBJJ0pvFJE
+14HWmMRAG0auPygMRhXubXU1ndMCgYEA9xpNWrl9poTjsZDNqvu60nYcq0W1escw
+ovmb87uxZ5u8fC8T1F3AVMYj4v0dTyA4F0mZenY+nri/hJBuanWVxa5Liu0fGnBr
+tEa2rzCMaajoDTNMKSygFz6CIMZbbZhozy0+9DHcRcC6b2UtIgB/+/ZQtrTvQ8Ea
+i6viarkq1nsCgYBznYAM8mynEqhoYvV/RyslBf8FgTLhjU3b/F26rODmhmwucLSi
+a9tf4ge5fTwjo3f17btnUND8mZrdICGxbex9dZKJtmgFbRn0TCdLGCwPTmIKRo7b
+zaqyYeglwSNI9WNJH+X4kuopR1L+f9AX59ExzJ8Fc4XuhEIfO3MuQeBJ/wKBgQDa
+8AgH0X/+EZJ42rcPvxiprxL5wbrpPSHf1M+T5gJqrXcUhNXJ/QMTWbekP+Y/HGn2
+YDTHZ4tWMJUoTJw4YVTBoQu33R8I2wDi6yCkGpzeZVStlXzuomZ6Ed1UUsvhT//V
+SN6VmLP1ba0CVB/oF49OXNDpAWlZm/f8NuBW9Rd6jwKBgQDi495IOjLJ8SvWRJLT
+c9AUmO7IVgipWvr51cF9IYxkzXIVIQIh1usy2NsrBxshAD+FbbWFVBfoptdKBZVK
+J8u+Ou4gTxs8SdGKGZWZpUMEKJbPsq8lE2aU3mBXiWcFRxYpu+n7nKap0Lla/xBD
+v77FY1M3FxGR6rNqPJQ9rRLFbA==
+-----END PRIVATE KEY-----
diff --git a/tests/test_apns.py b/tests/test_apns.py
index 3b24508..501e862 100644
--- a/tests/test_apns.py
+++ b/tests/test_apns.py
@@ -1,18 +1,13 @@
-import pytest
-from pypush import apns
-import asyncio
-
-# from aioapns import *
+import logging
import uuid
-import anyio
-
-# from pypush.apns import _util
-# from pypush.apns import albert, lifecycle, protocol
-from pypush import apns
+from pathlib import Path
-import logging
+import httpx
+import pytest
from rich.logging import RichHandler
+from pypush import apns
+
logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s")
@@ -26,17 +21,43 @@ async def test_activate():
@pytest.mark.asyncio
async def test_lifecycle_2():
- async with apns.create_apns_connection(
- certificate, key, courier="localhost"
- ) as connection:
- await connection.receive(
- apns.protocol.ConnectAck
- ) # Just wait until the initial connection is established. Don't do this in real code plz.
+ async with apns.create_apns_connection(certificate, key) as _:
+ pass
+
+
+ASSETS_DIR = Path(__file__).parent / "assets"
+
+
+async def send_test_notification(device_token, payload=b"hello, world"):
+ async with httpx.AsyncClient(
+ cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True
+ ) as client:
+ # Use the certificate and key from above
+ response = await client.post(
+ f"https://api.sandbox.push.apple.com/3/device/{device_token}",
+ content=payload,
+ headers={
+ "apns-topic": "dev.jjtech.pypush.tests",
+ "apns-push-type": "alert",
+ "apns-priority": "10",
+ },
+ )
+ assert response.status_code == 200
@pytest.mark.asyncio
-async def test_shorthand():
+async def test_scoped_token():
async with apns.create_apns_connection(
- *await apns.activate(), courier="localhost"
+ *await apns.activate(), sandbox=True
) as connection:
- await connection.receive(apns.protocol.ConnectAck)
+ token = await connection.mint_scoped_token("dev.jjtech.pypush.tests")
+
+ test_message = f"test-message-{uuid.uuid4().hex}"
+
+ await send_test_notification(token.hex(), test_message.encode())
+
+ await connection.expect_notification(
+ "dev.jjtech.pypush.tests",
+ token,
+ lambda c: c if c.payload == test_message.encode() else None,
+ )