/* Basic routines for Pure matrices. 2008-09-25 AG */

/* This is intended to become a grab bag for useful definitions of matrix
   operations (mostly linear algebra) and emulations of operations provided as
   builtins in numeric computation software like Octave. These examples also
   illustrate the use of Pure's matrix operations and, in particular, matrix
   comprehensions. Note, however, that we're striving for clarity rather than
   efficiency in these definitions, so some routines may be quite slow
   compared to hand-tuned C code. */

/* Some convenience functions for generating zero and identity matrices with
   given dimensions. These create double matrices by default; use the prelude
   functions imatrix and cmatrix to convert as needed. */

/* zeros (n,m) is the n x m matrix obtained from the prelude's dmatrix;
   zeros n is the square n x n special case. */
zeros (n::int,m::int)   = dmatrix (n,m);
zeros n::int            = dmatrix (n,n);

/* eye n is the n x n identity matrix, eye (n,m) its rectangular n x m
   counterpart: 1.0 where the row and column index coincide, 0.0 elsewhere.
   If n or m is zero the comprehension is empty and does not necessarily
   carry the requested dimensions, so the result is reshaped explicitly with
   redim (the same precaution the matrix product in this file takes).
   Negative dimensions are treated like zero. */
eye n::int              = eye (n,n);
eye (n::int,m::int)     = redim (max 0 n,max 0 m)
                          {double (i==j) | i = 0..n-1; j = 0..m-1};

/* Basic matrix arithmetic. */

/* Mixed matrix-scalar arithmetic. */

/* Each operator is lifted to a (scalar, matrix) or (matrix, scalar) pair by
   mapping it over the matrix elements with the scalar kept on its original
   side. The guards restrict these rules to the case where the other operand
   is not a matrix. Operator sections are used throughout, except for
   "element minus scalar", which needs flip because (-a) would denote unary
   minus rather than a section. */
a + x::matrix           = map (a+) x if not matrixp a;
x::matrix + a           = map (+a) x if not matrixp a;

a - x::matrix           = map (a-) x if not matrixp a;
x::matrix - a           = map (flip (-) a) x if not matrixp a;

a * x::matrix           = map (a*) x if not matrixp a;
x::matrix * a           = map (*a) x if not matrixp a;

a / x::matrix           = map (a/) x if not matrixp a;
x::matrix / a           = map (/a) x if not matrixp a;

/* Element-wise arithmetic. */

/* Sum and difference of two matrices, combined element by element with
   zipwith. The rules only apply when both operands have the same
   dimensions. */
a::matrix + b::matrix   = zipwith (+) a b if dim b==dim a;
a::matrix - b::matrix   = zipwith (-) a b if dim b==dim a;

/* Dot product. */

/* dot x y adds up the products x!i*y!i over all indices i, starting from 0
   and proceeding from left to right. Both arguments must be vectors of the
   same size. */
dot x::matrix y::matrix = foldl (\s i -> s + x!i*y!i) 0 (0..n-1)
                            when n = #x end
                            if vectorp x && vectorp y && #x==#y;

/* Matrix multiplication. (The redim is needed to properly handle all empty
   matrix cases. See the Octave manual for an explanation of this issue.) */

/* Product of two matrices whose inner dimensions agree: each entry is the
   dot product of a row of x with a column of y. The result is reshaped to
   (rows of x, columns of y) so that the dimensions also come out right when
   the comprehension is empty. */
x::matrix * y::matrix   = redim (n,m) {dot r c | r = rows x; c = cols y}
                            when n,_ = dim x; _,m = dim y end
                            if dim x!1==dim y!0;

/* Convenience functions to print matrices in "short" or "long" format a la
   Octave. These also emulate Octave's way to show empty matrices along with
   their dimensions. FIXME: This isn't quite the same as Octave's display
   algorithm yet, as it uses fixed point format, and only double and int
   matrices are supported right now. Contributions are welcome. ;-) */

using system;

// Display format selection: keep exactly one of the two definitions below
// active. As written, matrices are rendered with short_format; to switch to
// long_format, comment out the first line and uncomment the second.
// NOTE(review): presumably __show__ is the interpreter's pretty-printing
// hook -- confirm against the Pure manual.
__show__ x::matrix = short_format x;
//__show__ x::matrix = long_format x;

/* Shared worker for short_format and long_format. dfmt and ifmt are the
   printf conversions applied to a single element of a double and an int
   matrix, respectively. An empty matrix is rendered as "{}" followed by its
   dimensions in parentheses. Otherwise each row is put on a line of its own
   (preceded by a newline), and the output is terminated with a final
   newline. The caller has to make sure that a nonempty x is either a double
   or an int matrix. */
format_matrix dfmt::string ifmt::string x::matrix
= sprintf "{}(%dx%d)" (dim x) if null x;
= strcat [row i | i = 0..n-1] + "\n"
with row i = "\n" + strcat [sprintf fmt (x!(i,j)) | j = 0..m-1] end
when n,m = dim x; fmt = if dmatrixp x then dfmt else ifmt end;

/* Short format: doubles in fixed point notation with 5 decimals, ints as
   plain integers, both right-aligned in 10 columns. Nonempty matrices of
   other element types are not handled (the rule does not apply to them). */
short_format x::matrix
= format_matrix "%10.5f" "%10d" x if null x || dmatrixp x || imatrixp x;

/* Long format: like short_format, but with 15 decimals for doubles and a
   field width of 20 columns. */
long_format x::matrix
= format_matrix "%20.15f" "%20d" x if null x || dmatrixp x || imatrixp x;