Allow channel message handlers to short circuit
- a message handler can return logical True to prevent subsequent message handlers from running
This commit is contained in:
		
							parent
							
								
									a61b15cf6a
								
							
						
					
					
						commit
						e005826151
					
				| @ -152,7 +152,7 @@ def client_connected(link): | ||||
|     # Register message types and add callback to channel | ||||
|     channel = link.get_channel() | ||||
|     channel.register_message_type(StringMessage) | ||||
|     channel.add_message_callback(server_message_received) | ||||
|     channel.add_message_handler(server_message_received) | ||||
| 
 | ||||
| def client_disconnected(link): | ||||
|     RNS.log("Client disconnected") | ||||
| @ -290,7 +290,7 @@ def link_established(link): | ||||
|     # Register messages and add handler to channel | ||||
|     channel = link.get_channel() | ||||
|     channel.register_message_type(StringMessage) | ||||
|     channel.add_message_callback(client_message_received) | ||||
|     channel.add_message_handler(client_message_received) | ||||
| 
 | ||||
|     # Inform the user that the server is | ||||
|     # connected | ||||
|  | ||||
| @ -150,28 +150,34 @@ class Channel(contextlib.AbstractContextManager): | ||||
|         return False | ||||
| 
 | ||||
|     def register_message_type(self, message_class: Type[MessageBase]): | ||||
|         if not issubclass(message_class, MessageBase): | ||||
|             raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.") | ||||
|         if message_class.MSGTYPE is None: | ||||
|             raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") | ||||
|         try: | ||||
|             message_class() | ||||
|         except Exception as ex: | ||||
|             raise ChannelException(CEType.ME_INVALID_MSG_TYPE, | ||||
|                                      f"{message_class} raised an exception when constructed with no arguments: {ex}") | ||||
|         with self._lock: | ||||
|             if not issubclass(message_class, MessageBase): | ||||
|                 raise ChannelException(CEType.ME_INVALID_MSG_TYPE, | ||||
|                                        f"{message_class} is not a subclass of {MessageBase}.") | ||||
|             if message_class.MSGTYPE is None: | ||||
|                 raise ChannelException(CEType.ME_INVALID_MSG_TYPE, | ||||
|                                        f"{message_class} has invalid MSGTYPE class attribute.") | ||||
|             try: | ||||
|                 message_class() | ||||
|             except Exception as ex: | ||||
|                 raise ChannelException(CEType.ME_INVALID_MSG_TYPE, | ||||
|                                        f"{message_class} raised an exception when constructed with no arguments: {ex}") | ||||
| 
 | ||||
|         self._message_factories[message_class.MSGTYPE] = message_class | ||||
|             self._message_factories[message_class.MSGTYPE] = message_class | ||||
| 
 | ||||
|     def add_message_callback(self, callback: MessageCallbackType): | ||||
|         if callback not in self._message_callbacks: | ||||
|             self._message_callbacks.append(callback) | ||||
|     def add_message_handler(self, callback: MessageCallbackType): | ||||
|         with self._lock: | ||||
|             if callback not in self._message_callbacks: | ||||
|                 self._message_callbacks.append(callback) | ||||
| 
 | ||||
|     def remove_message_callback(self, callback: MessageCallbackType): | ||||
|         self._message_callbacks.remove(callback) | ||||
|     def remove_message_handler(self, callback: MessageCallbackType): | ||||
|         with self._lock: | ||||
|             self._message_callbacks.remove(callback) | ||||
| 
 | ||||
|     def shutdown(self): | ||||
|         self._message_callbacks.clear() | ||||
|         self.clear_rings() | ||||
|         with self._lock: | ||||
|             self._message_callbacks.clear() | ||||
|             self.clear_rings() | ||||
| 
 | ||||
|     def clear_rings(self): | ||||
|         with self._lock: | ||||
| @ -205,19 +211,29 @@ class Channel(contextlib.AbstractContextManager): | ||||
|                 env.tracked = False | ||||
|                 self._rx_ring.remove(env) | ||||
| 
 | ||||
|     def _run_callbacks(self, message: MessageBase): | ||||
|         with self._lock: | ||||
|             cbs = self._message_callbacks.copy() | ||||
| 
 | ||||
|         for cb in cbs: | ||||
|             try: | ||||
|                 if cb(message): | ||||
|                     return | ||||
|             except Exception as ex: | ||||
|                 RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR) | ||||
| 
 | ||||
|     def receive(self, raw: bytes): | ||||
|         try: | ||||
|             envelope = Envelope(outlet=self._outlet, raw=raw) | ||||
|             message = envelope.unpack(self._message_factories) | ||||
|             with self._lock: | ||||
|                 message = envelope.unpack(self._message_factories) | ||||
|                 is_new = self.emplace_envelope(envelope, self._rx_ring) | ||||
|                 self.prune_rx_ring() | ||||
|             if not is_new: | ||||
|                 RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) | ||||
|                 return | ||||
|             RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) | ||||
|             for cb in self._message_callbacks: | ||||
|                 threading.Thread(target=cb, name="Message Callback", args=[message], daemon=True).start() | ||||
|             threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start() | ||||
|         except Exception as ex: | ||||
|             RNS.log(f"Channel: Error receiving data: {ex}") | ||||
| 
 | ||||
|  | ||||
| @ -245,13 +245,49 @@ class TestChannel(unittest.TestCase): | ||||
|         self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) | ||||
|         self.assertFalse(envelope.tracked) | ||||
| 
 | ||||
|     def test_multiple_handler(self): | ||||
|         handler1_called = 0 | ||||
|         handler1_return = True | ||||
|         handler2_called = 0 | ||||
| 
 | ||||
|         def handler1(msg: MessageBase): | ||||
|             nonlocal handler1_called, handler1_return | ||||
|             self.assertIsInstance(msg, MessageTest) | ||||
|             handler1_called += 1 | ||||
|             return handler1_return | ||||
| 
 | ||||
|         def handler2(msg: MessageBase): | ||||
|             nonlocal handler2_called | ||||
|             self.assertIsInstance(msg, MessageTest) | ||||
|             handler2_called += 1 | ||||
| 
 | ||||
|         message = MessageTest() | ||||
|         self.h.channel.register_message_type(MessageTest) | ||||
|         self.h.channel.add_message_handler(handler1) | ||||
|         self.h.channel.add_message_handler(handler2) | ||||
|         envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0) | ||||
|         raw = envelope.pack() | ||||
|         self.h.channel.receive(raw) | ||||
| 
 | ||||
|         self.assertEqual(1, handler1_called) | ||||
|         self.assertEqual(0, handler2_called) | ||||
| 
 | ||||
|         handler1_return = False | ||||
|         envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1) | ||||
|         raw = envelope.pack() | ||||
|         self.h.channel.receive(raw) | ||||
| 
 | ||||
|         self.assertEqual(2, handler1_called) | ||||
|         self.assertEqual(1, handler2_called) | ||||
| 
 | ||||
| 
 | ||||
|     def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): | ||||
|         decoded: [MessageBase] = [] | ||||
| 
 | ||||
|         def handle_message(message: MessageBase): | ||||
|             decoded.append(message) | ||||
| 
 | ||||
|         self.h.channel.add_message_callback(handle_message) | ||||
|         self.h.channel.add_message_handler(handle_message) | ||||
|         self.assertEqual(len(self.h.outlet.packets), 0) | ||||
| 
 | ||||
|         envelope = self.h.channel.send(message) | ||||
|  | ||||
| @ -382,7 +382,7 @@ class TestLink(unittest.TestCase): | ||||
| 
 | ||||
|         channel = l1.get_channel() | ||||
|         channel.register_message_type(MessageTest) | ||||
|         channel.add_message_callback(handle_message) | ||||
|         channel.add_message_handler(handle_message) | ||||
|         channel.send(test_message) | ||||
| 
 | ||||
|         time.sleep(0.5) | ||||
| @ -466,7 +466,7 @@ def targets(yp=False): | ||||
|             message.data = message.data + " back" | ||||
|             channel.send(message) | ||||
|         channel.register_message_type(MessageTest) | ||||
|         channel.add_message_callback(handle_message) | ||||
|         channel.add_message_handler(handle_message) | ||||
| 
 | ||||
|     m_rns = RNS.Reticulum("./tests/rnsconfig") | ||||
|     id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user