/* Pure's dict and hdict data types based on AVL trees. */

/* Copyright (c) 2008 by Albert Graef <Dr.Graef@t-online.de>.
   Copyright (c) 2008 by Jiri Spitz <jiri.spitz@bluetone.cz>.

   This file is part of the Pure programming language and system.

   Pure is free software: you can redistribute it and/or modify it under the
   terms of the GNU General Public License as published by the Free Software
   Foundation, either version 3 of the License, or (at your option) any later
   version.

   Pure is distributed in the hope that it will be useful, but WITHOUT ANY
   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
   FOR a PARTICULAR PURPOSE.  See the GNU General Public License for more
   details.

   You should have received a copy of the GNU General Public License along
   with this program.  If not, see <http://www.gnu.org/licenses/>. */


/* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
   The AVL tree algorithm used here originates from the SWI-Prolog
   implementation of association lists. The original file was created by
   R. A. O'Keefe and updated for SWI-Prolog by Jan Wielemaker. For the
   original file see http://www.swi-prolog.org.

   The port from SWI-Prolog and the deletion operations (rmfirst, rmlast,
   delete), which are missing from the original file, were provided by
   Jiri Spitz.
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */


/* Tree constructors. */
// `nil` is the empty tree. It is declared nullary so that it is treated as a
// constant symbol (rather than as a variable) when it appears in patterns.
private nullary nil;
// `bin Key Value Balance Left Right` is an inner tree node.
private bin;

/*****
Tree for dict and hdict is either:
-  nil  (empty tree) or
-  bin Key value Balance Left Right  (Left, Right: trees)

Balance: ( 1), ( 0), or (-1) denoting |L|-|R| = 1, 0, or -1, respectively
*****/


/* Public operations: ******************************************************

emptydict, emptyhdict:          return the empty dict or hdict
dict xs, hdict xs:              create a dict or hdict from list xs
dictp d, hdictp d:              check whether d is a dict or hdict
mkdict y xs, mkhdict y xs:      create dict or hdict from a list of keys and
                                a constant value

#d                              size of dict or hdict d
d!x:                            get value from d by key x

null d                          tests whether d is the empty dict or hdict
member d x                      tests whether d contains member with key x
members d, list d               list members of d (in ascending order for dict)
keys d:                         lists keys of d (in ascending order for dict)
vals d:                         list values of d

first d, last d                 return first and last member of dict
rmfirst d, rmlast d             remove first and last member from dict
insert d xy                     insert the pair xy (of the form x => y) into d
                                (replace existing element)
update d x y                    fully curried version of insert
delete d x                      remove the member with key x from d

   *************************************************************************/

// Internal rebalancing helpers; their definitions are in the "Private
// functions" part of this file.
private adjustd avl_geq;

// Type check for dicts: yields 1 when the argument is a term built with the
// Dict constructor, and 0 for any other value.
dictp x                         = case x of
                                    Dict _ = 1;
                                    _      = 0
                                  end;

// Type check for hdicts: yields 1 when the argument is a term built with the
// Hdict constructor, and 0 for any other value.
hdictp x                        = case x of
                                    Hdict _ = 1;
                                    _       = 0
                                  end;

// create an empty dict or hdict
// Both are the empty tree `nil` wrapped in the respective constructor symbol.
emptydict                       = Dict  nil;
emptyhdict                      = Hdict nil;

// Build a dict from a list of `key => value` pairs: the pairs are inserted
// one at a time, left to right, starting from the empty dict. The rule only
// applies when the argument is a list.
dict xys                        = foldl add emptydict xys
                                  with
                                    add d xy = insert d xy
                                  end
                                  if listp xys;
// Build an hdict from a list of `key => value` pairs: the pairs are inserted
// one at a time, left to right, starting from the empty hdict. The rule only
// applies when the argument is a list.
hdict xys                       = foldl add emptyhdict xys
                                  with
                                    add d xy = insert d xy
                                  end
                                  if listp xys;

// insert a new member into the dict or hdict
//
// The second argument is a `key => value` pair. `t` is bound to the
// constructor symbol (Dict or Hdict); it selects the insertion routine and is
// re-applied to the new tree to wrap the result. An existing entry for the
// same key is replaced.
//
// The local helpers insertd/inserth return a two-element list
// [newTree, heightHasGrown]; only the tree (element 0) is kept at top level.
insert (t@Dict  d) (x         => y) |
insert (t@Hdict d) (x         => y)
        = if t === Dict
                then t ((insertd d x y)!0)
                else t ((inserth d (hash x) x y)!0)
with
  // insertd: insertion into a dict tree, keys compared with ==, < and >.
  // An empty tree becomes a single balanced node; the height has grown.
  insertd nil key         val
                                = [(bin key val ( 0) nil nil), 1];

  // Key already present: replace key and value in place, shape unchanged.
  insertd (bin k _         b l r) key         val
                                = [(bin key val b l r), 0] if key == k;

  // Smaller key: insert into the left subtree, then fix up this node
  // (-1 tells adjust that the left side was touched).
  insertd (bin k         v b l r) key         val
        = adjust leftHasChanged (bin k v b newl r) (-1)
        when
                [newl, leftHasChanged] = insertd l key val
        end
        if key < k;

  // Greater key: insert into the right subtree, then fix up this node
  // (1 tells adjust that the right side was touched).
  insertd (bin k         v b l r) key         val
        = adjust rightHasChanged (bin k v b l newr) ( 1)
        when
                [newr, rightHasChanged] = insertd r key val
        end
        if key > k;

  // inserth: insertion into an hdict tree. Nodes are keyed by the integer
  // hash of the user key; the node value is a list of `key => value` pairs
  // sharing that hash.
  inserth nil k::int x y        = [(bin k [x => y] ( 0) nil nil), 1];

  // Hash already present: update the pair list of this node only.
  inserth (bin k::int v b l r) key::int x y
                = [(bin k (inserth2 v x y) b l r), 0] if k == key;

  inserth (bin k::int v b l r) key::int x y
        = adjust leftHasChanged (bin k v b newl r) (-1)
        when
                [newl, leftHasChanged] = inserth l key x y
        end
        if key < k;

  inserth (bin k::int v b l r) key::int x y
        = adjust rightHasChanged (bin k v b l newr) ( 1)
        when
                [newr, rightHasChanged] = inserth r key x y
        end
        if key > k;

  // inserth2: insert into the pair list of one hash node. Keys are matched
  // with === ; a matching pair gets the new value, otherwise the new pair is
  // appended at the end of the list.
  inserth2 [] x  y              = [x => y];
  inserth2 ((x1 => y):xys) x2 y1
                                = ((x1 => y1):xys) if x1 === x2;
  inserth2 ((x  => y):xys) x1 y1
                                = ((x  => y ):(inserth2 xys x1 y1));

  // adjust flag tree LoR: `flag` says whether the touched subtree has grown,
  // LoR is -1 (left) or 1 (right). Returns [tree, heightHasGrown].
  // Nothing grew below: the node is returned unchanged.
  adjust 0 oldTree _            = [oldTree, 0];

  // A subtree grew: look up new balance, growth flag and rebalance flag.
  adjust 1 (bin key         val b0 l r) LoR
        = [rebal toBeRebalanced (bin key val b0 l r) b1, whatHasChanged]
        when
                [b1, whatHasChanged, toBeRebalanced] = table b0 LoR
        end;

  // No rotation needed: just store the new balance factor.
  rebal 0 (bin k         v _ l r) b
                                = bin k v b l r;

  // Rotation needed: avl_geq returns [tree, flag]; only the tree is used.
  rebal 1  oldTree _            = (avl_geq oldTree)!0;

/*
// Balance rules for insertions
//      balance where           balance   whole tree    to be
//      before  inserted        after     increased     rebalanced
table   ( 0)    (-1)            = [( 1),        1,              0];
table   ( 0)    ( 1)            = [(-1),        1,              0];
table   ( 1)    (-1)            = [( 0),        0,              1];
table   ( 1)    ( 1)            = [( 0),        0,              0];
table   (-1)    (-1)            = [( 0),        0,              0];
table   (-1)    ( 1)            = [( 0),        0,              1]
*/

// table w/o pattern matching
// bb = balance before, wi = where inserted (-1 left, 1 right);
// yields [balance after, whole tree increased, to be rebalanced].
  table bb::int wi::int = [ba, wti, tbr]
        when
                ba      = if bb == 0 then -wi else 0;
                wti     = bb == 0;
                tbr     = (bb + wi) == 0;
        end
end;

// delete a member by key from the dict or hdict
//
// `t` is bound to the constructor symbol (Dict or Hdict); it selects the
// deletion routine and is re-applied to the new tree. If the key is not
// present the tree contents are unchanged.
//
// The local helpers deleted/deleteh/rmlast return a two-element list
// [newTree, heightHasChanged]; only the tree (element 0) is kept at top level.
delete (t@Dict  d) x         |
delete (t@Hdict d) x
        = if t === Dict
                then t ((deleted d x)!0)
                else t ((deleteh d (hash x) x)!0)
with
  // deleted: deletion from a dict tree, keys compared with ==, < and >.
  deleted nil _                 = [nil, 0];

  // Node found and it has no left subtree: the right subtree replaces it.
  deleted (bin k         _ _ nil r  ) key
                                = [r, 1] if key == k;

  // Node found and it has no right subtree: the left subtree replaces it.
  deleted (bin k         _ _ l   nil) key
                                = [l, 1] if key == k;

  // Node found with two subtrees: the last (rightmost) member of the left
  // subtree takes its place and is removed from the left subtree.
  // NOTE: `rl` and `ll` bind the left and the right subtree of the left
  // child, in that order; the names look swapped but the node is rebuilt in
  // the same order, so this has no effect.
  deleted (bin k _         b (bin kl         vl bl rl ll) r  ) key
        = adjustd leftHasChanged (bin lastk lastv b newl r) (-1)
        when
                [lastk, lastv]  = last (bin kl vl bl rl ll);
                [newl, leftHasChanged]
                                = rmlast (bin kl vl bl rl ll)
        end
        if key == k;

  // Smaller key: delete in the left subtree (-1), then fix up this node.
  deleted (bin k v b l r) key
        = adjustd leftHasChanged (bin k v b newl r) (-1)
        when
                [newl, leftHasChanged] = deleted l key
        end
        if key < k;

  // Greater key: delete in the right subtree (1), then fix up this node.
  deleted (bin k         v b l r) key
        = adjustd rightHasChanged (bin k v b l newr) ( 1)
        when
                [newr, rightHasChanged] = deleted r key
        end
        if key > k;

  // deleteh: deletion from an hdict tree. Nodes are keyed by the integer
  // hash; the node value `xys` is the list of pairs sharing that hash. A node
  // is only removed from the tree when its pair list becomes empty.
  deleteh nil _ _               = [nil, 0];

  // Hash found, no left subtree.
  deleteh (bin k::int xys b nil r  ) key::int x
        = (if (newxys == [])
        then [r, 1]
        else [bin k newxys b nil r, 0])
                when
                        newxys = deleteh2 xys x
                end
        if k == key;

  // Hash found, no right subtree.
  deleteh (bin k::int xys b l   nil) key::int x
        = (if (newxys == [])
        then [l, 1]
        else [bin k newxys b l nil, 0])
                when
                        newxys = deleteh2 xys x
                end
        if k == key;

  // Hash found, two subtrees, and the pair list becomes empty: replace the
  // node by the last node of the left subtree (same scheme as in deleted).
  deleteh (bin k::int xys b (bin kl vl bl rl ll) r) key::int x
        = adjustd leftHasChanged (bin lastk lastv b newl r) (-1)
        when
                [lastk, lastv]  = last (bin kl vl bl rl ll);
                [newl, leftHasChanged]  = rmlast (bin kl vl bl rl ll)
        end
        if (k == key) && ((deleteh2 xys x) == []);

  // Hash found, two subtrees, pair list stays non-empty: only the list
  // changes. (`key` is used for the node key here; it equals `k` by the
  // guard.)
  deleteh (bin k::int xys b l r) key::int x
        = [bin key (deleteh2 xys x) b l r, 0]
        if k == key;

  deleteh (bin k::int v b l r) key::int x
        = adjustd leftHasChanged (bin k v b newl r) (-1)
        when
                [newl, leftHasChanged] = deleteh l key x
        end
        if key < k;

  deleteh (bin k::int v b l r) key::int x
        = adjustd rightHasChanged (bin k v b l newr) ( 1)
        when
                [newr, rightHasChanged] = deleteh r key x
        end
        if key > k;

  // deleteh2: remove the first pair whose key is === to the given key from a
  // pair list; the list is returned unchanged if there is no such pair.
  deleteh2 [] _                 = [];
  deleteh2 ((x1 => _) : xys) x2 = xys if x1 === x2;
  deleteh2 ((x  => y) : xys) x1 = (x => y) : (deleteh2 xys x1);

  // rmlast: remove the rightmost node of a tree; returns
  // [newTree, heightHasChanged].
  rmlast nil                    = [nil, 0];
  rmlast (bin _ _ _      l nil) = [l,   1];
  rmlast (bin k v b::int l r  )
        = adjustd rightHasChanged (bin k v b l newr) ( 1)
        when [newr, rightHasChanged] = rmlast r end;

  // last: key and value of the rightmost node, as the list [key, value].
  last (bin x y _ _ nil)        = [x, y];
  last (bin _ _ _ _ d2 )        = last d2
end;


// Build a dict that maps every key in the list xs to the same value y.
// The rule only applies when xs is a list.
mkdict y xs                     = dict (map (\x -> (x => y)) xs) if listp xs;
// Build an hdict that maps every key in the list xs to the same value y.
// The rule only applies when xs is a list.
mkhdict y xs                    = hdict (map (\x -> (x => y)) xs) if listp xs;

// Emptiness test: 1 if the wrapped tree is the empty tree `nil`, else 0.
null (Dict t)                   = case t of
                                    nil = 1;
                                    _   = 0
                                  end;

null (Hdict t)                  = case t of
                                    nil = 1;
                                    _   = 0
                                  end;

// Number of members of a dict: every tree node holds exactly one member.
#(Dict t)                       = count t
with
  count nil                     = 0;
  count (bin _ _ _ l r)         = count l + count r + 1
end;

// Number of members of an hdict: every tree node holds a list of pairs, so
// the lengths of these lists are summed over all nodes.
#(Hdict t)                      = total t
with
  total nil                     = 0;
  total (bin _ pairs _ l r)     = total l + total r + #pairs
end;

// Test whether a dict contains the given key (1 if so, 0 if not).
// The tree is descended by comparing keys with >, < and ==.
member (Dict t) key             = search t key
with
  search nil _                  = 0;

  search (bin nodekey _ _ l r) wanted
                                = search l wanted if nodekey > wanted;
                                = search r wanted if nodekey < wanted;
                                = 1 if nodekey == wanted
end;

// Test whether an hdict contains the given key (1 if so, 0 if not).
// The tree is searched by the integer hash of the key; the pair list of the
// node with that hash is then scanned for a key that is === to k.
member (Hdict d) k              = member d (hash k) k
with
  // No node with this hash.
  member nil _ _                = 0;
  // Descend by hash value; on equal hashes scan the node's pair list.
  member (bin k::int xys _ d1 d2) k1::int x1
                                = member d1 k1 x1 if k > k1;
                                = member d2 k1 x1 if k < k1;
                                = memberk xys x1;

  // Linear scan of a pair list for a syntactically equal key.
  memberk [] _                  = 0;
  memberk ((x1 => y):_  ) x2    = 1 if x1 === x2;
  memberk (        _:xys) x2    = memberk xys x2
end;

// List all members of a dict as `key => value` pairs, produced by an
// in-order (left, node, right) walk of the tree.
members (Dict t)                = collect t []
with
  // `rest` holds the members that come after the current subtree.
  collect nil rest              = rest;
  collect (bin k v _ l r) rest  = collect l ((k => v) : (collect r rest))
end;

// List all members of an hdict as `key => value` pairs. The tree is walked
// in order (left, node, right) and the pair lists of the nodes are
// concatenated.
members (Hdict t)               = collect t []
with
  // `rest` holds the members that come after the current subtree.
  collect nil rest              = rest;
  collect (bin _ pairs _ l r) rest
                                = collect l (pairs + (collect r rest))
end;

// `list` is a synonym for `members` on dicts and hdicts.
list (Dict  t)                  = members (Dict  t);
list (Hdict t)                  = members (Hdict t);

// First member of a dict: the `key => value` pair stored in the leftmost
// node of the tree.
first (Dict t)                  = leftmost t
with
  leftmost (bin k v _ nil _)    = (k => v);
  leftmost (bin _ _ _ l   _)    = leftmost l
end;

// Last member of a dict: the `key => value` pair stored in the rightmost
// node of the tree.
last (Dict t)                   = rightmost t
with
  rightmost (bin k v _ _ nil)   = (k => v);
  rightmost (bin _ _ _ _ r  )   = rightmost r
end;

// Remove the first member (leftmost node) from a dict.
// The local helper returns [newTree, heightHasChanged]; only the tree is
// kept. An empty dict is returned unchanged.
rmfirst (Dict t)                = Dict ((dropmin t)!0)
with
  dropmin nil                   = [nil, 0];
  // Leftmost node reached: its right subtree takes its place.
  dropmin (bin _ _ _ nil r)     = [r, 1];
  // Otherwise remove from the left subtree (-1) and fix up this node.
  dropmin (bin k v b l   r)
        = adjustd changed (bin k v b l1 r) (-1)
        when
                [l1, changed] = dropmin l
        end
end;

// Remove the last member (rightmost node) from a dict.
// The local helper returns the two-element list [newTree, heightHasChanged];
// only the tree (element 0) is kept. An empty dict is returned unchanged.
rmlast (Dict d)                 = Dict ((rmlast d)!0)
with
  // Empty tree: nothing to remove. (The elements must be separated by a
  // comma; `[nil 0]` would be a one-element list holding the term `nil 0`.)
  rmlast nil                    = [nil, 0];
  // Rightmost node reached: its left subtree takes its place.
  rmlast (bin _ _ _ l nil)      = [l, 1];
  // Otherwise remove from the right subtree (1) and fix up this node.
  rmlast (bin k v b l   r)
        = adjustd rightHasChanged (bin k v b l newr) ( 1)
        when
                [newr, rightHasChanged] = rmlast r
        end
end;

// List all keys of a dict, produced by an in-order (left, node, right) walk
// of the tree.
keys (Dict t)                   = collect t []
with
  // `rest` holds the keys that come after the current subtree.
  collect nil rest              = rest;
  collect (bin k _ _ l r) rest  = collect l (k : (collect r rest))
end;

// List all keys of an hdict: the tree is walked in order (left, node, right)
// and the keys of each node's pair list are extracted.
keys (Hdict t)                  = collect t []
with
  // `rest` holds the keys that come after the current subtree.
  collect nil rest              = rest;
  collect (bin _ pairs _ l r) rest
        = collect l ((map (\(k => _) -> k) pairs) + (collect r rest))
end;

// List all values of a dict, produced by an in-order (left, node, right)
// walk of the tree, i.e. in the same order as `keys` lists the keys.
vals (Dict t)                   = collect t []
with
  // `rest` holds the values that come after the current subtree.
  collect nil rest              = rest;
  collect (bin _ v _ l r) rest  = collect l (v : (collect r rest))
end;

// List all values of an hdict: the tree is walked in order (left, node,
// right) and the values of each node's pair list are extracted.
vals (Hdict t)                  = collect t []
with
  // `rest` holds the values that come after the current subtree.
  collect nil rest              = rest;
  collect (bin _ pairs _ l r) rest
        = collect l ((map (\(_ => v) -> v) pairs) + (collect r rest))
end;

// Look up the value stored under key k in a dict.
// Throws out_of_bounds when the search reaches an empty subtree.
(Dict t)!k                      = fetch t k
with
  fetch nil _                   = throw out_of_bounds;

  // Go left on a smaller key, right on a greater key; otherwise this node
  // holds the value.
  fetch (bin nodekey v _ l r) wanted
                                = fetch l wanted if wanted < nodekey;
                                = fetch r wanted if wanted > nodekey;
                                = v
end;

// Look up the value stored under key k in an hdict.
// The node is located by the integer hash of k, then its pair list is
// scanned for a key that is === to k. Throws out_of_bounds if there is no
// node with that hash or no matching pair in the node.
(Hdict t)!k                     = find t (hash k) k
with
  find nil _ _                  = throw out_of_bounds;

  find (bin h::int pairs _ l r) wanted::int key
                                = find l wanted key if h > wanted;
                                = find r wanted key if h < wanted;
                                = scan pairs key;

  scan []           _           = throw out_of_bounds;
  scan ((k1 => v):_   ) k2      = v if k1 === k2;
  scan (        _ :rest) key    = scan rest key
end;

// Curried form of insert: key and value are passed as separate arguments
// and combined into a `key => value` pair for insert.
update (Dict  t) x y            = insert (Dict  t) (x => y);
update (Hdict t) x y            = insert (Hdict t) (x => y);

// Equality of two dicts: their member lists are compared with ==.
a@(Dict  _) == b@(Dict  _)      = xs == ys
        when
                xs = members a;
                ys = members b
        end;

// Equality of two hdicts: each must contain all keys of the other, and the
// values of the first must equal (==) the values found in the second under
// the first one's keys. Yields 0 as soon as a key is missing on either side.
a@(Hdict _) == b@(Hdict _)
        = (if all (member a) kb
                then (if all (member b) ka
                        then (vals a) == (map ((!)b) ka)
                        else 0)
                else 0)
        when
                ka = keys a;
                kb = keys b
        end;


// Inequality of two dicts: their member lists are compared with !=.
a@(Dict  _) != b@(Dict  _)      = xs != ys
        when
                xs = members a;
                ys = members b
        end;
// Inequality of two hdicts: the negation of the hdict equality check.
a@(Hdict _) != b@(Hdict _)      = not (a == b);

/* Private functions. */

// adjustd ToF tree LoR: fix up the balance of `tree` after a deletion in one
// of its subtrees.
//   ToF   1 if the height of the touched subtree has changed, 0 otherwise
//   tree  the node whose subtree was modified (still carrying its old balance)
//   LoR   -1 if the deletion took place in the left subtree, 1 if in the right
// Returns the two-element list [newTree, heightHasChanged].
adjustd ToF::int tree LoR::int
                                = adjust ToF tree LoR
with
  // Subtree height unchanged: the node stays as it is.
  adjust 0 oldTree _            = [oldTree, 0];

  // Subtree height changed: look up new balance, height flag and whether a
  // rotation is required.
  adjust 1 (bin key         val b0 l r) LoR
        = rebal toBeRebalanced (bin key val b0 l r) b1 whatHasChanged
        when
                [b1, whatHasChanged, toBeRebalanced] = table b0 LoR
        end;

  // No rotation: store the new balance and pass the height flag on.
  rebal 0 (bin k         v _ l r) b whatHasChanged
                                = [bin k v b l r, whatHasChanged];

  // Rotation: avl_geq yields [tree, flag] itself, so the flag from the table
  // is ignored here.
  rebal 1  oldTree _ _          = avl_geq oldTree;

/*
// Balance rules for deletions
//      balance where           balance   whole tree    to be
//      before  deleted         after     decreased     rebalanced
table   ( 0)    ( 1)            = [( 1),  0,            0];
table   ( 0)    (-1)            = [(-1),  0,            0];
table   ( 1)    ( 1)            = [( 0),  1,            1];
//                                        ^^^^
// It depends on the tree pattern in avl_geq whether it really decreases

table   ( 1)    (-1)            = [( 0),  1,            0];
table   (-1)    ( 1)            = [( 0),  1,            0];
table   (-1)    (-1)            = [( 0),  1,            1];
//                                        ^^^^
// It depends on the tree pattern in avl_geq whether it really decreases
*/

// table w/o pattern matching
// bb = balance before, wd = where deleted (-1 left, 1 right);
// yields [balance after, whole tree decreased, to be rebalanced].
  table bb wd   = [ba, wtd, tbr]
        when
                ba      = if bb == 0 then wd else 0;
                wtd     = abs bb;
                tbr     = bb == wd;
        end
end;

// Single and double tree rotations - these are common for insert and delete
/*
  The patterns (-1)-(-1), (-1)-( 1), ( 1)-( 1) and ( 1)-(-1) on the LHS always
  change the tree height and these are the only patterns which can happen
  after an insertion. That's the reason why the insertion code can rely on
  its own balance table alone to decide the needed changes (it discards the
  flag returned here).
  The patterns (-1)-( 0) and ( 1)-( 0) do not change the tree height. After a
  deletion any pattern can occur and so we return 1 or 0 as a flag of
  a height change.
*/
// avl_geq tree: rebalance a node whose balance factor (-1 or 1) can no longer
// be kept after the update. Returns the two-element list
// [rotatedTree, heightHasChanged].
avl_geq d = avl_geq d
with
  // Node and its right child both at (-1): single rotation.
  avl_geq (bin a         va (-1) alpha (bin b         vb (-1) beta gamma))
        = [bin b vb ( 0) (bin a va ( 0) alpha beta) gamma, 1];

  // Node at (-1), right child at ( 0): single rotation.
  avl_geq (bin a         va (-1) alpha (bin b         vb ( 0) beta gamma))
        = [bin b vb ( 1) (bin a va (-1) alpha beta) gamma, 0];
                // the tree doesn't decrease with this pattern

  // Node at (-1), right child at ( 1): double rotation; the grandchild x
  // becomes the new root and its old balance b1 determines the balances of
  // the two new children.
  avl_geq (bin a         va (-1) alpha
                        (bin b vb ( 1) (bin x vx b1 beta gamma) delta))
        = [bin x vx ( 0) (bin a va b2 alpha beta) (bin b vb b3 gamma delta), 1]
                when
                        [b2, b3] = table b1
                end;

  // Node and its left child both at ( 1): single rotation.
  avl_geq (bin b vb ( 1) (bin a va ( 1) alpha beta) gamma)
        = [bin a va ( 0) alpha (bin b vb ( 0) beta  gamma), 1];

  // Node at ( 1), left child at ( 0): single rotation.
  avl_geq (bin b vb ( 1) (bin a va ( 0) alpha beta) gamma)
        = [bin a va (-1) alpha (bin b vb ( 1) beta  gamma), 0];
                        // the tree doesn't decrease with this pattern

  // Node at ( 1), left child at (-1): double rotation; the grandchild x
  // becomes the new root.
  avl_geq (bin b vb ( 1)
                        (bin a va (-1) alpha (bin x vx b1 beta gamma)) delta)
        = [bin x vx ( 0) (bin a va b2 alpha beta) (bin b vb b3 gamma delta), 1]
                when
                        [b2, b3] = table b1
                end;

/*
  table ( 1)                    = [( 0), (-1)];
  table (-1)                    = [( 1), ( 0)];
  table ( 0)                    = [( 0), ( 0)]
*/

// table w/o pattern matching
// Maps the old balance of the grandchild to the balances of the new left and
// right child after a double rotation.
  table bal = [b1, b2]
        when
                b1 =   bal == (-1);
                b2 = -(bal ==   1);
        end
end;