// matrix_mul_test.dsp -- Faustine test: multiply two 10x10 matrices and
// expose the rows of the product.
//
// Recovered from a collapsed git diff (new file in 043c676..1878a84):
//   https://scm.cri.ensmp.fr/git/Faustine.git/blobdiff_plain/043c676f59520b93dfacfa0d8d7e1fdd448cd7dc..1878a8448a5a73cbf289306beb5e88ab48561129:/dsp_files/tests/matrix_mul_test.dsp
//
// ASSUMPTION (TODO confirm against the Faustine interpreter): the comments
// below describe the *intended* vector meaning of the four primitives:
//   x, n : vectorize   groups n consecutive samples of x into a vector,
//   v, w : concat      joins two vectors end to end,
//   v, i : nth         selects element i (0-based) of vector v.
// In this file they are bound to plain arithmetic operators (+, *, /),
// presumably as stand-ins; those bindings are kept exactly as written.
// The input/output counts quoted for each definition follow from the box
// algebra alone and do not depend on that assumption.

vectorize = +;
serialize = _ , 1 : +;   // not referenced anywhere in this file
concat = *;
nth = /;

// concat1(m): m inputs -> 1 output.
// Wraps each input signal in a size-1 vector and concatenates them left to
// right, producing a vector of size m.
// NOTE(review): only m >= 1 terminates; m <= 0 never reaches the (1) rule.
concat1 = case {
  (1) => _, 1 : vectorize;
  (m) => concat1(m - 1), (_, 1 : vectorize) : concat;
};

// concat2(n, m): n*m inputs -> 1 output.
// Consumes n consecutive groups of m signals. Each group becomes a size-m
// row (concat1(m)) wrapped in a size-1 vector, and the rows are
// concatenated in input order, giving an n x m matrix (a vector of rows).
// NOTE(review): only n >= 1 terminates.
concat2 = case {
  (1, m) => concat1(m), 1 : vectorize;
  (n, m) => concat2(n - 1, m), (concat1(m), 1 : vectorize) : concat;
};

// make_input_matrix(n, m): 1 input -> 1 output.
// Groups m consecutive samples into a row vector, then n consecutive rows
// into an n x m matrix.
make_input_matrix(n, m) = _, m : vectorize : _, n : vectorize;

// make_output_matrix(n, m): n*m inputs -> 1 output.
// Packs n*m scalar signals (row-major order) into an n x m matrix.
make_output_matrix(n, m) = concat2(n, m);

// accumulate_vector(k): 1 input -> 1 output.
// Sum of elements 0 .. k-1 of the input vector.
accumulate_vector(k) = _ <: sum(i, k, (_, i : nth));

// get_column(k, j): 1 input -> 1 output.
// From a matrix with k rows, gathers element j of every row into a column
// vector of size k.
get_column(k, j) = _ <: par(p, k, (_, p : nth : _, j : nth)) : concat1(k);

// get_line(i): 1 input -> 1 output. Row i of the input matrix.
get_line(i) = _, i : nth;

// make_line(i, k, m): 2*m inputs -> m outputs.
// Expects m (A, B) input pairs. Output j is row i of A multiplied
// (assumed element-wise) by column j of B and summed over the k terms,
// i.e. the product entry C[i][j].
make_line(i, k, m) = par(j, m, (get_line(i), get_column(k, j) : * : accumulate_vector(k)));

// multiply(n, k, m): 2*n*m inputs -> n*m outputs.
// One make_line per row i; outputs are row-major: C[0][0..m-1], C[1][..], ...
multiply(n, k, m) = par(i, n, make_line(i, k, m));

// matrix_mul(n, k, m): 2 inputs -> 1 output.
// Input 1 is read as an n x k matrix A and input 2 as a k x m matrix B.
// The (A, B) pair is fanned out (<:) to all n*m cells, and the results are
// packed back into the n x m product matrix C = A x B.
// (Precedence: ',' binds tighter than ':', which binds tighter than '<:'.)
matrix_mul(n, k, m) = make_input_matrix(n, k), make_input_matrix(k, m) <: multiply(n, k, m) : make_output_matrix(n, m);

// matrix_output(n, m): 1 input -> n outputs; output i is row i of the matrix.
// NOTE(review): m is not used; it is kept so callers can pass both sizes.
matrix_output(n, m) = _ <: par(i, n, (_, i : nth));

// Test sizes: A is n_rows x n_inner, B is n_inner x n_cols.
n_rows = 10;
n_inner = 10;
n_cols = 10;

// 2 inputs (the two matrix sample streams) -> n_rows outputs (product rows).
process = matrix_mul(n_rows, n_inner, n_cols) : matrix_output(n_rows, n_cols);