From 1250e2bd0a094955d3fbc740912a3d6ac9489df8 Mon Sep 17 00:00:00 2001 From: Pierre Fersing Date: Sun, 14 Jan 2024 19:44:36 +0100 Subject: [PATCH] Add tests for reasoncode comparison --- src/paho/mqtt/reasoncodes.py | 11 ++++++++++- tests/test_reasoncodes.py | 26 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 tests/test_reasoncodes.py diff --git a/src/paho/mqtt/reasoncodes.py b/src/paho/mqtt/reasoncodes.py index 93093636..1f95dd27 100644 --- a/src/paho/mqtt/reasoncodes.py +++ b/src/paho/mqtt/reasoncodes.py @@ -16,10 +16,12 @@ ******************************************************************* """ +import functools from .packettypes import PacketTypes +@functools.total_ordering class ReasonCodes: """MQTT version 5.0 reason codes class. @@ -173,11 +175,18 @@ def __eq__(self, other): if isinstance(other, int): return self.value == other if isinstance(other, str): - return self.value == str(self) + return other == str(self) if isinstance(other, ReasonCodes): return self.value == other.value return False + def __lt__(self, other): + if isinstance(other, int): + return self.value < other + if isinstance(other, ReasonCodes): + return self.value < other.value + return NotImplemented + def __str__(self): return self.getName() diff --git a/tests/test_reasoncodes.py b/tests/test_reasoncodes.py new file mode 100644 index 00000000..1e620787 --- /dev/null +++ b/tests/test_reasoncodes.py @@ -0,0 +1,26 @@ +from paho.mqtt.packettypes import PacketTypes +from paho.mqtt.reasoncodes import ReasonCodes + + +class TestReasonCode: + def test_equality(self): + rc_success = ReasonCodes(PacketTypes.CONNACK, "Success") + assert rc_success == 0 + assert rc_success == "Success" + assert rc_success != "Protocol error" + assert rc_success == ReasonCodes(PacketTypes.CONNACK, "Success") + + rc_protocol_error = ReasonCodes(PacketTypes.CONNACK, "Protocol error") + assert rc_protocol_error == 130 + assert rc_protocol_error == "Protocol error" + assert rc_protocol_error != "Success" + assert rc_protocol_error == ReasonCodes(PacketTypes.CONNACK, "Protocol error") + + def test_comparison(self): + rc_success = ReasonCodes(PacketTypes.CONNACK, "Success") + rc_protocol_error = ReasonCodes(PacketTypes.CONNACK, "Protocol error") + + assert not rc_success > 0 + assert rc_protocol_error > 0 + assert not rc_success != 0 + assert rc_protocol_error != 0