Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite: APNs from the ground up #100

Merged
merged 30 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
44073df
initial structure
JJTech0130 May 11, 2024
b3374d5
restructure to use namespaces
JJTech0130 May 11, 2024
bbb7fbb
found a better way that doesn't require .vscode config
JJTech0130 May 11, 2024
c789465
implement basics of APNs with asyncio
JJTech0130 May 13, 2024
4b33e5e
add test for async with API
JJTech0130 May 13, 2024
ada147e
add folder structure for planned packages
JJTech0130 May 13, 2024
844c015
restructure under single pip package
JJTech0130 May 13, 2024
f0a99e7
readme: update installation instructions
JJTech0130 May 13, 2024
8119f54
add setuptools scm
JJTech0130 May 13, 2024
25e15cc
packaging: make setuptools generate _version.py
JJTech0130 May 13, 2024
e8c8928
apns: start anyio refactor
JJTech0130 May 14, 2024
037b477
wip: forwarding cli tool
JJTech0130 May 14, 2024
49b38ea
more refactoring
JJTech0130 May 15, 2024
866c3c4
protocol: automatically generate packet conversion
JJTech0130 May 15, 2024
354a402
depend on rich
JJTech0130 May 15, 2024
9239fc0
apns: implement more high level commands
JJTech0130 May 15, 2024
a7e59be
apns: wrap with CommandStream
JJTech0130 May 15, 2024
4dca8b6
apns: proxy at raw packet level, so can recover if parsing fails
JJTech0130 May 15, 2024
4905023
apns: try to decode topic when parsing SendMessage packet
JJTech0130 May 15, 2024
67b2a4c
cli: refactor
JJTech0130 May 16, 2024
bb727af
proxy: use SNI rather than localhost addresses
JJTech0130 May 16, 2024
897180f
apns: refactoring, new lifecycle management
JJTech0130 May 17, 2024
121c530
apns: lifecycle improvements
JJTech0130 May 17, 2024
8221d04
apns: refactor new api out of new
JJTech0130 May 17, 2024
468a65b
test: remove dep on aioapns
JJTech0130 May 17, 2024
33cc7e5
apns: _protocol.py: clean up and rename @auto_packet -> @command
JJTech0130 May 18, 2024
c1c061b
apns: protocol.py: handle unknown Type values better
JJTech0130 May 18, 2024
1973d5c
apns: fix minor type checking error
JJTech0130 May 18, 2024
72ee9a6
apns: protocol.py: suppress __repr__ for 29, 30, and 32 PubSub comman…
JJTech0130 May 18, 2024
1c10e01
bump minimum Python to 3.9 to reflect actual testing
JJTech0130 May 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
apns: start anyio refactor
  • Loading branch information
JJTech0130 committed May 14, 2024
commit e8c8928add74c1e0a86764d871a9264b9dd44714
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ build-backend = "setuptools.build_meta"
name = "pypush"
dynamic = ["version"]
dependencies = [
"anyio",
"httpx",
"cryptography",
'importlib_metadata; python_version>="3.8"',
Expand Down
321 changes: 213 additions & 108 deletions pypush/apns/connection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import asyncio
import typing
import random
import plistlib
import anyio.streams
import anyio.streams.tls
import httpx
import ssl
import time
import logging
import base64
import anyio
#from anyio.
from anyio.abc import TaskGroup

from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization, hashes
Expand Down Expand Up @@ -34,44 +38,53 @@ def __init__(
private_key: rsa.RSAPrivateKey,
token: typing.Union[bytes, None] = None,
):
self._incoming_queue: typing.List[payload.Payload] = []
self._queue_event = asyncio.Event()

self._tasks: typing.List[asyncio.Task] = []

self.connected = False
"""
Create a new APNs connection
Please use the async context manager to manage the connection

:param certificate: activation certificate from Albert to authenticate with APNs
:param private_key: private key for the activation certificate

:param token: optional token root token to use when connecting
"""
self.certificate = certificate
self.private_key = private_key
self.token = token

self._incoming_queue: typing.List[payload.Payload] = []
self._queue_event = anyio.Event()

async def _send(self, payload: payload.Payload):
await payload.write_to_stream(self._writer)

async def _receive(
self, id: int, filter: typing.Callable[[payload.Payload], bool] = lambda x: True
):
while True:
# TODO: Come up with a more efficient way to search for messages
for message in self._incoming_queue:
if message.id == id and filter(message):
# remove the message from the queue and return it
self._incoming_queue.remove(message)
return message

# If no messages were found, wait for the queue to be updated
await self._queue_event.wait()
self._filters: dict[str, int] = {}
self._socket = None

async def _queue_messages(self):
while True:
self._incoming_queue.append(
await payload.Payload.read_from_stream(self._reader)
)
assert self._socket is not None
try:
self._incoming_queue.append(
await payload.Payload.read_from_stream(self._socket)
)
except:
# Reconnect if the connection is dropped
await self._connect(None)
continue
self._queue_event.set()
self._queue_event.clear()
self._queue_event = anyio.Event()

async def connect(self, reconnect: bool = True):
# TODO: Implement auto reconnection
self._event_loop = asyncio.get_event_loop()
async def _ping(self):
while True:
await anyio.sleep(60)
try:
await self._send_packet(payload.Payload(0x01, [])) # TODO: Is this correct
except:
# Reconnect if the connection is dropped
await self._connect(None)
continue

async def _connect(self, task_group: typing.Union[TaskGroup, None]):
# If task_group is None, don't spawn background tasks, assume they will continue from a previous connection
# Must be able to call this function repeatedly to reconnect the socket if dropped
assert self._socket is None or task_group is None # Either this is a fresh connection (socket is None) or we are reconnecting (task_group is None)

context = ssl.create_default_context()
context.set_alpn_protocols(ALPN)
Expand All @@ -80,92 +93,184 @@ async def connect(self, reconnect: bool = True):
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE

self._reader, self._writer = await asyncio.open_connection(
COURIER_HOST, COURIER_PORT, ssl=context
)
self.connected = True
self._socket = await anyio.connect_tcp(COURIER_HOST, COURIER_PORT, ssl_context=context)

self._tasks.append(self._event_loop.create_task(self._queue_messages()))
# TODO: Send authenticated connect packet
# TODO: Send set state packet
await self._update_filters()

await self._connect_pkt(self.certificate, self.private_key, self.token)
await self._state_pkt(0x01)
if task_group is not None:
task_group.start_soon(self._queue_messages)

async def aclose(self):
self.connected = False
for task in self._tasks:
task.cancel()
self._writer.close()
await self._writer.wait_closed()

async def _aclose(self):
if self._socket is not None:
await self._socket.aclose()
self._socket = None

async def __aenter__(self):
await self.connect()
return self

async def __aexit__(self, exc_type, exc, tb):
await self.aclose()

async def _state_pkt(self, state: int):
log.debug(f"Sending state message with state {state}")
await self._send(
payload.Payload(
0x14,
[
payload.Field(1, state.to_bytes(1, "big")),
payload.Field(2, 0x7FFFFFFF.to_bytes(4, "big")),
],
)
)

async def _connect_pkt(
self,
certificate: x509.Certificate,
private_key: rsa.RSAPrivateKey,
token: typing.Union[bytes, None],
# self._tg must be managed using __aenter__ and __aexit__ to ensure it is properly closed
self._tg = anyio.create_task_group()
await self._tg.__aenter__()

await self._connect(self._tg)

async def __aexit__(self, exc_type, exc_value, traceback):
await self._tg.__aexit__(exc_type, exc_value, traceback)

await self._aclose()

async def _send_packet(self, payload: payload.Payload):
if self._socket is None:
raise Exception("Not connected")
await payload.write_to_stream(self._socket)

async def _receive_packet(
self, filter: typing.Callable[[payload.Payload], bool] = lambda x: True
):
flags = 0b01000001 # TODO: Root/sub-connection flags

cert = certificate.public_bytes(serialization.Encoding.DER)
nonce = (
b"\x00" + int(time.time() * 1000).to_bytes(8, "big") + random.randbytes(8)
)
signature = b"\x01\x01" + private_key.sign(
nonce, padding.PKCS1v15(), hashes.SHA1()
)

p = payload.Payload(
7,
[
payload.Field(0x2, b"\x01"),
payload.Field(0x5, flags.to_bytes(4, "big")),
payload.Field(0xC, cert),
payload.Field(0xD, nonce),
payload.Field(0xE, signature),
# TODO: Metrics/optional device info fields
],
)

if token:
p.fields.insert(0, payload.Field(0x1, token))

await self._send(p)

resp = await self._receive(8)

if resp.fields_with_id(1)[0].value != b"\x00":
raise Exception("Failed to connect")

if len(resp.fields_with_id(3)) > 0:
new_token = resp.fields_with_id(3)[0].value
else:
if token is None:
raise Exception("No token received")
new_token = token
while True:
# TODO: Come up with a more efficient way to search for messages
for message in self._incoming_queue:
if message.id == id and filter(message):
# remove the message from the queue and return it
self._incoming_queue.remove(message)
return message

log.debug(
f"Received connect response with token {base64.b64encode(new_token).decode()}"
)
# If no messages were found, wait for the queue to be updated
await self._queue_event.wait()

return new_token
async def _update_filters(self):
# TODO: Send filter packet with filter list
pass

async def _add_filter(self, filter: str):
if filter in self._filters:
self._filters[filter] += 1
else:
self._filters[filter] = 1
await self._update_filters()

async def _remove_filter(self, filter: str):
if filter in self._filters:
self._filters[filter] -= 1
if self._filters[filter] == 0:
del self._filters[filter]
await self._update_filters()

async def


#@property
#def connected(self):
# return self._socket is not None

# async def __aenter__(self, reconnect: bool = True):

# async with httpx.AsyncClient() as client:
# response = await client.get("http://init-p01st.push.apple.com/bag")
# APNS_CONFIG = plistlib.loads(plistlib.loads(response.content)["bag"])
# # TODO: Implement auto reconnection
# self._event_loop = asyncio.get_event_loop()

# context = ssl.create_default_context()
# context.set_alpn_protocols(ALPN)

# # TODO: Verify courier certificate
# context.check_hostname = False
# context.verify_mode = ssl.CERT_NONE

# self._socket = await anyio.connect_tcp(COURIER_HOST, COURIER_PORT, ssl_context=context)
# self._socket.
# self._socket.__aenter__
# self._tg = anyio.create_task_group()
# await self._tg.__aenter__()
# # Create a task group to manage the queue and keepalive tasks
# async with anyio.create_task_group() as tg:
# tg.start_soon(self._queue_messages)

# await self._connect_pkt(self.certificate, self.private_key, self.token)
# await self._state_pkt(0x01)

# self._tasks.append(self._event_loop.create_task(self._queue_messages()))

# await self._connect_pkt(self.certificate, self.private_key, self.token)
# await self._state_pkt(0x01)






# async def _queue_messages(self):
# while True:
# self._incoming_queue.append(
# await payload.Payload.read_from_stream(self._reader)
# )
# self._queue_event.set()
# self._queue_event.clear()

# async def _state_pkt(self, state: int):
# log.debug(f"Sending state message with state {state}")
# await self._send(
# payload.Payload(
# 0x14,
# [
# payload.Field(1, state.to_bytes(1, "big")),
# payload.Field(2, 0x7FFFFFFF.to_bytes(4, "big")),
# ],
# )
# )

# async def _connect_pkt(
# self,
# certificate: x509.Certificate,
# private_key: rsa.RSAPrivateKey,
# token: typing.Union[bytes, None],
# ):
# flags = 0b01000001 # TODO: Root/sub-connection flags

# cert = certificate.public_bytes(serialization.Encoding.DER)
# nonce = (
# b"\x00" + int(time.time() * 1000).to_bytes(8, "big") + random.randbytes(8)
# )
# signature = b"\x01\x01" + private_key.sign(
# nonce, padding.PKCS1v15(), hashes.SHA1()
# )

# p = payload.Payload(
# 7,
# [
# payload.Field(0x2, b"\x01"),
# payload.Field(0x5, flags.to_bytes(4, "big")),
# payload.Field(0xC, cert),
# payload.Field(0xD, nonce),
# payload.Field(0xE, signature),
# # TODO: Metrics/optional device info fields
# ],
# )

# if token:
# p.fields.insert(0, payload.Field(0x1, token))

# await self._send(p)

# resp = await self._receive(8)

# if resp.fields_with_id(1)[0].value != b"\x00":
# raise Exception("Failed to connect")

# if len(resp.fields_with_id(3)) > 0:
# new_token = resp.fields_with_id(3)[0].value
# else:
# if token is None:
# raise Exception("No token received")
# new_token = token

# log.debug(
# f"Received connect response with token {base64.b64encode(new_token).decode()}"
# )

# return new_token


# TODO: Implement sub-connections
Loading