import ipaddress
import sqlite3
import random
import time


def createASNs(conn, n=10_000):
    cursor = conn.cursor()
    data = []
    start = time.time()
    for i in range(n):
        start_ip, end_ip = list(
            sorted(
                (
                    ipaddress.IPv6Address(random.randint(0, 2**128 - 1)),
                    ipaddress.IPv6Address(random.randint(0, 2**128 - 1)),
                )
            )
        )
        if int(start_ip) > int(end_ip):
            raise Exception(f"invalid ips {start_ip:.2f} {end_ip:.2f}")
        start_ip_text = str(start_ip)
        end_ip_text = str(end_ip)
        start_ip_blob = int(start_ip).to_bytes(16)
        end_ip_blob = int(end_ip).to_bytes(16)
        # store our IP address as two columns of 64-bit integers. Note that
        # sqlite supports _signed integer_ columns, so we need to subtract by
        # 2**63
        start_ip_high = (int(start_ip) >> 64) - 2**63
        start_ip_low = (int(start_ip) & 0xFFFFFFFFFFFFFFFF) - 2**63
        end_ip_high = (int(end_ip) >> 64) - 2**63
        end_ip_low = (int(end_ip) & 0xFFFFFFFFFFFFFFFF) - 2**63

        asn = random.randint(1, 65535)
        country = "US"
        name = f"AS {i}"

        data.append(
            (
                start_ip_text,
                end_ip_text,
                start_ip_blob,
                end_ip_blob,
                start_ip_high,
                start_ip_low,
                end_ip_high,
                end_ip_low,
                asn,
                country,
                name,
            )
        )

    # Perform bulk insert
    cursor.executemany(
        """
        INSERT INTO ipv6_ranges (
            start_ip, end_ip, start_ip_blob, end_ip_blob,
            start_ip_high, start_ip_low, end_ip_high, end_ip_low,
            asn, country, name
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """,
        data,
    )
    conn.commit()

    end = time.time()
    print(f"creating random ASNs: {end-start:.4f}")


def testText(conn):
    cursor = conn.cursor()

    start = time.time()
    for i in range(1000):
        randip = ipaddress.IPv6Address(random.randint(0, 2**128 - 1))
        cursor.execute(
            "SELECT asn FROM ipv6_ranges WHERE ? BETWEEN start_ip AND end_ip;",
            [str(randip)],
        ).fetchall()

    end = time.time()
    print(f"text between: {end-start:.4f} {1000/(end-start):.2f} selects / sec")


def testText2(conn):
    cursor = conn.cursor()

    start = time.time()
    for i in range(1000):
        randip = ipaddress.IPv6Address(random.randint(0, 2**128 - 1))
        ip = str(randip)
        cursor.execute(
            """
            SELECT asn FROM ipv6_ranges
             WHERE ? >= start_ip
             AND ? <= end_ip;""",
            [ip, ip],
        ).fetchall()

    end = time.time()
    print(f"text lt/gt: {end-start:.4f} {1000/(end-start):.2f} selects / sec")


def testBlob(conn):
    cursor = conn.cursor()

    start = time.time()
    for i in range(1000):
        randip = random.randint(0, 2**128 - 1)
        cursor.execute(
            """SELECT asn FROM ipv6_ranges WHERE ?
               BETWEEN start_ip_blob AND end_ip_blob;""",
            [randip.to_bytes(16)],
        ).fetchall()

    end = time.time()
    print(f"blob between: {end-start:.4f} {1000/(end-start):.2f} selects / sec")


def testBlob2(conn):
    cursor = conn.cursor()

    start = time.time()
    for i in range(1000):
        randip = random.randint(0, 2**128 - 1)
        ipb = randip.to_bytes(16)
        cursor.execute(
            """SELECT asn FROM ipv6_ranges
               WHERE ? >= start_ip_blob
                 AND ? <= end_ip_blob;""",
            [ipb, ipb],
        ).fetchall()

    end = time.time()
    print(f"blob lt/gt: {end-start:.4f} {1000/(end-start):.2f} selects / sec")


def testBlob3(conn):
    cursor = conn.cursor()

    start = time.time()
    for i in range(1000):
        randip = random.randint(0, 2**128 - 1)
        ipb = randip.to_bytes(16)
        cursor.execute(
            """SELECT asn FROM ipv6_ranges WHERE rowid IN 
                (SELECT ROWID FROM ipv6_ranges WHERE ? >= start_ip_blob
                INTERSECT
                SELECT ROWID FROM ipv6_ranges WHERE ? <= end_ip_blob)""",
            [ipb, ipb],
        ).fetchall()

    end = time.time()
    print(f"blob intersect: {end-start:.4f} {1000/(end-start):.2f} selects / sec")


def testInt(conn):
    cursor = conn.cursor()

    start = time.time()
    for i in range(1000):
        randip = random.randint(0, 2**128 - 1)
        cursor.execute(
            """SELECT asn FROM ipv6_ranges WHERE
               ? BETWEEN start_ip_high AND end_ip_high AND
               ? BETWEEN start_ip_low AND end_ip_low;""",
            [(randip >> 64) - 2**63, (randip & 0xFFFFFFFFFFFFFFFF) - 2**63],
        ).fetchall()

    end = time.time()
    print(f"int between: {end-start:.4f} {1000/(end-start):.2f} selects / sec")


def testInt2(conn):
    cursor = conn.cursor()

    start = time.time()
    for i in range(1000):
        randip = random.randint(0, 2**128 - 1)
        high, low = ((randip >> 64) - 2**63, (randip & 0xFFFFFFFFFFFFFFFF) - 2**63)
        cursor.execute(
            """SELECT asn FROM ipv6_ranges WHERE
               ? >= start_ip_high AND ? <= end_ip_high AND
               ? >= start_ip_low AND ? <= end_ip_low;""",
            [high, high, low, low],
        ).fetchall()

    end = time.time()
    print(f"int lt/gt: {end-start:.4f} {1000/(end-start):.2f} selects / sec")


def main():
    """Create the benchmark database, populate it, and run every lookup benchmark.

    Reads the table definitions from ``schema.sql`` and writes
    ``ipv6_ranges.db``, both relative to the current working directory.
    """
    # Create a new SQLite database
    conn = sqlite3.connect("ipv6_ranges.db")
    try:
        # Close the schema file promptly instead of leaking the handle.
        with open("schema.sql", encoding="utf-8") as schema_file:
            conn.executescript(schema_file.read())
        conn.commit()
        # NOTE(review): unless schema.sql drops/recreates ipv6_ranges, repeated
        # runs accumulate rows from earlier runs — confirm.

        createASNs(conn)
        for benchmark in (
            testText,
            testText2,
            testBlob,
            testBlob2,
            testBlob3,
            testInt,
            testInt2,
        ):
            benchmark(conn)
    finally:
        # Release the database even if a benchmark raises.
        conn.close()


if __name__ == "__main__":
    # Run the benchmark suite only when executed as a script, not on import.
    main()