#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import argparse
import json
import shutil
import math
import re
import sys
import numpy as np


# Braille glyphs used to draw the timeline bars, ordered from the fullest
# cell (index 0) down to an empty string (last index).  make_bars() repeats
# CHARS[0] for every whole cell of a bar and picks one further entry for the
# fractional remainder: the larger the remainder, the lower the index.
CHARS = ['⣿', '⣷', '⣶', '⣦', '⣤', '⣄', '⡄', '⡀', '']


def get_terminal_size():
    """Return the terminal size as a ``(columns, rows)`` pair.

    ``shutil.get_terminal_size()`` is used when available.  On interpreters
    that lack it, the size is read from ``stty size``, which prints
    ``rows cols``; the two values are swapped so that both code paths agree
    on the ordering and index 0 is always the width.

    Raises:
        ValueError: in the fallback path, if ``stty size`` does not print
            exactly two integers (for example when no terminal is attached).
    """
    if hasattr(shutil, 'get_terminal_size'):
        return shutil.get_terminal_size()

    # Close the pipe deterministically instead of leaking it.
    with os.popen('stty size', 'r') as pipe:
        rows, cols = pipe.read().split()

    # stty reports "rows cols"; callers expect the width first.
    return (int(cols), int(rows))


def make_bars(begins, ends, total, width):
    """Render one timeline row as a string of bar glyphs and spaces.

    Each ``(begin, end)`` pair taken from ``begins`` and ``ends`` becomes a
    bar whose length is ``width * (end - begin) / total`` cells: whole cells
    are drawn with ``CHARS[0]`` and the fractional rest selects one more
    entry from ``CHARS``.  Every bar after the first is preceded by spaces
    proportional to the gap since the previous bar's end; the first bar is
    emitted without any leading offset.

    Args:
        begins: start timestamps, paired positionally with ``ends``.
        ends: end timestamps.
        total: time span that corresponds to ``width`` cells.
        width: number of character cells representing ``total``.

    Returns:
        The concatenated row; an empty string when there are no pairs.
    """
    last_index = len(CHARS) - 1
    pieces = []
    previous = 0

    for position, (begin, end) in enumerate(zip(begins, ends)):
        span = float(end - begin) / total
        whole_cells = int(math.floor(width * span))
        fraction = width * span - whole_cells

        # A large fraction maps to a low index, i.e. a fuller glyph.
        partial = CHARS[int(last_index * (1 - fraction))]
        solid = CHARS[0] * whole_cells

        if position == 0:
            gap = ''
        else:
            gap = ' ' * int((float(begin - previous) / total) * width)

        pieces.append('{}{}{}'.format(gap, solid, partial))
        previous = end

    return ''.join(pieces)


def analyse(fp, name_fmt_func, name_header):
    """Print a timeline and a timing summary table for a JSON trace.

    The JSON document read from ``fp`` must have a top-level ``traceEvents``
    list whose entries each carry ``tid``, ``ts`` and ``ph`` keys.  Events
    are grouped by ``tid``; ``ph == 'B'`` timestamps are paired positionally
    with ``ph == 'E'`` timestamps to form the measured intervals.

    Output, written to stdout:
      * one bar row per task, ordered by the task's earliest timestamp;
      * a table with the number of intervals, their summed and mean
        duration, and the share of the overall trace span.

    Timestamps are divided by 1000 for the "(ms)" columns, so they are
    presumably microseconds -- confirm against the trace producer.

    Args:
        fp: readable file object containing the JSON trace.
        name_fmt_func: callable mapping a ``tid`` to the label to display.
        name_header: heading for the first column of the summary table.
    """
    events = json.load(fp)['traceEvents']

    if not events:
        # Nothing to draw or summarise; the code below needs >= 1 event.
        return

    # Bucket timestamps per task in a single pass instead of rescanning the
    # whole event list once per task and per phase.
    stamps = {}
    begins = {}
    ends = {}

    for event in events:
        tid, ts, phase = event['tid'], event['ts'], event['ph']

        if tid not in stamps:
            stamps[tid], begins[tid], ends[tid] = [], [], []

        stamps[tid].append(ts)

        if phase == 'B':
            begins[tid].append(ts)
        elif phase == 'E':
            ends[tid].append(ts)

    # Tasks in order of first appearance on the time axis.
    tasks = sorted(stamps, key=lambda t: min(stamps[t]))

    start = min(stamps[tasks[0]])
    total = max(max(values) for values in stamps.values()) - start
    term_width = get_terminal_size()[0]

    times = {t: [end - begin for begin, end in zip(begins[t], ends[t])] for t in tasks}

    for task in tasks:
        name = name_fmt_func(task)
        # Fix the label column at eight characters: truncate or pad.
        name = name[:7] + '…' if len(name) > 8 else name + ' ' * (8 - len(name))
        # Keep the label as text: encoding it to bytes would make format()
        # render it as "b'...'" on Python 3.
        print('{} {}'.format(name, make_bars(begins[task], ends[task], total, term_width - 11)))

    print("\n {: <16} | {: >4} | {: >12} | {: >12} | {: >5}".format(name_header, '#', 'Total (ms)', 'Mean (ms)', '%'))
    print('-' * (16 + 4 + 12 + 12 + 9 + 10))

    fmt = ' {: <16} | {: >4} | {: >12.4f} | {: >12.4f} | {: >3.2f}'

    for task in tasks:
        tt = times[task]
        name = name_fmt_func(task)
        total_time = np.sum(tt) / 1000.0
        mean_time = np.mean(tt) / 1000.0
        num_events = len(tt)
        percentage = float(np.sum(tt)) / total * 100
        print(fmt.format(name, num_events, total_time, mean_time, percentage))


def analyse_trace(fp):
    """Analyse a task trace, labelling each row by its short task name.

    Thread ids that look like ``Ufo<Name>Task-<hex>`` are shortened to
    ``<Name>``.  Ids that do not match the pattern are shown unchanged
    instead of aborting the whole analysis with an ``AttributeError``.

    Args:
        fp: readable file object containing the JSON trace.
    """
    pattern = re.compile(r'Ufo([A-Za-z]*)Task-[a-f0-9x]*')

    def relevant_name(task):
        match = pattern.match(task)
        # re.match() returns None on a mismatch; fall back to the raw id.
        return match.group(1) if match else task

    analyse(fp, relevant_name, 'Task')


def analyse_opencl(fp):
    """Analyse an OpenCL trace, labelling each row with its unmodified tid.

    Args:
        fp: readable file object containing the JSON trace.
    """
    def kernel_name(tid):
        # The identifier is displayed as-is.
        return tid

    analyse(fp, kernel_name, 'Kernel')


def main():
    """Command-line entry point: parse arguments and analyse the input file.

    The trace flavour is taken from ``--trace`` / ``--opencl`` when given.
    Otherwise it is guessed from the input's file name (``trace.*.json`` or
    ``opencl.*.json``).  The guess looks only at the base name, so the file
    may live in any directory and directory names cannot influence it.

    Exits with status 0 after a successful analysis and with status 1 when
    the flavour cannot be determined.
    """
    parser = argparse.ArgumentParser()

    def open_valid_file(arg):
        """argparse ``type`` hook: return ``(open file, path)`` or report an error."""
        if not os.path.exists(arg):
            # parser.error() prints the message and exits; it does not return.
            parser.error("`{}' does not exist.".format(arg))

        return open(arg, 'r'), arg

    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--trace', action='store_true', default=None, help="JSON input is a trace file")
    group.add_argument('--opencl', action='store_true', default=None, help="JSON input is OpenCL trace")

    parser.add_argument('input', type=open_valid_file)

    args = parser.parse_args()
    fp, path = args.input

    # Guess from the file name only, not from the directories leading to it.
    filename = os.path.basename(path)

    # The context manager closes the input even when sys.exit() unwinds.
    with fp:
        if args.trace or re.match(r'trace\..*\.json', filename):
            analyse_trace(fp)
            sys.exit(0)

        if args.opencl or re.match(r'opencl\..*\.json', filename):
            analyse_opencl(fp)
            sys.exit(0)

    print("Could not guess trace file type, please supply --trace or --opencl.")
    # Signal the failure to the calling shell instead of exiting with 0.
    sys.exit(1)


# Run the command-line interface only when executed as a script, not when
# the module is imported.
if __name__ == '__main__':
    main()
