#!/usr/bin/env python
#
# Copyright 2011 Jared Boone
#
# This file is part of Project Ubertooth.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street,
# Boston, MA 02110-1301, USA.

import signal
import sys
import threading
import numpy

from argparse import ArgumentParser

from PySide2 import QtCore, QtGui, QtWidgets
from PySide2.QtCore import Qt, QPointF, QLineF

from specan import Ubertooth

# Default scan window used when no range is given on the command line (MHz).
DEFAULT_LOWER_FREQ, DEFAULT_UPPER_FREQ = 2400, 2483

# Hard limits for any requested scan window (MHz). Going much further than
# this causes the Ubertooth to stop responding.
MIN_FREQ, MAX_FREQ = 2300, 2600

class SpecanThread(threading.Thread):
    """Background thread that pulls spectrum sweeps from an Ubertooth.

    Every (frequency_axis, rssi_values) pair yielded by ``device.specan(...)``
    is copied and handed to *new_frame_callback*. The callback therefore runs
    on this worker thread, not on the GUI thread.
    """

    def __init__(self, device, low_frequency, high_frequency, new_frame_callback, ubertooth_device=-1):
        threading.Thread.__init__(self)
        # Daemon, so a device read that never returns cannot keep the process alive.
        self.daemon = True

        self._device = device
        self._ubertooth_device = ubertooth_device
        self._low_frequency = low_frequency
        self._high_frequency = high_frequency
        self._new_frame_callback = new_frame_callback
        self._stopping = False  # set by stop(); polled between frames
        self._stopped = False   # True once stop() has seen the thread exit

    def run(self):
        """Stream frames to the callback until stop() is requested or the source ends."""
        frame_source = self._device.specan(self._low_frequency, self._high_frequency, ubertooth_device=self._ubertooth_device)
        for frequency_axis, rssi_values in frame_source:
            # Do not deliver a frame that arrived after stop() was requested.
            if self._stopping:
                break
            # Copies decouple the callback from buffers the source may reuse.
            self._new_frame_callback(numpy.copy(frequency_axis), numpy.copy(rssi_values))
            # Check again so a stop requested during the callback does not
            # wait for one more (possibly slow) sweep.
            if self._stopping:
                break

    def stop(self, timeout=3.0):
        """Ask the thread to finish and wait up to *timeout* seconds for it.

        The request is only noticed between frames, so a blocked device read
        can outlast the timeout; ``_stopped`` records whether the thread has
        actually exited rather than merely been asked to.
        """
        self._stopping = True
        self.join(timeout)
        self._stopped = not self.is_alive()


class RenderArea(QtWidgets.QWidget):
    """Spectrum plot fed by a background SpecanThread.

    Frames arrive on the worker thread through ``_new_frame``; repaints happen
    on the GUI thread. Each repaint:

    - draws the newest sweep (white) into a pixmap that fades a little every
      frame;
    - overlays the max-hold trace over the last ``_persisted_frames_depth``
      sweeps (green);
    - overlays a dBm / MHz reticle (blue);
    - draws the peak marker (red) and up to two mouse markers (yellow /
      magenta).
    """

    # Emitted from the worker thread for every new frame. The connected slot
    # belongs to this widget, which lives in the GUI thread, so Qt queues the
    # call onto the GUI thread: QWidget.update() must not be called from any
    # other thread.
    _frame_ready = QtCore.Signal()

    def __init__(self, device, parent=None, ubertooth_device=-1, lower_freq=DEFAULT_LOWER_FREQ, upper_freq=DEFAULT_UPPER_FREQ):
        QtWidgets.QWidget.__init__(self, parent)

        self._graph = None    # persistent, slowly fading trace pixmap
        self._reticle = None  # cached grid/label pixmap, rebuilt on resize

        self._device = device
        self._frame = None             # latest (frequency_axis, rssi_values)
        self._persisted_frames = None  # ring buffer of recent rssi rows (max-hold)
        self._persisted_frames_depth = 350
        self._persisted_frames_next_index = 0
        self._path_max = None          # max-hold path, drawn on top in paintEvent

        # Display range: frequencies in Hz (arguments are MHz), levels in dBm.
        self._low_frequency = lower_freq * 1e6
        self._high_frequency = upper_freq * 1e6
        self._frequency_step = 1e6
        self._high_dbm = 0.0
        self._low_dbm = -100.0

        # Marker state, written by the parent Window's input handlers.
        self._hide_markers = False
        self._mouse_x = None   # marker 1 frequency (Hz)
        self._mouse_y = None   # marker 1 level (dBm)
        self._mouse_x2 = None  # marker 2 frequency (Hz)
        self._mouse_y2 = None  # marker 2 level (dBm)

        self._clear_scheduled = False

        # Connect before the thread starts so no frame notification is lost.
        self._frame_ready.connect(self._on_frame_ready)

        self._thread = SpecanThread(self._device,
                                    self._low_frequency,
                                    self._high_frequency,
                                    self._new_frame,
                                    ubertooth_device=ubertooth_device)
        self._thread.start()

    def schedule_clear(self):
        """Request that the trace and max-hold history be reset on the next paint."""
        self._clear_scheduled = True

    def stop_thread(self):
        """Stop the capture thread (waits up to its join timeout)."""
        self._thread.stop()

    def _new_graph(self):
        """Replace the trace pixmap with a black one matching the widget size."""
        self._graph = QtGui.QPixmap(self.width(), self.height())
        self._graph.fill(Qt.black)

    def _new_reticle(self):
        """Replace the reticle pixmap with a transparent one matching the widget size."""
        self._reticle = QtGui.QPixmap(self.width(), self.height())
        self._reticle.fill(Qt.transparent)

    def _new_persisted_frames(self, frequency_bins):
        """Reset the max-hold history to ``_persisted_frames_depth`` empty rows.

        Rows are filled with -182, far below the lowest displayed level
        (``_low_dbm``), so unfilled rows never win the max-hold.
        """
        self._persisted_frames = numpy.empty((self._persisted_frames_depth, frequency_bins))
        self._persisted_frames.fill(-128 + -54)
        self._persisted_frames_next_index = 0

    def minimumSizeHint(self):
        """Ask for 4 px per frequency step and 1 px per dB of the display range."""
        x_points = round((self._high_frequency - self._low_frequency) / self._frequency_step)
        y_points = round(self._high_dbm - self._low_dbm)
        return QtCore.QSize(x_points * 4, y_points * 1)

    @QtCore.Slot()
    def _on_frame_ready(self):
        """Schedule a repaint; always runs on the GUI thread."""
        self.update()

    def _new_frame(self, frequency_axis, rssi_values):
        """Record a frame delivered by SpecanThread (worker thread) and request a repaint.

        The history row is written before ``_frame`` is published, so a
        concurrent paint never sees a frame without a history buffer.
        """
        if self._persisted_frames is None:
            self._new_persisted_frames(len(frequency_axis))
        self._persisted_frames[self._persisted_frames_next_index] = rssi_values
        self._persisted_frames_next_index = (self._persisted_frames_next_index + 1) % self._persisted_frames.shape[0]
        self._frame = (frequency_axis, rssi_values)
        self._frame_ready.emit()

    def _draw_graph(self):
        """Fade the trace pixmap and draw the latest frame (and markers) into it."""
        if self._clear_scheduled:
            self._clear_scheduled = False
            self._new_graph()
            # Before the first frame there is no history to reset (and no
            # frame to size it from).
            if self._persisted_frames is not None:
                self._new_persisted_frames(self._persisted_frames.shape[1])

        if self._graph is None or self._graph.size() != self.size():
            self._new_graph()

        frame = self._frame  # snapshot: the worker thread may replace it
        painter = QtGui.QPainter(self._graph)
        try:
            painter.setRenderHint(QtGui.QPainter.Antialiasing)
            # Translucent black wash: older sweeps fade out over successive paints.
            painter.fillRect(0, 0, self._graph.width(), self._graph.height(), QtGui.QColor(0, 0, 0, 10))
            if frame is not None:
                self._draw_frame(painter, *frame)
        finally:
            painter.end()

    def _draw_frame(self, painter, frequency_axis, rssi_values):
        """Draw the current sweep, store the max-hold path and draw markers."""
        x_axis = self._hz_to_x(frequency_axis)
        y_now = self._dbm_to_y(rssi_values)
        y_max = self._dbm_to_y(numpy.amax(self._persisted_frames, axis=0))

        pen = QtGui.QPen()
        pen.setBrush(Qt.white)
        painter.setPen(pen)
        painter.drawPath(self._polyline(x_axis, y_now))
        # The max-hold trace is drawn un-faded on top of the pixmap in paintEvent.
        self._path_max = self._polyline(x_axis, y_max)

        peak = self._peak_index(y_max)
        if peak is None or self._hide_markers:
            return
        self._draw_markers(painter, pen, x_axis[peak], y_max[peak])

    def _peak_index(self, y_max):
        """Return the index of the strongest max-hold bin, or None.

        None is returned when no bin rises above the bottom of the display
        (``_low_dbm``). Ties resolve to the lowest index.
        """
        levels = self._y_to_dbm(y_max)
        if len(levels) == 0:
            return None
        peak = int(numpy.argmax(levels))
        if levels[peak] > self._low_dbm:
            return peak
        return None

    def _draw_markers(self, painter, pen, peak_x, peak_y):
        """Draw the peak marker plus any mouse markers and their offsets.

        The peak is red; marker 1 (left click) is yellow and marker 2 (right
        click) magenta. Bracketed labels are frequency differences in MHz.
        """
        peak_mhz = self._x_to_hz(peak_x) / 1e6
        has_marker1 = self._mouse_x is not None
        has_marker2 = self._mouse_x2 is not None

        pen.setBrush(Qt.red)
        pen.setStyle(Qt.DotLine)
        painter.setPen(pen)
        painter.drawText(QPointF(peak_x + 4, 30), '%.06f' % peak_mhz)
        painter.drawText(QPointF(30, peak_y - 4), '%d' % (self._y_to_dbm(peak_y)))
        self._draw_crosshair(painter, peak_x, peak_y)

        if has_marker1:
            m1_x = self._hz_to_x(self._mouse_x)
            m1_y = self._dbm_to_y(self._mouse_y)
            m1_mhz = self._mouse_x / 1e6
            # Peak minus marker 1, still in the peak's red.
            painter.drawText(QPointF(m1_x + 4, 58), '(%.06f)' % (peak_mhz - m1_mhz))
            pen.setBrush(Qt.yellow)
            painter.setPen(pen)
            painter.drawText(QPointF(m1_x + 4, 44), '%.06f' % m1_mhz)
            painter.drawText(QPointF(54, m1_y - 4), '%d' % (self._mouse_y))
            self._draw_crosshair(painter, m1_x, m1_y)
            if has_marker2:
                # Marker 1 minus marker 2, drawn at marker 2 in yellow.
                painter.drawText(QPointF(self._hz_to_x(self._mouse_x2) + 4, 118), '(%.06f)' % (m1_mhz - self._mouse_x2 / 1e6))

        if has_marker2:
            m2_x = self._hz_to_x(self._mouse_x2)
            m2_y = self._dbm_to_y(self._mouse_y2)
            m2_mhz = self._mouse_x2 / 1e6
            pen.setBrush(Qt.red)
            painter.setPen(pen)
            # Peak minus marker 2.
            painter.drawText(QPointF(m2_x + 4, 102), '(%.06f)' % (peak_mhz - m2_mhz))
            pen.setBrush(Qt.magenta)
            painter.setPen(pen)
            painter.drawText(QPointF(m2_x + 4, 88), '%.06f' % m2_mhz)
            painter.drawText(QPointF(78, m2_y - 4), '%d' % (self._mouse_y2))
            self._draw_crosshair(painter, m2_x, m2_y)
            if has_marker1:
                # Marker 2 minus marker 1, drawn at marker 1 in magenta.
                painter.drawText(QPointF(self._hz_to_x(self._mouse_x) + 4, 74), '(%.06f)' % (m2_mhz - self._mouse_x / 1e6))

    def _draw_crosshair(self, painter, x, y):
        """Draw full-height and full-width lines through the point (x, y)."""
        painter.drawLine(QPointF(x, 0), QPointF(x, self.height()))
        painter.drawLine(QPointF(0, y), QPointF(self.width(), y))

    @staticmethod
    def _polyline(xs, ys):
        """Build a QPainterPath through consecutive (xs[i], ys[i]) points."""
        path = QtGui.QPainterPath()
        path.moveTo(float(xs[0]), float(ys[0]))
        for x, y in zip(xs, ys):
            path.lineTo(float(x), float(y))
        return path

    def _draw_reticle(self):
        """Rebuild the cached grid pixmap when it is missing or the widget resized.

        Horizontal lines every 20 dB and vertical lines every 10 frequency
        steps (10 MHz), each labelled in white.
        """
        if self._reticle is not None and self._reticle.size() == self.size():
            return
        self._new_reticle()

        left = self._hz_to_x(self._low_frequency)
        right = self._hz_to_x(self._high_frequency)
        top = self._dbm_to_y(self._high_dbm)
        bottom = self._dbm_to_y(self._low_dbm)
        dbm_levels = numpy.arange(self._low_dbm, self._high_dbm, 20.0)
        frequencies = numpy.arange(self._low_frequency, self._high_frequency, self._frequency_step * 10.0)

        painter = QtGui.QPainter(self._reticle)
        try:
            painter.setRenderHint(QtGui.QPainter.Antialiasing)

            painter.setPen(Qt.blue)
            # One drawLine per line rather than drawLines(), kept for the old
            # (<1.0) PySide API shipped with Ubuntu 10.10.
            for dbm in dbm_levels:
                y = self._dbm_to_y(dbm)
                painter.drawLine(QLineF(left, y, right, y))
            for frequency in frequencies:
                x = self._hz_to_x(frequency)
                painter.drawLine(QLineF(x, top, x, bottom))

            painter.setPen(Qt.white)
            for dbm in dbm_levels:
                painter.drawText(QPointF(left + 2, self._dbm_to_y(dbm) - 2), '%+.0f' % dbm)
            for frequency in frequencies:
                painter.drawText(QPointF(self._hz_to_x(frequency) + 2, top + 10), '%.0f' % (frequency / 1e6))
        finally:
            painter.end()

    def paintEvent(self, event):
        """Composite trace pixmap, max-hold path and half-opacity reticle."""
        self._draw_graph()
        self._draw_reticle()

        painter = QtGui.QPainter(self)
        try:
            painter.setRenderHint(QtGui.QPainter.Antialiasing)
            painter.setPen(QtGui.QPen())
            painter.setBrush(QtGui.QBrush())

            if self._graph:
                painter.drawPixmap(0, 0, self._graph)

            if self._path_max:
                painter.setPen(Qt.green)
                painter.drawPath(self._path_max)

            painter.setOpacity(0.5)
            if self._reticle:
                painter.drawPixmap(0, 0, self._reticle)
        finally:
            painter.end()

    def _hz_to_x(self, frequency_hz):
        """Map frequency (Hz, scalar or numpy array) to a widget x coordinate."""
        delta = frequency_hz - self._low_frequency
        span = self._high_frequency - self._low_frequency
        normalized = delta / span
        return normalized * self.width()

    def _x_to_hz(self, x):
        """Inverse of _hz_to_x: widget x coordinate to frequency in Hz."""
        span = self._high_frequency - self._low_frequency
        fraction = x / self.width()
        delta = fraction * span
        return delta + self._low_frequency

    def _dbm_to_y(self, dbm):
        """Map a level (dBm, scalar or numpy array) to a widget y coordinate (0 = top)."""
        delta = self._high_dbm - dbm
        span = self._high_dbm - self._low_dbm
        normalized = delta / span
        return normalized * self.height()

    def _y_to_dbm(self, y):
        """Inverse of _dbm_to_y: widget y coordinate to level in dBm."""
        span = self._high_dbm - self._low_dbm
        fraction = y / self.height()
        delta = fraction * span
        return self._high_dbm - delta


class Window(QtWidgets.QWidget):
    """Top-level window hosting a RenderArea and the Ubertooth device it reads.

    Mouse clicks place markers on the plot. Key presses clear the plot,
    toggle the markers, print help or quit (press H for the list).
    """

    def __init__(self, parent=None, ubertooth_device=-1, lower_freq=DEFAULT_LOWER_FREQ, upper_freq=DEFAULT_UPPER_FREQ):
        QtWidgets.QWidget.__init__(self, parent)

        self._device = self._open_device()

        self.render_area = RenderArea(self._device, ubertooth_device=ubertooth_device, lower_freq=lower_freq, upper_freq=upper_freq)

        # Zero margins place the render area at the window origin, so mouse
        # coordinates received by this window are also render-area coordinates.
        main_layout = QtWidgets.QGridLayout()
        main_layout.setContentsMargins(0, 0, 0, 0)
        main_layout.addWidget(self.render_area, 0, 0)
        self.setLayout(main_layout)

        self.setWindowTitle("Ubertooth Spectrum Analyzer")

    def sizeHint(self):
        return QtCore.QSize(480, 160)

    def _open_device(self):
        return Ubertooth.Ubertooth()

    def closeEvent(self, event):
        """Stop the capture thread before closing the device it reads from."""
        self.render_area.stop_thread()
        self._device.close()
        event.accept()

    def _toggle_markers(self):
        """Clear both mouse markers and flip the visibility of the peak marker."""
        area = self.render_area
        area._mouse_x = None
        area._mouse_y = None
        area._mouse_x2 = None
        area._mouse_y2 = None
        area._hide_markers = not area._hide_markers

    # handle mouse button clicks
    def mousePressEvent(self, event):
        """Left/right click sets marker 1/2 at the pointer (frequency and level).

        A middle click clears both markers and toggles marker visibility.
        """
        area = self.render_area
        button = event.button()
        if button == Qt.LeftButton:
            area._mouse_x = area._x_to_hz(float(event.x()))
            area._mouse_y = area._y_to_dbm(float(event.y()))
            area._hide_markers = False
        elif button == Qt.RightButton:
            area._mouse_x2 = area._x_to_hz(float(event.x()))
            area._mouse_y2 = area._y_to_dbm(float(event.y()))
            area._hide_markers = False
        elif button == Qt.MiddleButton:
            self._toggle_markers()
        event.accept()

    # handle key presses
    def keyPressEvent(self, event):
        """Dispatch single-letter commands (case-insensitive); H prints help."""
        try:
            key = chr(event.key()).upper()
        except (ValueError, OverflowError):
            # Non-character keys (e.g. Qt.Key_Escape == 0x01000000) are
            # outside the range chr() accepts.
            print('Unknown key pressed: 0x%x' % event.key())
            event.ignore()
            return
        event.accept()

        if key == 'H':
            print('Key                  Action\n')
            print(' <LEFT MOUSE>        Mark LEFT frequency / signal strength at pointer')
            print(' <RIGHT MOUSE>       Mark RIGHT frequency / signal strength at pointer')
            print(' <MIDDLE MOUSE>      Toggle visibility of frequency / signal strength markers')
            print(' C                   Clear graph')
            print(' H                   Print this HELP text')
            print(' M                   Simulate MIDDLE MOUSE click (for those with trackpads)')
            print(' Q                   Quit')
        elif key == 'M':
            self._toggle_markers()
        elif key == 'C':
            self.render_area.schedule_clear()
        elif key == 'Q':
            print('Quit!')
            self.close()
        else:
            print('Unsupported key pressed:', key)


def sigint_handler(*args):
    """Handler for the SIGINT signal (Ctrl-C).

    Closes all windows before quitting. This delivers Window.closeEvent,
    which stops the capture thread and closes the Ubertooth device;
    QApplication.quit() alone only leaves the event loop.
    Before the QApplication exists, Ctrl-C behaves like the default handler.
    """
    app = QtWidgets.QApplication.instance()
    if app is None:
        raise KeyboardInterrupt
    app.closeAllWindows()
    app.quit()


def convert_wifi(channel):
    """Return the centre frequency, in MHz, of 2.4 GHz Wi-Fi *channel* (1-14).

    Channels 1-13 are spaced 5 MHz apart starting at 2412 MHz. Channel 14 is
    the special case at 2484 MHz, 12 MHz above channel 13.

    Raises:
        ValueError: after printing an error, if *channel* is outside 1-14.
    """
    if channel < 1 or channel > 14:
        print("ERROR: channel " + str(channel) + " is not a valid wifi channel")
        raise ValueError("invalid wifi channel: " + str(channel))

    if channel == 14:
        # Does not follow the 5 MHz spacing of channels 1-13.
        return 2484
    return 2407 + 5 * channel

def check_freq(freq):
    """Ensure *freq* (MHz) lies within [MIN_FREQ, MAX_FREQ].

    Returns None when the value is in range; otherwise prints an explanatory
    error and raises ValueError.
    """
    if freq < MIN_FREQ:
        problem = " MHz is below minimum frequency of " + str(MIN_FREQ)
    elif freq > MAX_FREQ:
        problem = " MHz is above maximum frequency of " + str(MAX_FREQ)
    else:
        return
    print("ERROR: frequency of " + str(freq) + problem)
    raise ValueError()


def check_freq_pair(freq1, freq2):
    """Validate a (lower, upper) scan range given in MHz.

    Each bound must pass check_freq, and the lower bound must be strictly
    below the upper one. RenderArea divides by (upper - lower) to map
    frequencies to pixels, so an empty range would crash the display.

    Raises:
        ValueError: after printing a diagnostic, if the range is invalid.
    """
    check_freq(freq1)
    check_freq(freq2)

    if freq1 > freq2:
        print("ERROR: lower frequency of " + str(freq1) + " MHz is above upper frequency of " + str(freq2) + " MHz")
        raise ValueError()
    if freq1 == freq2:
        print("ERROR: lower and upper frequency are both " + str(freq1) + " MHz; the scan range must not be empty")
        raise ValueError()

if __name__ == '__main__':
    signal.signal(signal.SIGINT, sigint_handler)

    parser = ArgumentParser()
    parser.add_argument("-U", type=int, dest="device",
                        help="set ubertooth device to use")
    parser.add_argument("-l", type=int, dest="lower_freq", help="lower bound for scan, in MHz (no less than " + str(MIN_FREQ) + ")")
    parser.add_argument("-u", type=int, dest="upper_freq", help="upper bound for scan, in MHz (no more than " + str(MAX_FREQ) + ")")
    parser.add_argument("--wifi", type=str, nargs='?', dest="wifi", metavar="channel(s)", help="display the spectrum for the wifi channels provided, either as a single number for one channel, or a range (e.g. 1-11) for two channels", const="1", default=False)
    parser.add_argument("--padding", type=int, dest="padding", help="padding on both ends when using --wifi, measured in MHz (default 10)", default=10)

    # Unrecognised arguments are tolerated; sys.argv is handed to Qt below.
    (options, extras) = parser.parse_known_args()

    # -1 is the "no specific device" value used as the default throughout this file.
    ubertooth_device = options.device
    if ubertooth_device is None:
        ubertooth_device = -1

    lower_freq = options.lower_freq
    upper_freq = options.upper_freq

    if options.wifi:
        # --wifi takes precedence: any -l/-u values are ignored.
        lower_channel = upper_channel = None

        parts = options.wifi.split("-")

        try:
            if len(parts) == 1:
                lower_channel = upper_channel = int(parts[0])
            elif len(parts) == 2:
                lower_channel = int(parts[0])
                upper_channel = int(parts[1])
            else:
                raise ValueError()
        except ValueError:
            print("ERROR: invalid channel range: " + options.wifi)
            sys.exit(1)

        try:
            lower_freq = convert_wifi(lower_channel) - options.padding
            upper_freq = convert_wifi(upper_channel) + options.padding
        except ValueError:
            # convert_wifi has already printed the reason.
            sys.exit(1)
    else:
        # Compare with None so an explicit "-l 0" is rejected by
        # check_freq_pair instead of silently becoming the default.
        if lower_freq is None:
            lower_freq = DEFAULT_LOWER_FREQ
        if upper_freq is None:
            upper_freq = DEFAULT_UPPER_FREQ

    try:
        check_freq_pair(lower_freq, upper_freq)
    except ValueError:
        # check_freq_pair has already printed the reason.
        sys.exit(1)

    app = QtWidgets.QApplication(sys.argv)
    window = Window(ubertooth_device=ubertooth_device, lower_freq=lower_freq, upper_freq=upper_freq)
    window.show()
    sys.exit(app.exec_())
