import numpy as np
from numpy.testing import assert_allclose
import matrix_transactions as u

def _arange_matrix(rows, cols):
    """Return a fresh float32 ``rows x cols`` matrix holding 0 .. rows*cols-1."""
    return np.arange(rows * cols).astype(np.float32).reshape(rows, cols)


def test_matrix_ops():
    """Exercise MatrixState transactions: rollback, commit, nesting, index tracking.

    Every scenario builds its own input matrix and keeps an independent
    ``copy()`` of it. Rollback is therefore checked against the original
    values even if ``MatrixState`` shares or mutates the array it was given.
    """
    # Scenario 1: swaps + row deletion, never committed -> changes discarded.
    m = _arange_matrix(3, 3)
    snapshot = m.copy()
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.swap_cols(0, 2)
        t.swap_rows(0, 1)
        t.delete_rows(2)
        assert_allclose(ms.matrix, [[5, 4, 3], [2, 1, 0]])
        # no commit
    assert_allclose(ms.matrix, snapshot)

    # Scenario 2: the same operations, committed -> changes persist after exit.
    # A fresh input is used so this does not depend on scenario 1 leaving
    # its array untouched.
    m = _arange_matrix(3, 3)
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.swap_cols(0, 2)
        t.swap_rows(0, 1)
        t.delete_rows(2)
        assert_allclose(ms.matrix, [[5, 4, 3], [2, 1, 0]])
        t.commit()
    assert_allclose(ms.matrix, [[5, 4, 3], [2, 1, 0]])

    # Scenario 3: multi-row deletion with index tracking, plus a nested
    # transaction; nothing is committed, so the original matrix comes back.
    m = _arange_matrix(5, 2)
    snapshot = m.copy()
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.delete_rows([2, 3, 0])
        assert_allclose(ms.matrix, [[2, 3], [8, 9]])
        # ms.indices[0] / ms.indices[1] list the surviving original row /
        # column positions.
        assert_allclose(ms.indices[0], [1, 4])
        assert_allclose(ms.indices[1], [0, 1])
        with ms.transaction() as tt:
            tt.delete_rows(0)
            assert_allclose(ms.matrix, [[8, 9]])
            assert_allclose(ms.indices[0], [4])
            assert_allclose(ms.indices[1], [0, 1])
        # NOTE(review): the state between the inner and outer exits is not
        # asserted; if an uncommitted inner transaction is meant to restore
        # the outer transaction's state ([[2, 3], [8, 9]]), check it here.
    assert_allclose(ms.matrix, snapshot)

    # Scenario 4: column and row deletions combined, then rolled back.
    m = _arange_matrix(4, 5)
    snapshot = m.copy()
    ms = u.MatrixState(m)
    with ms.transaction() as t:
        t.delete_cols([1, 2, 3])
        t.delete_rows([0, 2])
        assert_allclose(ms.matrix, [[5, 9], [15, 19]])
        assert_allclose(ms.indices[0], [1, 3])
        assert_allclose(ms.indices[1], [0, 4])
    assert_allclose(ms.matrix, snapshot)