/* Pure primitives. These operations are implemented either directly by
   machine instructions or by C functions provided in the runtime. */

/* Copyright (c) 2008 by Albert Graef <Dr.Graef@t-online.de>.

   This file is part of the Pure programming language and system.

   Pure is free software: you can redistribute it and/or modify it under the
   terms of the GNU General Public License as published by the Free Software
   Foundation, either version 3 of the License, or (at your option) any later
   version.

   Pure is distributed in the hope that it will be useful, but WITHOUT ANY
   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
   FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
   details.

   You should have received a copy of the GNU General Public License along
   with this program.  If not, see <http://www.gnu.org/licenses/>. */

/* Throw an exception. This binds the runtime function pure_throw to the Pure
   name 'throw'. It takes a single expression argument and is declared void,
   i.e., it is called for its effect only. */

extern void pure_throw(expr*) = throw; // IMPURE!

/* Force a thunk (x&). This usually happens automatically when the value of a
   thunk is needed. The runtime function pure_force is made available under
   the Pure name 'force'; it takes an expression and returns an expression. */

extern expr* pure_force(expr*) = force;

/* Syntactic equality. Both operators are thin wrappers around the runtime
   predicate `same': (===) yields its result, (!==) the negated result. */

extern bool same(expr* x, expr* y);
a === b         = same a b;
a !== b         = not (same a b);

/* Predicates to check for the built-in types. Each predicate yields 1 if its
   argument matches the corresponding type tag, and 0 for anything else. */

intp _::int             = 1;
intp _                  = 0 otherwise;

bigintp _::bigint       = 1;
bigintp _               = 0 otherwise;

doublep _::double       = 1;
doublep _               = 0 otherwise;

stringp _::string       = 1;
stringp _               = 0 otherwise;

pointerp _::pointer     = 1;
pointerp _              = 0 otherwise;

matrixp _::matrix       = 1;
matrixp _               = 0 otherwise;

/* Predicates to check for function objects, global (unbound) variables,
   function applications, proper lists, list nodes and tuples. The four
   predicates declared here (funp, lambdap, thunkp, varp) are implemented by
   the runtime; each takes an expression and returns a truth value. The
   remaining predicates are defined by equations in this file. */

extern bool funp(expr*), bool lambdap(expr*), bool thunkp(expr*);
extern bool varp(expr*);

// Yields 1 for any term of the form `f x' (an application), 0 otherwise.
applp t         = case t of (_ _) = 1; _ = 0 end;

// Proper list check: the empty list qualifies, and a cons cell qualifies iff
// its tail does. Walks the whole spine of the list.
listp []        = 1;
listp (_:tl)    = listp tl;
listp _         = 0 otherwise;

// List node check: only inspects the outermost constructor ([] or a cons
// cell), without descending into the tail.
listnp t        = case t of [] = 1; (_:_) = 1; _ = 0 end;

// Tuple check: the empty tuple and any pair (x,xs) count as tuples.
tuplep t        = case t of () = 1; (_,_) = 1; _ = 0 end;

/* Compute a 32 bit hash code of a Pure expression. Implemented by the
   runtime; the result is returned as a machine int. */

extern int hash(expr*);

/* Conversions between the different numeric and pointer types. The actual
   work is done by the runtime functions declared here. They are private to
   this module; user code goes through the int, bigint, double and pointer
   functions defined by the equations that follow. */

private pure_intval pure_dblval pure_bigintval pure_pointerval;
extern expr* pure_intval(expr*), expr* pure_dblval(expr*),
  expr* pure_bigintval(expr*), expr* pure_pointerval(expr*);

// Convert to a machine int. An int is returned unchanged; bigints, doubles
// and pointers are handed to pure_intval.
int x::int              = x;
int x::bigint           = pure_intval x;
int x::double           = pure_intval x;
int x::pointer          = pure_intval x;

// Convert to a bigint. A bigint is returned unchanged; ints, doubles and
// pointers are handed to pure_bigintval.
bigint x::bigint        = x;
bigint x::int           = pure_bigintval x;
bigint x::double        = pure_bigintval x;
bigint x::pointer       = pure_bigintval x;

// Convert to a double. A double is returned unchanged; ints and bigints are
// handed to pure_dblval. (There is no rule for pointers.)
double x::double        = x;
double x::int           = pure_dblval x;
double x::bigint        = pure_dblval x;

// Convert to a pointer. A pointer is returned unchanged; ints, bigints,
// doubles and strings are handed to pure_pointerval.
pointer x::pointer      = x;
pointer x::int          = pure_pointerval x;
pointer x::bigint       = pure_pointerval x;
pointer x::double       = pure_pointerval x;
pointer x::string       = pure_pointerval x;

/* Convert signed (8/16/32/64) bit integers to the corresponding unsigned
   quantities. These functions behave as if the value was "cast" to the
   corresponding unsigned C type, and are most useful for dealing with
   unsigned integers returned by external C routines. The routines always use
   the smallest Pure int type capable of holding the result: int for ubyte and
   ushort, bigint for uint and ulong. (Note that in the case of 64 bit values
   the C interface returns a bigint, that's why ulong takes a bigint
   parameter. The other routines all take an int as input.)

   ubyte and ushort keep just the low-order 8 and 16 bits, respectively, so
   that arguments outside the signed 8/16 bit range are reduced the way a C
   cast would reduce them (e.g., ubyte (-300) ==> 212, ubyte 300 ==> 44). For
   arguments inside that range this is the same as adding 0x100 or 0x10000 to
   negative values. */

ubyte x::int            = x and 0xff;
ushort x::int           = x and 0xffff;
// A negative machine int is shifted into the unsigned range by adding 2^32;
// the sum of an int and a bigint is a bigint, as is `bigint x'.
uint x::int             = if x>=0 then bigint x else x+0x100000000L;
// Likewise for 64 bit values, adding 2^64 to negative arguments.
ulong x::bigint         = if x>=0 then x else x+0x10000000000000000L;

/* Rounding functions. floor and ceil are bound under their own names, while
   the external functions __round and __trunc are made available as round and
   trunc. All four take and return doubles. */

extern double floor(double), double ceil(double);
extern double __round(double) = round, double __trunc(double) = trunc;

// Ints and bigints are already integral, so rounding them is the identity.
floor x::int            = x;
floor x::bigint         = x;
ceil x::int             = x;
ceil x::bigint          = x;
round x::int            = x;
round x::bigint         = x;
trunc x::int            = x;
trunc x::bigint         = x;

// Fractional part of x: what remains after subtracting the truncated value.
frac x::int             = x-trunc x;
frac x::bigint          = x-trunc x;
frac x::double          = x-trunc x;

/* Absolute value and sign of a number. Note that these don't distinguish
   between IEEE 754 positive and negative zeros; abs always returns 0.0, sgn 0
   for these. The real sign bit of a floating point zero can be obtained with
   sgn (1/x). */

// abs negates everything that is not strictly positive.
abs x::int              = if x>0 then x else -x;
abs x::bigint           = if x>0 then x else -x;
abs x::double           = if x>0 then x else -x;

// sgn yields a machine int (1, -1 or 0) for all three numeric types.
sgn x::int              = if x>0 then 1 else if x<0 then -1 else 0;
sgn x::bigint           = if x>0 then 1 else if x<0 then -1 else 0;
sgn x::double           = if x>0 then 1 else if x<0 then -1 else 0;

/* Generic min and max functions. */

// Both work for any operands on which (<=) resp. (>=) is defined. The first
// operand is returned if the comparison holds.
min a b                 = if a<=b then a else b;
max a b                 = if a>=b then a else b;

/* Generic succ and pred functions. */

// Successor and predecessor, for any operand on which x+1 resp. x-1 is
// defined.
succ n                  = n+1;
pred n                  = n-1;

/* Basic int and double arithmetic. The Pure compiler already knows how to
   handle these, we just need to supply rules with the right type tags. Note
   that each rule below has the same operation on both sides: the type tags
   on the left-hand side are what makes the operation applicable. */

// Unary operators on machine ints.
-x::int                 = -x;
~x::int                 = ~x;
not x::int              = not x;

// Shifts.
x::int<<y::int          = x<<y;
x::int>>y::int          = x>>y;

// Binary operators.
x::int+y::int           = x+y;
x::int-y::int           = x-y;
x::int*y::int           = x*y;
x::int/y::int           = x/y;
x::int div y::int       = x div y;
x::int mod y::int       = x mod y;
x::int or y::int        = x or y;
x::int and y::int       = x and y;

// Comparisons.
x::int<y::int           = x<y;
x::int>y::int           = x>y;
x::int<=y::int          = x<=y;
x::int>=y::int          = x>=y;
x::int==y::int          = x==y;
x::int!=y::int          = x!=y;

// Double arithmetic. As with machine ints, these rules only supply the type
// tags for operations which are handled by the compiler. There are no
// shift, div, mod, or, and, ~ or not rules for doubles.
-x::double              = -x;

x::double+y::double     = x+y;
x::double-y::double     = x-y;
x::double*y::double     = x*y;
x::double/y::double     = x/y;

// Comparisons.
x::double<y::double     = x<y;
x::double>y::double     = x>y;
x::double<=y::double    = x<=y;
x::double>=y::double    = x>=y;
x::double==y::double    = x==y;
x::double!=y::double    = x!=y;

// mixed operands (one machine int, one double); these are handled by the
// compiler as well, so again only the type tags are supplied here

// int on the left, double on the right
x::int+y::double        = x+y;
x::int-y::double        = x-y;
x::int*y::double        = x*y;
x::int/y::double        = x/y;

x::int<y::double        = x<y;
x::int>y::double        = x>y;
x::int<=y::double       = x<=y;
x::int>=y::double       = x>=y;
x::int==y::double       = x==y;
x::int!=y::double       = x!=y;

// double on the left, int on the right
x::double+y::int        = x+y;
x::double-y::int        = x-y;
x::double*y::int        = x*y;
x::double/y::int        = x/y;

x::double<y::int        = x<y;
x::double>y::int        = x>y;
x::double<=y::int       = x<=y;
x::double>=y::int       = x>=y;
x::double==y::int       = x==y;
x::double!=y::int       = x!=y;

/* Logical connectives and sequences. Please note that these enjoy
   call-by-name and short-circuit evaluation only as explicit calls! But we
   still want them to work if they are applied partially, so we add these
   rules here. Consequently, when invoked through these rules, both operands
   have already been evaluated. */

x::int&&y::int          = x&&y;
x::int||y::int          = x||y;
// The sequence operator discards its first operand and yields the second.
x$$y                    = y;

/* Bigint arithmetic. The operations are implemented by the runtime functions
   declared below, which are private to this module. Bigint operands are
   passed as void* and the shift counts as machine ints. All functions return
   an expression, except bigint_cmp which returns a machine int that the
   comparison rules in this file test against zero. */

private bigint_neg bigint_not bigint_add bigint_sub bigint_mul bigint_div
  bigint_mod bigint_shl bigint_shr bigint_and bigint_or bigint_cmp;
extern expr* bigint_neg(void*);
extern expr* bigint_not(void*);
extern expr* bigint_add(void*, void*);
extern expr* bigint_sub(void*, void*);
extern expr* bigint_mul(void*, void*);
extern expr* bigint_div(void*, void*);
extern expr* bigint_mod(void*, void*);
extern expr* bigint_shl(void*, int);
extern expr* bigint_shr(void*, int);
extern expr* bigint_and(void*, void*);
extern expr* bigint_or(void*, void*);
extern int bigint_cmp(void*, void*);

// Unary operators on bigints.
-x::bigint              = bigint_neg x;
~x::bigint              = bigint_not x;
// Logical negation: 1 if x is zero, 0 otherwise. The full bigint value is
// compared against zero here, rather than converting x to a machine int
// first. A machine int cannot represent every bigint, so a nonzero x which
// happens to be narrowed to 0 by the conversion would be taken for false.
not x::bigint           = bigint_cmp x 0L == 0;

// Bigint shifts. A negative shift count shifts by the negated count in the
// opposite direction.
x::bigint<<n::int       = if n>=0 then bigint_shl x n else bigint_shr x (-n);
x::bigint>>n::int       = if n>=0 then bigint_shr x n else bigint_shl x (-n);

// Binary bigint operators, delegated to the runtime helpers. Division with
// (/) is carried out on doubles after converting both operands.
a::bigint+b::bigint     = bigint_add a b;
a::bigint-b::bigint     = bigint_sub a b;
a::bigint*b::bigint     = bigint_mul a b;
a::bigint/b::bigint     = double a / double b;
a::bigint div b::bigint = bigint_div a b;
a::bigint mod b::bigint = bigint_mod a b;
a::bigint or b::bigint  = bigint_or a b;
a::bigint and b::bigint = bigint_and a b;

// Comparisons test the result of bigint_cmp against zero.
a::bigint<b::bigint     = bigint_cmp a b < 0;
a::bigint>b::bigint     = bigint_cmp a b > 0;
a::bigint<=b::bigint    = bigint_cmp a b <= 0;
a::bigint>=b::bigint    = bigint_cmp a b >= 0;
a::bigint==b::bigint    = bigint_cmp a b == 0;
a::bigint!=b::bigint    = bigint_cmp a b != 0;

// mixed int/bigint: the int operand is converted to a bigint, except for (/)
// where it is converted to a double

// int on the left
n::int+b::bigint        = bigint n+b;
n::int-b::bigint        = bigint n-b;
n::int*b::bigint        = bigint n*b;
n::int/b::bigint        = double n/b;
n::int div b::bigint    = bigint n div b;
n::int mod b::bigint    = bigint n mod b;
n::int or b::bigint     = bigint n or b;
n::int and b::bigint    = bigint n and b;

n::int<b::bigint        = bigint n<b;
n::int>b::bigint        = bigint n>b;
n::int<=b::bigint       = bigint n<=b;
n::int>=b::bigint       = bigint n>=b;
n::int==b::bigint       = bigint n==b;
n::int!=b::bigint       = bigint n!=b;

// int on the right
b::bigint+n::int        = b+bigint n;
b::bigint-n::int        = b-bigint n;
b::bigint*n::int        = b*bigint n;
b::bigint/n::int        = b/double n;
b::bigint div n::int    = b div bigint n;
b::bigint mod n::int    = b mod bigint n;
b::bigint or n::int     = b or bigint n;
b::bigint and n::int    = b and bigint n;

b::bigint<n::int        = b<bigint n;
b::bigint>n::int        = b>bigint n;
b::bigint<=n::int       = b<=bigint n;
b::bigint>=n::int       = b>=bigint n;
b::bigint==n::int       = b==bigint n;
b::bigint!=n::int       = b!=bigint n;

// mixed double/bigint: the bigint operand is converted to a double, so that
// the operation is carried out on doubles

// bigint on the left
b::bigint+d::double     = double b+d;
b::bigint-d::double     = double b-d;
b::bigint*d::double     = double b*d;
b::bigint/d::double     = double b/d;

b::bigint<d::double     = double b<d;
b::bigint>d::double     = double b>d;
b::bigint<=d::double    = double b<=d;
b::bigint>=d::double    = double b>=d;
b::bigint==d::double    = double b==d;
b::bigint!=d::double    = double b!=d;

// bigint on the right
d::double+b::bigint     = d+double b;
d::double-b::bigint     = d-double b;
d::double*b::bigint     = d*double b;
d::double/b::bigint     = d/double b;

d::double<b::bigint     = d<double b;
d::double>b::bigint     = d>double b;
d::double<=b::bigint    = d<=double b;
d::double>=b::bigint    = d>=double b;
d::double==b::bigint    = d==double b;
d::double!=b::bigint    = d!=double b;

/* The gcd and lcm functions from the GMP library. These return a bigint if at
   least one of the arguments is a bigint, a machine int otherwise. The
   runtime helpers declared here are private to this module and take both
   operands as bigints (passed as void*). */

private bigint_gcd bigint_lcm;
extern expr* bigint_gcd(void*, void*);
extern expr* bigint_lcm(void*, void*);

// gcd: int operands are converted to bigints for the helper. If both
// operands are ints, the result is converted back to an int.
gcd a::bigint b::bigint = bigint_gcd a b;
gcd a::int b::bigint    = bigint_gcd (bigint a) b;
gcd a::bigint b::int    = bigint_gcd a (bigint b);
gcd a::int b::int       = int (bigint_gcd (bigint a) (bigint b));

// lcm: same scheme as gcd.
lcm a::bigint b::bigint = bigint_lcm a b;
lcm a::int b::bigint    = bigint_lcm (bigint a) b;
lcm a::bigint b::int    = bigint_lcm a (bigint b);
lcm a::int b::int       = int (bigint_lcm (bigint a) (bigint b));

/* The pow function. Computes exact powers of ints and bigints. The result is
   always a bigint. Note that y must always be nonnegative here, but see
   math.pure which deals with the case y<0 using rational numbers.

   The runtime helper bigint_pow takes the base as a bigint and the exponent
   as a machine int. A bigint exponent is therefore converted with `int'
   before the call, and the corresponding rules only apply if the exponent
   fits into a (32 bit) machine int, so that the conversion cannot change its
   value. */

private bigint_pow;
extern expr* bigint_pow(void*, int);

pow x::int y::int       = bigint_pow (bigint x) y if y>=0;
pow x::bigint y::bigint = bigint_pow x (int y) if y>=0 && y<=0x7fffffff;

// mixed int/bigint
pow x::int y::bigint    = bigint_pow (bigint x) (int y)
                          if y>=0 && y<=0x7fffffff;
pow x::bigint y::int    = bigint_pow x y if y>=0;

/* The ^ operator. Computes inexact powers for any combination of int, bigint
   and double operands. The result is always a double. The computation is
   done by the C pow function, which is bound to the private name c_pow here
   so that it doesn't clash with the exact pow function. */

private c_pow;
extern double pow(double, double) = c_pow;

// Both operands are doubles already.
x::double^y::double     = c_pow x y;

// Only the exponent needs to be converted.
x::double^y::int        = c_pow x (double y);
x::double^y::bigint     = c_pow x (double y);

// Only the base needs to be converted.
x::int^y::double        = c_pow (double x) y;
x::bigint^y::double     = c_pow (double x) y;

// Both operands need to be converted.
x::int^y::int           = c_pow (double x) (double y);
x::bigint^y::bigint     = c_pow (double x) (double y);
x::int^y::bigint        = c_pow (double x) (double y);
x::bigint^y::int        = c_pow (double x) (double y);

/* Pointer arithmetic. We do this using bigints, so that the code is portable
   to 64 bit systems. Pointers are converted to bigints, the operation is
   performed on those, and offsets are converted back with `pointer'. */

const NULL = pointer 0; // the null pointer

// Check for a null pointer.
null p::pointer         = bigint p==0;

// The difference of two pointers is a bigint; pointer plus offset is a
// pointer again.
p::pointer-q::pointer   = bigint p-bigint q;
p::pointer+n::int       = pointer (bigint p+n);
p::pointer+n::bigint    = pointer (bigint p+n);

// Pointers are compared by their bigint values.
p::pointer<q::pointer   = bigint p <  bigint q;
p::pointer>q::pointer   = bigint p >  bigint q;
p::pointer<=q::pointer  = bigint p <= bigint q;
p::pointer>=q::pointer  = bigint p >= bigint q;
p::pointer==q::pointer  = bigint p == bigint q;
p::pointer!=q::pointer  = bigint p != bigint q;

/* IEEE floating point infinities and NaNs. Place these after the definitions
   of the built-in operators so that the double arithmetic works. inf is
   obtained as a product which is too large for a double, nan as the
   difference of two infinities. */

const inf = 1.0e307 * 1.0e307;
const nan = inf-inf;

/* Predicates to check for inf and nan values. Both yield 0 for anything
   which isn't a double. */

// Either infinity qualifies.
infp x::double  = x==inf || x==-inf;
infp _          = 0 otherwise;

// A NaN is recognized by the fact that it doesn't compare equal to itself.
nanp x::double  = not (x==x);
nanp _          = 0 otherwise;

/* Direct memory accesses. Use with care ... or else! The runtime functions
   declared here are private to this module; each one takes a pointer and
   returns a value of the indicated C type. They are made available to user
   code through the corresponding get_xxx functions. Note that the byte,
   short and int accessors are all declared with an int result, and the
   float and double accessors with a double result. */

private pointer_get_byte pointer_get_short pointer_get_int
  pointer_get_float pointer_get_double
  pointer_get_string pointer_get_pointer;
extern int pointer_get_byte(void *ptr);
extern int pointer_get_short(void *ptr);
extern int pointer_get_int(void *ptr);
extern double pointer_get_float(void *ptr);
extern double pointer_get_double(void *ptr);
extern char *pointer_get_string(void *ptr);
extern void *pointer_get_pointer(void *ptr);

// Public read accessors. These only accept a pointer argument and pass it
// on to the corresponding runtime function.
get_byte p::pointer     = pointer_get_byte p;
get_short p::pointer    = pointer_get_short p;
get_int p::pointer      = pointer_get_int p;
get_float p::pointer    = pointer_get_float p;
get_double p::pointer   = pointer_get_double p;
get_string p::pointer   = pointer_get_string p;
get_pointer p::pointer  = pointer_get_pointer p;

// Runtime functions which store a value of the indicated C type at a pointer.
// They are private to this module and are called through the put_xxx
// functions. All of them are declared void and are marked as impure, i.e.,
// they are called for their effect only.
private pointer_put_byte pointer_put_short pointer_put_int
  pointer_put_float pointer_put_double
  pointer_put_string pointer_put_pointer;
extern void pointer_put_byte(void *ptr, int x); // IMPURE!
extern void pointer_put_short(void *ptr, int x); // IMPURE!
extern void pointer_put_int(void *ptr, int x); // IMPURE!
extern void pointer_put_float(void *ptr, double x); // IMPURE!
extern void pointer_put_double(void *ptr, double x); // IMPURE!
extern void pointer_put_string(void *ptr, char *x); // IMPURE!
extern void pointer_put_pointer(void *ptr, void *x); // IMPURE!

// Public write accessors. The first argument is the target pointer, the
// second the value to be stored, whose type is checked here.
put_byte p::pointer v::int              = pointer_put_byte p v;
put_short p::pointer v::int             = pointer_put_short p v;
put_int p::pointer v::int               = pointer_put_int p v;
put_float p::pointer v::double          = pointer_put_float p v;
put_double p::pointer v::double         = pointer_put_double p v;
put_string p::pointer v::string         = pointer_put_string p v;
// put_pointer accepts both a string and a pointer as the value.
put_pointer p::pointer v::string        |
put_pointer p::pointer v::pointer       = pointer_put_pointer p v;

/* Sentries. These are expression "guards" which are applied to the target
   expression when it is garbage-collected. The sentry function places a
   sentry at an expression (and returns the modified expression),
   clear_sentry removes the sentry, and get_sentry returns it. NOTE: In the
   current implementation sentries can only be placed at applications and
   pointer objects. The sentry itself can be any type of object (but usually
   it's a function).

   In 'sentry f x', f is the sentry and x is the target expression (this is
   how the function is called in the definition of 'ref' in this file). */

extern expr* pure_sentry(expr*,expr*) = sentry; // IMPURE!
extern expr* pure_clear_sentry(expr*) = clear_sentry; // IMPURE!
extern expr* pure_get_sentry(expr*) = get_sentry;

/* Expression references. If you need these, then you're doomed. ;-) However,
   they can be useful as a last resort when you need to keep track of some
   local state or interface to the messy imperative world. Pure's references
   are implemented as Pure expression pointers so that you can readily pass
   them as pointers to a C function which expects a pure_expr** parameter.
   This may even be useful at times.

   'ref x' creates a reference pointing to x initially, 'put r x' sets a new
   value (and returns it), 'get r' retrieves the current value, and 'unref r'
   purges the referenced object and turns the reference into a dangling
   pointer. (The latter is used as a sentry on reference objects and shouldn't
   normally be called directly.) The refp predicate can be used to check for
   reference values. Note that manually removing the unref sentry turns the
   reference into just a normal pointer object and renders it unusable as a
   reference. Doing this will also leak memory, so don't! */

// Runtime helpers used to implement references; all of them are private to
// this module. pure_expr_pointer takes no arguments and returns the pointer
// object which is used as the reference. pointer_get_expr and
// pointer_put_expr read and write the expression stored at a pointer.
// pure_new and pure_free are applied to an expression before it is stored
// and when it is replaced or purged, respectively (presumably to manage the
// reference count of the stored expression -- TODO confirm against the
// runtime).
private pure_new pure_free pure_expr_pointer;
private pointer_get_expr pointer_put_expr;
extern expr* pure_new(expr*), expr* pure_expr_pointer();
extern void pure_free(expr*);
extern expr* pointer_get_expr(void*), void pointer_put_expr(void*, expr*);

// Create a reference which initially points to x. A new pointer r is
// obtained from pure_expr_pointer, (pure_new x) is stored in it, and unref
// is then placed as a sentry at r. The result is the value returned by
// `sentry', i.e., the pointer carrying the unref sentry. The order matters
// here: the initial value is stored before the sentry is put in place.
ref x = pointer_put_expr r (pure_new x) $$
        sentry unref r when r::pointer = pure_expr_pointer end;

// Purge a reference: the stored expression is passed to pure_free, then the
// sentry is removed from r, after which r no longer passes the refp check.
// Only applies if r is a reference (refp r); for other pointers the call
// is left unevaluated. This is the function which is placed as a sentry on
// references by `ref'.
unref r::pointer = pure_free (pointer_get_expr r) $$
                   clear_sentry r if refp r;

// Set a new value: the expression currently stored in r is passed to
// pure_free, (pure_new x) is stored in its place, and x is returned as the
// result. Only applies if r is a reference (refp r).
put r::pointer x = pure_free (pointer_get_expr r) $$
                   pointer_put_expr r (pure_new x) $$ x if refp r;

// Retrieve the expression currently stored in a reference. Only applies if
// the pointer passes the refp check.
get p::pointer = pointer_get_expr p if refp p;

// A reference is a pointer whose sentry is (syntactically) the unref
// function; anything else, including non-pointers, yields 0.
refp p::pointer = get_sentry p===unref;
refp _          = 0 otherwise;