import json
import math
from datetime import datetime, timedelta

import requests


# Timing constants used below to convert between slots, epochs and wall-clock time.
SLOTS_PER_EPOCH = 32  # number of slots that make up one epoch
SECONDS_PER_SLOT = 12  # duration of a single slot, in seconds


def main(
    validators_indices,
    eth2_api_url="http://localhost:5052/eth/v1/",
    genesis_timestamp=1606824023,
):
    """Print upcoming attestation/proposal duties and the idle gaps between them.

    Attester duties are fetched for the current and the next epoch, proposer
    duties for the current epoch only. Every duty slot after the head slot is
    listed together with the gap (in seconds) since the previous duty or the
    current time, followed by the longest gap within the current epoch and the
    longest gap overall. The first slot of the epoch after next is added as a
    placeholder so that the final gap assumes the worst case of having to
    attest right at that slot.

    Args:
        validators_indices: Validator indices to report on. They are sent to
            the API as a JSON array and compared numerically against the
            proposer duties.
        eth2_api_url: Base URL of the beacon node API. Endpoints are appended
            directly, so it must end with a slash.
        genesis_timestamp: Unix timestamp of slot 0. The default is presumably
            the mainnet beacon-chain genesis -- confirm before pointing this at
            another network.

    Raises:
        requests.HTTPError: If the beacon node answers with an error status.
    """

    def api_get(endpoint):
        response = requests.get(f"{eth2_api_url}{endpoint}")
        # Fail with the HTTP status instead of a confusing KeyError later on.
        response.raise_for_status()
        return response.json()

    def api_post(endpoint, data):
        # `data` is an already serialised JSON document, so label it as such.
        response = requests.post(
            f"{eth2_api_url}{endpoint}", data, headers={"Content-Type": "application/json"}
        )
        response.raise_for_status()
        return response.json()

    def slot_start_time(slot):
        return datetime.fromtimestamp(genesis_timestamp + slot * SECONDS_PER_SLOT)

    def seconds_between(later, earlier):
        # A slot that has already started (e.g. the head lags behind the clock)
        # is reported as a zero gap rather than a negative one.
        return max(0, math.floor((later - earlier).total_seconds()))

    def _update_gap(end, start, gap_store):
        gap = max(end - start, timedelta(0))
        # Strict comparison keeps the first of several equally long gaps.
        if gap > gap_store["longest_gap"]:
            gap_store["longest_gap"] = gap
            gap_store["gap_time_range"] = (end, start)

    def _print_longest_gap(gap_store):
        longest_gap = gap_store["longest_gap"]
        gap_end, gap_start = gap_store["gap_time_range"]
        print("*" * 80)
        if gap_end is None or gap_start is None:
            # No positive gap was ever recorded, so there is no range to show.
            print("0.0 seconds (0 slots), no gap found")
            return
        print(
            f"{longest_gap.total_seconds()} seconds"
            f" ({int(longest_gap.total_seconds()) // SECONDS_PER_SLOT} slots),"
            f" from {gap_start.strftime('%H:%M:%S')}"
            f" until {gap_end.strftime('%H:%M:%S')}"
        )

    head_slot = int(api_get("beacon/headers/head")["data"]["header"]["message"]["slot"])
    epoch = head_slot // SLOTS_PER_EPOCH

    indices_payload = json.dumps(validators_indices)
    cur_epoch_data = api_post(f"validator/duties/attester/{epoch}", indices_payload)["data"]
    next_epoch_data = api_post(f"validator/duties/attester/{epoch + 1}", indices_payload)["data"]

    # slot -> labels of the validators that have a duty in that slot
    duties = {}
    for d in (*cur_epoch_data, *next_epoch_data):
        slot = int(d["slot"])
        if slot > head_slot:
            # Stored as text so the list can be joined regardless of the JSON type.
            duties.setdefault(slot, []).append(str(d["validator_index"]))

    all_proposer_duties = api_get(f"validator/duties/proposer/{epoch}")["data"]

    # Normalise to int so indices passed as strings still match proposer duties.
    validators_indices_set = {int(v) for v in validators_indices}
    for s in all_proposer_duties:
        slot = int(s["slot"])
        if slot <= head_slot:
            continue

        prop_index = int(s["validator_index"])
        if prop_index in validators_indices_set:
            duties.setdefault(slot, []).append(f"{prop_index} (proposal)")

    # Also insert (still unknown) attestation duties at epoch after next,
    # assuming worst case of having to attest at its first slot
    first_slot_epoch_p2 = (epoch + 2) * SLOTS_PER_EPOCH
    duties.setdefault(first_slot_epoch_p2, [])

    duties = dict(sorted(duties.items()))

    print("Calculating attestation/proposal slots and gaps for validators:")
    print(f"  {validators_indices}")

    print("\nUpcoming voting/proposal slots and gaps")
    print("(Gap in seconds)")
    print("(slot/epoch - time range - validators)")
    print("*" * 80)

    # Floor to seconds
    prev_end_time = datetime.now().replace(microsecond=0)

    # Current epoch gaps
    cur_epoch_gap_store = {"longest_gap": timedelta(seconds=0), "gap_time_range": (None, None)}
    overall_gap_store = cur_epoch_gap_store.copy()

    next_epoch_start_slot = (epoch + 1) * SLOTS_PER_EPOCH
    next_epoch_start_time = slot_start_time(next_epoch_start_slot)

    in_next_epoch = False

    for slot, validators in duties.items():
        slot_start = slot_start_time(slot)
        slot_end = slot_start + timedelta(seconds=SECONDS_PER_SLOT)

        suf = ""
        if not in_next_epoch and slot >= next_epoch_start_slot:
            until_epoch_change = seconds_between(next_epoch_start_time, prev_end_time)
            until_next_duty = seconds_between(slot_start, next_epoch_start_time)
            print("- " * 40)
            print(f"Time until epoch change: {until_epoch_change} seconds")
            print(
                f"Epoch boundary (proposal duties are not yet known for next epoch): {next_epoch_start_time}"
            )
            print(f"Time until next duty: {until_next_duty} seconds")
            print("- " * 40)
            suf = "(after prev. slot duty or current time)"

            # The current-epoch gap ends at the epoch boundary at the latest.
            _update_gap(next_epoch_start_time, prev_end_time, cur_epoch_gap_store)
            in_next_epoch = True

        print(f"Gap - {seconds_between(slot_start, prev_end_time)} seconds {suf}")

        # The placeholder slot has no validators and only contributes its gap line.
        if validators:
            print(
                f"  {slot}/{slot // SLOTS_PER_EPOCH}"
                f" - {slot_start.strftime('%H:%M:%S')} until {slot_end.strftime('%H:%M:%S')}"
                f" - [{', '.join(validators)}]"
            )

        _update_gap(slot_start, prev_end_time, overall_gap_store)
        if not in_next_epoch:
            _update_gap(slot_start, prev_end_time, cur_epoch_gap_store)

        prev_end_time = slot_end

    print("\nLongest attestation and proposer duty gap (only current epoch, first):")
    _print_longest_gap(cur_epoch_gap_store)

    print("\nLongest attestation gap (first):")
    _print_longest_gap(overall_gap_store)


if __name__ == "__main__":
    # Command-line entry point: collect one or more validator indices and
    # hand them to main(), which talks to the default local beacon node.
    import argparse

    cli = argparse.ArgumentParser(
        description="Show validator duties of current and next epoch to find largest gap."
    )
    cli.add_argument(
        "indices",
        metavar="index",
        type=int,
        nargs="+",
        help="validator indices",
    )

    main(cli.parse_args().indices)