Initial work on Channel

This commit is contained in:
Aaron Heise 2023-02-25 18:23:25 -06:00
parent b381a61be8
commit 68cb4a6740
No known key found for this signature in database
GPG Key ID: 6BA54088C41DE8BF
5 changed files with 759 additions and 1 deletions

350
RNS/Channel.py Normal file
View File

@ -0,0 +1,350 @@
from __future__ import annotations
import collections
import enum
import threading
import time
from types import TracebackType
from typing import Type, Callable, TypeVar
import abc
import contextlib
import struct
import RNS
from abc import ABC, abstractmethod
_TPacket = TypeVar("_TPacket")
class ChannelOutletBase(ABC):
@abstractmethod
def send(self, raw: bytes) -> _TPacket:
raise NotImplemented()
@abstractmethod
def resend(self, packet: _TPacket) -> _TPacket:
raise NotImplemented()
@property
@abstractmethod
def mdu(self):
raise NotImplemented()
@property
@abstractmethod
def rtt(self):
raise NotImplemented()
@property
@abstractmethod
def is_usable(self):
raise NotImplemented()
@abstractmethod
def get_packet_state(self, packet: _TPacket) -> MessageState:
raise NotImplemented()
@abstractmethod
def timed_out(self):
raise NotImplemented()
@abstractmethod
def __str__(self):
raise NotImplemented()
@abstractmethod
def set_packet_timeout_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None,
timeout: float | None = None):
raise NotImplemented()
@abstractmethod
def set_packet_delivered_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None):
raise NotImplemented()
@abstractmethod
def get_packet_id(self, packet: _TPacket) -> any:
raise NotImplemented()
class CEType(enum.IntEnum):
ME_NO_MSG_TYPE = 0
ME_INVALID_MSG_TYPE = 1
ME_NOT_REGISTERED = 2
ME_LINK_NOT_READY = 3
ME_ALREADY_SENT = 4
ME_TOO_BIG = 5
class ChannelException(Exception):
def __init__(self, ce_type: CEType, *args):
super().__init__(args)
self.type = ce_type
class MessageState(enum.IntEnum):
MSGSTATE_NEW = 0
MSGSTATE_SENT = 1
MSGSTATE_DELIVERED = 2
MSGSTATE_FAILED = 3
class MessageBase(abc.ABC):
MSGTYPE = None
@abstractmethod
def pack(self) -> bytes:
raise NotImplemented()
@abstractmethod
def unpack(self, raw):
raise NotImplemented()
class Envelope:
def unpack(self, message_factories: dict[int, Type]) -> MessageBase:
msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6])
raw = self.raw[6:]
ctor = message_factories.get(msgtype, None)
if ctor is None:
raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}")
message = ctor()
message.unpack(raw)
return message
def pack(self) -> bytes:
if self.message.__class__.MSGTYPE is None:
raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE")
data = self.message.pack()
self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data
return self.raw
def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None):
self.ts = time.time()
self.id = id(self)
self.message = message
self.raw = raw
self.packet: _TPacket = None
self.sequence = sequence
self.outlet = outlet
self.tries = 0
self.tracked = False
class Channel(contextlib.AbstractContextManager):
def __init__(self, outlet: ChannelOutletBase):
self._outlet = outlet
self._lock = threading.RLock()
self._tx_ring: collections.deque[Envelope] = collections.deque()
self._rx_ring: collections.deque[Envelope] = collections.deque()
self._message_callback: Callable[[MessageBase], None] | None = None
self._next_sequence = 0
self._message_factories: dict[int, Type[MessageBase]] = self._get_msg_constructors()
self._max_tries = 5
def __enter__(self) -> Channel:
return self
def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None,
__traceback: TracebackType | None) -> bool | None:
self.shutdown()
return False
@staticmethod
def _get_msg_constructors() -> (int, Type[MessageBase]):
subclass_tuples = []
for subclass in MessageBase.__subclasses__():
with contextlib.suppress(Exception):
subclass() # verify constructor works with no arguments, needed for unpacking
subclass_tuples.append((subclass.MSGTYPE, subclass))
message_factories = dict(subclass_tuples)
return message_factories
def register_message_type(self, message_class: Type[MessageBase]):
if not issubclass(message_class, MessageBase):
raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.")
if message_class.MSGTYPE is None:
raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.")
try:
message_class()
except Exception as ex:
raise ChannelException(CEType.ME_INVALID_MSG_TYPE,
f"{message_class} raised an exception when constructed with no arguments: {ex}")
self._message_factories[message_class.MSGTYPE] = message_class
def set_message_callback(self, callback: Callable[[MessageBase], None]):
self._message_callback = callback
def shutdown(self):
self.clear_rings()
def clear_rings(self):
with self._lock:
for envelope in self._tx_ring:
if envelope.packet is not None:
self._outlet.set_packet_timeout_callback(envelope.packet, None)
self._outlet.set_packet_delivered_callback(envelope.packet, None)
self._tx_ring.clear()
self._rx_ring.clear()
def emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
with self._lock:
i = 0
for env in ring:
if env.sequence < envelope.sequence:
ring.insert(i, envelope)
return True
if env.sequence == envelope.sequence:
RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME)
return False
i += 1
envelope.tracked = True
ring.append(envelope)
return True
def prune_rx_ring(self):
with self._lock:
# Implementation for fixed window = 1
stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:]
for env in stale:
env.tracked = False
self._rx_ring.remove(env)
def receive(self, raw: bytes):
try:
envelope = Envelope(outlet=self._outlet, raw=raw)
message = envelope.unpack(self._message_factories)
with self._lock:
is_new = self.emplace_envelope(envelope, self._rx_ring)
self.prune_rx_ring()
if not is_new:
RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG)
return
RNS.log(f"Message received: {message}", RNS.LOG_DEBUG)
if self._message_callback:
threading.Thread(target=self._message_callback, name="Message Callback", args=[message], daemon=True)\
.start()
except Exception as ex:
RNS.log(f"Channel: Error receiving data: {ex}")
def is_ready_to_send(self) -> bool:
if not self._outlet.is_usable:
RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME)
return False
with self._lock:
for envelope in self._tx_ring:
if envelope.outlet == self._outlet and (not envelope.packet
or self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_SENT):
RNS.log("Channel: Link has a pending message.", RNS.LOG_EXTREME)
return False
return True
def _packet_tx_op(self, packet: _TPacket, op: Callable[[_TPacket], bool]):
with self._lock:
envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet),
self._tx_ring), None)
if envelope and op(envelope):
envelope.tracked = False
if envelope in self._tx_ring:
self._tx_ring.remove(envelope)
else:
RNS.log("Channel: Envelope not found in TX ring", RNS.LOG_DEBUG)
if not envelope:
RNS.log("Channel: Spurious message received.", RNS.LOG_EXTREME)
def _packet_delivered(self, packet: _TPacket):
self._packet_tx_op(packet, lambda env: True)
def _packet_timeout(self, packet: _TPacket):
def retry_envelope(envelope: Envelope) -> bool:
if envelope.tries >= self._max_tries:
RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR)
self.shutdown() # start on separate thread?
self._outlet.timed_out()
return True
envelope.tries += 1
self._outlet.resend(envelope.packet)
return False
self._packet_tx_op(packet, retry_envelope)
def send(self, message: MessageBase) -> Envelope:
envelope: Envelope | None = None
with self._lock:
if not self.is_ready_to_send():
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
self._next_sequence = (self._next_sequence + 1) % 0x10000
self.emplace_envelope(envelope, self._tx_ring)
if envelope is None:
raise BlockingIOError()
envelope.pack()
if len(envelope.raw) > self._outlet.mdu:
raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
envelope.packet = self._outlet.send(envelope.raw)
envelope.tries += 1
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout)
return envelope
class LinkChannelOutlet(ChannelOutletBase):
def __init__(self, link: RNS.Link):
self.link = link
def send(self, raw: bytes) -> RNS.Packet:
packet = RNS.Packet(self.link, raw, context=RNS.Packet.CHANNEL)
packet.send()
return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet:
if not packet.resend():
RNS.log("Failed to resend packet", RNS.LOG_ERROR)
return packet
@property
def mdu(self):
return self.link.MDU
@property
def rtt(self):
return self.link.rtt
@property
def is_usable(self):
return True # had issues looking at Link.status
def get_packet_state(self, packet: _TPacket) -> MessageState:
status = packet.receipt.get_status()
if status == RNS.PacketReceipt.SENT:
return MessageState.MSGSTATE_SENT
if status == RNS.PacketReceipt.DELIVERED:
return MessageState.MSGSTATE_DELIVERED
if status == RNS.PacketReceipt.FAILED:
return MessageState.MSGSTATE_FAILED
else:
raise Exception(f"Unexpected receipt state: {status}")
def timed_out(self):
self.link.teardown()
def __str__(self):
return f"{self.__class__.__name__}({self.link})"
def set_packet_timeout_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None,
timeout: float | None = None):
if timeout:
packet.receipt.set_timeout(timeout)
def inner(receipt: RNS.PacketReceipt):
callback(packet)
packet.receipt.set_timeout_callback(inner if callback else None)
def set_packet_delivered_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None):
def inner(receipt: RNS.PacketReceipt):
callback(packet)
packet.receipt.set_delivery_callback(inner if callback else None)
def get_packet_id(self, packet: RNS.Packet) -> any:
return packet.get_hash()

View File

@ -22,7 +22,7 @@
from RNS.Cryptography import X25519PrivateKey, X25519PublicKey, Ed25519PrivateKey, Ed25519PublicKey from RNS.Cryptography import X25519PrivateKey, X25519PublicKey, Ed25519PrivateKey, Ed25519PublicKey
from RNS.Cryptography import Fernet from RNS.Cryptography import Fernet
from RNS.Channel import Channel, LinkChannelOutlet
from time import sleep from time import sleep
from .vendor import umsgpack as umsgpack from .vendor import umsgpack as umsgpack
import threading import threading
@ -163,6 +163,7 @@ class Link:
self.destination = destination self.destination = destination
self.attached_interface = None self.attached_interface = None
self.__remote_identity = None self.__remote_identity = None
self._channel = None
if self.destination == None: if self.destination == None:
self.initiator = False self.initiator = False
self.prv = X25519PrivateKey.generate() self.prv = X25519PrivateKey.generate()
@ -462,6 +463,8 @@ class Link:
resource.cancel() resource.cancel()
for resource in self.outgoing_resources: for resource in self.outgoing_resources:
resource.cancel() resource.cancel()
if self._channel:
self._channel.shutdown()
self.prv = None self.prv = None
self.pub = None self.pub = None
@ -642,6 +645,27 @@ class Link:
if pending_request.request_id == resource.request_id: if pending_request.request_id == resource.request_id:
pending_request.request_timed_out(None) pending_request.request_timed_out(None)
def _ensure_channel(self):
if self._channel is None:
self._channel = Channel(LinkChannelOutlet(self))
return self._channel
def set_message_callback(self, callback, message_types=None):
if not callback:
if self._channel:
self._channel.set_message_callback(None)
return
self._ensure_channel()
if message_types:
for msg_type in message_types:
self._channel.register_message_type(msg_type)
self._channel.set_message_callback(callback)
def send_message(self, message: RNS.Channel.MessageBase):
self._ensure_channel().send(message)
def receive(self, packet): def receive(self, packet):
self.watchdog_lock = True self.watchdog_lock = True
if not self.status == Link.CLOSED and not (self.initiator and packet.context == RNS.Packet.KEEPALIVE and packet.data == bytes([0xFF])): if not self.status == Link.CLOSED and not (self.initiator and packet.context == RNS.Packet.KEEPALIVE and packet.data == bytes([0xFF])):
@ -788,6 +812,14 @@ class Link:
for resource in self.incoming_resources: for resource in self.incoming_resources:
resource.receive_part(packet) resource.receive_part(packet)
elif packet.context == RNS.Packet.CHANNEL:
if not self._channel:
RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG)
else:
plaintext = self.decrypt(packet.data)
self._channel.receive(plaintext)
packet.prove()
elif packet.packet_type == RNS.Packet.PROOF: elif packet.packet_type == RNS.Packet.PROOF:
if packet.context == RNS.Packet.RESOURCE_PRF: if packet.context == RNS.Packet.RESOURCE_PRF:
resource_hash = packet.data[0:RNS.Identity.HASHLENGTH//8] resource_hash = packet.data[0:RNS.Identity.HASHLENGTH//8]

View File

@ -75,6 +75,7 @@ class Packet:
PATH_RESPONSE = 0x0B # Packet is a response to a path request PATH_RESPONSE = 0x0B # Packet is a response to a path request
COMMAND = 0x0C # Packet is a command COMMAND = 0x0C # Packet is a command
COMMAND_STATUS = 0x0D # Packet is a status of an executed command COMMAND_STATUS = 0x0D # Packet is a status of an executed command
CHANNEL = 0x0E # Packet contains link channel data
KEEPALIVE = 0xFA # Packet is a keepalive packet KEEPALIVE = 0xFA # Packet is a keepalive packet
LINKIDENTIFY = 0xFB # Packet is a link peer identification proof LINKIDENTIFY = 0xFB # Packet is a link peer identification proof
LINKCLOSE = 0xFC # Packet is a link close message LINKCLOSE = 0xFC # Packet is a link close message

316
tests/channel.py Normal file
View File

@ -0,0 +1,316 @@
from __future__ import annotations
import threading
import RNS
from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase
from RNS.vendor import umsgpack
from typing import Callable
import contextlib
import typing
import types
import time
import uuid
import unittest
class Packet:
timeout = 1.0
def __init__(self, raw: bytes):
self.state = MessageState.MSGSTATE_NEW
self.raw = raw
self.packet_id = uuid.uuid4()
self.tries = 0
self.timeout_id = None
self.lock = threading.RLock()
self.instances = 0
self.timeout_callback: Callable[[Packet], None] | None = None
self.delivered_callback: Callable[[Packet], None] | None = None
def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float):
with self.lock:
if timeout is not None:
self.timeout = timeout
self.timeout_callback = callback
def send(self):
self.tries += 1
self.state = MessageState.MSGSTATE_SENT
def elapsed(timeout: float, timeout_id: uuid.uuid4):
with self.lock:
self.instances += 1
try:
time.sleep(timeout)
with self.lock:
if self.timeout_id == timeout_id:
self.timeout_id = None
self.state = MessageState.MSGSTATE_FAILED
if self.timeout_callback:
self.timeout_callback(self)
finally:
with self.lock:
self.instances -= 1
self.timeout_id = uuid.uuid4()
threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id],
daemon=True).start()
def clear_timeout(self):
self.timeout_id = None
def set_delivered_callback(self, callback: Callable[[Packet], None]):
self.delivered_callback = callback
def delivered(self):
with self.lock:
self.state = MessageState.MSGSTATE_DELIVERED
self.timeout_id = None
if self.delivered_callback:
self.delivered_callback(self)
class ChannelOutletTest(ChannelOutletBase):
def get_packet_state(self, packet: Packet) -> MessageState:
return packet.state
def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None,
timeout: float | None = None):
packet.set_timeout(callback, timeout)
def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None):
packet.set_delivered_callback(callback)
def get_packet_id(self, packet: Packet) -> any:
return packet.packet_id
def __init__(self, mdu: int, rtt: float):
self.link_id = uuid.uuid4()
self.timeout_callbacks = 0
self._mdu = mdu
self._rtt = rtt
self._usable = True
self.packets = []
self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None
def send(self, raw: bytes) -> Packet:
packet = Packet(raw)
packet.send()
self.packets.append(packet)
return packet
def resend(self, packet: Packet) -> Packet:
packet.send()
return packet
@property
def mdu(self):
return self._mdu
@property
def rtt(self):
return self._rtt
@property
def is_usable(self):
return self._usable
def timed_out(self):
self.timeout_callbacks += 1
def __str__(self):
return str(self.link_id)
class MessageTest(MessageBase):
MSGTYPE = 0xabcd
def __init__(self):
self.id = str(uuid.uuid4())
self.data = "test"
self.not_serialized = str(uuid.uuid4())
def pack(self) -> bytes:
return umsgpack.packb((self.id, self.data))
def unpack(self, raw):
self.id, self.data = umsgpack.unpackb(raw)
class ProtocolHarness(contextlib.AbstractContextManager):
def __init__(self, rtt: float):
self.outlet = ChannelOutletTest(mdu=500, rtt=rtt)
self.channel = Channel(self.outlet)
def cleanup(self):
self.channel.shutdown()
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
# self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})")
self.cleanup()
return False
class TestChannel(unittest.TestCase):
def setUp(self) -> None:
self.rtt = 0.001
self.retry_interval = self.rtt * 150
Packet.timeout = self.retry_interval
self.h = ProtocolHarness(self.rtt)
def tearDown(self) -> None:
self.h.cleanup()
def test_send_one_retry(self):
message = MessageTest()
self.assertEqual(0, len(self.h.outlet.packets))
envelope = self.h.channel.send(message)
self.assertIsNotNone(envelope)
self.assertIsNotNone(envelope.raw)
self.assertEqual(1, len(self.h.outlet.packets))
self.assertIsNotNone(envelope.packet)
self.assertTrue(envelope in self.h.channel._tx_ring)
self.assertTrue(envelope.tracked)
packet = self.h.outlet.packets[0]
self.assertEqual(envelope.packet, packet)
self.assertEqual(1, envelope.tries)
self.assertEqual(1, packet.tries)
self.assertEqual(1, packet.instances)
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
self.assertEqual(envelope.raw, packet.raw)
time.sleep(self.retry_interval * 1.5)
self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(2, envelope.tries)
self.assertEqual(2, packet.tries)
self.assertEqual(1, packet.instances)
time.sleep(self.retry_interval)
self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(self.h.outlet.packets[0], packet)
self.assertEqual(3, envelope.tries)
self.assertEqual(3, packet.tries)
self.assertEqual(1, packet.instances)
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
packet.delivered()
self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
time.sleep(self.retry_interval)
self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(3, envelope.tries)
self.assertEqual(3, packet.tries)
self.assertEqual(0, packet.instances)
self.assertFalse(envelope.tracked)
def test_send_timeout(self):
message = MessageTest()
self.assertEqual(0, len(self.h.outlet.packets))
envelope = self.h.channel.send(message)
self.assertIsNotNone(envelope)
self.assertIsNotNone(envelope.raw)
self.assertEqual(1, len(self.h.outlet.packets))
self.assertIsNotNone(envelope.packet)
self.assertTrue(envelope in self.h.channel._tx_ring)
self.assertTrue(envelope.tracked)
packet = self.h.outlet.packets[0]
self.assertEqual(envelope.packet, packet)
self.assertEqual(1, envelope.tries)
self.assertEqual(1, packet.tries)
self.assertEqual(1, packet.instances)
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
self.assertEqual(envelope.raw, packet.raw)
time.sleep(self.retry_interval * 7.5)
self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(5, envelope.tries)
self.assertEqual(5, packet.tries)
self.assertEqual(0, packet.instances)
self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state)
self.assertFalse(envelope.tracked)
def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]):
decoded: [MessageBase] = []
def handle_message(message: MessageBase):
decoded.append(message)
self.h.channel.set_message_callback(handle_message)
self.assertEqual(len(self.h.outlet.packets), 0)
envelope = self.h.channel.send(message)
time.sleep(self.retry_interval * 0.5)
self.assertIsNotNone(envelope)
self.assertIsNotNone(envelope.raw)
self.assertEqual(1, len(self.h.outlet.packets))
self.assertIsNotNone(envelope.packet)
self.assertTrue(envelope in self.h.channel._tx_ring)
self.assertTrue(envelope.tracked)
packet = self.h.outlet.packets[0]
self.assertEqual(envelope.packet, packet)
self.assertEqual(1, envelope.tries)
self.assertEqual(1, packet.tries)
self.assertEqual(1, packet.instances)
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
self.assertEqual(envelope.raw, packet.raw)
packet.delivered()
self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
time.sleep(self.retry_interval * 2)
self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(1, envelope.tries)
self.assertEqual(1, packet.tries)
self.assertEqual(0, packet.instances)
self.assertFalse(envelope.tracked)
self.assertEqual(len(self.h.outlet.packets), 1)
self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
self.assertFalse(envelope.tracked)
self.assertEqual(0, len(decoded))
self.h.channel.receive(packet.raw)
self.assertEqual(1, len(decoded))
rx_message = decoded[0]
self.assertIsNotNone(rx_message)
self.assertIsInstance(rx_message, message.__class__)
checker(rx_message)
def test_send_receive_message_test(self):
message = MessageTest()
def check(rx_message: MessageBase):
self.assertIsInstance(rx_message, message.__class__)
self.assertEqual(message.id, rx_message.id)
self.assertEqual(message.data, rx_message.data)
self.assertNotEqual(message.not_serialized, rx_message.not_serialized)
self.eat_own_dog_food(message, check)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@ -6,6 +6,9 @@ import threading
import time import time
import RNS import RNS
import os import os
from tests.channel import MessageTest
from RNS.Channel import MessageBase
APP_NAME = "rns_unit_tests" APP_NAME = "rns_unit_tests"
@ -46,6 +49,11 @@ def close_rns():
global c_rns global c_rns
if c_rns != None: if c_rns != None:
c_rns.m_proc.kill() c_rns.m_proc.kill()
# stdout, stderr = c_rns.m_proc.communicate()
# if stdout:
# print(stdout.decode("utf-8"))
# if stderr:
# print(stderr.decode("utf-8"))
class TestLink(unittest.TestCase): class TestLink(unittest.TestCase):
def setUp(self): def setUp(self):
@ -346,6 +354,52 @@ class TestLink(unittest.TestCase):
time.sleep(0.5) time.sleep(0.5)
self.assertEqual(l1.status, RNS.Link.CLOSED) self.assertEqual(l1.status, RNS.Link.CLOSED)
def test_10_channel_round_trip(self):
global c_rns
init_rns(self)
print("")
print("Channel round trip test")
# TODO: Load this from public bytes only
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
self.assertEqual(id1.hash, bytes.fromhex(fixed_keys[0][1]))
dest = RNS.Destination(id1, RNS.Destination.OUT, RNS.Destination.SINGLE, APP_NAME, "link", "establish")
self.assertEqual(dest.hash, bytes.fromhex("fb48da0e82e6e01ba0c014513f74540d"))
l1 = RNS.Link(dest)
time.sleep(1)
self.assertEqual(l1.status, RNS.Link.ACTIVE)
received = []
def handle_message(message: MessageBase):
received.append(message)
test_message = MessageTest()
test_message.data = "Hello"
l1.set_message_callback(handle_message)
l1.send_message(test_message)
time.sleep(0.5)
self.assertEqual(1, len(received))
rx_message = received[0]
self.assertIsInstance(rx_message, MessageTest)
self.assertEqual("Hello back", rx_message.data)
self.assertEqual(test_message.id, rx_message.id)
self.assertNotEqual(test_message.not_serialized, rx_message.not_serialized)
self.assertEqual(1, len(l1._channel._rx_ring))
l1.teardown()
time.sleep(0.5)
self.assertEqual(l1.status, RNS.Link.CLOSED)
self.assertEqual(0, len(l1._channel._rx_ring))
def size_str(self, num, suffix='B'): def size_str(self, num, suffix='B'):
units = ['','K','M','G','T','P','E','Z'] units = ['','K','M','G','T','P','E','Z']
@ -405,6 +459,11 @@ def targets(yp=False):
link.set_resource_started_callback(resource_started) link.set_resource_started_callback(resource_started)
link.set_resource_concluded_callback(resource_concluded) link.set_resource_concluded_callback(resource_concluded)
def handle_message(message):
message.data = message.data + " back"
link.send_message(message)
link.set_message_callback(handle_message, [MessageTest])
m_rns = RNS.Reticulum("./tests/rnsconfig") m_rns = RNS.Reticulum("./tests/rnsconfig")
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
d1 = RNS.Destination(id1, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, "link", "establish") d1 = RNS.Destination(id1, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, "link", "establish")