#!/usr/bin/python3
# -*- coding: utf-8 -*-

"""
Stressant is a simple yet complete stress-testing tool that forces
a computer to perform a series of tests using well-known Linux software
in order to detect possible design or construction failures.
"""

# Copyright (C) 2017 Antoine Beaupré <anarcat@debian.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import getpass
import logging
from logging.handlers import SMTPHandler, MemoryHandler
import multiprocessing
import os
import os.path
import re

import argparse
import smtplib
import socket
import subprocess
import tempfile
import time

# colorlog is optional: when it cannot be imported, or is a version
# without a StreamHandler, the name is rebound to False so that
# setupLogging() can test it with a plain `if colorlog:`
try:
    import colorlog
    if 'StreamHandler' not in dir(colorlog):
        colorlog = False
except ImportError:
    colorlog = False
import humanize

# version detection: prefer setuptools_scm; if it is not installed, or
# get_version() raises LookupError (presumably when it cannot determine
# a version), fall back to a generated __version module, and finally to
# a '???' placeholder
try:
    from setuptools_scm import get_version
    __version__ = get_version()
except (ImportError, LookupError):
    # try the local generated version
    #
    # XXX: this may load an arbitrary version in another package!
    try:
        from __version import __version__
    except ImportError:
        __version__ = '???'


class NegateAction(argparse.Action):
    '''argparse action for a --foo / --no-foo pair of toggle flags

    this folds 'store_true' and 'store_false' into a single action: an
    option string spelled with the negative prefix (``negative``,
    '--no' unless overridden) stores False, any other spelling stores
    True. the default is the value the *first* option string would
    store, so list the positive form first for a feature that is on by
    default, and the negative form first for one that is off.
    '''

    negative = '--no'

    def __init__(self, option_strings, *args, **kwargs):
        '''derive the default from the first option string; takes no value'''
        initial = not self._negated(option_strings[0])
        super(NegateAction, self).__init__(option_strings, *args,
                                           default=initial, nargs=0, **kwargs)

    def __call__(self, parser, ns, values, option):
        '''store False for the negative spelling, True otherwise'''
        setattr(ns, self.dest, not self._negated(option))

    def _negated(self, flag):
        '''tell whether the option string `flag` uses the negative prefix'''
        return flag.startswith(self.negative)


def parseArgs(argv=None):
    '''parse commandline arguments and set defaults

    `argv` is the list of arguments to parse; when None (the default),
    argparse reads sys.argv[1:] as before. passing a list makes it
    possible to call this from other code or tests.

    returns the argparse.Namespace; its keys are passed as keyword
    arguments to setupLogging() and the test functions.
    '''

    ####################################################################
    # XXX: WARNING! update doc/usage.rst when you change usage!
    ####################################################################
    parser = argparse.ArgumentParser(epilog=__doc__)
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument('--logfile', default=None, metavar='PATH',
                        help='write reports to the given logfile (default: %(default)s)')
    parser.add_argument('--email', help='send report by email to given address')
    # the port placeholder is substituted here, before argparse does its
    # own %-formatting, so this help must not contain other % sequences
    parser.add_argument('--smtpserver', metavar='HOST',
                        help=('SMTP server to use, use a colon to specify '
                              'the port number if non-default (%(port)d).'
                              ' will attempt to use STARTTLS to secure '
                              'the connection and fail if unsupported '
                              '(default: the domain part of the --email '
                              'address)') %
                        {'port': smtplib.SMTP_PORT})
    parser.add_argument('--smtpuser', metavar='USERNAME',
                        help=('username for the SMTP server '
                              '(default: no user)'))
    parser.add_argument('--smtppass', metavar='PASSWORD',
                        help=('password for the SMTP server '
                              '(default: prompted, if --smtpuser is '
                              'specified)'))
    # NegateAction toggles: the first option string sets the default
    parser.add_argument('--information', '--no-information', action=NegateAction,
                        help='gather basic information (default: %(default)s)')
    # no metavar: this toggle takes no value (NegateAction uses nargs=0)
    parser.add_argument('--disk', '--no-disk', action=NegateAction,
                        help='run disk tests (default: %(default)s)')
    parser.add_argument('--no-smart', '--smart', dest='smart', action=NegateAction,
                        help='run SMART tests (default: %(default)s)')
    # XXX: disk detection could be done in a number of ways:
    #
    # * psutil.disk_partitions() - only lists mounted, but psutil also has
    #   features like checking amount of RAM, sensors and network..
    #
    # * parsing /proc/partitions
    #
    # * glob !/sys/block/%s/device/block/%s/removable
    parser.add_argument('--diskDevice', default='/dev/sda', metavar='PATH',
                        help='device to benchmark (default: %(default)s)')
    ####################################################################
    # XXX: WARNING! update doc/usage.rst when you change usage!
    ####################################################################
    parser.add_argument('--jobFile', metavar='PATH',
                        default='/usr/share/doc/fio/examples/basic-verify.fio',
                        help='path to the fio job file to use (default: %(default)s)')
    parser.add_argument('--overwrite', action='store_true',
                        help='actually destroy the given device (default: %(default)s)')
    parser.add_argument('--writeSize', default='100M', metavar='SIZE',
                        help='size to write to disk, bytes or percentage (default: %(default)s)')
    parser.add_argument('--directory', default=None, metavar='PATH',
                        help='directory to perform file tests in, created if missing (default: %(default)s)')  # noqa: E501
    parser.add_argument('--diskRuntime', default='1m', metavar='TIME',
                        help='upper limit for disk benchmark (default: %(default)s)')
    parser.add_argument('--cpu', '--no-cpu', action=NegateAction,
                        help='run CPU tests (default: %(default)s)')
    parser.add_argument('--cpuBurnTime', default='1m', metavar='TIME',
                        help='timeout for CPU burn-in (default: %(default)s)')
    parser.add_argument('--network', '--no-network', action=NegateAction,
                        help='run network tests (default: %(default)s)')
    # see also https://iperf.fr/iperf-servers.php
    # XXX: we chose he.net because they are nice, but ideally we:
    # 1. would ask first
    # 2. have a DNS round-robin for this, like NTP
    parser.add_argument('--iperfServer', default='iperf.he.net', metavar='HOST',
                        help='iperf server to use (default: %(default)s)')
    parser.add_argument('--iperfTime', default=str(60), metavar='TIME',
                        help='timeout for iperf test, in seconds (default: %(default)s)')
    ####################################################################
    # XXX: WARNING! update doc/usage.rst when you change usage!
    ####################################################################
    return parser.parse_args(argv)


class BufferedSMTPHandler(SMTPHandler, MemoryHandler):
    """A handler class which buffers records and sends them all in a
    single email. The object is constructed with the arguments from
    SMTPHandler and MemoryHandler and basically behaves as a merge
    between the two classes.

    Following MemoryHandler, the buffer is sent when it reaches
    `capacity`, when a record at or above `flushLevel` is logged, or
    when the handler is flushed or closed (e.g. by logging.shutdown()).
    The email body is the formatted records, one per line. The buffer
    is emptied once the email is sent or delivery is abandoned.

    A delivery refused with a temporary error (SMTP code 450, e.g.
    greylisting) is retried up to `retries` times per flush, waiting
    `DELAY` seconds before each retry. Any other failure is reported
    through handleError() and the buffered records are dropped.

    The SMTPHandler.emit() implementation was copy-pasted here because
    it is not flexible enough to be overridden. We could possibly
    override the format() function to instead look at the internal
    buffer, but that would have possibly undesirable side-effects.
    """

    # retry delay, in seconds
    DELAY = 5*60

    def __init__(self, mailhost, fromaddr, toaddrs, subject,
                 credentials=None, secure=None,
                 capacity=5000, flushLevel=logging.ERROR, retries=1):
        # forward credentials and secure: dropping them here would
        # silently disable SMTP login and STARTTLS
        SMTPHandler.__init__(self, mailhost, fromaddr, toaddrs, subject,
                             credentials=credentials, secure=secure)
        self.retries = retries
        # set while flush() is sending, see shouldFlush()
        self._flushing = False
        MemoryHandler.__init__(self, capacity=capacity, flushLevel=flushLevel)

    def emit(self, record):
        '''buffer the record in the MemoryHandler'''
        MemoryHandler.emit(self, record)

    def shouldFlush(self, record):
        '''like MemoryHandler.shouldFlush(), but never while flushing

        flush() logs its own progress through the logging module, and
        those records come back to this handler. Without this guard, a
        full buffer would start a new flush (and send a new email) for
        each of them, recursively.
        '''
        if getattr(self, '_flushing', False):
            return False
        return MemoryHandler.shouldFlush(self, record)

    def flush(self):
        """Flush all records.

        Format the records and send them, in one email, to the
        specified addressees, then empty the buffer.

        The only change from SMTPHandler here is the way the email
        body is created.
        """
        if not self.buffer or getattr(self, '_flushing', False):
            return
        self._flushing = True
        try:
            self._deliver()
        finally:
            self._flushing = False

    def _deliver(self):
        '''email the buffered records, retrying temporary refusals, then
        empty the buffer whatever the outcome'''
        # the last record provides the subject, and is the one reported
        # to handleError() on failure
        record = self.buffer[-1]
        body = ''.join(self.format(r) + "\n" for r in self.buffer)
        retries = self.retries
        while True:
            try:
                self._send(body, record)
            except smtplib.SMTPRecipientsRefused as e:
                # e.recipients maps each refused address to (code, message)
                if not any(error[0] == 450 for error in e.recipients.values()):
                    # permanent refusal: retrying will not help
                    self.handleError(record)
                    break
                if retries <= 0:
                    logging.error('Could not send email: %s, dropping records', e)
                    break
                retries -= 1
                logging.info('got temporary error: "%s". waiting %d seconds for email',
                             e, self.DELAY)
                time.sleep(self.DELAY)
            except RecursionError:
                raise
            except Exception:
                # as in SMTPHandler.emit(): a logging handler must not
                # raise. this covers smtplib.SMTPException as well as
                # connection errors (refused, DNS failure, timeout)
                self.handleError(record)
                break
            else:
                logging.info('sent email to %s using %s', self.toaddrs, self.mailhost)
                break
        # BufferingHandler.flush(), next after MemoryHandler in the MRO,
        # just empties the buffer (including records logged above)
        super(MemoryHandler, self).flush()

    def _send(self, body, record):
        '''send one email with the given body; raises on any failure

        this is adapted from SMTPHandler.emit from Python 3.7.3.
        '''
        from email.message import EmailMessage
        import email.utils

        port = self.mailport
        if not port:
            port = smtplib.SMTP_PORT
        # change from stdlib: record already formatted
        msg = EmailMessage()
        msg['From'] = self.fromaddr
        msg['To'] = ','.join(self.toaddrs)
        msg['Subject'] = self.getSubject(record)
        msg['Date'] = email.utils.localtime()
        msg.set_content(body)
        # the context manager sends QUIT on success and always closes
        # the connection, also when sending fails
        with smtplib.SMTP(self.mailhost, port, timeout=self.timeout) as smtp:
            # change from stdlib: use TLS without requiring username
            if self.secure is not None:
                smtp.ehlo()
                smtp.starttls(*self.secure)
                smtp.ehlo()
            if self.username:
                smtp.login(self.username, self.password)
            smtp.send_message(msg)


def setupLogging(logfile=None, email=None,
                 smtpserver=None, smtpuser=None, smtppass=None,
                 **args):
    '''setup standard Python logging facilities

    we create a new level called "OUTPUT" (INFO + 1) for command output,
    so it is not colored like, and can be distinguished from, our own
    output

    handlers are added to the root logger, which is set to log
    everything from DEBUG up:

    * the console: colored through colorlog when it is available,
      plain logging.StreamHandler otherwise
    * `logfile`, if given
    * an email report to `email`, if given, sent through `smtpserver`
      (HOST or HOST:PORT, default: the domain part of `email`),
      requesting STARTTLS and logging in as `smtpuser` if given. the
      password is `smtppass` or, when missing, prompted on the terminal

    other keyword arguments are ignored, so the parsed commandline can
    be passed directly.
    '''
    defaultFormat = '%(levelname)s: %(message)s'
    logging.OUTPUT = logging.INFO + 1
    logging.addLevelName(logging.OUTPUT, 'OUTPUT')
    if colorlog:
        handler = colorlog.StreamHandler()
        handler.setFormatter(colorlog.ColoredFormatter('%(log_color)s' + defaultFormat))
        logger = colorlog.getLogger('')
    else:
        # without colorlog, still report on the console, uncolored;
        # otherwise INFO and OUTPUT records would not be shown at all
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(defaultFormat))
        logger = logging.getLogger('')
    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    if logfile:
        handler = logging.FileHandler(logfile)
        handler.setFormatter(logging.Formatter(defaultFormat))
        logger.addHandler(handler)
    if email:
        if not smtpserver:
            # XXX: need to do MX discovery
            smtpserver = email.rpartition('@')[2]
        # split "HOST:PORT": the handler always passes an explicit port
        # to smtplib (SMTP_PORT by default), so a port left inside the
        # host name would not be used
        host, sep, port = smtpserver.rpartition(':')
        if sep and port.isdigit() and ':' not in host:
            mailhost = (host, int(port))
        else:
            mailhost = smtpserver
        fromaddr = getpass.getuser() + '@' + socket.getfqdn()
        subject = 'Stressant report'
        credentials = None
        if smtpuser:
            # only prompt when no password was given on the commandline
            if not smtppass:
                smtppass = getpass.getpass('enter SMTP password for server %s: ' % smtpserver)
            credentials = (smtpuser, smtppass)
        handler = BufferedSMTPHandler(mailhost,
                                      fromaddr,
                                      email,
                                      subject,
                                      secure=(),
                                      credentials=credentials,
                                      flushLevel=logging.CRITICAL)
        handler.setFormatter(logging.Formatter(defaultFormat))
        logger.addHandler(handler)


def collectCmd(args):
    '''run a command and feed its output into the logging system

    stdout and stderr are merged and each line is logged at the OUTPUT
    level defined by setupLogging(), with trailing whitespace removed.
    leading whitespace is kept so indented or tabular output stays
    readable (the same as collectCmdWithTmp()). bytes that are not valid
    UTF-8 are replaced instead of aborting the run.

    a command that cannot be started, or exits with a non-zero status,
    is logged as an error; nothing is raised.
    '''
    logging.debug('Calling %s', ' '.join(args))
    try:
        proc = subprocess.Popen(args, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT)
    except OSError as e:
        logging.error('Command failed: %s', e)
        return
    # the context manager closes the pipe and waits for the process,
    # also if logging raises
    with proc:
        for line in proc.stdout:
            logging.log(logging.OUTPUT, line.rstrip().decode('utf-8', errors='replace'))
        returnCode = proc.wait()
    if returnCode != 0:
        logging.error("Command failed: Command '%s' returned non-zero exit status %d",
                      ' '.join(args), returnCode)


def collectCmdWithTmp(args):
    '''run a command that writes its report to a temporary file, then log it

    the temporary file name is appended to the last argument of the
    command (e.g. "--output=" becomes "--output=/tmp/..."). the caller's
    list is left untouched.

    the goal is to be able to run commands interactively: the command
    keeps the terminal for its own stdout/stderr (fio, for example,
    shows stuff on stderr that need to be unbuffered and shouldn't show
    up in logs), and only the report file is logged, line by line, at
    the OUTPUT level. bytes that are not valid UTF-8 are replaced.

    failures to start the command, or a non-zero exit status, are
    logged as errors; nothing is raised.
    '''
    with tempfile.NamedTemporaryFile() as tmpfile:
        # work on a copy so the caller's list is not modified
        cmd = list(args)
        cmd[-1] += tmpfile.name
        logging.debug('Calling %s', ' '.join(cmd))
        try:
            subprocess.check_call(cmd)
        except (OSError, subprocess.CalledProcessError) as e:
            # OSError: the command could not be started (e.g. not installed)
            logging.error("Command failed: %s", e)
        # the command wrote through its own file descriptor; our handle
        # is still at offset 0 and reads what it wrote
        for line in tmpfile.file:
            logging.log(logging.OUTPUT, line.rstrip().decode('utf-8', errors='replace'))


def gatherInfo(diskDevice='/dev/null', **args):
    '''log basic facts about the system

    reports the number of CPU cores and the total physical memory, then
    logs the output of `lshw -short` and of `smartctl -i` for
    `diskDevice`. extra keyword arguments are ignored.
    '''
    logging.info("CPU cores: %d", multiprocessing.cpu_count())

    # physical memory = page size * number of physical pages
    pageSize = os.sysconf('SC_PAGE_SIZE')
    pageCount = os.sysconf('SC_PHYS_PAGES')
    totalBytes = pageSize * pageCount
    logging.info("Memory: %s (%d bytes)",
                 humanize.naturalsize(totalBytes, binary=True, format="%.0f"),
                 totalBytes)

    # (log message and its arguments, command whose output is collected)
    sections = (
        (("Hardware inventory",), ["lshw", "-short"]),
        (("SMART information for %s", diskDevice),
         ["smartctl", "-qnoserial", "-i", diskDevice]),
    )
    for message, command in sections:
        logging.info(*message)
        collectCmd(command)


def testDrive(overwrite=False, diskDevice='/dev/null',
              writeSize=None, diskRuntime=None, reportDir='.',
              smart=True, directory=None, jobFile=None, **args):
    '''disk tests

    runs, in order:

    * a dd write/read bandwidth test on a temporary file in
      `directory` (created if missing), when given
    * an hdparm timing test on `diskDevice`
    * a fio stress test using `jobFile`, limited by `diskRuntime` and
      writing `writeSize`, either directly on `diskDevice` (only with
      `overwrite`, destructive!) or in `directory`
    * when `smart` is set, a long SMART self-test on `diskDevice`,
      which is only started: results must be read later with smartctl

    every step is attempted even if an earlier one could not run.
    `reportDir` and extra keyword arguments are ignored.
    '''
    if directory:
        os.makedirs(directory, exist_ok=True)
        _ddTest(directory)
    else:
        logging.warning('no dd test ran, provide --directory to run')
    logging.info("hdparm test on %s", diskDevice)
    collectCmd(["hdparm", "-Tt", diskDevice])

    _fioTest(overwrite, diskDevice, writeSize, diskRuntime, directory, jobFile)

    # logging.info("How long a test takes")
    # XXX: i often need to use -d sat on external drives
    if smart:
        _smartTest(diskDevice)


def _ddTest(directory):
    '''write then read back a 512MB temporary file in directory with dd'''
    fd, testFile = tempfile.mkstemp(dir=directory)
    # dd opens the file by name: only the path is needed
    os.close(fd)
    try:
        logging.info("Basic disk bandwidth tests")
        logging.info("Writing 512MB file %s", testFile)
        collectCmd(["dd", "bs=1M", "count=512", "conv=fdatasync",
                    "if=/dev/zero", "of=" + testFile])
        logging.info("Reading 512MB file %s", testFile)
        collectCmd(["dd", "bs=1M", "count=512", "of=/dev/null", "if=" + testFile])
    finally:
        os.unlink(testFile)


def _fioTest(overwrite, diskDevice, writeSize, diskRuntime, directory, jobFile):
    '''run the fio stress test on diskDevice (overwrite) or in directory'''
    logging.info("Disk stress test")
    if overwrite:
        # XXX: this is supposed to wipe the drive, but is that enough?
        # see https://www.backblaze.com/blog/how-to-securely-recycle-or-dispose-of-your-ssd/
        target = ["--size=" + writeSize, "--filename=" + diskDevice]
    elif directory:
        target = ["--size=" + writeSize, "--directory=" + directory]
    else:
        logging.error('--overwrite or --directory not specified, no fio test ran')
        return
    # --group_reporting, give only one report, not one per job
    cmd = ["fio", "--name=stressant", "--group_reporting", "--runtime=" + diskRuntime]
    try:
        with open(jobFile, 'rb') as source:
            job = source.read()
    except OSError as e:
        logging.error('cannot read fio job file: %s', e)
        return
    # comment out filename in the job, otherwise fio ignores our arguments
    job = re.sub(rb'^filename=.*$', rb'#\g<0>', job, flags=re.MULTILINE)
    with tempfile.NamedTemporaryFile(suffix=os.path.basename(jobFile)) as tmpfile:
        tmpfile.write(job)
        # fio reads the job file by name: flush our buffer so it sees
        # the content, not an empty file
        tmpfile.flush()
        # more ideas:
        # https://wiki.mikejung.biz/Benchmarking#Fio_Test_Options_and_Examples
        # https://gist.github.com/tcooper/9417014
        # https://github.com/GoogleCloudPlatform/PerfKitBenchmarker
        # how to precondition for SSD benchmarks:
        # https://www.spinics.net/lists/fio/msg02496.html
        # collectCmdWithTmp appends its report file name to "--output="
        collectCmdWithTmp(cmd + [tmpfile.name] + target + ["--output="])


def _smartTest(diskDevice):
    '''start a long SMART self-test on diskDevice, without waiting for it'''
    # XXX: need to parse this and wait and do magic
    # XXX: this is already available in -a, above
    # collectCmd(["smartctl", "-c", diskDevice])
    logging.info("Starting long SMART test")
    collectCmd(["smartctl", "-t", "long", diskDevice])
    logging.info("use 'smartctl -l selftest %s' to see test results", diskDevice)
    # the above says:
    # Please wait 10 minutes for test to complete.
    # Test will complete after Wed Jan  4 21:28:11 2017
    # in 10 minutes:
    # smartctl -l selftest $disk
    # smartctl -a $disk says:
    # Self-test execution status:      ( 249)	Self-test routine in progress...
    #                                               90% of test remaining.


def testCpu(cpuBurnTime=None, reportDir='.', **args):
    '''stress-test the CPU with stress-ng

    `cpuBurnTime` is handed to stress-ng's --timeout option as-is (e.g.
    "1m"). `reportDir` and extra keyword arguments are ignored.
    '''
    logging.info("CPU stress test for %s", cpuBurnTime)
    # CPU workers; "0" presumably means one per CPU, see stress-ng(1)
    workers = ["--cpu", "0", "--ignite-cpu"]
    # reporting: brief metrics and logs, thermal zones, times
    reporting = ["--metrics-brief", "--log-brief", "--tz", "--times"]
    collectCmd(["stress-ng", "--timeout", cpuBurnTime]
               + workers + reporting + ["--aggressive"])
    # --matrix 0 is apparently the best way to heat up the CPU
    #
    # --verify would be important, not sure it works with CPU.
    #
    # according to this it works with --vm:
    # https://wiki.ubuntu.com/Kernel/Reference/stress-ng
    # also, i7z is useful to show the status of the CPU, including temperatures

    # similar tools:
    # linpack: not in debian
    # mprime: not in debian, not free software
    # systester: not in debian


def testNetwork(iperfServer=None, iperfTime=None, **args):
    '''basic network tests: run an iperf3 client against `iperfServer`

    `iperfTime` is given to iperf3's -t option as a string (the
    commandline documents it in seconds). extra keyword arguments are
    ignored.
    '''
    logging.info('Running network benchmark')
    # we use iperf, but apparently netperf is more effective:
    # https://www.bufferbloat.net/projects/cerowrt/wiki/Netperf/
    # see also this article:
    # http://iwl.com/white-papers/iperf
    command = ['iperf3']
    command += ['-c', iperfServer]
    command += ['-t', iperfTime]
    collectCmd(command)


def main():
    '''run stressant from the commandline

    parse the arguments, configure logging, then run each test suite
    enabled on the commandline, in a fixed order: information
    gathering, disk, CPU and network.
    '''
    args = parseArgs()
    options = vars(args)
    setupLogging(**options)
    logging.info('Starting tests')
    suites = (('information', gatherInfo),
              ('disk', testDrive),
              ('cpu', testCpu),
              ('network', testNetwork))
    for flag, suite in suites:
        if options[flag]:
            suite(**options)
    logging.info("all done")
    # make sure emails get flushed
    logging.shutdown()


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logging.error("Interrupted")
        # exit with the conventional status for SIGINT (128 + 2) rather
        # than 0, so calling scripts can tell the run was aborted
        raise SystemExit(130)
