import argparse
import random
from datetime import datetime, timezone
from time import sleep

import psycopg2


# Connection settings for the local development database; expanded into
# psycopg2.connect(**params) by main(). Credentials are hard-coded dev values.
params = dict(
    host="localhost",
    port=8011,
    database="dev",
    user="dev",
    password="dev",
)


def make_parser():
    """Build the command-line parser for the simulator.

    Options:
        --number / -n: number of update rounds to run (int, default 50).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--number", "-n",
        default=50,
        type=int,
    )
    return cli



def update_counts(previous):
    """Return a copy of *previous* with each tracked count randomly nudged.

    ``infected``, ``dead`` and ``treated`` (in that order) each get an
    independent ``random.randint(-10, 10)`` step and are clamped at zero.
    Every other key is carried over unchanged, and *previous* itself is not
    modified.
    """
    nudged = dict(previous)
    for key in ("infected", "dead", "treated"):
        # Look the count up before drawing, so a missing key raises before
        # any random number is consumed for that field.
        base = previous[key]
        nudged[key] = max(base + random.randint(-10, 10), 0)
    return nudged


def update_outbreak_status(cursor, state, timestamp):
    """Advance one outbreak's counts and upsert the result into ``outbreak``.

    Args:
        cursor: DB-API cursor used to run the upsert; this function does not
            commit.
        state: mapping that must provide ``disease``, ``region``,
            ``infected``, ``dead`` and ``treated`` (the values bound below).
        timestamp: value written to the ``modified_at`` column.

    Returns:
        The new state produced by ``update_counts(state)``.
    """
    new_state = update_counts(state)

    # On a (disease, region) conflict the existing row's counts and
    # modified_at are overwritten instead of inserting another row.
    upsert_sql = """
        insert into outbreak
        (disease, region, infected, dead, treated, modified_at)
        values
        (%(disease)s, %(region)s, %(infected)s, %(dead)s, %(treated)s, %(modified_at)s)
        on conflict (disease, region) do update
        set
            infected = %(infected)s,
            dead = %(dead)s,
            treated = %(treated)s,
            modified_at = %(modified_at)s
    """
    # "modified_at" is inserted first, so a same-named key in new_state
    # would take precedence over *timestamp*.
    bind = {"modified_at": timestamp}
    bind.update(new_state)

    cursor.execute(upsert_sql, bind)
    return new_state


def update_outbreak_statuses(cursor, states):
    """Randomly advance a subset of outbreaks and record an audit row.

    When *states* is non-empty, a random non-empty subset of it is passed
    through ``update_outbreak_status`` (which upserts each one). The rest are
    carried over unchanged. Every call then inserts a ``'Success'`` row into
    ``table_audit``, stamped with the same ``modified_at`` used for the
    upserts. This function does not commit.

    Args:
        cursor: DB-API cursor used for all statements.
        states: sequence of outbreak state mappings.

    Returns:
        A new list in the same order as *states*. Selected outbreaks hold
        their updated state; all others hold the original one.
    """
    # Naive UTC timestamp: the same value datetime.utcnow() returned, without
    # that call's deprecation (Python 3.12+).
    modified_at = datetime.now(timezone.utc).replace(tzinfo=None)

    count = len(states)
    if count:
        # Sample over all states instead of a hard-coded [0, 1, 2], so
        # indices past 2 can be chosen and short lists never yield
        # out-of-range picks. For exactly three states these draws consume
        # the RNG like random.choice([1, 2, 3]) / random.sample([0, 1, 2]),
        # so seeded runs stay reproducible.
        batch_size = random.randint(1, count)
        chosen = set(random.sample(range(count), k=batch_size))
    else:
        chosen = set()

    updated = []
    for idx, state in enumerate(states):
        if idx in chosen:
            updated.append(update_outbreak_status(cursor, state, modified_at))
        else:
            updated.append(state)

    cursor.execute("""
        insert into table_audit
        (tablename, status, modified_at)
        values
        ('outbreak', 'Success', %(modified_at)s)
    """, {"modified_at": modified_at})
    return updated


def main():
    """Run the outbreak simulation against the database in ``params``.

    Truncates ``outbreak`` and ``table_audit``, then runs ``--number`` rounds
    of ``update_outbreak_statuses``. It commits after each round and sleeps
    0.5-1.199 s between rounds. The RNG is seeded with 13, so the sequence of
    random draws is the same on every run.
    """
    parser = make_parser()
    args = parser.parse_args()

    initial_states = [
        {
            "disease": "Cholera",
            "region": "Congo Basin",
            "infected": 2,
            "dead": 0,
            "treated": 0,
        }, {
            "disease": "SARS",
            "region": "Western China",
            "infected": 20,
            "dead": 4,
            "treated": 2,
        }, {
            "disease": "Avian Flu",
            "region": "Southern China",
            "infected": 4,
            "dead": 0,
            "treated": 1,
        }
    ]

    random.seed(13)
    # psycopg2's `with connection:` only commits or rolls back the current
    # transaction; it does not close the connection. Close it explicitly,
    # and let the cursor's own context manager close the cursor.
    conn = psycopg2.connect(**params)
    try:
        with conn, conn.cursor() as cursor:
            cursor.execute("truncate table outbreak")
            cursor.execute("truncate table table_audit")
            conn.commit()

            states = initial_states
            for _ in range(args.number):
                states = update_outbreak_statuses(cursor, states)
                conn.commit()
                delay_ms = random.randrange(500, 1_200)
                sleep(delay_ms / 1_000)
    finally:
        conn.close()


# Run the simulation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()