Created July 21, 2020 16:11
-
-
Save mclements/8641f8c3376d5ea3f0da3464541bb11f to your computer and use it in GitHub Desktop.
SML / MLton: naive implementation for BLAS-based matrix multiplication
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Naive BLAS-backed dense matrix multiplication for MLton.
 *
 * Only matmul2 is exported; everything between `local` and `in` is private.
 * The program must be linked against a CBLAS implementation so that the
 * C symbol cblas_dgemm resolves (e.g. -link-opt -lcblas or -lopenblas). *)
local
  (* FFI binding to
   *   void cblas_dgemm (Order, TransA, TransB, M, N, K,
   *                     alpha, A, lda, B, ldb, beta, C, ldc)
   * CBLAS enum arguments are passed as plain ints.
   * NOTE(review): assumes a CBLAS built with 32-bit integer arguments
   * (not an ILP64 build) -- confirm for the library you link against. *)
  val call = _import "cblas_dgemm" public:
    int * int * int * int * int * int * real * real Vector.vector * int
    * real Vector.vector * int * real * real Array.array * int -> unit;

  datatype cblasTranspose = NoTrans | Trans | ConjTrans | ConjNoTrans

  (* Numeric values of the CBLAS_ORDER enum, keyed by Array2's traversal
     constructors. *)
  fun cblasOrder Array2.RowMajor = 101
    | cblasOrder Array2.ColMajor = 102

  (* Numeric values of the CBLAS_TRANSPOSE enum.
     NOTE(review): 114 (ConjNoTrans) is not part of reference CBLAS; it
     appears to be an OpenBLAS extension -- confirm before using it. *)
  fun cblasTranspose NoTrans = 111
    | cblasTranspose Trans = 112
    | cblasTranspose ConjTrans = 113
    | cblasTranspose ConjNoTrans = 114

  (* Flatten an m x n Array2 into a column-major vector: element (i, j)
     lands at offset i + m*j, so the leading dimension is m.  Traversing
     in ColMajor order makes the writes into the buffer sequential. *)
  fun getVector a =
    let
      val (m, n) = Array2.dimensions a
      val flat = Array.array (m * n, 0.0)
      val () =
        Array2.appi Array2.ColMajor
          (fn (i, j, aij) => Array.update (flat, i + m * j, aij))
          {base = a, row = 0, col = 0, nrows = NONE, ncols = NONE}
    in
      Array.vector flat
    end

  (* Inverse of getVector: rebuild an m x n Array2 from a column-major
     buffer whose leading dimension is m. *)
  fun makeArray (a, m, n) =
    Array2.tabulate Array2.RowMajor (m, n, fn (i, j) => Array.sub (a, i + m * j))
in
  (* matmul2 (a, b) returns the m x n product a * b, where a is m x k and
   * b is k x n.  The product itself is computed by cblas_dgemm
   * (alpha = 1, beta = 0).
   *
   * Raises General.Size if the inner dimensions of a and b differ.
   *
   * Degenerate shapes (m, n or k equal to 0) are answered here without
   * calling BLAS: DGEMM rejects leading dimensions below 1, which such
   * shapes would produce, and the correct result is an m x n zero
   * matrix (empty when m or n is 0). *)
  fun matmul2 (a, b) =
    let
      val (m, k) = Array2.dimensions a
      val (k', n) = Array2.dimensions b
      val () = if k <> k' then raise General.Size else ()
    in
      if m = 0 orelse n = 0 orelse k = 0 then
        Array2.array (m, n, 0.0)
      else
        let
          val c = Array.array (m * n, 0.0)
          val () =
            call (cblasOrder Array2.ColMajor,
                  cblasTranspose NoTrans, cblasTranspose NoTrans,
                  m, n, k,
                  1.0, getVector a, m,
                  getVector b, k,
                  0.0, c, m)
        in
          makeArray (c, m, n)
        end
    end
end;
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment