diff --git a/dissect/target/loaders/mqtt.py b/dissect/target/loaders/mqtt.py index 5ef744216..cf87f4973 100644 --- a/dissect/target/loaders/mqtt.py +++ b/dissect/target/loaders/mqtt.py @@ -1,13 +1,18 @@ from __future__ import annotations +import atexit import logging +import math +import os import ssl +import sys import time import urllib from dataclasses import dataclass from functools import lru_cache from pathlib import Path from struct import pack, unpack_from +from threading import Thread from typing import Any, Callable, Iterator, Optional, Union import paho.mqtt.client as mqtt @@ -51,6 +56,34 @@ class SeekMessage: data: bytes = b"" +class MQTTTransferRatePerSecond: + def __init__(self, window_size: int = 10): + self.window_size = window_size + self.timestamps = [] + self.bytes = [] + + def record(self, timestamp: float, byte_count: int) -> MQTTTransferRatePerSecond: + while self.timestamps and (timestamp - self.timestamps[0] > self.window_size): + self.timestamps.pop(0) + self.bytes.pop(0) + + self.timestamps.append(timestamp) + self.bytes.append(byte_count) + return self + + def value(self, current_time: float) -> float: + if not self.timestamps: + return 0 + + elapsed_time = current_time - self.timestamps[0] + if elapsed_time == 0: + return 0 + + total_bytes = self.bytes[-1] - self.bytes[0] + + return total_bytes / elapsed_time + + class MQTTStream(AlignedStream): def __init__(self, stream: MQTTConnection, disk_id: int, size: Optional[int] = None): self.stream = stream @@ -62,12 +95,108 @@ def _read(self, offset: int, length: int, optimization_strategy: int = 0) -> byt return data +class MQTTDiagnosticLine: + def __init__(self, connection: MQTTConnection, total_peers: int): + self.connection = connection + self.total_peers = total_peers + self._columns, self._rows = os.get_terminal_size(0) + atexit.register(self._detach) + self._attach() + + def _attach(self) -> None: + # save cursor position + sys.stderr.write("\0337") + # set top and bottom margins of the scrolling region to default + sys.stderr.write("\033[r") + # restore cursor position + sys.stderr.write("\0338") + # move cursor down one line in the same column; if at the bottom, the screen scrolls up + sys.stderr.write("\033D") + # move cursor up one line in the same column; if at the top, screen scrolls down + sys.stderr.write("\033M") + # save cursor position again + sys.stderr.write("\0337") + # restrict scrolling to a region from the first line to one before the last line + sys.stderr.write(f"\033[1;{self._rows - 1}r") + # restore cursor position after setting scrolling region + sys.stderr.write("\0338") + + def _detach(self) -> None: + # save cursor position + sys.stderr.write("\0337") + # move cursor to the specified position (last line, first column) + sys.stderr.write(f"\033[{self._rows};1H") + # clear from cursor to end of the line + sys.stderr.write("\033[K") + # reset scrolling region to include the entire display + sys.stderr.write("\033[r") + # restore cursor position + sys.stderr.write("\0338") + # ensure the written content is displayed (flush output) + sys.stderr.flush() + + def display(self) -> None: + # prepare: set background color to blue and text color to white at the beginning of the line + prefix = "\x1b[44m\x1b[37m\r" + # reset all attributes (colors, styles) to their defaults afterwards + suffix = "\x1b[0m" + # separator to set background color to red and text style to bold + separator = "\x1b[41m\x1b[1m" + logo = "TARGETD" + + start = time.time() + transfer_rate = MQTTTransferRatePerSecond(window_size=7) + + while True: + time.sleep(0.05) + peers = "?" + try: + peers = len(self.connection.broker.peers(self.connection.host)) + except Exception: + pass + + recv = self.connection.broker.bytes_received + now = time.time() + transfer = transfer_rate.record(now, recv).value(now) / 1000 # convert to KB/s + failures = self.connection.retries + seconds_elapsed = round(now - start) % 60 + minutes_elapsed = math.floor((now - start) / 60) % 60 + hours_elapsed = math.floor((now - start) / 60**2) + timer = f"{hours_elapsed:02d}:{minutes_elapsed:02d}:{seconds_elapsed:02d}" + display = f"{timer} {peers}/{self.total_peers} peers {transfer:>8.2f} KB p/s {failures:>4} failures" + rest = self._columns - len(display) + padding = (rest - len(logo)) * " " + + # save cursor position + sys.stderr.write("\0337") + # move cursor to specified position (last line, first column) + sys.stderr.write(f"\033[{self._rows};1H") + # disable line wrapping + sys.stderr.write("\033[?7l") + # reset all attributes + sys.stderr.write("\033[0m") + # write the display line with prefix, calculated display content, padding, separator, and logo + sys.stderr.write(prefix + display + padding + separator + logo + suffix) + # enable line wrapping again + sys.stderr.write("\033[?7h") + # restore cursor position + sys.stderr.write("\0338") + # flush output to ensure it is displayed + sys.stderr.flush() + + def start(self) -> None: + t = Thread(target=self.display) + t.daemon = True + t.start() + + class MQTTConnection: broker = None host = None prev = -1 factor = 1 prefetch_factor_inc = 10 + retries = 0 def __init__(self, broker: Broker, host: str): self.broker = broker @@ -125,6 +254,7 @@ def read(self, disk_id: int, offset: int, length: int, optimization_strategy: in # message might have not reached agent, resend... self.broker.seek(self.host, disk_id, offset, flength, optimization_strategy) attempts = 0 + self.retries += 1 return message.data @@ -138,6 +268,8 @@ class Broker: mqtt_client = None connected = False case = None + bytes_received = 0 + monitor = False diskinfo = {} index = {} @@ -217,6 +349,9 @@ def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.client.MQTTM if casename != self.case: return + if self.monitor: + self.bytes_received += len(msg.payload) + if response == "DISKS": self._on_disk(hostname, msg.payload) elif response == "READ": @@ -238,9 +373,12 @@ def info(self, host: str) -> None: self.mqtt_client.publish(f"{self.case}/{host}/INFO") def topology(self, host: str) -> None: - self.topo[host] = [] + if host not in self.topo: + self.topo[host] = [] self.mqtt_client.subscribe(f"{self.case}/{host}/ID") time.sleep(1) # need some time to avoid race condition, i.e. MQTT might react too fast + # send a simple clear command (invalid, just clears the prev. msg) just in case TOPO is stale + self.mqtt_client.publish(f"{self.case}/{host}/CLR") self.mqtt_client.publish(f"{self.case}/{host}/TOPO") def connect(self) -> None: @@ -272,6 +410,7 @@ def connect(self) -> None: @arg("--mqtt-crt", dest="crt", help="client certificate file") @arg("--mqtt-ca", dest="ca", help="certificate authority file") @arg("--mqtt-command", dest="command", help="direct command to client(s)") +@arg("--mqtt-diag", action="store_true", dest="diag", help="show MQTT diagnostic information") class MQTTLoader(Loader): """Load remote targets through a broker.""" @@ -292,6 +431,7 @@ def detect(path: Path) -> bool: def find_all(path: Path, **kwargs) -> Iterator[str]: cls = MQTTLoader num_peers = 1 + if cls.broker is None: if (uri := kwargs.get("parsed_path")) is None: raise LoaderError("No URI connection details have been passed.") @@ -299,8 +439,13 @@ def find_all(path: Path, **kwargs) -> Iterator[str]: cls.broker = Broker(**options) cls.broker.connect() num_peers = int(options.get("peers", 1)) + cls.connection = MQTTConnection(cls.broker, path) + if options.get("diag", None): + cls.broker.monitor = True + MQTTDiagnosticLine(cls.connection, num_peers).start() + else: + cls.connection = MQTTConnection(cls.broker, path) - cls.connection = MQTTConnection(cls.broker, path) cls.peers = cls.connection.topo(num_peers) yield from cls.peers diff --git a/tests/loaders/test_mqtt.py b/tests/loaders/test_mqtt.py index 7cba68788..15ba108c3 100644 --- a/tests/loaders/test_mqtt.py +++ b/tests/loaders/test_mqtt.py @@ -44,6 +44,8 @@ def publish(self, topic: str, *args) -> None: begin = int(tokens[4], 16) end = int(tokens[5], 16) response.payload = self.disks[int(tokens[3])][begin : begin + end] + else: + return self.on_message(self, None, response)