import hashlib
from struct import *

"""
This implementation was reverse engineered using Wireshark (and source code), strace and two excellent articles:
- https://x-c3ll.github.io/posts/CVE-2018-7081-RCE-ArubaOS/
- https://packetstormsecurity.com/files/136997/Aruba-Authentication-Bypass-Insecure-Transport-Tons-Of-Issues.html
"""

def papi_encrypt(data):
    """XOR every byte of *data* with 0x93 and return the result as bytes.

    XOR with a constant is its own inverse, so calling this function on its
    own output returns the original bytes.

    Args:
        data: A bytes-like object (or any iterable of ints in range 0-255).

    Returns:
        A bytes object of the same length as *data*.

    Raises:
        ValueError: If an element XOR 0x93 falls outside range(0, 256).
    """
    # Build the bytes directly instead of concatenating a str character by
    # character and re-encoding it as latin-1 afterwards.
    return bytes(c ^ 0x93 for c in data)

def papi_header(dst_host, src_host, dst_port, src_port, sequence_number, message_code, body, calculate_checksum, natport=0, key=b'eG3eAUwZ5UEen1xu'):
    """Build the 0x4C-byte PAPI message header.

    Args:
        dst_host: Destination IPv4 address as a dotted-decimal string.
        src_host: Source IPv4 address as a dotted-decimal string.
        dst_port: Destination port (packed as big-endian unsigned 16-bit).
        src_port: Source port (packed as big-endian unsigned 16-bit).
        sequence_number: Sequence number (big-endian unsigned 16-bit).
        message_code: Application-specific PAPI message code (big-endian
            unsigned 16-bit).
        body: Message body as bytes. It is only fed into the checksum; it is
            NOT appended to the returned header.
        calculate_checksum: If true, the checksum field holds the MD5 digest
            described below; otherwise it is left as 16 zero bytes.
        natport: NAT port number (big-endian unsigned 16-bit).
        key: Secret appended to the MD5 input. The default is the
            'enhanced security' key this function has always used. The
            original code also listed b'asdf;lkj763' before overwriting it
            (presumably the key for setups without 'enhanced security' -
            TODO confirm).

    Returns:
        The header as bytes: 28 bytes of fields, 16 bytes of checksum and
        32 bytes of zero padding.
    """
    fields = (                                              # Description                                               Offset
        b'\x49\x72',                                        # Magic Header for PAPI message                             0x00
        b'\x00\x03',                                        # Protocol Version ??? I have observed values 1 and 3       0x02
        bytes(map(int, dst_host.split('.'))),               # Destination host                                          0x04
        bytes(map(int, src_host.split('.'))),               # Source host                                               0x08
        pack('>H', natport),                                # NAT Port number                                           0x0C
        b'\x00\x00',                                        # "garbage"                                                 0x0E
        pack('>H', dst_port),                               # Destination port                                          0x10
        pack('>H', src_port),                               # Source port                                               0x12
        b'\x20\x04',                                        # Packet type ???                                           0x14
        b'\x00\x00',                                        # Packet size ??? Seems unused in version 3                 0x16
        pack('>H', sequence_number),                        # sequence_number                                           0x18
        pack('>H', message_code),                           # PAPI message code - application specific                  0x1A
    )
    header = b''.join(fields)
    checksum = b'\x00' * 16                                 # Empty checksum
    padding = b'\x00' * 32
    if calculate_checksum:
        # The digest covers the header with a zeroed checksum field, the
        # padding, the body and finally the secret key.
        digest = hashlib.md5()
        digest.update(header + checksum + padding + body)
        digest.update(key)
        checksum = digest.digest()                          # Calculated checksum
    #        Checksum (0x1C)  Header padding (0x2C)  End (0x4C)
    return header + checksum + padding

def amapi_execute_command_object_body(parameters=None, object_type=0, op_type=0, app_id='', object_name='', major_ver=8, minor_ver=5):
    """Serialize an 'executeCommandObject' request body.

    Args:
        parameters: Iterable of dicts mapping parameter names to values
            (both str). Every key/value pair of every dict is written as one
            record. ``None`` (the default) means no parameters.
        object_type: Written as commandObj->objectType (u16).
        op_type: Written as commandObj->opType (u16).
        app_id: Written as commandObj->appId (string).
        object_name: Written as commandObj->objectName (string).
        major_ver: Written as commandObj->majorVer (u8).
        minor_ver: Written as commandObj->minorVer (u8).

    Returns:
        A tuple ``(14001, body)`` where ``body`` is the serialized bytes.
        14001 is presumably the PAPI message code to send the body with -
        TODO confirm against the caller.
    """
    # Flatten first so that numEntries is the number of records actually
    # written, even when a dict carries zero or several key/value pairs.
    entries = [
        (key, value)
        for keyvalue in (parameters if parameters is not None else ())
        for key, value in keyvalue.items()
    ]

    parts = [
        sxdr_write_str('executeCommandObject'),
        sxdr_write_u8(0x00),
        sxdr_write_u8(0x00),
        sxdr_write_u16(object_type),                    # commandObj->objectType
        sxdr_write_u16(op_type),                        # commandObj->opType
        sxdr_write_str(app_id),                         # commandObj->appId
        sxdr_write_str(object_name),                    # commandObj->objectName
        sxdr_write_u8(major_ver),                       # commandObj->majorVer
        sxdr_write_u8(minor_ver),                       # commandObj->minorVer
        sxdr_write_str(''),                             # commandObj->instance
        sxdr_write_bool(False),                         # commandObj->bIsLocalCommand
        sxdr_write_u16(0x00),                           # commandObj->commandStrLen
        sxdr_write_str(''),                             # commandObj->commandStr
        sxdr_write_u16(0x00),                           # commandObj->commandLogStrLen
        sxdr_write_str(''),                             # commandObj->commandLogStr
        sxdr_write_u16(0x00),                           # commandObj->rawDatalength
        sxdr_write_u16(0x00),                           # commandObj->secretLen
        sxdr_write_str(''),                             # commandObj->secret ???
        sxdr_write_u8(0x00),                            # commandObj->encryptFlag
        sxdr_write_u32(len(entries)),                   # numEntries
    ]
    for key, value in entries:
        parts += [
            sxdr_write_str(key),                        # key
            sxdr_write_str(value),                      # value
            sxdr_write_bool(False),                     # keyValue->wasMode
            sxdr_write_bool(False),                     # keyValue->is_inst_key
            sxdr_write_bool(False),                     # keyValue->is_reorder_key
        ]
    parts += [
        sxdr_write_u32(0x00),                           # numEntries
        sxdr_write_u8(0x00),                            # configEncrypt
        sxdr_write_u8(0x00),                            # commandObj->cliRequest
        sxdr_write_u8(0x00),                            # commandObj->webRequest
        sxdr_write_u8(0x00),
        sxdr_write_str('admin'),                        # commandObj->userName
        sxdr_write_str('root'),                         # commandObj->userRole
        sxdr_write_str('/sc/mynode'),                   # node
        sxdr_write_bool(False),                         # commandObj->preserve_no
        sxdr_write_str(''),                             # commandObj->scappcmd
        sxdr_write_str(''),                             # commandObj->currModeName
        sxdr_write_u16(0x00),                           # commandObj->flags
        sxdr_write_str(''),                             # commandObj->permlen
    ]
    return (14001, b''.join(parts))

def sxdr_write_ip(str_ip):
    """Encode a dotted-decimal IPv4 string as tag byte 0x05 plus its octets.

    The octets are emitted in reverse order of how they appear in the
    string (the original author marked this byte order as uncertain).
    """
    octets = [int(part) for part in str_ip.split('.')]
    octets.reverse() #???
    return b''.join((b'\x05', bytes(octets)))

def sxdr_write_u8(value):
    """Encode *value* as tag byte 0x02 followed by one unsigned byte."""
    encoded = pack('B', value)
    return b''.join((b'\x02', encoded))

def sxdr_write_u16(value):
    """Encode *value* as tag byte 0x03 followed by a big-endian unsigned 16-bit integer."""
    # Tag and payload are packed in one call; '>' disables padding.
    return pack('>BH', 0x03, value)

def sxdr_write_u32(value):
    """Encode *value* as tag byte 0x04 followed by a big-endian unsigned 32-bit integer."""
    # Tag and payload are packed in one call; '>' disables padding.
    return pack('>BI', 0x04, value)

def sxdr_write_bool(value):
    """Encode the truthiness of *value* as tag byte 0x07 followed by a u8 field.

    The payload is a complete ``sxdr_write_u8`` field (its own tag included)
    holding 1 for a truthy value and 0 otherwise.
    """
    flag = 1 if value else 0
    return b'\x07' + sxdr_write_u8(flag)

def sxdr_write_str(value):
    """Encode a str as tag byte 0x00, a big-endian 16-bit length, then the text.

    The text is encoded as latin-1, so the length prefix (the character
    count) equals the number of payload bytes.
    """
    length_prefix = pack('>H', len(value))
    payload = bytes(value, 'latin-1')
    return b''.join((b'\x00', length_prefix, payload))

def sxdr_write_mac(value):
    """Encode a colon-separated hex MAC string as tag byte 0x01 plus its raw bytes."""
    hex_digits = ''.join(value.split(':'))
    return b''.join((b'\x01', bytes.fromhex(hex_digits)))

def sxdr_write_ipv6(str_ip6):
    """Return tag byte 0x0A followed by 16 zero bytes.

    NOTE: *str_ip6* is currently ignored; the address payload is always
    all zeros.
    """
    return b'\x0A' + bytes(16)

def sxdr_write_ip_af(str_ip):
    """Return tag byte 0x0B followed by 20 zero bytes.

    NOTE: *str_ip* is currently ignored; the payload is always all zeros.
    """
    return b'\x0B' + bytes(20)

def sxdr_write_struct(value):
    """Prefix the already-serialized bytes *value* with tag byte 0x0C."""
    return b'\x0C' + value