diff --git a/tests/service_mocks/test_telnet_service_mock.py b/tests/service_mocks/test_telnet_service_mock.py index f84709d..72bd98e 100644 --- a/tests/service_mocks/test_telnet_service_mock.py +++ b/tests/service_mocks/test_telnet_service_mock.py @@ -109,3 +109,22 @@ def test_telnet_service_mock_add_credentials(): assert match_object, foo.decode() tn.close() + + +def test_telnet_service_mock_add_banner(): + with TelnetServiceMock("127.0.0.1", 8023, + scenario=TelnetScenario.GENERIC) as target: + banner = b"Scoobeedoobeedoo where are you?" + target.add_banner(banner) + + assert target.host == "127.0.0.1" + assert target.port == 8023 + + tn = Telnet(target.host, target.port, timeout=1.0) + + _, match_object, _ = tn.expect([banner], 1.0) + assert match_object + + _, match_object, foo = tn.expect([b"Login: ", b"login: "], 1.0) + assert match_object, foo.decode() + tn.close() diff --git a/threat9_test_bed/service_mocks/telnet_service_mock.py b/threat9_test_bed/service_mocks/telnet_service_mock.py index d004bcd..122a5c7 100644 --- a/threat9_test_bed/service_mocks/telnet_service_mock.py +++ b/threat9_test_bed/service_mocks/telnet_service_mock.py @@ -36,4 +36,9 @@ def get_command_mock(self, command: str): return command_mock def add_credentials(self, login: str, password: str): + """ Add custom credentials pair. """ self.protocol.add_credentials(login, password) + + def add_banner(self, banner: bytes): + """ Add welcoming banner after connection. """ + self.protocol.add_banner(banner) diff --git a/threat9_test_bed/telnet_service/protocol.py b/threat9_test_bed/telnet_service/protocol.py index bdd4afa..d34f47d 100644 --- a/threat9_test_bed/telnet_service/protocol.py +++ b/threat9_test_bed/telnet_service/protocol.py @@ -50,6 +50,7 @@ def __init__(self, scenario: TelnetScenario): self.password = None self.authorized = False + self.banner = b"" self._command_mocks = {} self._creds = [ ("admin", "admin"), @@ -73,6 +74,8 @@ def connection_made(self, transport: asyncio.Transport): self.remote_address = transport.get_extra_info("peername") logger.debug(f"Connection from {self.remote_address}") self.transport = transport + if self.banner: + self.transport.write(self.banner + b"\r\n") self.transport.write(b"Login: ") @authorized @@ -91,3 +94,6 @@ def add_command_handler(self, command: str, handler: typing.Callable): def add_credentials(self, login: str, password: str): self._creds.append((login, password)) + + def add_banner(self, banner: bytes): + self.banner = banner