"""
Plots race TypeRacer history.
Make sure you have permission to read the file

usage:
    $ python3 race_plot.py race_data.csv [day, week, month or year]
"""
import sys
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import cm
from datetime import datetime, timezone, timedelta
# Use matplotlib's 'ggplot' style for every figure this script draws.
plt.style.use('ggplot')

# strftime pattern used for x-axis tick labels, keyed by grouping granularity.
# NOTE: '%-d' (day of month without zero padding) is a platform-specific
# directive (glibc/BSD); Windows' strftime rejects it.
ticks_format = dict(
    day='%b, %-d',
    week='%b, %-d',
    month='%b',
    year='%Y',
)


def accuracy_to_color(acc_df):
    """Map accuracy values to integer colour intensities in ``[0, 100]``.

    Values at or below 0.9 map to 0, and values at or above 1.0 map to 100.
    Values in between scale linearly, one unit per 0.001 of accuracy.
    Presumably accuracy is a fraction in [0, 1] (as the 0.9/0.1 bounds
    suggest); confirm against the CSV export.

    Args:
        acc_df: pandas Series (or ndarray) of accuracy values.

    Returns:
        Series/array of ints in ``[0, 100]`` of the same shape.
    """
    scaled = (acc_df - 0.9).clip(0, 0.1) * 1000
    # Round before casting. Binary float error makes e.g. (1.0 - 0.9) * 1000
    # equal 99.999..., which a bare astype(int) would truncate to 99.
    return scaled.round().astype(int)


def group_date(gr):
    """Build a converter from a UTC timestamp string to its group key.

    The returned callable parses strings formatted as
    ``"%Y-%m-%d %H:%M:%S"``, interprets them as UTC, converts them to the
    machine's local timezone and then buckets them according to *gr*:

    - ``'day'``:   the local calendar ``date``
    - ``'week'``:  the ``date`` of the Sunday that closes that week
    - ``'month'``: a naive ``datetime`` for the 1st of the month
    - ``'year'``:  a naive ``datetime`` for January 1st

    An unknown *gr* raises ``KeyError`` when the converter is called.
    """
    def to_week_end(moment):
        # weekday() is 0 for Monday ... 6 for Sunday, so this lands on Sunday.
        return moment.date() + timedelta(days=6 - moment.weekday())

    bucketers = {
        'day': lambda moment: moment.date(),
        'week': to_week_end,
        'month': lambda moment: datetime(moment.year, moment.month, 1),
        'year': lambda moment: datetime(moment.year, 1, 1),
    }

    def convert(stamp):
        naive_utc = datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S")
        local = naive_utc.replace(tzinfo=timezone.utc).astimezone(tz=None)
        return bucketers[gr](local)

    return convert


def last_race_of_day(df, date):
    """Return the largest ``'Race #'`` among rows whose ``'date'`` equals *date*.

    Returns NaN when no row matches.
    """
    same_group = df['date'] == date
    return df.loc[same_group, 'Race #'].max()


def step_plot(ax, df, col, color, label):
    """Draw the per-date mean of *col* as a step line on *ax*.

    Each ``'date'`` group becomes one horizontal segment at the mean of *col*
    for that group. The segment ends half a race after the group's last
    ``'Race #'``, and the next segment starts at that same x position. The
    first segment starts half a race before the first race in *df*, rather
    than at a hard-coded 0, so CSVs that do not begin at race 1 are not
    stretched.

    Args:
        ax: matplotlib Axes to draw on.
        df: DataFrame with ``'date'``, ``'Race #'`` and *col* columns.
        col: name of the column to average per date.
        color: line colour, forwarded to ``ax.plot`` as ``c``.
        label: legend label, forwarded to ``ax.plot``.

    Returns:
        Whatever ``ax.plot`` returns (the list of created lines).
    """
    if df.empty:
        # Nothing to average; still return a (empty) line so legends work.
        return ax.plot([], [], linewidth=2, c=color, label=label)

    # groupby sorts by key, so segments come out in chronological order, and
    # each group is aggregated once instead of rescanning df per date.
    groups = df.groupby('date')
    means = groups[col].mean().tolist()
    ends = (groups['Race #'].max() + 0.5).tolist()
    starts = [df['Race #'].min() - 0.5] + ends[:-1]

    xs, ys = [], []
    for start, end, mean in zip(starts, ends, means):
        xs += [start, end]
        ys += [mean, mean]
    return ax.plot(xs, ys, linewidth=2, c=color, label=label)


if __name__ == "__main__":
    data = pd.read_csv(sys.argv[1], sep=',')
    gr = sys.argv[2] if (
        len(sys.argv) > 2 and sys.argv[2] in ticks_format) else 'day'
    data['date'] = data['Date/Time (UTC)'].map(group_date(gr))

    fig, ax = plt.subplots()
    ticks = []
    for date in set(data['date']):
        vert_line = last_race_of_day(data, date) + 0.5
        ax.axvline(x=vert_line, linestyle='dashed', color='black', linewidth=1)
        ticks.append((vert_line, date.strftime(ticks_format[gr])))
    plt.xticks(*zip(*ticks), rotation='60')
    for tick in ax.xaxis.get_majorticklabels():
        tick.set_horizontalalignment("right")
    cax = ax.scatter(data['Race #'], data['WPM'],
                     c=accuracy_to_color(data['Accuracy']),
                     cmap=cm.Oranges, label=None)
    ax2 = ax.twinx()
    ax2.grid(False)
    line_acc = step_plot(ax2, data, 'Accuracy', 'blue', 'Accuracy')
    line_wpm = step_plot(ax, data, 'WPM', 'red', 'WPM')
    ax.set_ylabel('WPM')
    ax2.set_ylabel('Accuracy')

    lns = line_wpm + line_acc
    labs = [l.get_label() for l in lns]
    ax.legend(lns, labs, loc=2)

    plt.title('Total %d games' % len(data))
    plt.show()