#!/usr/bin/python3

import argparse
import os
import csv
import fnmatch
import re
import datetime
from collections import defaultdict
import jinja2
import matplotlib.pyplot as plt

class Benchmarks:
    def __init__(self, args, dirname):
        self.data = []

        if args.filter:
            re_match = re.compile(fnmatch.translate(args.filter))
        else:
            re_match = None

        for f in os.listdir(dirname):
            if not f.endswith(".csv"): continue
            if re_match and not re_match.match(f): continue
            dt, db, shasum = os.path.splitext(f)[0].split("_")
            dt = datetime.datetime.strptime(dt, "%Y%m%d%H%M%S")
            with open(os.path.join(dirname, f), "rt") as fd:
                for row in csv.reader(fd):
                    if not row[0].startswith(db): continue
                    bench_name = row[0].split(".", 1)[1]
                    bench_values = row[1:]
                    if len(bench_values) == 2:
                        secs, count = bench_values
                        val = float(count) / float(secs)
                    self.data.append((dt, db, bench_name, val))

        # Jinja2 template engine
        from jinja2 import Environment, FileSystemLoader
        self.jinja2 = Environment(
            loader=FileSystemLoader([
                "bench/",
            ]),
            autoescape=True,
        )

    def plot(self):
        destdir = "bench/out"
        os.makedirs(destdir, exist_ok=True)

        self.plot_test_db_over_time(destdir)

        # TODO:
        #  - for each test, plot all DBs compared over time (logscale?)
        #  - for each db, plot all tests over time

    def plot_test_db_over_time(self, destdir):
        from matplotlib.dates import DayLocator, HourLocator, DateFormatter, drange

        # Aggregate by (db, test)
        graphs = defaultdict(list)
        for dt, db, name, vals in self.data:
            graphs[(db, name)].append((dt, vals))

        # plot over time
        for (db, name), data in graphs.items():
            fname = "{}_{}.png".format(db, name)
            plt.clf()
            plt.cla()
            fig, ax = plt.subplots()
            fig.autofmt_xdate()
            ax.xaxis_date()
            ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
            ax.fmt_xdata = DateFormatter('%Y-%m-%d')
            plt.title("{} {}".format(db, name))
            plt.ylabel("calls per second")
            plt.xlabel("sampled date")
            plt.plot_date([d[0] for d in data], [d[1] for d in data])
            plt.ylim(ymin=0)
            plt.savefig(os.path.join(destdir, fname), bbox_inches='tight')
            plt.close("all")

def main():
    """Parse the command line, load the CSV results from ``bench`` and plot them."""
    cli = argparse.ArgumentParser(description="Plot DB-All.e benchmark results.")
    cli.add_argument(
        "-f",
        "--filter",
        help="Graph only the data files matching the given glob",
        default=None,
    )
    options = cli.parse_args()

    Benchmarks(options, "bench").plot()


# Run the plotter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

