"""
Loopback tests for pastream.
"""
import os
import queue
import sys
import tempfile
import time

import numpy as np
import numpy.testing as npt
import pytest
import soundfile as sf

import pastream

# Number of frames requested per block when scanning/comparing sound files.
BLOCKSIZE = 2048
# Parameter values for the random-WAV fixtures (passed to them as
# ``request.param``).
TEST_LENGTHS = [5]


# Element-wise formatter: each value becomes a '0x'-prefixed hex string,
# right-aligned in a field of width 10.
vhex = np.vectorize('{:#10x}'.format)


def tohex(x):
    """Return the elements of *x*, reinterpreted as uint32, as hex strings."""
    return vhex(x.view('u4'))


def qreader_compare(stream, rxq):
    """Compare received audio buffers against the stream's input file.

    Buffers are pulled from *rxq* until a ``None`` item arrives.  Leading
    frames are discarded until channel 0 carries a nonzero sample; from that
    frame on, every received frame is compared bit-for-bit with frames read
    from a second, independent handle on the input file.

    On the first mismatch ``stream.set_exception()`` is called (from inside
    the ``except`` block, so the active exception is the ``AssertionError``)
    and reading stops.

    Parameters
    ----------
    stream
        Object providing ``dtype``, ``channels``, ``inp_fh`` and
        ``set_exception()``.
    rxq
        Queue yielding raw sample buffers, terminated by ``None``.

    Raises
    ------
    queue.Empty
        If no item arrives on *rxq* within one second.
    """
    dtype = stream.dtype[0]
    nchannels = stream.channels[1]
    found_delay = False

    # NOTE(review): assumes ``inp_fh.name`` is a named file object (as it is
    # for the temp-file fixtures), so ``.name.name`` is its path -- confirm
    # for other callers.
    with sf.SoundFile(stream.inp_fh.name.name, mode='rb') as inpf2:
        while True:
            try:
                item = rxq.get(timeout=1)
            except queue.Empty:
                raise queue.Empty("Timed out waiting for data.")
            if item is None:
                break

            outframes = np.frombuffer(item, dtype=dtype).reshape(-1, nchannels)

            if not found_delay:
                # The start of the signal is detected on channel 0 only; a
                # block whose channel 0 is still silent is skipped instead
                # of indexing into an empty result.
                nonzeros = np.flatnonzero(outframes[:, 0])
                if not nonzeros.size:
                    continue
                found_delay = True
                outframes = outframes[nonzeros[0]:]

            inframes = inpf2.read(len(outframes), dtype=dtype, always_2d=True)

            # Length differences are tolerated: only the overlapping prefix
            # of the two blocks is compared.
            mlen = min(len(inframes), len(outframes))
            inp = inframes[:mlen].view('u4')
            out = outframes[:mlen].view('u4')

            try:
                npt.assert_array_equal(inp, out, "Loopback data mismatch")
            except AssertionError:
                stream.set_exception()
                break

def sf_find_delay(xf, mask=0xFFFF0000, chan=None, blocksize=None):
    """Find the first frame of *xf* with a nonzero masked sample.

    The file is scanned block by block as int32 data.  The read position
    reported by ``xf.tell()`` on entry is restored before returning, even if
    reading raises.

    Parameters
    ----------
    xf
        Object providing ``tell()``, ``seek(pos)`` and
        ``blocks(blocksize, dtype=...)``.
    mask
        Bit mask applied to each sample (viewed as uint32) before testing it
        for nonzero.
    chan
        Index applied to the channel axis of every block.
    blocksize
        Frames per block; defaults to the module-level ``BLOCKSIZE``.

    Returns
    -------
    int
        Frame offset of the first hit, counted from the first frame yielded
        by ``xf.blocks()`` (presumably the current read position -- confirm
        against the soundfile docs), or -1 if no sample matches.
    """
    if blocksize is None:
        blocksize = BLOCKSIZE

    pos = xf.tell()
    off = -1
    try:
        for i, inpblk in enumerate(xf.blocks(blocksize, dtype='int32')):
            # View as unsigned first: the default mask does not fit in a
            # signed 32-bit integer, so masking the int32 data directly
            # either promotes or overflows depending on the NumPy version.
            masked = inpblk.view('u4')[:, chan] & mask
            hits = np.nonzero(masked)[0]
            # Test the number of hits, not their values: a lone hit at
            # frame index 0 is still a hit.
            if hits.size:
                off = i * blocksize + int(hits[0])
                break
    finally:
        xf.seek(pos)

    return off

def sf_assert_equal(tx_fh, rx_fh, mask=0xFFFF0000, chi=None, chf=None, allow_truncation=False, allow_delay=False):
    """Assert that *rx_fh* contains the same masked samples as *tx_fh*.

    The start of the signal in *rx_fh* is located with ``sf_find_delay``
    (using *mask* and channel *chi*) and *rx_fh* is positioned there.  Both
    files are then read in ``BLOCKSIZE``-frame int32 blocks and channels
    ``chi:chf`` are compared after applying *mask*.

    Parameters
    ----------
    tx_fh, rx_fh
        Transmitted and received sound files.
    mask
        Bit mask applied to both sides before comparison.
    chi, chf
        Start and stop of the channel slice to compare.
    allow_truncation
        When false, every received block must be as long as the transmitted
        one; when true, only the overlapping frames are compared.
    allow_delay
        When false, the signal must start at frame 0 of *rx_fh*.

    Raises
    ------
    AssertionError
        If the signal is not found, is delayed or truncated when that is
        not allowed, or the compared data differ.
    """
    delay = sf_find_delay(rx_fh, mask=mask, chan=chi)

    assert delay != -1, "Test Preamble pattern not found"
    assert allow_delay or delay == 0, "Signal delayed by %d frames" % delay

    rx_fh.seek(delay)
    channels = slice(chi, chf)

    for txblk in tx_fh.blocks(BLOCKSIZE, dtype='int32'):
        rxblk = rx_fh.read(BLOCKSIZE, dtype='int32')

        if not allow_truncation:
            assert len(txblk) == len(rxblk), "Some samples were dropped"

        # Compare only the frames present on both sides.
        ncommon = min(len(txblk), len(rxblk))
        expected = txblk[:ncommon, channels].view('u4') & mask
        received = rxblk[:ncommon, channels].view('u4') & mask

        npt.assert_array_equal(expected, received, "Loopback data mismatch")


class PortAudioLoopbackTester(object):
    """Shared fixtures and helpers for loopback tests.

    Subclasses provide a ``devargs`` mapping; ``assert_stream_equal`` copies
    it, merges in per-call keyword arguments and passes the result to
    ``pastream.SoundFileStream``.
    """

    def _gen_random(self, rdm_fh, nrepeats, nbytes):
        """Write a constant preamble followed by random frames to *rdm_fh*.

        The preamble is ``samplerate // 10`` frames with every sample set to
        the largest positive *nbytes*-byte value, left-justified in 32 bits.
        It is followed by *nrepeats* chunks of ``samplerate`` frames of
        random *nbytes*-byte values, also left-justified.
        """
        shift = 8 * (4 - nbytes)
        minval = -(0x80000000 >> shift)
        maxval = 0x7FFFFFFF >> shift

        preamble = np.full((rdm_fh.samplerate // 10, rdm_fh.channels),
                           maxval << shift, dtype=np.int32)
        rdm_fh.write(preamble)

        for _ in range(nrepeats):
            # randint's upper bound is exclusive, hence ``maxval + 1``.
            pattern = np.random.randint(minval, maxval + 1,
                                        (rdm_fh.samplerate, rdm_fh.channels),
                                        dtype=np.int32) << shift
            rdm_fh.write(pattern)

    def _random_wav(self, tmpdir_factory, nrepeats, samplerate, subtype, nbytes, channels=8):
        """Yield a rewound random WAV ``SoundFile`` backed by a temp file.

        The temp file is created in pytest's base temp directory.  Both the
        ``SoundFile`` and the temp file are closed once the generator is
        resumed or closed.
        """
        tmpdir = tmpdir_factory.getbasetemp()
        with tempfile.NamedTemporaryFile('w+b', dir=str(tmpdir)) as rdmf:
            with sf.SoundFile(rdmf, 'w+', samplerate, channels, subtype, format='wav') as rdm_fh:
                self._gen_random(rdm_fh, nrepeats, nbytes)
                rdm_fh.seek(0)
                yield rdm_fh

    @pytest.fixture(scope='session', params=TEST_LENGTHS)
    def randomwav84832(self, request, tmpdir_factory):
        """8-channel, 48 kHz, PCM_32 random WAV file."""
        # Delegate with a loop so the helper's cleanup runs at teardown.
        for rdm_fh in self._random_wav(tmpdir_factory, request.param, 48000, 'PCM_32', 4):
            yield rdm_fh

    @pytest.fixture(scope='session', params=TEST_LENGTHS)
    def randomwav84824(self, request, tmpdir_factory):
        """8-channel, 48 kHz, PCM_24 random WAV file."""
        for rdm_fh in self._random_wav(tmpdir_factory, request.param, 48000, 'PCM_24', 3):
            yield rdm_fh

    @pytest.fixture(scope='session', params=TEST_LENGTHS)
    def randomwav84816(self, request, tmpdir_factory):
        """8-channel, 48 kHz, PCM_16 random WAV file."""
        for rdm_fh in self._random_wav(tmpdir_factory, request.param, 48000, 'PCM_16', 2):
            yield rdm_fh

    @pytest.fixture(scope='session', params=TEST_LENGTHS)
    def randomwav84432(self, request, tmpdir_factory):
        """8-channel, 44.1 kHz, PCM_32 random WAV file."""
        for rdm_fh in self._random_wav(tmpdir_factory, request.param, 44100, 'PCM_32', 4):
            yield rdm_fh

    @pytest.fixture(scope='session', params=TEST_LENGTHS)
    def randomwav84416(self, request, tmpdir_factory):
        """8-channel, 44.1 kHz, PCM_16 random WAV file."""
        for rdm_fh in self._random_wav(tmpdir_factory, request.param, 44100, 'PCM_16', 2):
            yield rdm_fh

    def assert_stream_equal(self, inp_fh, qreader=qreader_compare, **kwargs):
        """Stream *inp_fh* through the device and wait until it finishes.

        *qreader* is handed to ``pastream.SoundFileStream`` together with
        ``self.devargs`` updated by *kwargs*.  The call polls
        ``stream.active`` every 0.1 s until the stream is no longer active.
        """
        devargs = dict(self.devargs)
        devargs.update(kwargs)

        # Pass the caller's reader through rather than always using the
        # default one.
        with pastream.SoundFileStream(inp_fh, qreader=qreader, **devargs) as stream:
            while stream.active:
                time.sleep(0.1)

class TestALSALoopback(PortAudioLoopbackTester):
    """Loopback tests run against the device named ``aduplex``."""

    devargs = {'fileblocksize': 512, 'device': 'aduplex'}

    def test_wav24(self, tmpdir, randomwav84824):
        """Stream the PCM_24 random WAV fixture from its start as int32."""
        randomwav84824.seek(0)
        self.assert_stream_equal(randomwav84824, dtype='int32')