#!/usr/bin/env python3

import logging
import sys
import argparse
import time
import rpc


# procfs files read by get_cpu_stat() and uptime() respectively.
SPDK_CPU_STAT = "/proc/stat"
SPDK_UPTIME = "/proc/uptime"

# Header row of the CPU table. The first entry labels the header line only;
# the table is printed with leave_first=True, so rows carry six values.
SPDK_CPU_STAT_HEAD = [
    'cpu_stat:',
    'user_stat',
    'nice_stat',
    'system_stat',
    'iowait_stat',
    'steal_stat',
    'idle_stat',
]

# Header rows of the per-bdev table, in KiB and MiB flavours.
SPDK_BDEV_KB_STAT_HEAD = [
    'Device',
    'tps',
    'KB_read/s',
    'KB_wrtn/s',
    'KB_dscd/s',
    'KB_read',
    'KB_wrtn',
    'KB_dscd',
]
SPDK_BDEV_MB_STAT_HEAD = [
    'Device',
    'tps',
    'MB_read/s',
    'MB_wrtn/s',
    'MB_dscd/s',
    'MB_read',
    'MB_wrtn',
    'MB_dscd',
]

# All-ones 32-bit value; read_bdev_stat() uses it both as the threshold and
# as the mask when a sector counter appears to have gone backwards.
SPDK_MAX_SECTORS = 0xffffffff


class BdevStat:
    """Per-bdev I/O counters taken from one iostat dictionary.

    Recognised keys of the input dictionary are copied onto the instance
    under shorter attribute names (see ``_ATTR_BY_KEY``); the three byte
    counters are converted to 512-byte sectors with ``value >> 9``.
    Unrecognised keys are ignored.

    Any counter that was not supplied reads as ``0`` through
    ``__getattr__``, so callers can do arithmetic on every field without
    checking for its presence.  ``BdevStat(None)`` is therefore an
    "all zero" sample.
    """

    # Input dictionary key -> attribute name set on the instance.
    _ATTR_BY_KEY = {
        'name': 'bdev_name',
        'bytes_read': 'rd_sectors',
        'bytes_written': 'wr_sectors',
        'bytes_unmapped': 'dc_sectors',
        'num_read_ops': 'rd_ios',
        'num_write_ops': 'wr_ios',
        'num_unmap_ops': 'dc_ios',
        'read_latency_ticks': 'rd_ticks',
        'write_latency_ticks': 'wr_ticks',
        'unmap_latency_ticks': 'dc_ticks',
        'queue_depth': 'ios_pgr',
        'io_time': 'tot_ticks',
        'weighted_io_time': 'rq_ticks',
    }

    # Keys whose values are byte counts and get converted to sectors.
    _BYTE_COUNTERS = frozenset(
        ('bytes_read', 'bytes_written', 'bytes_unmapped'))

    def __init__(self, dictionary):
        """Populate the instance from *dictionary*; ``None`` sets nothing."""
        if dictionary is None:
            return

        for key, value in dictionary.items():
            attr = self._ATTR_BY_KEY.get(key)
            if attr is None:
                continue
            if key in self._BYTE_COUNTERS:
                value = value >> 9  # bytes -> 512-byte sectors
            setattr(self, attr, value)

        self.rd_merges = 0
        self.wr_merges = 0
        self.dc_merges = 0
        # Sample timestamp; overwritten by the code that takes the sample.
        self.upt = 0.0

    def __getattr__(self, name):
        """Return ``0`` for any counter that was never set.

        Only invoked when normal attribute lookup fails.  Special
        ``__dunder__`` names still raise ``AttributeError`` so that
        ``hasattr`` and similar protocol probes on an instance do not
        receive a bogus ``0``.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return 0


def uptime():
    """Return the first whitespace-separated field of SPDK_UPTIME's first
    line, converted to ``float``."""
    with open(SPDK_UPTIME, 'r') as uptime_file:
        first_line = uptime_file.readline()
    fields = first_line.split()
    return float(fields[0])


def _stat_format(data, header, leave_first=False):
    list_size = len(data)
    header_len = len(header)

    if list_size == 0:
        raise AssertionError
    list_len = len(data[0])

    for ll in data:
        if len(ll) != list_len:
            raise AssertionError
        for i, r in enumerate(ll):
            ll[i] = str(r)

    if (leave_first and list_len + 1 != header_len) or \
            (not leave_first and list_len != header_len):
        raise AssertionError

    item_sizes = [0 for i in range(header_len)]

    for i in range(0, list_len):
        if leave_first and i == 0:
            item_sizes[i] = len(header[i + 1])

        data_len = 0
        for x in data:
            data_len = max(data_len, len(x[i]))
        index = i + 1 if leave_first else i
        item_sizes[index] = max(len(header[index]), data_len)

    _format = '  '.join('%%-%ss' % item_sizes[i] for i in range(0, header_len))
    print(_format % tuple(header))
    if leave_first:
        print('\n'.join(_format % ('', *tuple(ll)) for ll in data))
    else:
        print('\n'.join(_format % tuple(ll) for ll in data))

    print()
    sys.stdout.flush()


def read_cpu_stat(last_cpu_info, cpu_info):
    """Print the CPU time shares between two samples as a one-row table.

    Args:
        last_cpu_info: counters from the previous sample (as returned by
            get_cpu_stat()), or a falsy value for the first sample, in
            which case the shares are computed from the absolute counters.
        cpu_info: counters from the current sample; at least eight values
            are required.

    The columns follow SPDK_CPU_STAT_HEAD: user, nice, system, iowait,
    steal, idle.  The "system" column is the sum of fields 2, 5 and 6.

    Raises:
        IndexError: if either sample holds fewer than eight counters.
    """
    # Positions within the aggregate "cpu" line, per proc(5).
    user, nice, system, idle, iowait, irq, softirq, steal = range(8)

    delta = [
        cpu_info[i] - (last_cpu_info[i] if last_cpu_info else 0)
        for i in range(8)
    ]

    # The total must cover every field that is displayed (steal included),
    # otherwise the printed shares can add up to more than 100%.
    jiffies = sum(delta)

    shown = (
        delta[user],
        delta[nice],
        delta[system] + delta[irq] + delta[softirq],
        delta[iowait],
        delta[steal],
        delta[idle],
    )
    # No elapsed jiffies (identical samples): report zero instead of
    # dividing by zero.
    info_stat = [
        "{:.2%}".format(ticks / jiffies if jiffies else 0.0)
        for ticks in shown
    ]

    _stat_format([info_stat], SPDK_CPU_STAT_HEAD, True)


def check_positive(value):
    """Convert *value* with ``int()`` and return it if it is greater than 0.

    Intended as an argparse ``type=`` callable.

    Raises:
        argparse.ArgumentTypeError: if the converted number is zero or
            negative.
        ValueError: propagated from ``int()`` for non-integer input.
    """
    number = int(value)
    if number > 0:
        return number
    raise argparse.ArgumentTypeError(
        "%s should be positive int value" % number)


def get_cpu_stat():
    """Return the counters of the aggregate ``cpu`` line of SPDK_CPU_STAT.

    The file is scanned for the first line whose first token is exactly
    ``cpu`` (per-core lines such as ``cpu0`` do not match).  Lines are
    tokenised on arbitrary whitespace, so parsing does not depend on the
    exact number of spaces after the label or between the fields.

    Returns:
        list of int: the counters in file order, or an empty list when no
        aggregate line is found.
    """
    with open(SPDK_CPU_STAT, "r") as cpu_file:
        for line in cpu_file:
            fields = line.split()
            if fields and fields[0] == "cpu":
                return [int(field) for field in fields[1:]]
    return []


def _sector_delta(cur, last):
    """Return the increase of a sector counter from *last* to *cur*.

    When the counter went backwards while the previous value still fit in
    SPDK_MAX_SECTORS, the (negative) difference is masked with
    SPDK_MAX_SECTORS, i.e. treated as a 32-bit wrap-around.
    """
    delta = cur - last
    if cur < last and last <= SPDK_MAX_SECTORS:
        delta &= SPDK_MAX_SECTORS
    return delta


def read_bdev_stat(last_stat, stat, mb, use_upt):
    """Print the per-bdev throughput table and return the current samples.

    Args:
        last_stat: list of BdevStat returned by the previous call, or a
            falsy value for the first sample.
        stat: iostat dictionary; its ``'bdevs'`` list is read always, and
            ``'ticks'`` / ``'tick_rate'`` are read when *use_upt* is false.
        mb: print MiB columns when true, KiB columns otherwise.
        use_upt: take the sample time from uptime() instead of the tick
            counter carried in *stat*.

    Returns:
        list of BdevStat for the bdevs in *stat*, each stamped with the
        current sample time in ``upt``; pass it as *last_stat* next time.

    A bdev without a previous sample (first call, or a bdev that appeared
    since the last call) is compared against an all-zero baseline, so its
    figures are averages since time zero of the chosen clock.
    """
    if use_upt:
        upt_cur = uptime()
        upt_rate = 1  # uptime() already yields the interval unit directly
    else:
        upt_cur = stat['ticks']
        upt_rate = stat['tick_rate']

    # Sectors are 512 bytes: 2 per KiB, 2048 per MiB.
    unit = 2048 if mb else 2

    # Index the previous samples by name once; setdefault keeps the first
    # entry should a name ever occur twice.
    previous = {}
    for old in last_stat or ():
        previous.setdefault(old.bdev_name, old)
    baseline = BdevStat(None)  # every counter reads as 0

    info_stats = []
    bdev_stats = []
    for bdev in stat['bdevs']:
        _stat = BdevStat(bdev)
        _stat.upt = upt_cur
        bdev_stats.append(_stat)
        _last_stat = previous.get(_stat.bdev_name, baseline)

        # get the interval time
        upt = (_stat.upt - _last_stat.upt) / upt_rate

        rd_sec = _sector_delta(_stat.rd_sectors, _last_stat.rd_sectors)
        wr_sec = _sector_delta(_stat.wr_sectors, _last_stat.wr_sectors)
        dc_sec = _sector_delta(_stat.dc_sectors, _last_stat.dc_sectors)
        ios = ((_stat.rd_ios + _stat.dc_ios + _stat.wr_ios) -
               (_last_stat.rd_ios + _last_stat.dc_ios + _last_stat.wr_ios))

        # Rates use the same wrap-corrected deltas as the totals.  A zero
        # interval yields zero rates rather than a division error.
        if upt:
            rates = [ios / upt, rd_sec / upt / unit,
                     wr_sec / upt / unit, dc_sec / upt / unit]
        else:
            rates = [0.0, 0.0, 0.0, 0.0]
        totals = [rd_sec / unit, wr_sec / unit, dc_sec / unit]

        info_stats.append(
            [_stat.bdev_name] +
            ["{:.2f}".format(value) for value in rates + totals])

    # _stat_format() rejects an empty table; with no bdevs there is simply
    # nothing to print.
    if info_stats:
        _stat_format(
            info_stats,
            SPDK_BDEV_MB_STAT_HEAD if mb else SPDK_BDEV_KB_STAT_HEAD)
    return bdev_stats


def get_bdev_stat(client, name):
    """Return the result of ``rpc.bdev.bdev_get_iostat`` for *client*.

    *name* is forwarded as the ``name`` keyword argument.
    """
    iostat = rpc.bdev.bdev_get_iostat(client, name=name)
    return iostat


def io_stat_display(args, cpu_info, stat):
    """Take one sample and print the tables selected on the command line.

    The CPU table is skipped only when bdev stats alone were requested,
    and the bdev table only when CPU stats alone were requested; with both
    flags or neither, both tables are printed (CPU first).

    Args:
        args: parsed arguments; reads ``cpu_stat``, ``bdev_stat`` and, for
            the bdev table, ``client``, ``name``, ``mb_display`` and
            ``use_uptime``.
        cpu_info: CPU counters from the previous call, or None.
        stat: bdev samples from the previous call, or None.

    Returns:
        tuple ``(cpu_counters, bdev_samples)`` for the next call; the
        element of a table that was not printed is None.
    """
    want_cpu = args.cpu_stat or not args.bdev_stat
    want_bdev = args.bdev_stat or not args.cpu_stat

    new_cpu_info = None
    if want_cpu:
        new_cpu_info = get_cpu_stat()
        read_cpu_stat(cpu_info, new_cpu_info)

    new_bdev_stats = None
    if want_bdev:
        raw_stat = get_bdev_stat(args.client, args.name)
        new_bdev_stats = read_bdev_stat(
            stat, raw_stat, args.mb_display, args.use_uptime)

    return new_cpu_info, new_bdev_stats


def io_stat_display_loop(args):
    """Create the RPC client and print stats every ``args.interval`` seconds.

    The client is stored on ``args.client``.  Stats are printed at least
    once; after each report the function sleeps for the interval and
    returns as soon as the accumulated sleep time reaches
    ``args.time_in_second``.
    """
    log_level = getattr(logging, args.verbose.upper())
    args.client = rpc.client.JSONRPCClient(
        args.server_addr, args.port, args.timeout, log_level=log_level)

    step = args.interval
    duration = args.time_in_second

    # Samples from the previous round, fed back in to compute deltas.
    prev_cpu = None
    prev_bdev = None

    elapsed = 0
    while True:
        prev_cpu, prev_bdev = io_stat_display(args, prev_cpu, prev_bdev)

        time.sleep(step)
        elapsed += step
        if elapsed >= duration:
            return


# Script entry point: parse the command line, validate option combinations
# and run the display loop.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='SPDK iostats command line interface')

    # --- what to display -------------------------------------------------
    parser.add_argument('-c', '--cpu-status', dest='cpu_stat',
                        action='store_true', help="Only display cpu status",
                        required=False, default=False)

    parser.add_argument('-d', '--bdev-status', dest='bdev_stat',
                        action='store_true', help="Only display Blockdev io stats",
                        required=False, default=False)

    # --- units (mutually exclusive, checked below) -----------------------
    parser.add_argument('-k', '--kb-display', dest='kb_display',
                        action='store_true', help="Display drive stats in KiB",
                        required=False, default=False)

    parser.add_argument('-m', '--mb-display', dest='mb_display',
                        action='store_true', help="Display drive stats in MiB",
                        required=False, default=False)

    # --- timing ----------------------------------------------------------
    parser.add_argument('-u', '--use-uptime', dest='use_uptime',
                        action='store_true', help='Use uptime or spdk ticks(default) as \
                        the interval variable to calculate iostat changes.',
                        required=False, default=False)

    # The integer default 0 is not passed through check_positive (argparse
    # applies type= only to string values), so 0 means "option not given".
    parser.add_argument('-i', '--interval', dest='interval',
                        type=check_positive, help='Time interval (in seconds) on which \
                        to poll I/O stats. Used in conjunction with -t',
                        required=False, default=0)

    parser.add_argument('-t', '--time', dest='time_in_second',
                        type=check_positive, help='The number of second to display stats \
                        before returning. Used in conjunction with -i',
                        required=False, default=0)

    # --- RPC connection --------------------------------------------------
    parser.add_argument('-s', "--server", dest='server_addr',
                        help='RPC domain socket path or IP address',
                        default='/var/tmp/spdk.sock')

    parser.add_argument('-p', "--port", dest='port',
                        help='RPC port number (if server_addr is IP address)',
                        default=4420, type=int)

    parser.add_argument('-b', '--name', dest='name',
                        help="Name of the Blockdev. Example: Nvme0n1", required=False)

    parser.add_argument('-o', '--timeout', dest='timeout',
                        help='Timeout as a floating point number expressed in seconds \
                        waiting for response. Default: 60.0',
                        default=60.0, type=float)

    parser.add_argument('-v', dest='verbose', action='store_const', const="INFO",
                        help='Set verbose mode to INFO', default="ERROR")

    args = parser.parse_args()

    # -i and -t must be given together or not at all.  parser.error()
    # prints the usage plus the message and exits with status 2;
    # ArgumentTypeError is only handled by argparse inside type= callbacks
    # and would surface here as an uncaught traceback.
    if (args.interval == 0) != (args.time_in_second == 0):
        parser.error(
            "interval and time_in_second should be greater than 0 at the same time")

    if args.kb_display and args.mb_display:
        parser.print_help()
        # sys.exit() rather than the site-provided exit() helper, which is
        # not guaranteed to exist (e.g. when run with python -S).
        sys.exit()

    io_stat_display_loop(args)