/* Basic matrix functions. */

/* Copyright (c) 2008 by Albert Graef <Dr.Graef@t-online.de>.

   This file is part of the Pure programming language and system.

   Pure is free software: you can redistribute it and/or modify it under the
   terms of the GNU General Public License as published by the Free Software
   Foundation, either version 3 of the License, or (at your option) any later
   version.

   Pure is distributed in the hope that it will be useful, but WITHOUT ANY
   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
   FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
   details.

   You should have received a copy of the GNU General Public License along
   with this program.  If not, see <http://www.gnu.org/licenses/>. */

/* Additional matrix predicates. The numeric matrix types will only be
   available if the interpreter was built with support for the GNU Scientific
   Library (GSL). This is the case iff the global gsl_version variable is
   predefined by the interpreter. */

extern int matrix_type(expr* x);

// Each predicate has one equation for genuine matrices, which compares the
// code reported by matrix_type (1 is tested for double, 2 for complex, 3 for
// int matrices), and a catch-all equation yielding 0 for everything else.
dmatrixp x::matrix      = matrix_type x==1;
dmatrixp _              = 0;
cmatrixp x::matrix      = matrix_type x==2;
cmatrixp _              = 0;
imatrixp x::matrix      = matrix_type x==3;
imatrixp _              = 0;

/* The nmatrixp predicate checks for any kind of numeric (double, complex or
   int) matrix, smatrixp for symbolic matrices. */

// Numeric matrices are those with a matrix_type code of 1 or more; code 0 is
// tested for symbolic matrices. Any non-matrix argument yields 0.
nmatrixp x::matrix      = matrix_type x>=1;
nmatrixp _              = 0;
smatrixp x::matrix      = matrix_type x==0;
smatrixp _              = 0;

/* Pure represents row and column vectors as matrices with 1 row or column,
   respectively. The following predicates allow you to check for these special
   kinds of matrices. */

// A vector is anything that is a row vector or a column vector.
vectorp x       = rowvectorp x || colvectorp x;
// Row vector: a matrix whose first dimension (number of rows) is 1.
rowvectorp x    = matrixp x && (n==1 when n::int,_ = dim x end);
// Column vector: a matrix whose second dimension (number of columns) is 1.
colvectorp x    = matrixp x && (m==1 when _,m::int = dim x end);

/* Matrix comparisons. */

// Equality of two matrices. The equations below are tried in order, so each
// case only sees arguments that were rejected by all the guards before it.
// Case 1: both operands are numeric matrices of the same type; they are
// compared directly with '==='.
x::matrix == y::matrix  = x === y
                            if nmatrixp x && matrix_type x == matrix_type y;
// Mixed numeric cases: one operand is converted (int -> double, or anything
// numeric -> complex) so that both sides have the same type before '==='.
                        = cmatrix x === y if nmatrixp x && cmatrixp y;
                        = x === cmatrix y if cmatrixp x && nmatrixp y;
                        = dmatrix x === y if imatrixp x && dmatrixp y;
                        = x === dmatrix y if dmatrixp x && imatrixp y;
// Comparisons involving symbolic matrices: the dimensions must agree, then
// the elements are compared one by one in row-major order using '!='.
                        = 0 if dim x != dim y;
                        = compare 0 with
                            // compare i yields 1 once all n elements have been
                            // visited, and 0 at the first index where the two
                            // elements differ.
                            compare i::int = 1 if i>=n;
                                           = 0 if x!i != y!i;
                                           = compare (i+1);
                          end when n::int = #x end;

// Inequality is the logical negation of the equality test above.
x::matrix != y::matrix  = not x == y;

/* Size of a matrix (#x) and its dimensions (dim x=n,m where n is the number
   of rows, m the number of columns). Note that Pure supports empty matrices,
   thus the total size may well be zero, meaning that either the number of
   rows or the number of columns is zero, or both. The null predicate checks
   for empty matrices. */

// The runtime primitives are private; use '#', 'dim' and 'null' instead.
private matrix_size matrix_dim;
extern int matrix_size(expr *x), expr* matrix_dim(expr *x);

// #x: total number of elements of x, as reported by matrix_size.
#x::matrix              = matrix_size x;
// dim x: the dimensions of x, as reported by matrix_dim.
dim x::matrix           = matrix_dim x;
// null x: true iff x has no elements at all (#x is zero).
null x::matrix          = #x==0;

/* The stride of a matrix denotes the real row size of the underlying C array,
   see the description of the 'pack' function below for further details.
   There's little use for this value in Pure, but it may be needed when
   interfacing to C. */

// The runtime primitive is private; 'stride' is the public entry point.
private matrix_stride;
extern int matrix_stride(expr *x);

// stride x: the value reported by matrix_stride for x (compared against the
// column count by 'packed' to detect contiguous storage).
stride x::matrix        = matrix_stride x;

/* Indexing. Note that, in contrast to Octave and MATLAB, all indices are
   zero-based, and you *must* use Pure's indexing operator '!' to retrieve an
   element. You can either get an element by specifying its row-major index in
   the range 0..#x-1, or with a two-dimensional subscript in the form of a
   pair (i,j), where i denotes the row and j the column index. Both operations
   take only constant time, and an 'out_of_bounds' exception is thrown if an
   index falls outside the valid range. */

private matrix_elem_at matrix_elem_at2;
extern expr* matrix_elem_at(expr* x, int i);
extern expr* matrix_elem_at2(expr* x, int i, int j);

// Element at row-major index i. The index is validated against the total
// size before the runtime primitive is called.
x::matrix!i::int        = if 0<=i && i<#x then matrix_elem_at x i
                          else throw out_of_bounds;

// Element at row i, column j. Both subscripts are validated against the
// dimensions of x before the runtime primitive is called.
x::matrix!(i::int,j::int)
                        = if 0<=i && i<n && 0<=j && j<m
                          then matrix_elem_at2 x i j
                          else throw out_of_bounds
                          when n::int,m::int = dim x end;

/* Matrix slices (x!!ns). As with simple indexing, elements can be addressed
   using either singleton (row-major) indices or index pairs (row,column). The
   former is specified as a list of int values, the latter as a pair of lists
   of int values. As with list slicing, index ranges may be non-contiguous
   and/or non-monotonous. However, the case of contiguous and monotonous
   ranges is optimized by making good use of the 'submat' operation below. We
   also add some convenience rules to handle matrix ranges as well as "mixed"
   ranges (ns,ms) where either ns or ms is a singleton. */

// Convenience rules which normalize the index argument before the actual
// slicing rules are applied. They are written for an arbitrary left operand
// x and rewrite the index only.
// An index given as a matrix is flattened to a list of its elements.
x!!ns::matrix           = x!!list ns;
// The same for either component of a (rows,columns) index pair.
x!!(ns::matrix,ms)      = x!!(list ns,ms);
x!!(ns,ms::matrix)      = x!!(ns,list ms);
// A single int in either component is promoted to a one-element list.
x!!(ns::int,ms)         = x!!([ns],ms);
x!!(ns,ms::int)         = x!!(ns,[ms]);

// Two-dimensional slice x!!(ns,ms): ns selects rows, ms selects columns.
x::matrix!!(ns,ms)      = case ns,ms of
                            // Fast path: both index lists are contiguous
                            // ascending runs, so the result is the single
                            // submatrix at offset (n,m) with #ns rows and
                            // #ms columns.
                            ns@(n:_),ms@(m:_) = submat x (n,m) (#ns,#ms)
                              if cont ns && cont ms;
                            // General path: stack the selected rows first,
                            // then pick the selected columns from that
                            // intermediate matrix.
                            _ = colcatmap (mth (rowcatmap (nth x) ns)) ms;
                          end with
                            // cont xs: a singleton list counts as contiguous;
                            // a longer list must consist of ints, each exactly
                            // one greater than its predecessor.
                            cont [n] = 1;
                            cont (n::int:ns@(m::int:_)) = cont ns if m==n+1;
                            cont _ = 0 otherwise;
                            // Row/column lookup wrapped in 'catch' with the
                            // handler 'cst {}'; presumably this yields the
                            // empty matrix {} when row/col throws (e.g.
                            // out_of_bounds), so such indices add nothing.
                            nth x n = catch (cst {}) (row x n);
                            mth x m = catch (cst {}) (col x m);
                          end;
// One-dimensional slice by row-major indices. If all indices are ints and x
// is packed, x is viewed as a single row and the two-dimensional rule above
// is reused; otherwise the elements are looked up one at a time.
x::matrix!!ns           = if all intp ns && packed x
                          then rowvector x!!([0],ns)
                          else colcatmap (nth x) ns with
                            // Element lookup as a 1x1 matrix, wrapped in
                            // 'catch' like the helpers of the rule above.
                            nth x n = catch (cst {}) {x!n};
                          end;

/* Extract rows and columns from a matrix. */

private matrix_slice;
extern expr* matrix_slice(expr* x, int i1, int j1, int i2, int j2);

// Row i of x, obtained as the slice spanning columns 0..m-1 of that row.
// Throws out_of_bounds if i is not a valid row index.
row x::matrix i::int    = case dim x of
                            n::int,m::int = matrix_slice x i 0 i (m-1)
                                              if 0<=i && i<n;
                            _::int,_::int = throw out_of_bounds;
                          end;

// Column j of x, obtained as the slice spanning rows 0..n-1 of that column.
// Throws out_of_bounds if j is not a valid column index.
col x::matrix j::int    = case dim x of
                            n::int,m::int = matrix_slice x 0 j (n-1) j
                                              if 0<=j && j<m;
                            _::int,_::int = throw out_of_bounds;
                          end;

// List of all rows of x, in order.
rows x::matrix          = [row x i|i=0..n-1] when n::int,_ = dim x end;

// List of all columns of x, in order.
cols x::matrix          = [col x j|j=0..m-1] when _,m::int = dim x end;

/* Convert a matrix to a list and vice versa. list x converts a matrix x to a
   flat list of its elements, while list2 converts it to a list of lists.
   Conversely, matrix xs converts a list of lists or matrices specifying the
   rows of the matrix to the corresponding matrix; otherwise, the result is a
   row vector. NOTE: The matrix function may throw a 'bad_matrix_value x' in
   case of dimension mismatch, where x denotes the offending submatrix. */

// Flat list of the elements of x, visited by row-major index 0..#x-1.
list x::matrix          = map (\i -> x!i) (0..#x-1);
// List of lists: one inner list per row, holding that row's elements.
list2 x::matrix         = map (\i -> map (\j -> x!(i,j)) (0..m-1)) (0..n-1)
                            when n::int,m::int = dim x end;

// The empty list gives the empty matrix.
matrix []               = {};
// Nonempty list: a list of lists is turned into rows (each inner list is
// concatenated column-wise, the results are stacked); a list containing at
// least one matrix is stacked row-wise; anything else becomes a single row.
matrix xs@(_:_)         = if all listp xs then rowcatmap colcat xs
                          else if any matrixp xs then rowcat xs
                          else colcat xs;

/* Extract (sub-,super-) diagonals from a matrix. Sub- and super-diagonals for
   k=0 return the main diagonal. Indices for sub- and super-diagonals can also
   be negative, in which case the corresponding super- or sub-diagonal is
   returned instead. In each case the result is a row vector. */

// The C-level names are kept private; each runtime function is declared
// with an alias ('= name'), which is the name meant to be used from Pure.
private matrix_diag matrix_subdiag matrix_supdiag;
// diag x: main diagonal of x.
extern expr* matrix_diag(expr* x) = diag;
// subdiag x k: k-th sub-diagonal of x.
extern expr* matrix_subdiag(expr* x, int k) = subdiag;
// supdiag x k: k-th super-diagonal of x.
extern expr* matrix_supdiag(expr* x, int k) = supdiag;

/* Extract a submatrix of a given size at a given offset. The result shares
   the underlying storage with the input matrix (i.e., matrix elements are
   *not* copied) and so this is a comparatively cheap operation. */

// submat x (i,j) (n,m): the n x m block of x whose top-left corner is at
// row i, column j. The offset and size are translated to the inclusive
// end indices expected by matrix_slice.
submat x::matrix (i::int,j::int) (n::int,m::int)
                        = matrix_slice x i j i2 j2
                            when i2 = i+n-1; j2 = j+m-1 end;

/* Construct matrices from lists of rows and columns. These take either
   scalars or submatrices as inputs; corresponding dimensions must match.
   rowcat combines submatrices vertically, like {x;y}; colcat combines them
   horizontally, like {x,y}. NOTE: Like the built-in matrix constructs, these
   operations may throw a 'bad_matrix_value x' in case of dimension mismatch,
   where x denotes the offending submatrix. */

// rowcat xs: the runtime function matrix_rows under its Pure alias; combines
// the members of xs vertically.
extern expr* matrix_rows(expr *x) = rowcat;
// colcat xs: the runtime function matrix_columns under its Pure alias;
// combines the members of xs horizontally.
extern expr* matrix_columns(expr *x) = colcat;

/* Combinations of rowcat/colcat and map. These are used, in particular, for
   implementing matrix comprehensions. */

// Map f over a list and stack the results vertically; an empty list gives
// the empty matrix without calling f.
rowcatmap _ []          = {};
rowcatmap f xs@(_:_)    = rowcat $ map f xs;

// Map f over a list and join the results horizontally; an empty list gives
// the empty matrix without calling f.
colcatmap _ []          = {};
colcatmap f xs@(_:_)    = colcat $ map f xs;

/* Optimization rules for "void" matrix comprehensions (cf. the catmap
   optimization rules at the beginning of prelude.pure). */

// Macro rules: when the result of rowcatmap/colcatmap is discarded with
// 'void', rewrite the call to 'do f x', which only applies f to the members
// of x and does not build a matrix.
def void (rowcatmap f x) = do f x;
def void (colcatmap f x) = do f x;

/* Convenience functions to create zero matrices with the given dimensions
   (either a pair denoting the number of rows and columns, or just the row
   size in order to create a row vector). */

// Zero matrices are created by the *_dup constructors with a null source
// pointer (pointer 0). A pair (n,m) gives an n x m matrix, a single int n
// gives a 1 x n row vector.
dmatrix (n::int,m::int) = dmatrix_dup (pointer 0,n,m);
dmatrix n::int          = dmatrix_dup (pointer 0,1,n);

cmatrix (n::int,m::int) = cmatrix_dup (pointer 0,n,m);
cmatrix n::int          = cmatrix_dup (pointer 0,1,n);

imatrix (n::int,m::int) = imatrix_dup (pointer 0,n,m);
imatrix n::int          = imatrix_dup (pointer 0,1,n);

/* Matrix conversions. These convert between different types of numeric
   matrices. You can also extract the real and imaginary parts of a (complex)
   matrix. */

// Runtime conversion primitives; private, use dmatrix/cmatrix/imatrix.
private matrix_double matrix_complex matrix_int;
extern expr* matrix_double(expr *x), expr* matrix_complex(expr *x),
  expr* matrix_int(expr *x);

// Conversion to a double matrix; only defined for int and double matrices.
dmatrix x::matrix       = matrix_double x if imatrixp x || dmatrixp x;
// Conversion to an int matrix; only defined for int and double matrices.
imatrix x::matrix       = matrix_int x if imatrixp x || dmatrixp x;
// Conversion to a complex matrix; defined for any numeric matrix.
cmatrix x::matrix       = matrix_complex x if nmatrixp x;

// Runtime primitives for the real and imaginary parts; private as well.
private matrix_re matrix_im;
extern expr* matrix_re(expr *x), expr* matrix_im(expr *x);

// Real and imaginary part of a matrix; defined for any numeric matrix.
re x::matrix            = matrix_re x if nmatrixp x;
im x::matrix            = matrix_im x if nmatrixp x;

/* Pack a matrix. This creates a copy of the matrix which has the data in
   contiguous storage. It also frees up extra memory if the matrix was created
   as a slice from a bigger matrix (see 'submat' above). The 'packed'
   predicate can be used to verify whether a matrix is already packed. Note
   that even if a matrix is already packed, 'pack' will make a copy of it
   anyway, so this routine also provides a quick way to copy a matrix, e.g.,
   if you want to pass it as an input/output parameter to a GSL routine. */

// pack x: rebuild x by joining it horizontally with an empty matrix.
pack x::matrix          = colcat [x,{}];
// packed x: true iff the stride of x equals its number of columns.
packed x::matrix        = stride x==m when _,m::int = dim x end;

/* Change the dimensions of a matrix without changing its size. The total
   number of elements must match that of the input matrix. Reuses the
   underlying storage of the input matrix if possible (i.e., if the matrix is
   packed). */

// Runtime primitive; private, use 'redim' instead.
private matrix_redim;
extern expr* matrix_redim(expr* x, int n, int m);

// redim (n,m) x: x with dimensions n x m. Only defined if both dimensions
// are nonnegative and their product equals the size of x; otherwise the
// guard fails and this equation does not apply.
redim (n::int,m::int) x::matrix
                        = matrix_redim x n m if n>=0 && m>=0 && n*m==#x;

/* You can also redim a matrix to a given row size. In this case the row size
   must divide the total size of the matrix. */

// redim m x: x with row size m; the number of rows is derived as #x div m.
// Only defined if m is positive and divides #x exactly.
redim m::int x::matrix  = redim (#x div m,m) x if m>0 && #x mod m==0;
// Special case: a row size of zero is accepted for an empty matrix, which
// is returned unchanged.
                        = x if m==0 && #x==0;

/* Convert a matrix to a row or column vector. */

// rowvector x: all elements of x as a single row (1 x #x).
rowvector x::matrix     = redim (1,n) x when n::int = #x end;
// colvector x: all elements of x as a single column (#x x 1).
colvector x::matrix     = redim (n,1) x when n::int = #x end;

/* Transpose a matrix. */

// Runtime primitive; private, use the postfix operator ' instead.
private matrix_transpose;
extern expr* matrix_transpose(expr *x);

// x': the transpose of x, as computed by matrix_transpose.
x::matrix'              = matrix_transpose x;

/* Reverse a matrix. 'rowrev' reverses the rows, 'colrev' the columns,
   'reverse' both dimensions. */

// rowrev x: split x into its rows, reverse that list and stack it again.
rowrev x::matrix        = rowcat $ reverse $ rows x;
// colrev x: split x into its columns, reverse that list and join it again.
colrev x::matrix        = colcat $ reverse $ cols x;
// reverse x: reverse the columns first, then the rows.
reverse x::matrix       = rowrev $ colrev x;

/* catmap et al on matrices. This allows list and matrix comprehensions to
   draw values from matrices as well as from lists, treating the matrix as a
   flat list of its elements. */

// For a matrix argument the generator is the flat list of its elements.
catmap f x::matrix      = catmap f $ list x;
rowcatmap f x::matrix   = rowcat $ map f $ list x;
colcatmap f x::matrix   = colcat $ map f $ list x;

/* Implementations of the other customary list operations, so that these can
   be used on matrices, too. These operations treat the matrix essentially as
   if it was a flat list of its elements. Aggregate results are then converted
   back to matrices with the appropriate dimensions. (This depends on the
   particular operation; functions like map and zip keep the dimensions of the
   input matrix intact, while other functions like filter, take or scanl
   always return a flat row vector. Also note that the zip-style operations
   require that the row sizes of all arguments match.) */

// Cycling through the elements; cyclen is only defined for nonempty x.
cycle x::matrix         = cycle $ list x;
cyclen n::int x::matrix = cyclen n $ list x if not null x;

// Each operation below works on the flat element list of x. Operations
// returning a sequence rebuild a matrix with colcat; map restores the
// original dimensions, slices are taken with '!!'.
all p x::matrix         = all p $ list x;
any p x::matrix         = any p $ list x;
do f x::matrix          = do f $ list x;
drop k::int x::matrix   = x!!(k..n-1) when n::int = #x end;
dropwhile p x::matrix   = colcat $ dropwhile p $ list x;
filter p x::matrix      = colcat $ filter p $ list x;
foldl f a x::matrix     = foldl f a $ list x;
foldl1 f x::matrix      = foldl1 f $ list x;
foldr f a x::matrix     = foldr f a $ list x;
foldr1 f x::matrix      = foldr1 f $ list x;
// head, init, last and tail are only defined for nonempty matrices.
head x::matrix          = x!0 if not null x;
init x::matrix          = x!!(0..#x-2) if not null x;
last x::matrix          = x!(#x-1) if not null x;
map f x::matrix         = redim (dim x) $ colcat $ map f $ list x;
scanl f a x::matrix     = colcat $ scanl f a $ list x;
scanl1 f x::matrix      = colcat $ scanl1 f $ list x;
scanr f a x::matrix     = colcat $ scanr f a $ list x;
scanr1 f x::matrix      = colcat $ scanr1 f $ list x;
take k::int x::matrix   = x!!(0..k-1);
takewhile p x::matrix   = colcat $ takewhile p $ list x;
tail x::matrix          = x!!(1..#x-1) if not null x;

// Helpers computing the dimensions of a zipped result: the smallest row
// count of the arguments, paired with the column count of the first one
// (the callers below require the column counts of all arguments to agree).
private zipdim zip3dim;
zipdim x::matrix y::matrix
                        = min (dim x!0) (dim y!0),dim x!1;
zip3dim x::matrix y::matrix z::matrix
                        = min (dim x!0) (min (dim y!0) (dim z!0)),dim x!1;

// The zip-style operations zip the flat element lists of their arguments
// and reshape the result to the dimensions computed above. Every equation
// is guarded by the requirement that all arguments have the same number of
// columns; otherwise it does not apply.
zip x::matrix y::matrix = redim (zipdim x y) $
                          colcat (zip (list x) (list y))
                            if dim x!1==dim y!1;
zip3 x::matrix y::matrix z::matrix
                        = redim (zip3dim x y z) $
                          colcat (zip3 (list x) (list y) (list z))
                            if dim x!1==dim y!1 && dim x!1==dim z!1;
zipwith f x::matrix y::matrix
                        = redim (zipdim x y) $
                          colcat (zipwith f (list x) (list y))
                            if dim x!1==dim y!1;
zipwith3 f x::matrix y::matrix z::matrix
                        = redim (zip3dim x y z) $
                          colcat (zipwith3 f (list x) (list y) (list z))
                            if dim x!1==dim y!1 && dim x!1==dim z!1;
// dowith/dowith3 delegate to the list versions and build no result matrix,
// so no reshaping is needed here.
dowith f x::matrix y::matrix
                        = dowith f (list x) (list y)
                            if dim x!1==dim y!1;
dowith3 f x::matrix y::matrix z::matrix
                        = dowith3 f (list x) (list y) (list z)
                            if dim x!1==dim y!1 && dim x!1==dim z!1;

// unzip x: unzip the flat element list of x and rebuild each component as a
// matrix with the same dimensions as x.
unzip x::matrix         = redim d (colcat u), redim d (colcat v)
                            when d = dim x; u,v = unzip (list x) end;
// unzip3 x: the same for three components.
unzip3 x::matrix        = redim d (colcat u), redim d (colcat v),
                          redim d (colcat w)
                            when d = dim x; u,v,w = unzip3 (list x) end;

/* Low-level operations for converting between matrices and raw pointers.
   These are typically used to shovel around massive amounts of numeric data
   between Pure and external C routines, when performance and throughput are
   important considerations (e.g., graphics, video and audio applications).
   The usual caveats apply. */

/* Convert a matrix to a pointer, which points directly to the underlying C
   array. You have to be careful when passing such a pointer to C functions if
   the underlying data is non-contiguous; when in doubt, first use the 'pack'
   function from above to place the data in contiguous storage. */

// Runtime primitive; private, use 'pointer' instead.
private pure_pointerval;
extern expr* pure_pointerval(expr*);

// pointer x: the pointer value obtained from pure_pointerval for matrix x.
pointer x::matrix       = pure_pointerval x;

/* Create a numeric matrix from a pointer, without copying the data. The
   caller must ensure that the pointer points to properly initialized memory
   big enough to accommodate the requested dimensions, which persists for the
   entire lifetime of the matrix object. */

private matrix_from_double_array matrix_from_complex_array
  matrix_from_int_array;
extern expr* matrix_from_double_array(void* p, int n, int m);
extern expr* matrix_from_complex_array(void* p, int n, int m);
extern expr* matrix_from_int_array(void* p, int n, int m);

// (p,n,m) wraps the data at p as an n x m matrix; (p,n) is shorthand for a
// 1 x n row vector.
dmatrix (p::pointer,n::int,m::int)
                        = matrix_from_double_array p n m;
dmatrix (p::pointer,n::int)
                        = matrix_from_double_array p 1 n;

cmatrix (p::pointer,n::int,m::int)
                        = matrix_from_complex_array p n m;
cmatrix (p::pointer,n::int)
                        = matrix_from_complex_array p 1 n;

imatrix (p::pointer,n::int,m::int)
                        = matrix_from_int_array p n m;
imatrix (p::pointer,n::int)
                        = matrix_from_int_array p 1 n;

/* Create a numeric matrix from a pointer, copying the data to fresh memory.
   The source pointer p may also be NULL, in which case the new matrix is
   filled with zeros instead. */

private matrix_from_double_array_dup matrix_from_complex_array_dup
  matrix_from_int_array_dup;
extern expr* matrix_from_double_array_dup(void* p, int n, int m);
extern expr* matrix_from_complex_array_dup(void* p, int n, int m);
extern expr* matrix_from_int_array_dup(void* p, int n, int m);

// (p,n,m) creates an n x m matrix from the data at p; (p,n) is shorthand
// for a 1 x n row vector.
dmatrix_dup (p::pointer,n::int,m::int)
                        = matrix_from_double_array_dup p n m;
dmatrix_dup (p::pointer,n::int)
                        = matrix_from_double_array_dup p 1 n;

cmatrix_dup (p::pointer,n::int,m::int)
                        = matrix_from_complex_array_dup p n m;
cmatrix_dup (p::pointer,n::int)
                        = matrix_from_complex_array_dup p 1 n;

imatrix_dup (p::pointer,n::int,m::int)
                        = matrix_from_int_array_dup p n m;
imatrix_dup (p::pointer,n::int)
                        = matrix_from_int_array_dup p 1 n;

/* Some additional functions for alternate base types. These work like the
   routines above, but initialize the data from float, complex float, short
   and byte arrays, respectively. */

private matrix_from_float_array_dup matrix_from_complex_float_array_dup
  matrix_from_short_array_dup matrix_from_byte_array_dup;
extern expr* matrix_from_float_array_dup(void* p, int n, int m);
extern expr* matrix_from_complex_float_array_dup(void* p, int n, int m);
extern expr* matrix_from_short_array_dup(void* p, int n, int m);
extern expr* matrix_from_byte_array_dup(void* p, int n, int m);

// (p,n,m) creates an n x m matrix from the data at p; (p,n) is shorthand
// for a 1 x n row vector.
float_dmatrix_dup (p::pointer,n::int,m::int)
                        = matrix_from_float_array_dup p n m;
float_dmatrix_dup (p::pointer,n::int)
                        = matrix_from_float_array_dup p 1 n;

float_cmatrix_dup (p::pointer,n::int,m::int)
                        = matrix_from_complex_float_array_dup p n m;
float_cmatrix_dup (p::pointer,n::int)
                        = matrix_from_complex_float_array_dup p 1 n;

short_imatrix_dup (p::pointer,n::int,m::int)
                        = matrix_from_short_array_dup p n m;
short_imatrix_dup (p::pointer,n::int)
                        = matrix_from_short_array_dup p 1 n;

byte_imatrix_dup (p::pointer,n::int,m::int)
                        = matrix_from_byte_array_dup p n m;
byte_imatrix_dup (p::pointer,n::int)
                        = matrix_from_byte_array_dup p 1 n;