Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add monitoring option to MQTT Loader #709

Merged
merged 8 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 147 additions & 2 deletions dissect/target/loaders/mqtt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

import atexit
import logging
import math
import os
import ssl
import sys
import time
import urllib
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from struct import pack, unpack_from
from threading import Thread
from typing import Any, Callable, Iterator, Optional, Union

import paho.mqtt.client as mqtt
Expand Down Expand Up @@ -51,6 +56,34 @@ class SeekMessage:
data: bytes = b""


class MQTTTransferRatePerSecond:
def __init__(self, window_size: int = 10):
self.window_size = window_size
self.timestamps = []
self.bytes = []

def record(self, timestamp: float, byte_count: int) -> MQTTTransferRatePerSecond:
while self.timestamps and (timestamp - self.timestamps[0] > self.window_size):
self.timestamps.pop(0)
self.bytes.pop(0)

self.timestamps.append(timestamp)
self.bytes.append(byte_count)
return self

def value(self, current_time: float) -> float:
if not self.timestamps:
return 0

elapsed_time = current_time - self.timestamps[0]
if elapsed_time == 0:
return 0

total_bytes = self.bytes[-1] - self.bytes[0]

return total_bytes / elapsed_time


class MQTTStream(AlignedStream):
def __init__(self, stream: MQTTConnection, disk_id: int, size: Optional[int] = None):
self.stream = stream
Expand All @@ -62,12 +95,108 @@ def _read(self, offset: int, length: int, optimization_strategy: int = 0) -> byt
return data


class MQTTDiagnosticLine:
def __init__(self, connection: MQTTConnection, total_peers: int):
self.connection = connection
self.total_peers = total_peers
self._columns, self._rows = os.get_terminal_size(0)
atexit.register(self._detach)
self._attach()

def _attach(self) -> None:
# save cursor position
sys.stderr.write("\0337")
# set top and bottom margins of the scrolling region to default
sys.stderr.write("\033[r")
# restore cursor position
sys.stderr.write("\0338")
# move cursor down one line in the same column; if at the bottom, the screen scrolls up
sys.stderr.write("\033D")
# move cursor up one line in the same column; if at the top, screen scrolls down
sys.stderr.write("\033M")
# save cursor position again
sys.stderr.write("\0337")
# restrict scrolling to a region from the first line to one before the last line
sys.stderr.write(f"\033[1;{self._rows - 1}r")
# restore cursor position after setting scrolling region
sys.stderr.write("\0338")

def _detach(self) -> None:
# save cursor position
sys.stderr.write("\0337")
# move cursor to the specified position (last line, first column)
sys.stderr.write(f"\033[{self._rows};1H")
# clear from cursor to end of the line
sys.stderr.write("\033[K")
# reset scrolling region to include the entire display
sys.stderr.write("\033[r")
# restore cursor position
sys.stderr.write("\0338")
# ensure the written content is displayed (flush output)
sys.stderr.flush()

def display(self) -> None:
# prepare: set background color to blue and text color to white at the beginning of the line
prefix = "\x1b[44m\x1b[37m\r"
# reset all attributes (colors, styles) to their defaults afterwards
suffix = "\x1b[0m"
# separator to set background color to red and text style to bold
separator = "\x1b[41m\x1b[1m"
logo = "TARGETD"

start = time.time()
transfer_rate = MQTTTransferRatePerSecond(window_size=7)

while True:
time.sleep(0.05)
peers = "?"
try:
peers = len(self.connection.broker.peers(self.connection.host))
except Exception:
pass

recv = self.connection.broker.bytes_received
now = time.time()
transfer = transfer_rate.record(now, recv).value(now) / 1000 # convert to KB/s
failures = self.connection.retries
seconds_elapsed = round(now - start) % 60
minutes_elapsed = math.floor((now - start) / 60) % 60
hours_elapsed = math.floor((now - start) / 60**2)
timer = f"{hours_elapsed:02d}:{minutes_elapsed:02d}:{seconds_elapsed:02d}"
display = f"{timer} {peers}/{self.total_peers} peers {transfer:>8.2f} KB p/s {failures:>4} failures"
rest = self._columns - len(display)
padding = (rest - len(logo)) * " "

# save cursor position
sys.stderr.write("\0337")
# move cursor to specified position (last line, first column)
sys.stderr.write(f"\033[{self._rows};1H")
# disable line wrapping
sys.stderr.write("\033[?7l")
# reset all attributes
sys.stderr.write("\033[0m")
# write the display line with prefix, calculated display content, padding, separator, and logo
sys.stderr.write(prefix + display + padding + separator + logo + suffix)
# enable line wrapping again
sys.stderr.write("\033[?7h")
# restore cursor position
sys.stderr.write("\0338")
# flush output to ensure it is displayed
sys.stderr.flush()

def start(self) -> None:
t = Thread(target=self.display)
t.daemon = True
t.start()


class MQTTConnection:
broker = None
host = None
prev = -1
factor = 1
prefetch_factor_inc = 10
retries = 0

def __init__(self, broker: Broker, host: str):
self.broker = broker
Expand Down Expand Up @@ -125,6 +254,7 @@ def read(self, disk_id: int, offset: int, length: int, optimization_strategy: in
# message might have not reached agent, resend...
self.broker.seek(self.host, disk_id, offset, flength, optimization_strategy)
attempts = 0
self.retries += 1

return message.data

Expand All @@ -138,6 +268,8 @@ class Broker:
mqtt_client = None
connected = False
case = None
bytes_received = 0
monitor = False

diskinfo = {}
index = {}
Expand Down Expand Up @@ -217,6 +349,9 @@ def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.client.MQTTM
if casename != self.case:
return

if self.monitor:
self.bytes_received += len(msg.payload)

if response == "DISKS":
self._on_disk(hostname, msg.payload)
elif response == "READ":
Expand All @@ -238,9 +373,12 @@ def info(self, host: str) -> None:
self.mqtt_client.publish(f"{self.case}/{host}/INFO")

def topology(self, host: str) -> None:
self.topo[host] = []
if host not in self.topo:
self.topo[host] = []
self.mqtt_client.subscribe(f"{self.case}/{host}/ID")
time.sleep(1) # need some time to avoid race condition, i.e. MQTT might react too fast
# send a simple clear command (invalid, just clears the prev. msg) just in case TOPO is stale
self.mqtt_client.publish(f"{self.case}/{host}/CLR")
self.mqtt_client.publish(f"{self.case}/{host}/TOPO")

def connect(self) -> None:
Expand Down Expand Up @@ -272,6 +410,7 @@ def connect(self) -> None:
@arg("--mqtt-crt", dest="crt", help="client certificate file")
@arg("--mqtt-ca", dest="ca", help="certificate authority file")
@arg("--mqtt-command", dest="command", help="direct command to client(s)")
@arg("--mqtt-diag", action="store_true", dest="diag", help="show MQTT diagnostic information")
class MQTTLoader(Loader):
"""Load remote targets through a broker."""

Expand All @@ -292,15 +431,21 @@ def detect(path: Path) -> bool:
def find_all(path: Path, **kwargs) -> Iterator[str]:
cls = MQTTLoader
num_peers = 1

if cls.broker is None:
if (uri := kwargs.get("parsed_path")) is None:
raise LoaderError("No URI connection details have been passed.")
options = dict(urllib.parse.parse_qsl(uri.query, keep_blank_values=True))
cls.broker = Broker(**options)
cls.broker.connect()
num_peers = int(options.get("peers", 1))
cls.connection = MQTTConnection(cls.broker, path)
if options.get("diag", None):
cls.broker.monitor = True
MQTTDiagnosticLine(cls.connection, num_peers).start()
else:
cls.connection = MQTTConnection(cls.broker, path)

cls.connection = MQTTConnection(cls.broker, path)
cls.peers = cls.connection.topo(num_peers)
yield from cls.peers

Expand Down
2 changes: 2 additions & 0 deletions tests/loaders/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def publish(self, topic: str, *args) -> None:
begin = int(tokens[4], 16)
end = int(tokens[5], 16)
response.payload = self.disks[int(tokens[3])][begin : begin + end]
else:
return
self.on_message(self, None, response)


Expand Down