#!/usr/bin/env python3
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import collections
import io
import math
import os

import pandas as pd
import seaborn
from matplotlib import pyplot
from tabulate import tabulate

import orjson

# Libraries reported by tab(), in table row order. tab() uses the median of the
# first entry as the baseline of the "Relative (latency)" column and the "json"
# entry as the baseline of the graphs.
LIBRARIES = ("orjson", "ujson", "rapidjson", "simplejson", "json")


def aggregate():
    """Load the benchmark result files of one run directory under ``.benchmarks``.

    The first sub-directory of ``.benchmarks`` (in sorted order) is used and
    every file inside it is parsed as JSON.

    Returns:
        A mapping ``group -> library -> stats``. Each ``stats`` dict holds
        ``data`` (the raw timings scaled by 1000), ``median`` (scaled by
        1000), ``ops`` and ``correct``, all taken from the entries of the
        file's ``"benchmarks"`` list.

    Raises:
        IndexError: if ``.benchmarks`` contains no entries.
    """
    # os.listdir() returns entries in arbitrary order; sort so the chosen run
    # directory does not depend on the filesystem.
    run_dirs = sorted(os.listdir(".benchmarks"))
    benchmarks_dir = os.path.join(".benchmarks", run_dirs[0])

    res = collections.defaultdict(dict)
    # Sorted for the same reason: when several files report the same
    # group/library pair, the file that sorts last wins deterministically.
    for filename in sorted(os.listdir(benchmarks_dir)):
        # Read bytes: orjson decodes UTF-8 itself, so the result does not
        # depend on the locale's default text encoding.
        with open(os.path.join(benchmarks_dir, filename), "rb") as fileh:
            data = orjson.loads(fileh.read())

        for each in data["benchmarks"]:
            stats = each["stats"]
            extra_info = each["extra_info"]
            # NOTE(review): the factor of 1000 presumably converts seconds to
            # milliseconds (the table header reports milliseconds) - confirm.
            res[each["group"]][extra_info["lib"]] = {
                "data": [val * 1000 for val in stats["data"]],
                "median": stats["median"] * 1000,
                "ops": stats["ops"],
                "correct": extra_info["correct"],
            }
    return res


def _document_name(group):
    """Return the document name of a group label such as ``"x.json serialization"``.

    The operation word is removed as a whole word from either end.
    ``str.strip("serialization")`` must not be used for this: it strips a
    *set of characters* and turns ``"twitter.json"`` into ``"witter.json"``.
    """
    name = group.strip()
    # "deserialization" is tried first because it ends with "serialization".
    for word in ("deserialization", "serialization"):
        if name.startswith(word):
            name = name[len(word):].strip()
        if name.endswith(word):
            name = name[: -len(word)].strip()
    return name


def _markdown_table(group, results):
    """Render the per-library results of one group as a GitHub-flavoured table.

    Libraries whose result is not ``correct`` get empty cells. The relative
    latency column is computed against the first library in ``LIBRARIES`` and
    is left empty when that baseline is unavailable or zero.
    """
    headers = (
        "Library",
        "Median latency (milliseconds)",
        "Operations per second",
        "Relative (latency)",
    )

    rows = []
    for lib in LIBRARIES:
        entry = results[lib]
        correct = entry["correct"]
        rows.append(
            [
                lib,
                entry["median"] if correct else None,
                int(entry["ops"]) if correct else None,
                None,
            ]
        )

    baseline = rows[0][1]
    has_baseline = isinstance(baseline, float) and baseline > 0
    # "github" groups are printed with one more digit than the others.
    precision = 2 if group.startswith("github") else 1

    for row in rows:
        median = row[1]
        if not isinstance(median, float):
            row[1] = None
            continue
        if has_baseline:
            row[3] = f"{median / baseline:.1f}"
        row[1] = f"{median:.{precision}f}"

    return tabulate(rows, headers, tablefmt="github")


def _barplot_rows(group, results):
    """Return one plotting record per library for the given group."""
    operation = "deserialization" if "deserialization" in group else "serialization"
    document = _document_name(group)
    return [
        {
            "operation": operation,
            "group": document,
            "library": lib,
            "latency": results[lib]["median"],
            "operations": int(results[lib]["ops"])
            if results[lib]["correct"]
            else None,
        }
        for lib in LIBRARIES
    ]


def _plot_operation(operation, rows):
    """Save a bar chart of operations/second relative to ``json`` to ``doc/``.

    Adds a ``"relative"`` key to every record in ``rows``. Records without a
    usable measurement or baseline get NaN instead of raising.
    """
    baselines = {
        row["group"]: row["operations"] for row in rows if row["library"] == "json"
    }

    max_y = 0
    for row in rows:
        baseline = baselines.get(row["group"])
        if row["operations"] is None or not baseline:
            # Incorrect result, or missing/zero stdlib baseline: nothing to plot.
            row["relative"] = math.nan
            continue
        row["relative"] = row["operations"] / baseline
        max_y = max(max_y, row["relative"])

    graph = seaborn.barplot(
        pd.DataFrame(rows),
        x="group",
        y="relative",
        orient="x",
        hue="library",
        errorbar="sd",
        legend="brief",
    )
    graph.set_xlabel("Document")
    graph.set_ylabel("Operations/second relative to stdlib json")

    pyplot.title(operation)

    # Ensure the Y ticks include 1 (the stdlib reference) and the rounded-up
    # maximum; large odd maxima are bumped to the next even number.
    max_y = int(math.ceil(max_y))
    if max_y > 10 and max_y % 2 > 0:
        max_y = max_y + 1
    ticks = sorted(
        {1, max_y} | {int(y) for y in graph.get_yticks() if int(y) <= max_y}
    )
    graph.set_yticks(ticks)

    # Label ticks as multiples of the stdlib json throughput, e.g. "3x".
    graph.set_yticklabels([f"{tick}x" for tick in ticks])

    # Reference line for stdlib json.
    pyplot.axhline(y=1, color="#999", linestyle="dashed")

    pyplot.savefig(fname=f"doc/{operation}", dpi=300)
    pyplot.close()


def tab(obj):
    """Print Markdown result tables and save relative-throughput graphs.

    Args:
        obj: mapping ``group -> library -> result``. Every result needs the
            keys ``median``, ``ops`` and ``correct``, and every group needs an
            entry for each name in ``LIBRARIES``.

    Side effects:
        Prints one ``####`` section with a table per group (groups in reverse
        sorted order) to stdout and writes one figure per operation
        (``doc/deserialization``, ``doc/serialization``) that has data.

    Raises:
        KeyError: if a group lacks one of the ``LIBRARIES``.
    """
    seaborn.set(rc={"figure.facecolor": (0, 0, 0, 0)})
    seaborn.set_style("darkgrid")

    buf = io.StringIO()
    barplot_data = []
    for group, val in sorted(obj.items(), reverse=True):
        buf.write(f"\n#### {group}\n\n")
        buf.write(_markdown_table(group, val) + "\n")
        barplot_data.extend(_barplot_rows(group, val))

    for operation in ("deserialization", "serialization"):
        per_op_data = [each for each in barplot_data if each["operation"] == operation]
        if per_op_data:
            _plot_operation(operation, per_op_data)

    print(buf.getvalue())


# Script entry point: aggregate the benchmark results and emit tables/graphs.
# There is no __main__ guard, so this also runs when the module is imported.
tab(aggregate())
