Add some internal documenation
This commit is contained in:
parent
44dc2d06c6
commit
464dc23ff0
@ -46,7 +46,7 @@ class StringMessage(RNS.MessageBase):
|
|||||||
# message arrives over the channel.
|
# message arrives over the channel.
|
||||||
#
|
#
|
||||||
# MSGTYPE must be unique across all message types we
|
# MSGTYPE must be unique across all message types we
|
||||||
# register with the channel. MSGTYPEs >= 0xff00 are
|
# register with the channel. MSGTYPEs >= 0xf000 are
|
||||||
# reserved for the system.
|
# reserved for the system.
|
||||||
MSGTYPE = 0x0101
|
MSGTYPE = 0x0101
|
||||||
|
|
||||||
@ -159,17 +159,36 @@ def client_disconnected(link):
|
|||||||
RNS.log("Client disconnected")
|
RNS.log("Client disconnected")
|
||||||
|
|
||||||
def server_message_received(message):
|
def server_message_received(message):
|
||||||
|
"""
|
||||||
|
A message handler
|
||||||
|
@param message: An instance of a subclass of MessageBase
|
||||||
|
@return: True if message was handled
|
||||||
|
"""
|
||||||
global latest_client_link
|
global latest_client_link
|
||||||
|
|
||||||
# When a message is received over any active link,
|
# When a message is received over any active link,
|
||||||
# the replies will all be directed to the last client
|
# the replies will all be directed to the last client
|
||||||
# that connected.
|
# that connected.
|
||||||
|
|
||||||
|
# In a message handler, any deserializable message
|
||||||
|
# that arrives over the link's channel will be passed
|
||||||
|
# to all message handlers, unless a preceding handler indicates it
|
||||||
|
# has handled the message.
|
||||||
|
#
|
||||||
|
#
|
||||||
if isinstance(message, StringMessage):
|
if isinstance(message, StringMessage):
|
||||||
RNS.log("Received data on the link: " + message.data + " (message created at " + str(message.timestamp) + ")")
|
RNS.log("Received data on the link: " + message.data + " (message created at " + str(message.timestamp) + ")")
|
||||||
|
|
||||||
reply_message = StringMessage("I received \""+message.data+"\" over the link")
|
reply_message = StringMessage("I received \""+message.data+"\" over the link")
|
||||||
latest_client_link.get_channel().send(reply_message)
|
latest_client_link.get_channel().send(reply_message)
|
||||||
|
|
||||||
|
# Incoming messages are sent to each message
|
||||||
|
# handler added to the channel, in the order they
|
||||||
|
# were added.
|
||||||
|
# If any message handler returns True, the message
|
||||||
|
# is considered handled and any subsequent
|
||||||
|
# handlers are skipped.
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
##########################################################
|
##########################################################
|
||||||
#### Client Part #########################################
|
#### Client Part #########################################
|
||||||
|
109
RNS/Channel.py
109
RNS/Channel.py
@ -14,6 +14,13 @@ TPacket = TypeVar("TPacket")
|
|||||||
|
|
||||||
|
|
||||||
class ChannelOutletBase(ABC, Generic[TPacket]):
|
class ChannelOutletBase(ABC, Generic[TPacket]):
|
||||||
|
"""
|
||||||
|
An abstract transport layer interface used by Channel.
|
||||||
|
|
||||||
|
DEPRECATED: This was created for testing; eventually
|
||||||
|
Channel will use Link or a LinkBase interface
|
||||||
|
directly.
|
||||||
|
"""
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def send(self, raw: bytes) -> TPacket:
|
def send(self, raw: bytes) -> TPacket:
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
@ -64,6 +71,9 @@ class ChannelOutletBase(ABC, Generic[TPacket]):
|
|||||||
|
|
||||||
|
|
||||||
class CEType(enum.IntEnum):
|
class CEType(enum.IntEnum):
|
||||||
|
"""
|
||||||
|
ChannelException type codes
|
||||||
|
"""
|
||||||
ME_NO_MSG_TYPE = 0
|
ME_NO_MSG_TYPE = 0
|
||||||
ME_INVALID_MSG_TYPE = 1
|
ME_INVALID_MSG_TYPE = 1
|
||||||
ME_NOT_REGISTERED = 2
|
ME_NOT_REGISTERED = 2
|
||||||
@ -73,12 +83,18 @@ class CEType(enum.IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class ChannelException(Exception):
|
class ChannelException(Exception):
|
||||||
|
"""
|
||||||
|
An exception thrown by Channel, with a type code.
|
||||||
|
"""
|
||||||
def __init__(self, ce_type: CEType, *args):
|
def __init__(self, ce_type: CEType, *args):
|
||||||
super().__init__(args)
|
super().__init__(args)
|
||||||
self.type = ce_type
|
self.type = ce_type
|
||||||
|
|
||||||
|
|
||||||
class MessageState(enum.IntEnum):
|
class MessageState(enum.IntEnum):
|
||||||
|
"""
|
||||||
|
Set of possible states for a Message
|
||||||
|
"""
|
||||||
MSGSTATE_NEW = 0
|
MSGSTATE_NEW = 0
|
||||||
MSGSTATE_SENT = 1
|
MSGSTATE_SENT = 1
|
||||||
MSGSTATE_DELIVERED = 2
|
MSGSTATE_DELIVERED = 2
|
||||||
@ -86,14 +102,29 @@ class MessageState(enum.IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class MessageBase(abc.ABC):
|
class MessageBase(abc.ABC):
|
||||||
|
"""
|
||||||
|
Base type for any messages sent or received on a Channel.
|
||||||
|
Subclasses must define the two abstract methods as well as
|
||||||
|
the MSGTYPE class variable.
|
||||||
|
"""
|
||||||
|
# MSGTYPE must be unique within all classes sent over a
|
||||||
|
# channel. Additionally, MSGTYPE > 0xf000 are reserved.
|
||||||
MSGTYPE = None
|
MSGTYPE = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def pack(self) -> bytes:
|
def pack(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Create and return the binary representation of the message
|
||||||
|
@return: binary representation of message
|
||||||
|
"""
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def unpack(self, raw):
|
def unpack(self, raw):
|
||||||
|
"""
|
||||||
|
Populate message from binary representation
|
||||||
|
@param raw: binary representation
|
||||||
|
"""
|
||||||
raise NotImplemented()
|
raise NotImplemented()
|
||||||
|
|
||||||
|
|
||||||
@ -101,6 +132,10 @@ MessageCallbackType = NewType("MessageCallbackType", Callable[[MessageBase], boo
|
|||||||
|
|
||||||
|
|
||||||
class Envelope:
|
class Envelope:
|
||||||
|
"""
|
||||||
|
Internal wrapper used to transport messages over a channel and
|
||||||
|
track its state within the channel framework.
|
||||||
|
"""
|
||||||
def unpack(self, message_factories: dict[int, Type]) -> MessageBase:
|
def unpack(self, message_factories: dict[int, Type]) -> MessageBase:
|
||||||
msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6])
|
msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6])
|
||||||
raw = self.raw[6:]
|
raw = self.raw[6:]
|
||||||
@ -131,6 +166,12 @@ class Envelope:
|
|||||||
|
|
||||||
|
|
||||||
class Channel(contextlib.AbstractContextManager):
|
class Channel(contextlib.AbstractContextManager):
|
||||||
|
"""
|
||||||
|
Channel provides reliable delivery of messages over
|
||||||
|
a link. Channel is not meant to be instantiated
|
||||||
|
directly, but rather obtained from a Link using the
|
||||||
|
get_channel() function.
|
||||||
|
"""
|
||||||
def __init__(self, outlet: ChannelOutletBase):
|
def __init__(self, outlet: ChannelOutletBase):
|
||||||
self._outlet = outlet
|
self._outlet = outlet
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
@ -146,10 +187,14 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
|
|
||||||
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
|
||||||
__traceback: TracebackType | None) -> bool | None:
|
__traceback: TracebackType | None) -> bool | None:
|
||||||
self.shutdown()
|
self._shutdown()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False):
|
def register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False):
|
||||||
|
"""
|
||||||
|
Register a message class for reception over a channel.
|
||||||
|
@param message_class: Class to register. Must extend MessageBase.
|
||||||
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if not issubclass(message_class, MessageBase):
|
if not issubclass(message_class, MessageBase):
|
||||||
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
||||||
@ -157,7 +202,7 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
if message_class.MSGTYPE is None:
|
if message_class.MSGTYPE is None:
|
||||||
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
||||||
f"{message_class} has invalid MSGTYPE class attribute.")
|
f"{message_class} has invalid MSGTYPE class attribute.")
|
||||||
if message_class.MSGTYPE >= 0xff00 and not is_system_type:
|
if message_class.MSGTYPE >= 0xf000 and not is_system_type:
|
||||||
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
|
||||||
f"{message_class} has system-reserved message type.")
|
f"{message_class} has system-reserved message type.")
|
||||||
try:
|
try:
|
||||||
@ -169,20 +214,34 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
self._message_factories[message_class.MSGTYPE] = message_class
|
self._message_factories[message_class.MSGTYPE] = message_class
|
||||||
|
|
||||||
def add_message_handler(self, callback: MessageCallbackType):
|
def add_message_handler(self, callback: MessageCallbackType):
|
||||||
|
"""
|
||||||
|
Add a handler for incoming messages. A handler
|
||||||
|
has the signature (message: MessageBase) -> bool.
|
||||||
|
Handlers are processed in the order they are
|
||||||
|
added. If any handler returns True, processing
|
||||||
|
of the message stops; handlers after the
|
||||||
|
returning handler will not be called.
|
||||||
|
@param callback: Function to call
|
||||||
|
@return:
|
||||||
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if callback not in self._message_callbacks:
|
if callback not in self._message_callbacks:
|
||||||
self._message_callbacks.append(callback)
|
self._message_callbacks.append(callback)
|
||||||
|
|
||||||
def remove_message_handler(self, callback: MessageCallbackType):
|
def remove_message_handler(self, callback: MessageCallbackType):
|
||||||
|
"""
|
||||||
|
Remove a handler
|
||||||
|
@param callback: handler to remove
|
||||||
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._message_callbacks.remove(callback)
|
self._message_callbacks.remove(callback)
|
||||||
|
|
||||||
def shutdown(self):
|
def _shutdown(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._message_callbacks.clear()
|
self._message_callbacks.clear()
|
||||||
self.clear_rings()
|
self._clear_rings()
|
||||||
|
|
||||||
def clear_rings(self):
|
def _clear_rings(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for envelope in self._tx_ring:
|
for envelope in self._tx_ring:
|
||||||
if envelope.packet is not None:
|
if envelope.packet is not None:
|
||||||
@ -191,14 +250,15 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
self._tx_ring.clear()
|
self._tx_ring.clear()
|
||||||
self._rx_ring.clear()
|
self._rx_ring.clear()
|
||||||
|
|
||||||
def emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
|
def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
i = 0
|
i = 0
|
||||||
for env in ring:
|
for existing in ring:
|
||||||
if env.sequence < envelope.sequence:
|
if existing.sequence > envelope.sequence \
|
||||||
|
and not existing.sequence // 2 > envelope.sequence: # account for overflow
|
||||||
ring.insert(i, envelope)
|
ring.insert(i, envelope)
|
||||||
return True
|
return True
|
||||||
if env.sequence == envelope.sequence:
|
if existing.sequence == envelope.sequence:
|
||||||
RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME)
|
RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME)
|
||||||
return False
|
return False
|
||||||
i += 1
|
i += 1
|
||||||
@ -206,7 +266,7 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
ring.append(envelope)
|
ring.append(envelope)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def prune_rx_ring(self):
|
def _prune_rx_ring(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
# Implementation for fixed window = 1
|
# Implementation for fixed window = 1
|
||||||
stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:]
|
stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:]
|
||||||
@ -225,13 +285,13 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR)
|
RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR)
|
||||||
|
|
||||||
def receive(self, raw: bytes):
|
def _receive(self, raw: bytes):
|
||||||
try:
|
try:
|
||||||
envelope = Envelope(outlet=self._outlet, raw=raw)
|
envelope = Envelope(outlet=self._outlet, raw=raw)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
message = envelope.unpack(self._message_factories)
|
message = envelope.unpack(self._message_factories)
|
||||||
is_new = self.emplace_envelope(envelope, self._rx_ring)
|
is_new = self._emplace_envelope(envelope, self._rx_ring)
|
||||||
self.prune_rx_ring()
|
self._prune_rx_ring()
|
||||||
if not is_new:
|
if not is_new:
|
||||||
RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG)
|
RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG)
|
||||||
return
|
return
|
||||||
@ -241,6 +301,10 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
RNS.log(f"Channel: Error receiving data: {ex}")
|
RNS.log(f"Channel: Error receiving data: {ex}")
|
||||||
|
|
||||||
def is_ready_to_send(self) -> bool:
|
def is_ready_to_send(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if Channel is ready to send.
|
||||||
|
@return: True if ready
|
||||||
|
"""
|
||||||
if not self._outlet.is_usable:
|
if not self._outlet.is_usable:
|
||||||
RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME)
|
RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME)
|
||||||
return False
|
return False
|
||||||
@ -273,7 +337,7 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
def retry_envelope(envelope: Envelope) -> bool:
|
def retry_envelope(envelope: Envelope) -> bool:
|
||||||
if envelope.tries >= self._max_tries:
|
if envelope.tries >= self._max_tries:
|
||||||
RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR)
|
RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR)
|
||||||
self.shutdown() # start on separate thread?
|
self._shutdown() # start on separate thread?
|
||||||
self._outlet.timed_out()
|
self._outlet.timed_out()
|
||||||
return True
|
return True
|
||||||
envelope.tries += 1
|
envelope.tries += 1
|
||||||
@ -283,13 +347,18 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
self._packet_tx_op(packet, retry_envelope)
|
self._packet_tx_op(packet, retry_envelope)
|
||||||
|
|
||||||
def send(self, message: MessageBase) -> Envelope:
|
def send(self, message: MessageBase) -> Envelope:
|
||||||
|
"""
|
||||||
|
Send a message. If a message send is attempted and
|
||||||
|
Channel is not ready, an exception is thrown.
|
||||||
|
@param message: an instance of a MessageBase subclass to send on the Channel
|
||||||
|
"""
|
||||||
envelope: Envelope | None = None
|
envelope: Envelope | None = None
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if not self.is_ready_to_send():
|
if not self.is_ready_to_send():
|
||||||
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
|
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
|
||||||
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
|
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
|
||||||
self._next_sequence = (self._next_sequence + 1) % 0x10000
|
self._next_sequence = (self._next_sequence + 1) % 0x10000
|
||||||
self.emplace_envelope(envelope, self._tx_ring)
|
self._emplace_envelope(envelope, self._tx_ring)
|
||||||
if envelope is None:
|
if envelope is None:
|
||||||
raise BlockingIOError()
|
raise BlockingIOError()
|
||||||
|
|
||||||
@ -304,10 +373,20 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def MDU(self):
|
def MDU(self):
|
||||||
|
"""
|
||||||
|
Maximum Data Unit: the number of bytes available
|
||||||
|
for a message to consume in a single send.
|
||||||
|
@return: number of bytes available
|
||||||
|
"""
|
||||||
return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence)
|
return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence)
|
||||||
|
|
||||||
|
|
||||||
class LinkChannelOutlet(ChannelOutletBase):
|
class LinkChannelOutlet(ChannelOutletBase):
|
||||||
|
"""
|
||||||
|
An implementation of ChannelOutletBase for RNS.Link.
|
||||||
|
Allows Channel to send packets over an RNS Link with
|
||||||
|
Packets.
|
||||||
|
"""
|
||||||
def __init__(self, link: RNS.Link):
|
def __init__(self, link: RNS.Link):
|
||||||
self.link = link
|
self.link = link
|
||||||
|
|
||||||
|
@ -464,7 +464,7 @@ class Link:
|
|||||||
for resource in self.outgoing_resources:
|
for resource in self.outgoing_resources:
|
||||||
resource.cancel()
|
resource.cancel()
|
||||||
if self._channel:
|
if self._channel:
|
||||||
self._channel.shutdown()
|
self._channel._shutdown()
|
||||||
|
|
||||||
self.prv = None
|
self.prv = None
|
||||||
self.pub = None
|
self.pub = None
|
||||||
@ -801,7 +801,7 @@ class Link:
|
|||||||
RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG)
|
RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG)
|
||||||
else:
|
else:
|
||||||
plaintext = self.decrypt(packet.data)
|
plaintext = self.decrypt(packet.data)
|
||||||
self._channel.receive(plaintext)
|
self._channel._receive(plaintext)
|
||||||
packet.prove()
|
packet.prove()
|
||||||
|
|
||||||
elif packet.packet_type == RNS.Packet.PROOF:
|
elif packet.packet_type == RNS.Packet.PROOF:
|
||||||
|
@ -153,7 +153,7 @@ class ProtocolHarness(contextlib.AbstractContextManager):
|
|||||||
self.channel = Channel(self.outlet)
|
self.channel = Channel(self.outlet)
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
self.channel.shutdown()
|
self.channel._shutdown()
|
||||||
|
|
||||||
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||||
__traceback: types.TracebackType) -> bool:
|
__traceback: types.TracebackType) -> bool:
|
||||||
@ -282,7 +282,7 @@ class TestChannel(unittest.TestCase):
|
|||||||
self.h.channel.add_message_handler(handler2)
|
self.h.channel.add_message_handler(handler2)
|
||||||
envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0)
|
envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0)
|
||||||
raw = envelope.pack()
|
raw = envelope.pack()
|
||||||
self.h.channel.receive(raw)
|
self.h.channel._receive(raw)
|
||||||
|
|
||||||
self.assertEqual(1, handler1_called)
|
self.assertEqual(1, handler1_called)
|
||||||
self.assertEqual(0, handler2_called)
|
self.assertEqual(0, handler2_called)
|
||||||
@ -290,7 +290,7 @@ class TestChannel(unittest.TestCase):
|
|||||||
handler1_return = False
|
handler1_return = False
|
||||||
envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1)
|
envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1)
|
||||||
raw = envelope.pack()
|
raw = envelope.pack()
|
||||||
self.h.channel.receive(raw)
|
self.h.channel._receive(raw)
|
||||||
|
|
||||||
self.assertEqual(2, handler1_called)
|
self.assertEqual(2, handler1_called)
|
||||||
self.assertEqual(1, handler2_called)
|
self.assertEqual(1, handler2_called)
|
||||||
@ -348,7 +348,7 @@ class TestChannel(unittest.TestCase):
|
|||||||
self.assertFalse(envelope.tracked)
|
self.assertFalse(envelope.tracked)
|
||||||
self.assertEqual(0, len(decoded))
|
self.assertEqual(0, len(decoded))
|
||||||
|
|
||||||
self.h.channel.receive(packet.raw)
|
self.h.channel._receive(packet.raw)
|
||||||
|
|
||||||
self.assertEqual(1, len(decoded))
|
self.assertEqual(1, len(decoded))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user