
(*
Conventions:
    k, k1, k2, ... for kinds
    e, e1, e2, ... for expressions
    v, v1, v2, ... for values
    t, t1, t2, ... for types
    x, x1, x2, ... for variable names
    i, i1, i2, ... for integer constants
*)

(* Carries a message.  Judging by the name it reports a feature that is not
 * implemented; no raise site is visible in this part of the file. *)
exception NotImplemented of string;;
(* Carries a message describing a failure inside the tool itself; new_var
 * raises it when its counter overflows. *)
exception InternalError of string;;

(* offset in a file; the first character is at offset 0 *)
type pos = int;;

(* A variable is a name paired with an integer index.
 * new_var replaces the index with a fresh value from its counter, and
 * string_of_var prints index 0 as the bare name and any other index i
 * as name$i. *)
type var = string * int;;

(* Returns a variable with the same name as the argument and an index taken
 * from a counter private to this function (1, 2, 3, ...).  The incoming
 * index is ignored.  Raises InternalError if the counter has wrapped around
 * to a negative value. *)
let new_var:var->var =
  let counter = ref 0 in
  fun (name, _) ->
    incr counter;
    let fresh = !counter in
    if fresh < 0 then raise (InternalError "variable count overflow")
    else (name, fresh)
;;

(* Maps and sets keyed by var.  Keys are ordered by name first and by index
 * second. *)
module VarCompare = struct
  type t = var
  let compare ((name1, idx1):t) ((name2, idx2):t) =
    let by_name = String.compare name1 name2 in
    if by_name <> 0 then by_name else compare (idx1:int) idx2
end;;
module VarMap = Map.Make(VarCompare);;
module VarSet = Set.Make(VarCompare);;

(* Returns the set of keys of a VarMap, whatever the type of its values. *)
let varmap_domain (vmap: 'a VarMap.t): VarSet.t =
  let add_key key _ keys = VarSet.add key keys in
  VarMap.fold add_key vmap VarSet.empty
;;

(* Linearity attribute carried by KType kinds and TRecord types; the printers
 * in this file mark Linear with an at-sign prefix. *)
type linearity = Nonlinear | Linear;;
(* NOTE(review): only referenced from a commented-out component of KType in
 * this part of the file. *)
type traced = Traced | Untraced;;
(* a size; presumably measured in bits, going by the name -- TODO confirm *)
type nbits = int;;
(* print_kind2 omits the size of a KType whose size equals wordsize *)
let wordsize = 32;;

(* values have type of kind KType *)
(* Kinds are split into "kind" and "kind2" to prevent kinds
 * of the form (...)->int and (...)->bool, which might be difficult
 * to deal with in the arithmetic engine.
 *)
type kind =
| K2 of kind2      (* an arrow kind or a type kind *)
| KInt             (* kind of integer arithmetic types (see TInt) *)
| KBool            (* kind of boolean arithmetic types (see TBool) *)
and kind2 =
(* (argument kinds, result kind); print_kind2 writes it as result<-(args) *)
| KArrow of (kind list) * kind2
(* (size, linearity); print_kind2 leaves the size out when it equals wordsize *)
| KType of nbits * (*traced * *)linearity
;;

(* type parameters have a name and a kind  [int M, int N; N>4] *)
type tparam = var * kind;;

(* Short aliases for the number types of the Num and Big_int libraries.
 * In this file big_int holds the coefficients of IArith and num holds the
 * coefficients of IInfer. *)
type num = Num.num;;
type big_int = Big_int.big_int;;
(* NOTE(review): not referenced in this part of the file *)
type int_binary_op = IAddOp | ISubOp;;
(* comparison operators of BCompare; string_of_int_compare_op gives their
 * printed spelling *)
type int_compare_op = BEqOp | BNeOp | BLtOp | BGtOp | BLeOp | BGeOp;;
(* connectives of BBinary *)
type bool_binary_op = BAndOp | BOrOp;;
(* NOTE(review): only referenced from the commented-out TSInt constructor in
 * this part of the file *)
type sign = Signed | Unsigned;;

(* Type constructors. 
 Examples:

 TRecord  struct { Mem[] mem1; Mem[] mem2; }

*)
(* Types, together with the integer and boolean arithmetic types that can
 * appear inside them and the placeholders used during inference.  All of
 * these definitions are mutually recursive. *)
type typ =
(* (argument types, return type, limit); print_typ writes ret<-(args) *)
| TArrow of typ list * typ * limit  (* kind KType Nonlinear, where each typ has kind KType _ *)
| TNamed of var               (* a type referred to by name; presumably declared by a type_decl -- TODO confirm *)
| TVar of var                 (* kind determined by environment *)
| TApp of typ * typ list      (* if t0:(k1,...,kn)->k0, t1:k1,...,tn:kn, then t0(t1,...,tn):k0 *)
| TInt of int_arith           (* t has kind KInt ==> (t = TInt(...) or t = TVar(v)) *)
| TBool of bool_arith         (* t has kind KBool ==> (t = TBool(...) or t = TVar(v)) *)
(*
| TSInt of sign * nbits       (* kind type(nbits)<-int *)
| TSBool                      (* kind type(wordbits)<-bool *)
*)
(* (linearity, fields); print_typ hides a field name that equals the
 * field's 1-based position written as a decimal number *)
| TRecord of linearity * field list  (* each typ has kind KType *)
(* (bound parameters, where-constraint, body) *)
| TExists of tparam list * bool_arith * typ
(* (bound parameters, where-constraint, body) *)
| TAll of tparam list * bool_arith * typ
| TFun of tparam list * typ    (* if t:k2 when A:k1, then (Fun A:k1 => t):k1->k2 *)
| TInfer of typ_infer ref      (* inference placeholder; see typ_infer *)
(* carries a string; presumably an antiquotation placeholder that is filled
 * in elsewhere -- TODO confirm *)
| TAntiquote of string
(*
| TPtr of typ                 (* t has traced/untraced nonlinear type => (TPtr t) has traced 1-word nonlinear type *)
| TOPtr of typ                (* t has traced/untraced nonlinear type => (TOPtr t) has traced 1-word nonlinear type *)
*)

(* An integer arithmetic type has the form i0 + i1*x1 + ... + in*xn.
 * where all ik are nonzero and all xi are distinct.
 *
 * Two integer arithmetic types
 *   i0 + i1*x1 + ... + in*xn and
 *   i0' + i1'*x1' + ... + in'*xn'
 * are equal iff
 *   - n = n'
 *   - i0 = i0'
 *   - for each 1<=j<=n, there is an ik'*xk' such that ij=ik' and xj=xk'
 * Integer equality is easy enough that we can use it to do type inference,
 * and we don't need to call a special constraint solver.
 *
 * For type inference, we use rational arithmetic.
 * Suppose you have two constraints:
 *   2 * alpha = 3 * beta
 *   beta = 10
 * If you start with the first constraint, you temporarily have a rational
 * constraint, such as:
 *   alpha = (3/2) * beta
 * Which later is successfully resolved to
 *   alpha = 15
 *   beta = 10
 * An IInfer is analogous to a TInfer: it is a placeholder for
 * yet-to-be-inferred information, but it is not a complete, legal type.
 *)
and int_arith =
(* (constant i0, coefficient of each variable) *)
| IArith of big_int * (big_int VarMap.t)
(* (constant, coefficient of each variable, coefficient and placeholder of
 * each not-yet-inferred term) *)
| IInfer of num * (num VarMap.t) * ((num * typ_infer ref) VarMap.t)

and bool_arith =
| BVar of var
| BConst of bool
| BNot of bool_arith
| BBinary of bool_binary_op * bool_arith * bool_arith
| BCompare of int_compare_op * int_arith * int_arith
| BInfer of typ_infer ref     (* inference placeholder; printed like a TInfer *)

(* a placeholder for a yet-to-be-inferred type
 *
 * (TUnknown (x, k, tv_env)) may be overwritten with
 *   (TKnown t), if t has kind k in environment tv_env.
 * The "x" is only used for printing error messages.
 *
 * (TInfer (ref (TKnown t))) may be replaced with t.
 *)
and typ_infer =
| TKnown of typ
| TUnknown of var * kind * kind VarMap.t

(* a record field: its name and its type *)
and field = string * typ
(* a parameter: a mutable reference to its variable, and its type *)
and param = var ref * typ

(* Third component of TArrow and the type of fun_decl_limit.  print_typ
 * writes nothing extra for Unlimited, the word limitany for LimitAny, and
 * the bracketed arithmetic type for Limited.
 * NOTE(review): what is being limited is not visible in this part of the
 * file. *)
and limit = Unlimited | LimitAny | Limited of int_arith
;;

(* NOTE(review): presumably the resolved definition behind a TNamed type --
 * TODO confirm against the code that builds these. *)
type named_type =
| TStruct of (linearity * field list)
(*| TUnion of (linearity * field list) *)
| TNative of (nbits * linearity)
;;

(* The right-hand side of a type declaration (see type_decl).
 * Constructor meanings below are read off the constructor names and payload
 * types; confirm against the parser. *)
type type_spec =
| AbbrevSpec of typ           (* presumably an abbreviation for an existing type *)
(* size, linearity, optional type parameters, fields *)
| StructSpec of nbits * linearity * tparam list option * field list
(* a list of names, each with an optional big_int *)
| EnumSpec of (var * big_int option) list
(*| HeapSpec of tparam list option * field list *)
(*| UnionSpec of linearity * tparam list option * field list*)
(* size, linearity, optional type parameters *)
| NativeSpec of nbits * linearity * tparam list option
;;

(* a type declaration: the declared name and its specification *)
type type_decl = var * type_spec;;

(* Operator carried by an EAssign expression: plain assignment or one of the
 * compound forms. *)
type assign_op =
| AssignOp | MultAssignOp | DivAssignOp | ModAssignOp
| AddAssignOp | SubAssignOp | LShiftAssignOp | RShiftAssignOp
| BitwiseAndAssignOp | XorAssignOp | BitwiseOrAssignOp
;;

(* NOTE(review): in this part of the file unary_op and binary_op are only
 * referenced from the commented-out EUnary and EBinary constructors. *)
type unary_op = PositiveOp | NegativeOp | BitwiseNotOp | NotOp;;

(* NOTE(review): both XorOp and BitwiseXorOp are declared; whether they are
 * meant to differ is not visible here. *)
type binary_op =
| MultOp | DivOp | ModOp | AddOp | SubOp
| LShiftOp | RShiftOp | LtOp | GtOp | LeOp | GeOp
| EqOp | NeOp | BitwiseAndOp | XorOp | BitwiseOrOp | BitwiseXorOp
| AndOp | OrOp
;;

(* A type annotation: a type, its kind, and an optional list of named nested
 * annotations.
 * NOTE(review): the meaning of the nested list is not visible in this part
 * of the file. *)
type typ_ann = Typ_ann of typ * kind * (string * typ_ann) list option;;

(* An expression node: the expression itself, an optional source position,
 * and a mutable slot for its type annotation.  The constructors exp_at,
 * exp_nowhere and exp_at_exp all start with exp_typ = None. *)
type exp = {exp_raw:exp_raw; exp_pos:pos option; mutable exp_typ:typ_ann option}
and exp_raw =
| EVar of var ref
| EBool of bool
| EInt of big_int
| EUnit
| ECall of exp * exp list                  (* presumably (callee, arguments) *)
| EOverload of string list * string option ref
| EAssign of assign_op * exp * exp
(*
| EUnary of unary_op * exp
| EBinary of binary_op * exp * exp
*)
| EStruct of var
| ERecord of linearity * (string * typ option * exp) list
| EMember of exp * string                  (* presumably (record expression, field name) *)
| ETApp of exp * (string * typ) list
| EPack of exp * typ * (string * typ) list
| EAntiquote of string
;;

(* a single declaration: (variable, optional declared type) and an expression *)
type decl = (var ref * typ option) * exp;;
(* like decl, but with several (variable, optional type) pairs for one expression *)
type mdecl = (var ref * typ option) list * exp;;
(* a list of (string, variable) pairs attached to SDecl and SMDecl *)
type unpack_spec = (string * var) list;;

(* A statement node: the statement itself and an optional source position.
 * NOTE(review): the roles of the individual stmt components of SDecl,
 * SMDecl, SWhile and SFor are not visible in this part of the file. *)
type stmt = {stmt_raw:stmt_raw; stmt_pos:pos option}
and stmt_raw =
| SBlock of stmt list
| SDecl of unpack_spec option * decl * stmt
| SMDecl of unpack_spec option * mdecl * stmt
| SReturn of exp
| SWhile of exp * stmt * stmt
| SFor of (tparam (*XXX * typ option*)) list * bool_arith * (param * exp) list * exp * exp list option * stmt * stmt
| SContinue of exp list
| SIfElse of exp * stmt * stmt
| SBoolCase of bool_arith * stmt * stmt
| SExp of exp
| SAntiquote of string
(* | STryCatch of stmt * string * stmt *)
;;

(* The body of a function declaration. *)
type fun_body =
| FunBody of stmt
| FunNative
| FunLocalNative of string   (* Behaves like a native function, but defined in the same file *)
| FunStruct

type linkage_spec = LinkageC | LinkageCpp;;

(* A function declaration.  Each parameter and the return type are paired
 * with a mutable typ_ann option cell.
 * NOTE(review): when those cells are filled in is not visible in this part
 * of the file. *)
type fun_decl =
{
  fun_decl_is_inline: bool;
  fun_decl_linkage: linkage_spec;
  fun_decl_name: var;
  (* optional type parameters together with a boolean constraint on them *)
  fun_decl_tparams: (tparam list * bool_arith) option;
  fun_decl_params: (param * (typ_ann option ref)) list;
  fun_decl_ret: typ * (typ_ann option ref);
  fun_decl_limit: limit;
  fun_decl_stmt: fun_body;
}
;;

(* An item of an import or export list: either a symbol or a literal, both
 * carried as strings. *)
type cd_symbol =
  CDSymbol of string
| CDLiteral of string
;;

(* The payload of a CDEmit directive. *)
type emittable =
  PureC of string                 (* presumably C source text passed through as is -- TODO confirm *)
| GlobalCode of typ * var * exp
| ExternGlobal of typ * var

(* Directives collected in program_directives. *)
type compiler_directive =
  CDImport of cd_symbol list
| CDExports of cd_symbol list
| CDExportsAll
| CDEmit of emittable
| CDGlobal of typ * var * exp
| CDEndOfFile            (* horrible hack *)
;;



(* A whole program: its type declarations, function declarations and
 * compiler directives.  Only the function declaration list is mutable. *)
type program =
{
  program_type_decls: type_decl list;
  mutable program_fun_decls: fun_decl list;
  program_directives: compiler_directive list
}
;;

(* Wraps exp_raw in an expression node located at offset pos, with no type
 * annotation yet. *)
let exp_at (exp_raw:exp_raw) (pos:pos) =
  let exp_pos = Some pos in
  {exp_raw; exp_pos; exp_typ = None};;
(* Wraps exp_raw in an expression node that has neither a source position nor
 * a type annotation. *)
let exp_nowhere (exp_raw:exp_raw) =
  {exp_raw; exp_pos = None; exp_typ = None};;
(* Wraps exp_raw in an expression node that borrows the source position of an
 * existing expression; the new node has no type annotation. *)
let exp_at_exp (exp_raw:exp_raw) (located:exp) =
  {exp_raw; exp_pos = located.exp_pos; exp_typ = None};;

(* Wraps stmt_raw in a statement node located at offset pos. *)
let stmt_at (stmt_raw:stmt_raw) (pos:pos) =
  let stmt_pos = Some pos in
  {stmt_raw; stmt_pos};;
(* Wraps stmt_raw in a statement node without a source position. *)
let stmt_nowhere (stmt_raw:stmt_raw) =
  {stmt_raw; stmt_pos = None};;
(* Wraps stmt_raw in a statement node that borrows the source position of an
 * existing statement. *)
let stmt_at_stmt (stmt_raw:stmt_raw) (located:stmt) =
  {stmt_raw; stmt_pos = located.stmt_pos};;

(* Creates a fresh inference placeholder of kind k in environment tv_env.
 * The placeholder is labelled with a new variable derived from x. *)
let new_tinfer (x:var) (k:kind) (tv_env:kind VarMap.t):typ =
  let placeholder = TUnknown (new_var x, k, tv_env) in
  TInfer (ref placeholder);;
(*let new_binfer (x:var):bool_arith = BInfer (ref (TUnknown (new_var x, KBool)));;*)

(* command-line flags used by the typechecker, interpreter, and compiler *)
(* defaults to true; NOTE(review): presumably cleared to skip typechecking *)
let do_typecheck = ref true;;
(* defaults to false; the pretty-printers below are described as being used
 * for the trace feature *)
let do_trace = ref false;;

(* Carries a message; by its name, reports a syntax error in the input. *)
exception SyntaxError of string;;
(* The remaining exceptions wrap another exception together with extra
 * context: a source position, a type, a message, or a second exception.
 * NOTE(review): which of the two exceptions in NestedExn is the outer one is
 * not visible in this part of the file. *)
exception PosExn of pos * exn;;
exception TypeExn of typ * exn;;
exception MessageExn of string * exn;;
exception NestedExn of exn * exn;;





(* pretty-printing types and expressions (used for the trace feature) *)

(* Renders a variable: the bare name when its index is 0, otherwise the name
 * followed by a dollar sign and the index. *)
let string_of_var (x:var):string =
  let (name, idx) = x in
  if idx = 0 then name
  else name ^ "$" ^ (string_of_int idx)
;;

(* Returns the printed spelling of a comparison operator. *)
let string_of_int_compare_op (op:int_compare_op) =
  match op with
  | BEqOp -> "==" | BNeOp -> "!="
  | BLtOp -> "<"  | BLeOp -> "<="
  | BGtOp -> ">"  | BGeOp -> ">="
;;

(* Runs the printing action f inside a Format box indented by 2.  The output
 * of f is wrapped in parentheses when the level k of the construct is lower
 * than the precedence prec of its context. *)
let format_block prec k f =
  let (opening, closing) = if k < prec then ("(", ")") else ("", "") in
  Format.printf "@[<2>%s" opening;
  f ();
  Format.printf "%s@]" closing
;;

(* Prints a variable after applying the renaming vmap to it; a variable
 * without an entry in vmap is printed unchanged.  Index 0 prints as the bare
 * name, any other index as name$index. *)
let print_var (vmap:var VarMap.t) (x:var):unit =
  let (name, idx) = try VarMap.find x vmap with Not_found -> x in
  if idx = 0 then Format.printf "%s" name
  else Format.printf "%s$%d" name idx
;;

(*
 * Pick a new name for a variable, that doesn't conflict with
 * any existing name in vmap, but is still as simple as possible.
 *
 * For each variable x = (name, i), rename_var will pick some (name, j)
 * with as low a j as possible, and add the following mappings to vmap:
 *   (name, i) -> (name, j)
 *   (name, -1) -> (name, j)
 * The -1 entry is a hack to track which j values have been assigned to name.
 *
 * Returns the extended map together with the chosen variable (name, j).
 * j is 0 for the first renaming of a given name and one more than the
 * previously assigned j afterwards.
 *
 * XXX: when print_typ is called more than once, the variable names
 * don't necessarily match up -- particularly if they appear free
 * in one print_typ call and bound in another.
 *)
let rename_var (vmap:var VarMap.t) ((name, i):var):(var VarMap.t * var) =
  (* bookkeeping key that remembers the last index handed out for name *)
  let marker = (name, -1) in
  let j =
    try 1 + (snd (VarMap.find marker vmap))
    with Not_found -> 0 in
  let renamed = (name, j) in
  (* the marker is added first so that, as before, the (name, i) binding is
   * the one that survives when i happens to be -1 *)
  ( VarMap.add (name, i) renamed (VarMap.add marker renamed vmap),
    renamed )
;;

(* Renames every type parameter with rename_var, from left to right, threading
 * the renaming map through.  Returns the final map and the parameters with
 * their new names, in the original order and with their kinds unchanged. *)
let rename_tparams
    (vmap:var VarMap.t)
    (tparams:tparam list):(var VarMap.t * tparam list) =
  (* accumulate in reverse and flip once at the end, instead of appending to
   * the end of the list at every step *)
  let (vmap, renamed_rev) =
    List.fold_left
      (fun (vmap, acc) (x, k) ->
        let (vmap, x) = rename_var vmap x in
        (vmap, (x, k) :: acc))
      (vmap, [])
      tparams in
  (vmap, List.rev renamed_rev)
;;

(* Prints the elements of lis with print_element, separating consecutive
 * elements with a comma followed by a Format break hint.  Prints nothing for
 * the empty list. *)
let print_list (print_element:'a->unit) (lis:'a list):unit =
  match lis with
  | [] -> ()
  | first :: rest ->
      print_element first;
      List.iter
        (fun elt -> Format.printf ",@ "; print_element elt)
        rest
;;

(* Pretty-printers for kinds.  prec is the precedence of the enclosing
 * context; an arrow kind is parenthesized when prec is greater than 1.
 *
 * A type kind prints as the word type, preceded by an at-sign when it is
 * linear and followed by its size when the size differs from wordsize.
 * An arrow kind prints as result<-(arguments). *)
let rec print_kind (prec:int) (k:kind):unit =
  match k with
  | KInt -> Format.printf "int"
  | KBool -> Format.printf "bool"
  | K2 inner -> print_kind2 prec inner
and print_kind2 (prec:int) (k:kind2):unit =
  match k with
  | KArrow (args, ret) ->
      format_block prec 1 (fun () ->
        print_kind2 1 ret;
        Format.printf "<-@,(";
        print_list (print_kind prec) args;
        Format.printf ")")
  | KType (size, lin) ->
      let lin_prefix = (match lin with Linear -> "@" | Nonlinear -> "") in
      let size_suffix = if size = wordsize then "" else (string_of_int size) in
      Format.printf "%stype%s" lin_prefix size_suffix
;;

(* Applies f to each element of lis in order, emitting a comma followed by a
 * Format break hint between consecutive elements. *)
let print_comma_list (f:'a->unit) (lis:'a list) =
  ignore
    (List.fold_left
      (fun is_first elt ->
        (* every element except the first is preceded by the separator *)
        if not is_first then Format.printf ",@ ";
        f elt;
        false)
      true
      lis)
;;

(* Prints a comma-separated list of type parameters, each as its kind, a
 * space, and its name after renaming through vmap. *)
let print_tparams
    (vmap:var VarMap.t)
    (tparams:tparam list):unit =
  let print_tparam (x, k) =
    print_kind 0 k;
    Format.printf " ";
    print_var vmap x in
  print_comma_list print_tparam tparams
;;

(* Prints the sign and coefficient that precede a variable in an arithmetic
 * type.  A non-negative coefficient gets a leading plus sign unless it is
 * the first term, and a coefficient of 1 is left implicit; any other
 * coefficient is followed by a star.  A negative coefficient prints a minus
 * sign followed by the prefix of its absolute value. *)
let rec print_int_prefix (is_first:bool) (i:num):unit =
  if Num.lt_num i (Num.Int 0) then
  (
    Format.printf "-";
    print_int_prefix true (Num.minus_num i)
  )
  else
  (
    let sign = if is_first then "" else "+" in
    let coeff =
      if Num.eq_num i (Num.Int 1) then "" else (Num.string_of_num i) ^ "*" in
    Format.printf "%s%s" sign coeff
  )
;;

(*
 * Mutually recursive pretty-printers for integer arithmetic types, boolean
 * arithmetic types and general types.  They all write to the standard
 * formatter through Format.printf and return unit.
 *
 *   vmap: renaming applied to each variable as it is printed (see print_var);
 *         binders extend it through rename_tparams
 *   prec: precedence of the enclosing context; a construct emitted through
 *         format_block at level k is wrapped in parentheses when k < prec
 *)

(* Prints the variable terms followed by the constant.  prec is accepted for
 * uniformity with the other printers but is not consulted. *)
let rec print_int_arith (vmap:var VarMap.t) (prec:int) (t:int_arith):unit =
  match t with
  | IArith (i0, vars) ->
    (
      (* no_terms stays true only when there is no variable term at all *)
      let no_terms = (VarMap.fold
        (fun x i is_first ->
          print_int_prefix is_first (Num.num_of_big_int i);
          print_var vmap x;
          false)
        vars
        true) in
      (* the constant needs an explicit plus sign only when it follows a term
       * and is positive; a zero constant is dropped when terms were printed *)
      Format.printf "%s%s"
        (if no_terms || (Big_int.le_big_int i0 Big_int.zero_big_int) then "" else "+")
        (if (not no_terms) && (Big_int.eq_big_int i0 Big_int.zero_big_int) then
          "" else
          Big_int.string_of_big_int i0)
    )
  | IInfer (i0, vars, ivars) ->
    (
      let no_terms = (VarMap.fold
        (fun x i is_first -> print_int_prefix is_first i; print_var vmap x; false)
        vars
        true) in
      (* terms whose factor is still an inference placeholder are printed as
       * the parenthesized placeholder *)
      let no_terms = (VarMap.fold
        (fun _ (i, ti) is_first ->
          print_int_prefix is_first i;
          Format.printf "(";
          print_typ vmap 0 (TInfer ti);
          Format.printf ")";
          false)
        ivars
        no_terms) in
      Format.printf "%s%s"
        (if no_terms || (Num.le_num i0 (Num.Int 0)) then "" else "+")
        (if (not no_terms) && (Num.eq_num i0 (Num.Int 0)) then "" else Num.string_of_num i0)
    )

(* Prints a boolean arithmetic type.  Levels used for parenthesization:
 * 3 for ||, 4 for &&, 5 for comparisons, 6 for negation. *)
and print_bool_arith (vmap:var VarMap.t) (prec:int) (t:bool_arith):unit =
  let block = format_block prec in
  match t with
  | BVar x -> print_var vmap x
  | BConst true -> Format.printf "true"
  | BConst false -> Format.printf "false"
  | BNot t -> block 6 (fun () -> Format.printf "!"; print_bool_arith vmap 6 t)
  | BBinary (BOrOp, t1, t2) ->
      block 3 (fun () ->
        print_bool_arith vmap 3 t1;
        Format.printf "||";
        print_bool_arith vmap 4 t2)
  | BBinary (BAndOp, t1, t2) ->
      block 4 (fun () ->
        print_bool_arith vmap 4 t1;
        Format.printf "&&";
        print_bool_arith vmap 5 t2)
  | BCompare (op, t1, t2) ->
      block 5 (fun () ->
        print_int_arith vmap 6 t1;
        Format.printf "%s" (string_of_int_compare_op op);
        print_int_arith vmap 6 t2)
  | BInfer ti -> print_typ vmap prec (TInfer ti)

(* Shared printer for the binding forms exists, all and fun: renames the bound
 * parameters, then prints the keyword, the bracketed parameter list and the
 * body.  The where-clause follows the parameters after a semicolon and is
 * left out when it is the constant true. *)
and print_binder
    (vmap:var VarMap.t)
    (prec:int)
    (keyword:string)
    (tparams:tparam list)
    (wher:bool_arith)
    (tb:typ):unit =
  format_block prec 1 (fun () ->
    let (vmap, tparams) = rename_tparams vmap tparams in
    Format.printf "%s[" keyword;
    print_tparams vmap tparams;
    (match wher with BConst true -> () | _ ->
      Format.printf ";";
      print_bool_arith vmap 0 wher);
    Format.printf "]@ ";
    print_typ vmap 1 tb)

(* Prints a type. *)
and print_typ (vmap:var VarMap.t) (prec:int) (t:typ):unit =
  let block = format_block prec in
  match t with
  | TArrow (args, ret, limit) ->
      (* return type first, then the optional limit, then the arguments *)
      block 1 (fun () ->
        print_typ vmap 2 ret;
        (match limit with
        | Unlimited -> Format.printf "<-@,("
        | LimitAny -> Format.printf "<-@,limitany@,("
        | Limited tlimit ->
            (* the limit sits between square brackets; the opening bracket
             * used to be followed by a parenthesis that was never closed *)
            Format.printf "<-@,[";
            print_int_arith vmap 0 tlimit;
            Format.printf "]@,(");
        print_list (print_typ vmap prec) args;
        Format.printf ")")
  | TNamed x -> print_var vmap x
  | TVar x -> print_var vmap x
  | TApp (tfun, targs) ->
      block 2 (fun () ->
        print_typ vmap 2 tfun;
        Format.printf "[";
        print_comma_list (print_typ vmap 0) targs;
        Format.printf "]")
  | TInt t -> print_int_arith vmap prec t
  | TBool t -> print_bool_arith vmap prec t
  | TRecord (lin, fields) ->
    (
      Format.printf "%s[" (match lin with Linear -> "@" | Nonlinear -> ".");
      (* a field whose name is just its 1-based position is printed without
       * a name; any other field name follows the type after a space *)
      let (_, fields) = List.fold_left
        (fun (i, fields) (field_name, t) ->
          if field_name = (string_of_int i) then (i + 1, fields @ [("", t)])
          else (i + 1, fields @ [(" " ^ field_name, t)]))
        (1, [])
        fields in
      print_comma_list
        (fun (field_name, t) ->
          print_typ vmap 0 t;
          Format.printf "%s" field_name)
        fields;
      Format.printf "]"
    )
  | TExists (tparams, wher, tb) -> print_binder vmap prec "exists" tparams wher tb
  | TAll (tparams, wher, tb) -> print_binder vmap prec "all" tparams wher tb
  | TFun (tparams, tb) ->
      (* fun has no where-clause; the constant true makes print_binder omit it *)
      print_binder vmap prec "fun" tparams (BConst true) tb
  | TInfer ti ->
    (
      match !ti with
      | TKnown t -> print_typ vmap prec t
      | TUnknown (x, _, _) -> Format.printf "?%s" (string_of_var x)
    )
  | TAntiquote s ->
      (* This case used to be missing, so printing such a type raised
       * Match_failure.  The stored text is printed as is.
       * NOTE(review): the concrete antiquotation syntax is not visible in
       * this part of the file -- adjust the rendering if it needs
       * delimiters. *)
      Format.printf "%s" s
;;

