import hashlib
from struct import *

"""
This implementation was reverse-engineered using Wireshark (and source code), strace, and two excellent articles:
- https://x-c3ll.github.io/posts/CVE-2018-7081-RCE-ArubaOS/
- https://packetstormsecurity.com/files/136997/Aruba-Authentication-Bypass-Insecure-Transport-Tons-Of-Issues.html
"""

def papi_encrypt(data, key=0x93):
    """XOR every byte of *data* with a single-byte key.

    XOR is its own inverse, so applying this function twice returns the
    original data; the same function serves for encryption and decryption.

    Args:
        data: a bytes-like object (or any iterable of ints in 0-255).
        key: the single-byte XOR key, 0-255. Defaults to 0x93, the value
            this module has always used.

    Returns:
        A new ``bytes`` object of the same length as *data*.
    """
    # XOR the byte values directly instead of round-tripping through a str
    # (the old chr()/latin-1 approach was quadratic due to repeated +=).
    return bytes(b ^ key for b in data)

def papi_header(dst_host, src_host, dst_port, src_port, sequence_number, message_code, body, calculate_checksum):
    """Build the fixed 0x4C-byte header that precedes a PAPI message.

    Args:
        dst_host: destination IPv4 address as a dotted quad, e.g. '10.0.0.1'.
        src_host: source IPv4 address as a dotted quad.
        dst_port: destination port, packed as a big-endian 16-bit value.
        src_port: source port, packed as a big-endian 16-bit value.
        sequence_number: big-endian 16-bit sequence number.
        message_code: big-endian 16-bit, application-specific PAPI message code.
        body: the message body (bytes). Only read when calculate_checksum is
            true, because the checksum covers header + body.
        calculate_checksum: if true, the checksum field is the MD5 of the
            header (with checksum and padding zeroed), the body, and a
            hard-coded key; if false it is left as 16 zero bytes.

    Returns:
        The 0x4C-byte header as bytes (the body is NOT appended).

    Raises:
        ValueError: if dst_host or src_host is not four dot-separated
            integers in 0-255.
        struct.error: if a port, the sequence number or the message code
            does not fit in an unsigned 16-bit field.
    """
    def _ipv4_octets(addr):
        # int() and bytes() already reject non-numeric and out-of-range parts;
        # the length check catches a wrong number of parts, which would
        # otherwise silently shift every later header field.
        octets = bytes(map(int, addr.split('.')))
        if len(octets) != 4:
            raise ValueError('expected a dotted-quad IPv4 address, got %r' % (addr,))
        return octets

                                                            # Description                                               Offset
    header = b'\x49\x72'                                    # Magic Header for PAPI message                             0x00
    header += b'\x00\x03'                                   # Protocol Version ??? I have observed values 1 and 3       0x02
    header += _ipv4_octets(dst_host)                        # Destination host                                          0x04
    header += _ipv4_octets(src_host)                        # Source host                                               0x08
    header += b'\x00\x00'                                   # NAT Port number                                           0x0C
    header += b'\x00\x00'                                   # "garbage"                                                 0x0E
    header += pack('>H', dst_port)                          # Destination port                                          0x10
    header += pack('>H', src_port)                          # Source port                                               0x12
    header += b'\x20\x04'                                   # Packet type ???                                           0x14
    header += b'\x00\x00'                                   # Packet size ??? Seems unused in version 3                 0x16
    header += pack('>H', sequence_number)                   # sequence_number                                           0x18
    header += pack('>H', message_code)                      # PAPI message code - application specific                  0x1A
    checksum = b'\x00'*16                                   # Empty checksum
    padding = b'\x00'*32
    if calculate_checksum:
        # The digest is computed over the header with the checksum and
        # padding fields still zeroed, followed by the body and a fixed key.
        m = hashlib.md5()
        m.update(header + checksum + padding + body)
        key = b'asdf;lkj763'
        m.update(key)
        checksum = m.digest()                               # Calculated checksum
    header += checksum                                      # Checksum                                                  0x1C
    header += padding                                       # Header padding                                            0x2C
                                                            # End                                                       0x4C
    return header

def sxdr_write_ip(str_ip):
    """Encode an IPv4 address as an SXDR IP value: tag 0x05 + 4 octets.

    The octets are written in the reverse of dotted-quad order, e.g.
    '10.0.0.1' -> 05 01 00 00 0a. The reason for the reversal is unconfirmed
    (presumably the peer stores addresses little-endian — TODO confirm).

    Raises:
        ValueError: if str_ip is not four dot-separated integers in 0-255.
    """
    octets = bytes(map(int, str_ip.split('.')))
    # A wrong number of parts would produce a value of the wrong length and
    # desynchronize everything encoded after it.
    if len(octets) != 4:
        raise ValueError('expected a dotted-quad IPv4 address, got %r' % (str_ip,))
    return b'\x05' + octets[::-1]

def sxdr_write_u8(value):
    """Encode an unsigned 8-bit integer as an SXDR value: tag 0x02, then the byte."""
    return pack('>BB', 0x02, value)

def sxdr_write_u16(value):
    """Encode an unsigned 16-bit integer as an SXDR value: tag 0x03, then 2 big-endian bytes."""
    return pack('>BH', 0x03, value)

def sxdr_write_u32(value):
    """Encode an unsigned 32-bit integer as an SXDR value: tag 0x04, then 4 big-endian bytes."""
    return pack('>BI', 0x04, value)

def sxdr_write_bool(value):
    """Encode the truthiness of *value* as an SXDR bool.

    The wire form is tag 0x07 followed by a complete SXDR u8 value that
    holds 1 when *value* is truthy and 0 otherwise.
    """
    flag = 1 if value else 0
    return b'\x07' + sxdr_write_u8(flag)

def sxdr_write_str(value):
    """Encode a string as an SXDR string value.

    Layout: tag 0x00, a big-endian 16-bit byte length, then the raw bytes.

    Args:
        value: text (encoded as Latin-1) or an already-encoded bytes-like
            object, which is written unchanged.

    Raises:
        UnicodeEncodeError: if text contains characters outside Latin-1.
        struct.error: if the encoded value is longer than 65535 bytes.
        TypeError: if value is neither text nor bytes-like.
    """
    if isinstance(value, (bytes, bytearray, memoryview)):
        raw = bytes(value)
    else:
        raw = bytes(value, 'latin-1')
    # The length prefix counts encoded bytes, so measure after encoding
    # rather than relying on a 1-char-per-byte assumption.
    return b'\x00' + pack('>H', len(raw)) + raw

def sxdr_write_mac(value):
    """Encode a MAC address string as an SXDR MAC value: tag 0x01 + raw bytes.

    Hex digits may be separated by ':' (00:11:22:aa:bb:cc), '-'
    (00-11-22-aa-bb-cc) or '.' (0011.22aa.bbcc), or not separated at all;
    case does not matter.

    NOTE(review): no length check is done; presumably callers pass a 6-byte
    MAC-48 address — confirm what the peer expects before enforcing one.

    Raises:
        ValueError: if the text left after removing separators is not valid hex.
    """
    hex_digits = value.translate(str.maketrans('', '', ':-.'))
    return b'\x01' + bytes.fromhex(hex_digits)

def sxdr_write_ipv6(str_ip6):
    """Emit an SXDR IPv6 value: tag 0x0A followed by an all-zero 16-byte address.

    NOTE(review): str_ip6 is currently ignored, so this always encodes '::'.
    It looks like a placeholder. The on-wire byte order for IPv6 is not
    established in this module (the IPv4 encoder reverses its octets), so a
    real encoder needs confirmation against captured traffic.
    """
    tag = b'\x0A'
    zero_address = bytes(16)
    return tag + zero_address

def sxdr_write_ip_af(str_ip):
    """Emit an SXDR address-family IP value: tag 0x0B followed by 20 zero bytes.

    NOTE(review): str_ip is currently ignored; this looks like a placeholder.
    The meaning of the 20-byte payload is unconfirmed (possibly an
    address-family field plus a 16-byte address — verify against traffic).
    """
    tag = b'\x0B'
    zero_payload = bytes(20)
    return tag + zero_payload