open Ast;;
open Big_int;;
open Num;;
open Types;;

(*
Type checking algorithm:

Stage 1
  - Named types may be mutually recursive.  First, an environment
    of named types is built.
    For each named type declared as "type A[B1,...,Bn] = N",
    add A to the global type environment.
  - Type abbreviations cannot be mutually recursive.  They are processed
    in the order that they are declared.
    For each type abbreviation declared as "type A[B1,...,Bn] = t",
    check that t is well-formed,
    then add A to the global type environment.

Note that a "TVar" type may be a locally declared type variable, a named type,
or a type abbreviation.  We could eagerly replace non-local TVars with
named types and abbreviation definitions in an early stage, but we do it
lazily instead, for the sake of printing better error messages.

Stage 2
  - Check the definitions of all named types.

Stage 3
  - Next, all global declarations, including functions, are collected
    in a global type environment, which maps variables to poly_types.
    A poly_type is a list of type parameters followed by a type.  Each
    type parameter has a type variable name and a kind.
    We don't look at any statements or expressions
    yet, nor do we do any kind/type checking.

Stage 4
  - Next, all types in the global type environment are checked.  Each
    type must have kind KType.  This stage could be merged with the
    previous stage, but it is left separate for clarity.

Stage 5
  - Next, all statements and expressions are checked.

Some key subroutines:

"Normalization": simplify expressions of the form
  TVar x
  TApp (TVar x, ...)
  TApp (TFun ..., ...)
  TInfer (ref (TKnown ...))
where "x" is an abbreviation or named type.  Normalization is delayed
for as long as possible for the sake of generating good error messages.
It must be performed before pattern matching on a typ.

"Substitution": replace local type variables with specific types.

"Unification": instantiate TInfer types in a way that makes two types
identical.

XXX: some places allow linear variables to leak
  - ends of functions
  - ends of for loops

*)

(*
 * Control-flow context threaded through statement checking.
 *)
type stmt_control = {
  (* Type that the expression of a return statement is coerced to. *)
  stmt_ret_type : typ;
  (* Function type called when control continues (e.g. the while-loop
   * body uses a nullary arrow returning __Unit); None when there is no
   * such target. *)
  stmt_cont_type : typ option;
  (* Arguments, type-variable context, and variable environment used to
   * check the implicit continue call when a block falls off its end. *)
  stmt_cont_default : (exp list * tv_context * var_to_latyp) option;
  (* When true, reaching the end of the block is an error unless an
   * implicit continue (stmt_cont_default) can be checked instead. *)
  stmt_must_jump : bool;
}
;;

(* Debug tracing hook.  Tracing is currently disabled: the argument is
 * ignored and unit is returned.  The disabled implementation was
 *   if !do_trace then (Format.printf s; Format.print_flush ())
 *)
let prog_trace _msg : unit = ()
;;

(*
 * Coerce type t1 to type t2, by attempting implicit coercions.
 *
 * Try to unify t1 with t2, but if t2 is an existential and t1 is not,
 * then try an implicit pack coercion before unifying.  Two record types
 * are coerced field by field, so a pack coercion may also apply to an
 * individual field; the records must agree on linearity and on the
 * names and order of their fields.
 *
 * More coercions may be added later.  But the implicit coercions
 * should always be conservative in this sense: they should only
 * modify expressions that are guaranteed not to type-check.  In other
 * words, if there is any possibility that t1 unifies with t2, then
 * no implicit coercions will be attempted.
 *
 * This function only adds constraints to cs; the caller is responsible
 * for solving them.  Raises TypeMismatch when two record types differ
 * in linearity, number of fields, or field names.
 *)
let rec coerce_type
    (tu:type_utils)
    (cs:constraint_set)
    (t1:typ)
    (t2:typ):unit =
  match (tu.norm_typ t1, tu.norm_typ t2) with
  (* Both existential, or t1 not yet inferred: t1 might still unify with
   * t2 directly, so (per the conservative rule above) don't coerce. *)
  | (TExists _, TExists _)
  | (TInfer _, TExists _) -> add_constraint cs t1 t2
  | (_, TExists (tparams, wher, tb2)) ->
    (
      (* To coerce t1 to (exists[..A..] where w. tb2), substitute a TInfer
       * for A in tb2, then unify t1 with tb2.  Substitute the same in
       * for w, and unify this with true.
       *)
      let targs = List.map
        (fun (x, k) -> new_tinfer x k cs.cs_tvc.tvc_vars)
        tparams in
      let submap = make_submap tparams targs in
      add_constraint
        cs
        (TBool (BConst true))
        (TBool (tu.subst_bool_arith submap wher));
      coerce_type tu cs t1 (tu.subst_typ submap tb2)
    )
  | (TRecord (lin1, fields1), TRecord (lin2, fields2)) ->
    (
      (* Report mismatches against the original, un-normalized types. *)
      let mismatch () = raise (TypeMismatch ("", t1, t2)) in
      (* Structural <> (not physical !=) on the lengths. *)
      if List.length fields1 <> List.length fields2 then mismatch () else
      if lin1 <> lin2 then mismatch () else
      List.iter2
        (fun (x1, tf1) (x2, tf2) ->
          if x1 <> x2 then mismatch () else
          coerce_type tu cs tf1 tf2)
        fields1
        fields2
    )
  | _ -> add_constraint cs t1 t2
;;

(*
 * Instantiate a universally quantified type.
 *
 * If t normalizes to (all[..A..] where w. tb), every bound variable A is
 * replaced by a fresh inference variable, the instantiated where-clause
 * is constrained to equal true in cs, and the instantiated body is
 * returned.  Any other type is returned as given (not normalized).
 *)
let elim_all
    (tu:type_utils)
    (cs:constraint_set)
    (t:typ):typ =
  prog_trace "\tin elim_all.\n";
  match tu.norm_typ t with
  | TAll (tparams, wher, tb) ->
      let fresh_tinfer (x, k) = new_tinfer x k cs.cs_tvc.tvc_vars in
      let submap = make_submap tparams (List.map fresh_tinfer tparams) in
      let inst_wher = tu.subst_bool_arith submap wher in
      add_constraint cs (TBool (BConst true)) (TBool inst_wher);
      tu.subst_typ submap tb
  | _ -> t
;;

(*
 * Build the type annotation for t in type-variable environment tv_env.
 *
 * The annotation pairs t with its kind as computed by check_typ.  Record
 * and struct types additionally carry one annotation per field.  An
 * existential is annotated by its body, checked with the existential's
 * parameters added to tv_env.
 *)
let rec make_typ_ann
    (tu:type_utils)
    (tv_env:var_to_kind)
    (t:typ):typ_ann =
  let kind_of ty = check_typ tu.tu_gtv_env tv_env true ty in
  (* Computed up front for every shape, so an ill-kinded t is rejected
   * even when it is an existential. *)
  let k = kind_of t in
  let annotate_fields fields =
    List.map (fun (name, tf) -> (name, Typ_ann (tf, kind_of tf, None))) fields in
  match tu.unroll_typ t with
  | TU_Typ (TExists (tparams, _, body)) ->
      make_typ_ann tu (env_add_list tparams tv_env) body
  | TU_Typ (TRecord (_, fields))
  | TU_Struct (_, fields) ->
      Typ_ann (t, k, Some (annotate_fields fields))
  | TU_Typ _ -> Typ_ann (t, k, None)
;;

(*
 * Find the type parameter named x_name in tparams.
 * Returns that parameter and the remaining parameters, which keep their
 * original relative order (List.partition is order-preserving).
 * Raises TypeError if no parameter has that name, and InternalError if
 * more than one does.
 *)
let get_tparam (x_name:string) (tparams:tparam list):tparam * tparam list =
  prog_trace "\tin get_tparam.\n";
  let has_name ((y_name, _), _) = (y_name = x_name) in
  match List.partition has_name tparams with
  | ([found], rest) -> (found, rest)
  | ([], _) -> raise (TypeError ("variable " ^ x_name ^ " not found"))
  | (_ :: _ :: _, _) -> raise (InternalError "get_tparam")
;;

(*
 * Collect all named type declarations (structs and native types) into a
 * fresh global type environment.  Nothing is validated here; each
 * declaration is recorded with its declared type parameters, its
 * underlying type, and the kind KType (size, linearity) it declares.
 * Type abbreviations are skipped; collect_type_abbrevs adds them.
 *)
let collect_named_types (p:program):var_to_global_type =
  prog_trace "in Collect Named Types.\n";
  let record_decl gtv_env (x, spec) =
    let add_named tparams body size lin =
      gtv_env_add x (GT_Named (tparams, body, KType (size, lin))) gtv_env in
    match spec with
    | StructSpec (size, lin, tparams, fields) ->
        add_named tparams (TStruct (lin, fields)) size lin
    | NativeSpec (size, lin, tparams) ->
        add_named tparams (TNative (size, lin)) size lin
    | AbbrevSpec _ -> gtv_env
  in
  List.fold_left record_decl VarMap.empty p.program_type_decls
;;

(*
 * Type abbreviations cannot be mutually recursive, so they are processed
 * in declaration order: each one is checked against an environment that
 * holds the named types plus only the abbreviations declared before it.
 * For each abbreviation "A = t", check that t is well-formed (with no
 * local type variables in scope), then add A, with t and its kind, to the
 * global type environment.  Any error is wrapped in a MessageExn that
 * names the offending declaration.
 *)
let collect_type_abbrevs (p:program) (gtv_env:var_to_global_type):var_to_global_type =
  prog_trace "in Collect Type Abbrevs.\n";
  let add_abbrev env (x, spec) =
    match spec with
    | AbbrevSpec t ->
        (try
           let k = check_typ env VarMap.empty false t in
           gtv_env_add x (GT_Abbrev (t, k)) env
         with err ->
           raise (MessageExn ("checking type declaration " ^ string_of_var x, err)))
    | _ -> env
  in
  List.fold_left add_abbrev gtv_env p.program_type_decls
;;

(*
 * Check the definition of every named type in gtv_env (abbreviations
 * are skipped; they were checked when collected).
 *
 * For a struct, each field must have kind KType; a nonlinear struct may
 * not contain a linear field; and the declared kind KType (size, lin)
 * must unify with KType (sum of field sizes, lin of the struct body).
 * Native types have no body to check.
 *
 * Errors are wrapped in MessageExn, first with the field name and then
 * with the name of the type declaration.
 *)
let check_named_types (gtv_env:var_to_global_type):unit =
  prog_trace "in Check Named Types.\n";
  VarMap.iter
    (fun x gt ->
      match gt with
      | GT_Abbrev _ -> ()
      | GT_Named (tparams_opt, nt, k) ->
          try
          (
            let tv_env = env_add_list
              (match tparams_opt with Some s -> s | None -> [])
              VarMap.empty in
            (* Check each field's kind; return the total size of the fields. *)
            let check_fields (fields:field list):int =
              List.fold_left
                (fun size (name, t) ->
                  try
                  (
                    let kfield = check_typ gtv_env tv_env false t in
                    match (k, kfield) with
                    (* a linear struct may hold fields of either linearity *)
                    | (KType (_, Linear), K2 (KType (n, _))) -> size + n
                    (* a nonlinear struct may hold only nonlinear fields *)
                    | (KType (_, Nonlinear), K2 (KType (n, Nonlinear))) -> size + n
                    | (_, K2 (KType _)) -> raise (TypeError "linear field in nonlinear object")
                    | _ -> raise (TypeError "field must have kind type")
                  )
                  with err -> raise (MessageExn ("checking field " ^ name, err)))
                0
                fields in
            match nt with
            | TNative _ -> ()
            | TStruct (lin, fields) ->
                unify_kinds (K2 k) (K2 (KType (check_fields fields, lin)))
            | _ ->
                (* collect_named_types only builds TStruct/TNative bodies;
                 * report anything else explicitly instead of Match_failure. *)
                raise (InternalError "check_named_types: named type is neither a struct nor native")
          )
          with err -> raise (MessageExn ("checking type declaration " ^ (string_of_var x), err))
    )
    gtv_env
;;

let check_program (p:program):unit =
  prog_trace "in Check Program.\n";
  let gtv_env = collect_type_abbrevs p (collect_named_types p) in
  check_named_types gtv_env;
  let tu = make_type_utils gtv_env in
  let {
    subst_typ = subst_typ;
    subst_int_arith = subst_int_arith;
    subst_bool_arith = subst_bool_arith;
    norm_typ = norm_typ;
    norm_int_arith = norm_int_arith;
    norm_bool_arith = norm_bool_arith;
    barith_typ = barith_typ;
    get_struct_or_record = get_struct_or_record;
  } = tu in

  let t__unit = TVar ("__Unit", 0) in
  let t__bool = TVar ("__Bool", 0) in
  let t__int = TVar ("__Int", 0) in

  let add_constructors (fun_decls: fun_decl list):fun_decl list =
  prog_trace "in add_constructors.\n";
    List.fold_left
      (fun fun_decls (x, spec) ->
        match spec with
        | StructSpec (size, lin, tparams_opt, fields) ->
          (
            let tparams = (match tparams_opt with Some s -> s | None -> []) in
            let tv_env = env_add_list tparams VarMap.empty in
            let tstruct =
              match tparams_opt with
              | None -> TVar x
              | Some s -> TApp (TVar x, List.map (fun (x, _) -> TVar x) s) in
            let fdecl =
            { fun_decl_is_inline = true;
              fun_decl_linkage = LinkageCpp;
              fun_decl_name = struct_fun x;
              fun_decl_tparams =
                (match tparams_opt with
                | None -> None
                | Some s -> Some (s, BConst true));
              fun_decl_params =
                List.map (fun (name, t) ->
                  ((ref (name, 0), t),
                    ref (Some (make_typ_ann tu tv_env t)))) fields;
              fun_decl_ret =
                (tstruct, ref (Some (make_typ_ann tu tv_env tstruct)));
              fun_decl_limit = LimitAny;
              fun_decl_stmt = FunStruct } in
            fdecl::fun_decls
          )
        | _ -> fun_decls)
      fun_decls
      p.program_type_decls
  in

  let build_global_tenv (fun_decls: fun_decl list):var_to_typ =
    prog_trace "in build_global_tenv.\n";
    List.fold_left
      (fun gt_env {
          fun_decl_name = x;
          fun_decl_tparams = tparams_opt;
          fun_decl_params = params;
          fun_decl_ret = ret;
          fun_decl_limit = limit } ->
        try
        (
          let t = TArrow (List.map (fun ((_, t), _) -> t) params, fst ret, limit) in

          let tv_env = List.fold_left
            (fun tv_env (x, k) -> tv_env_add x k tv_env)
            VarMap.empty
            (match tparams_opt with None -> [] | Some (s, _) -> s) in
          ( match tparams_opt with
            | None -> ()
            | Some (_, s) -> ignore (check_typ gtv_env tv_env false (TBool s)));
          ignore (check_typ gtv_env tv_env false t);

          (* set the type annotations *)
          (snd ret) := Some (make_typ_ann tu tv_env (fst ret));
          List.iter
            (fun ((_, t), ann) ->
              ann := Some (make_typ_ann tu tv_env t))
            params;

          (* create a polymorphic named type "Typeof_<x>" *)
          let ptype =
          (
            match tparams_opt with
            | None -> t
            | Some (tparams, wher) -> TAll (tparams, wher, t)
          ) in
          gt_env_add x ptype gt_env
        )
        with err -> raise (MessageExn ("Checking function declaration " ^ (string_of_var x), err)))
      VarMap.empty
      fun_decls
  in

  p.program_fun_decls <- add_constructors p.program_fun_decls;
  let gt_env = build_global_tenv (p.program_fun_decls) in

  let do_check_function (fdecl:fun_decl):unit =
    prog_trace "\tin do_check_function.\n";
    let { fun_decl_name = fname;
          fun_decl_tparams = ftparams_opt;
          fun_decl_params = fparams;
          fun_decl_ret = fret;
          fun_decl_stmt = fstmt;
          fun_decl_limit = flimit } = fdecl in

    let tenv_add
        (x:var ref)
        (lin:linearity)
        (assigned:bool)
        (t:typ)
        ((vmap,v_tenv):var_to_latyp):var_to_latyp =
      prog_trace "\t\tin tenv_add.\n"; 
      let (xname, xnum) = !x in
      (if VarMap.mem (xname, xnum) v_tenv then
        (* XXX: let's revisit this sometime; look at Xanadu.
         * raise (TypeError ((string_of_var x) ^ " already in environment"))
         *)
        x := (new_var (xname, 0))); 
        (VarMap.add (xname, 0) !x vmap, VarMap.add !x (lin, assigned, t) v_tenv)
    in

    (*
     * Return an updated version of tenv1, reflecting any changes to
     * the assignment statuses of variables from tenv2.
     *)
    let apply_tenv_updates
        ((vmap, v_tenv1):var_to_latyp)
        ((_, v_tenv2):var_to_latyp)
        (f:var -> bool -> bool -> bool):
        var_to_latyp =
      prog_trace "\t\tin apply_tenv_updates.\n";
      let v_tenv = VarMap.fold
        (fun x (lin1, assigned1, t1) tenv ->
          if not (VarMap.mem x v_tenv2) then raise (InternalError ("merge " ^ (string_of_var x))) else
          let (_, assigned2, t2) = VarMap.find x v_tenv2 in
          VarMap.add x (lin1, f x assigned1 assigned2, t1) tenv)
        v_tenv1
        VarMap.empty in
      (vmap, v_tenv)
    in

(* function: check_exp *)
    let rec check_exp
        (tvc:tv_context)
        (tenv:var_to_latyp)
        (definitely_assign:bool)
        (e:exp):var_to_latyp * typ  =
      prog_trace "\t\tin check_exp.";
      try
      (
        let () =
          match e.exp_raw with
          | EVar _ -> ()
          | EMember _ -> ()
          | _ -> if definitely_assign then raise (TypeError "cannot assign to this expression")
        in

        let exp_coerce_type t1 t2 =
	prog_trace "\t\tin exp_coerce_type.\n";
        (
          let cs = new_constraint_set [] tu tvc in
          coerce_type tu cs t1 t2;
          unify_constraints cs true
          (* XXX: check that all unification variables were resolved? *)
        ) in

        let (tenv, t) =
        (
          match e.exp_raw with
          | EVar x ->
            (
              if VarMap.mem !x (fst tenv) then
              (
                let (vmap, v_tenv) = tenv in
                x := VarMap.find !x vmap;
                let (lin, assigned, t) = VarMap.find !x v_tenv in
                if definitely_assign then
                  ((vmap, VarMap.add !x (lin, true, t) v_tenv), t)
                else if assigned then
                (
                  match lin with
                  | Nonlinear -> (tenv, t)
                  | Linear -> ((vmap, VarMap.add !x (lin, false, t) v_tenv), t)
                )
                else raise (TypeError ((string_of_var !x) ^ " is unassigned here"))
              )
              else if VarMap.mem !x gt_env then (tenv, VarMap.find !x gt_env)
              else raise (TypeError ((string_of_var !x) ^ " not in scope"))
            )
          | EBool b -> (tenv, TApp (TNamed ("Bool", 0), [TBool (BConst b)]))
          | EInt i -> (tenv, TApp (TNamed ("Int32", 0),
              [TInt (IArith (i, VarMap.empty))]))
          | EUnit -> (tenv, TNamed ("__Unit", 0))
          | ECall ({exp_raw = EOverload (names, name_ref); exp_pos = Some epos} as e_overload, args) ->
            (
              (* Overloaded function call:
               * Try the function names in order; use the first function call
               * that typechecks.
               *)
              let rec f names exn_list =
              (
                match names with
                | [] -> raise (OverloadFailure exn_list)
                | name::names ->
                  (
                    try
                    (
                      name_ref := Some name;
                      if not (VarMap.mem (name, 0) gt_env) then raise (TypeError (name ^ " not in scope")) else
                      let tf = VarMap.find (name, 0) gt_env in
                      let t = check_call tvc tenv tf args in
                      e_overload.exp_typ <- Some (make_typ_ann tu tvc.tvc_vars tf);
                      t
                    )
                    with err -> f names (exn_list @ [(name, err)])
                  )
              ) in
              f names []
            )
          | ECall (ef, args) ->
            (
              let (tenv, tf) = check_exp tvc tenv false ef in
              check_call tvc tenv tf args
            )
          | EOverload _ -> raise (TypeError "overloaded function must be used in a function call")
          | EAssign (op, e1, e2) ->
            (
              let (tenv, t1) = check_exp tvc tenv true e1 in
              let (tenv, t2) = check_exp tvc tenv false e2 in
              exp_coerce_type t2 t1;
              (
                match op with
                | AssignOp -> ()
                | _ -> raise (InternalError "EAssign not implemented") (* XXX *)
              );
              (tenv, t2)
            )
          | EStruct x ->
            (
              let xfun = struct_fun x in
              if not (VarMap.mem xfun gt_env) then raise (TypeError ("no such struct: " ^ (string_of_var x))) else
              (tenv, VarMap.find xfun gt_env)
            )
          | EMember (e, field_name) ->
            (
              let (tenv, ts) = check_exp tvc tenv false e in
              let (lin, fields) = get_struct_or_record ts in
              if (not (List.mem_assoc field_name fields)) then raise (TypeExn (ts, (TypeError ("record has no field " ^ field_name)))) else
              let tf = List.assoc field_name fields in
              (tenv, tf)
            )
          | ERecord (lin, fields) ->
            (
              let (tenv, tfields) = List.fold_left
                (fun (tenv, tfields) (field_name, t_opt, e) ->
                  let (tenv, t2) = check_exp tvc tenv false e in
                  let t =
                  (
                    match t_opt with
                    | None -> t2
                    | Some t ->
                      (
                        ignore (check_typ gtv_env tvc.tvc_vars false t);
                        exp_coerce_type t2 t;
                        t
                      )
                  ) in
                  (tenv, tfields @ [(field_name, t)])
                )
                (tenv, [])
                fields in
              (tenv, TRecord (lin, tfields))
            )
          | ETApp (e, targs) ->
            (
              let (tenv, ta) = check_exp tvc tenv false e in
              match norm_typ ta with
              | TAll (tparams, wher, tb) ->
                (
                  (*
                   * If e:(all[C,D,E,F;wher] t), then
                   * e[D=td,F=tf]:(fun[D,F] all[A,C;wher] t)[td,tf]
                   *)
                  let (tlist, tparams) =
                    List.fold_left
                      (fun (tlist, tparams) (x_name, targ) ->
                        let ((yparam, kparam), tparams) = get_tparam x_name tparams in
                        unify_kinds kparam (check_typ gtv_env tvc.tvc_vars false targ);
                        (((yparam, kparam), targ)::tlist, tparams))
                      ([], tparams)
                      targs in
                  let t = norm_typ (TApp
                    ( TFun (List.map fst tlist, TAll (tparams, wher, tb)),
                      List.map snd tlist)) in
                  match tparams with
                  | [] ->
                    (
                      (* when all variables are gone, eliminate the TAll *)
                      let cs = new_constraint_set [] tu tvc in
                      let t = elim_all tu cs t in
                      unify_constraints cs true;
                      (tenv, t)
                    )
                  | _ -> (tenv, t)
                )
              | _ -> raise (TypeError "polymorphic type expected")
            )
          | EPack (e, tdest, targs) ->
            (
              let (tenv, ta) = check_exp tvc tenv false e in
              unify_kinds
                (check_typ gtv_env tvc.tvc_vars false tdest)
                (check_typ gtv_env tvc.tvc_vars false ta);
              match norm_typ tdest with
              | TExists (tparams, wher, t) ->
                (
                  (* To coerce t1 to (exists[..A..] where w. tb2), substitute a TInfer
                   * for A in tb2, then unify t1 with tb2.  Substitute the same in
                   * for w, and unify this with true.
                   *)
                  let (submap, tparams) = List.fold_left
                    (fun (submap, tparams) (x_name, targ) ->
                      let ((yparam, kparam), tparams) = get_tparam x_name tparams in
                      unify_kinds kparam (check_typ gtv_env tvc.tvc_vars false targ);
                      (VarMap.add yparam targ submap, tparams))
                    (VarMap.empty, tparams)
                    targs in
                  let submap = List.fold_left
                    (fun submap (x, k) ->
                      let t = new_tinfer x k tvc.tvc_vars in
                      VarMap.add x t submap)
                    submap
                    tparams in
                  let (t1a, t2a) = (ta, subst_typ submap t) in
                  let (t1b, t2b) =
                    ( TBool (BConst true),
                      TBool (subst_bool_arith submap wher)) in
                  unify_constraints
                    (new_constraint_set
                      [(t1a, t2a, (t1a, t2a)); (t1b, t2b, (t1b, t2b))]
                      tu
                      tvc)
                    true;
                  (* XXX: check that all unification variables were resolved? *)
                  (tenv, tdest)
                )
              | _ -> raise (TypeExn (tdest, TypeError "pack expects an existential type"))
            )
        ) in

        (* Make sure the type has no unresolved inference variables.
         * XXX: what is the right place to check this?  How should
         * inference failures be reported?
         * Note: ERecord relies on the check_typ here.
         *)
        ignore (check_typ gtv_env tvc.tvc_vars false t);

        e.exp_typ <- Some (make_typ_ann tu tvc.tvc_vars t);

        (tenv, t)
      )
      with err ->
      (
        match e.exp_pos with
        | None -> raise err
        | Some pos -> raise (PosExn (pos, err))
      )
    and check_call
        (tvc:tv_context)
        (tenv:var_to_latyp)
        (tf:typ)
        (args:exp list):var_to_latyp * typ  =
    (
      prog_trace "\t\tin check_call.\n";
      let (tenv, targs) = List.fold_left
        (fun (tenv, targs) arg ->
          let (tenv, targ) = check_exp tvc tenv false arg in
          (tenv, targs @ [targ]))
        (tenv, [])
        args in
      let cs = new_constraint_set [] tu tvc in
      let tf = elim_all tu cs tf in
      match norm_typ tf with
      | TArrow (params, ret, limit) ->
        (
          (* Is our limit big enough to call the function? *)
          ( match (flimit, limit) with
            | (LimitAny, _) -> raise (InternalError "check_call: LimitAny")
            | (_, LimitAny) -> ()
            | (Unlimited, _) -> ()
            | (Limited _, Unlimited) -> raise (TypeError "cannot call unlimited function from a limited function")
            | (Limited tlimit1, Limited tlimit2) ->
                add_constraint cs (TBool (BCompare (BGtOp, tlimit1, tlimit2))) (TBool (BConst true))
          );

          if (List.length params) != (List.length targs) then raise (TypeError "incorrect number of arguments to function") else
          (* Unify argument and parameter types *)
          List.iter2 (coerce_type tu cs) targs params;
          unify_constraints cs true;
          (* check that all inference variables were resolved: *)
          ignore (check_typ gtv_env tvc.tvc_vars false tf);

          ( match (flimit, limit) with
            | (Unlimited, Limited _) ->
              (
                match check_typ gtv_env tvc.tvc_vars false ret with
                | K2(KType (0, _)) -> ()
                | _ -> raise (TypeError "limited function cannot return run-time data to unlimited function")
              )
            | _ -> ()
          );

          (tenv, ret)
        )
      | _ -> raise (TypeError "cannot call a non-function")
    )
    in

    let get_tbool
        (tvc:tv_context)
        (tenv:var_to_latyp)
        (e:exp):var_to_latyp * bool_arith option =
      prog_trace "\t\tin get_tbool.\n";
    (
      let (tenv, t) = check_exp tvc tenv false e in
      match norm_typ t with
      | TApp (TNamed ("Bool", 0), [tb]) -> (tenv, Some (barith_typ tb))
      | TExists ([(("B", i1), KBool)], BConst true, t2) ->
        (
          (* XXX: there needs to be a better way to do this *)
          match norm_typ t2 with
          | TApp (TNamed ("Bool", 0), [TVar ("B", i2)]) when i1 = i2 ->
              (tenv, None)
          | _ -> raise (TypeExn (t, TypeError "boolean expected"))
        )
      | _ -> raise (TypeExn (t, TypeError "boolean expected"))
    ) in

    let unpack_stmt
        (unpack_opt:unpack_spec option)
        (tvc:tv_context)
        (t:typ):tv_context * typ =
      prog_trace "\t\tin unpack_stmt.\n";
    (
      match (unpack_opt, norm_typ t) with
      | (None, _) -> (tvc, t)
      | (Some unpack_map, TExists (tparams, wher, t)) ->
        (
          let (submap, tparams, tvars) = List.fold_left
            (fun (submap, tparams, tvars) (x1_name, x2) ->
              let ((yparam, kparam), tparams) = get_tparam x1_name tparams in
              ( env_add yparam (TVar x2) submap,
                tparams,
                env_add x2 kparam tvars))
            (VarMap.empty, tparams, tvc.tvc_vars)
            unpack_map in
          let (submap, tvars) = List.fold_left
            (fun (submap, tvars) (x, k) ->
              let x2 = new_var x in
              (env_add x (TVar x2) submap, env_add x2 k tvars))
            (submap, tvars)
            tparams in
          ( { tvc_vars = tvars;
              tvc_known = (subst_bool_arith submap wher)::tvc.tvc_known; },
            subst_typ submap t)
        )
      | _ -> raise (TypeExn (t, TypeError "only an existential type can be unpacked"))
    ) in

    (* If the control reaches the end of a block, check that
     * the function jumps if necessary.
     *)
    let check_control_end
        (control:stmt_control)
        (tvc:tv_context)
        (tv_env:var_to_latyp):var_to_latyp =
      prog_trace "\t\tin check_control_end.\n";
    (
      match control with
      | {stmt_must_jump = false} -> tv_env
      | { stmt_must_jump = true;
          stmt_cont_type = Some tf;
          stmt_cont_default = Some (args, tvc_loop, tv_env_loop)} ->
        (
          (* First, check that the continue statement would typecheck in
           * the original loop body environment, updated to reflect the
           * current assignment state of variables.
           *)
          let tv_env_loop = apply_tenv_updates
            tv_env_loop
            tv_env
            (fun _ a1 a2 -> a2) in
          let _ = check_call tvc_loop tv_env_loop tf args in
          (* Use our own environment to compute the new tv_env *)
          let (tv_env, _) = check_call tvc tv_env tf args in
          tv_env
        )
      | { stmt_must_jump = true;
          stmt_cont_type = Some tf;
          stmt_cont_default = None} -> raise (TypeError "statement must continue or return")
      | {stmt_must_jump = true} -> raise (TypeError "function must return")
    ) in

    let rec check_stmt
        (control:stmt_control)
        (tvc:tv_context)
        (tenv:var_to_latyp)
        (s:stmt):(var_to_latyp) =
      prog_trace "\t\tin check_stmt.";
      let stmt_coerce_type tvc t1 t2 =
        prog_trace "\t\tin stmt_coerce_type."; 
      (
        let cs = new_constraint_set [] tu tvc in
        coerce_type tu cs t1 t2;
        unify_constraints cs true
        (* XXX: check that all unification variables were resolved? *)
      ) in

      try
      (
        match s.stmt_raw with
        | SBlock slist ->
          (
            let rec f tenv slist =
              match slist with
              | [] -> check_control_end control tvc tenv
              | [s] -> check_stmt control tvc tenv s
              | s::slist ->
                  let tenv = check_stmt
                    {control with stmt_must_jump = false} tvc tenv s in
                  f tenv slist
            in f tenv slist
          )
        | SDecl (unpack_opt, ((x, t_opt), e), s2) ->
          (
            let (tenv, t2) = check_exp tvc tenv false e in
            let (tvc, t2) = unpack_stmt unpack_opt tvc t2 in
            let t =
            (
              match t_opt with
              | None -> t2
              | Some t ->
                (
                  ignore (check_typ gtv_env tvc.tvc_vars true t);
                  stmt_coerce_type tvc t2 t;
                  t
                )
            ) in
            match (check_typ gtv_env tvc.tvc_vars false t) with
            | K2 (KType (_, lin)) ->
              (
                let tenv2 = tenv_add x lin true t tenv in
                let tenv2 = check_stmt control tvc tenv2 s2 in
                apply_tenv_updates
                  tenv
                  tenv2
                  (fun x assigned1 assigned2 -> assigned2)
              )
            | _ -> raise (InternalError "expression's type does not have kind type")
          )
        | SMDecl (unpack_opt, (params, e), s2) ->
          (
            (* Destructuring declaration: evaluate e, optionally unpack its
             * type, then bind one variable per struct/record field. *)
            let (tenv_e, t_e) = check_exp tvc tenv false e in
            let (tvc_m, t_m) = unpack_stmt unpack_opt tvc t_e in
            let (_, fields) = get_struct_or_record t_m in
            if List.length params <> List.length fields then
              raise (TypeError ("number of declared variables must match number of fields in struct"))
            else
            (* A variable's type must have kind type; the kind's linearity
             * determines how the new variable is tracked in the environment. *)
            let linearity_of t =
              match check_typ gtv_env tvc_m.tvc_vars true t with
              | K2 (KType (_, lin)) -> lin
              | _ -> raise (TypeError "Type does not have kind type") in
            (* Bind one declared variable to its field.  With an explicit type
             * annotation, the field's type must coerce to the annotation and
             * the annotation becomes the variable's type. *)
            let bind_field tenv_acc (xp, tp_opt) (_, t_field) =
              let (lin, t_var) =
                (match tp_opt with
                 | None -> (linearity_of t_field, t_field)
                 | Some t_decl ->
                   let lin = linearity_of t_decl in
                   stmt_coerce_type tvc_m t_field t_decl;
                   (lin, t_decl)) in
              ignore (check_typ gtv_env tvc_m.tvc_vars false t_var);
              tenv_add xp lin true t_var tenv_acc in
            let tenv_bound = List.fold_left2 bind_field tenv_e params fields in
            let tenv_body = check_stmt control tvc_m tenv_bound s2 in
            apply_tenv_updates
              tenv_e
              tenv_body
              (fun _ _ assigned2 -> assigned2)
          )
        | SReturn e ->
          (
            (* The returned value must coerce to the enclosing function's
             * declared return type. *)
            let (tenv_after, t_value) = check_exp tvc tenv false e in
            stmt_coerce_type tvc t_value control.stmt_ret_type;
            tenv_after
          )
        | SWhile (e, s3, s4) ->
          (
            (* Loops are only permitted in unlimited functions. *)
            (match flimit with
             | Unlimited -> ()
             | _ -> raise (TypeError "cannot use loops in a limited function"));

            let (tenv_cond, tbool_opt) = get_tbool tvc tenv e in
            (* When the condition is a type-level boolean, the body may assume
             * it and the exit statement may assume its negation. *)
            let known_with f =
              match tbool_opt with
              | None -> tvc.tvc_known
              | Some tbool -> f tbool :: tvc.tvc_known in
            let tvc_body = { tvc with tvc_known = known_with (fun b -> b) } in
            let tvc_exit = { tvc with tvc_known = known_with (fun b -> BNot b) } in
            (* Inside the body, "continue" takes no arguments and the body is
             * not required to end in a jump. *)
            let body_control =
              { control with
                stmt_must_jump = false;
                stmt_cont_type = Some (TArrow ([], TVar ("__Unit", 0), Unlimited)) } in
            let tenv_body = check_stmt body_control tvc_body tenv_cond s3 in
            let tenv_exit = check_stmt control tvc_exit tenv_cond s4 in

            (* Repeating the loop returns to the state before the condition
             * was evaluated (tenv), so the body must leave every variable's
             * assignment state exactly as it found it. *)
            ignore (apply_tenv_updates
              tenv
              tenv_body
              (fun x assigned1 assigned2 ->
                if assigned1 = assigned2 then assigned1
                else raise (TypeError ((string_of_var x) ^ " assignment state must be same before and after loop body"))));
            tenv_exit
          )
        | SFor (tparams, wher, params, e_test, args, s_body, s_exit) ->
          (
            (* A for-loop with explicit loop variables.  tparams/wher are the
             * loop's type parameters and where-clause, params the loop
             * variables paired with their initial values, e_test the loop
             * condition, args optional default "continue" arguments, s_body
             * the body, and s_exit the statement checked under the negated
             * condition.  Loops are only permitted in unlimited functions. *)
            (match flimit with Unlimited -> () | _ -> raise (TypeError "cannot use loops in a limited function"));

            (* Build the new type variable context: the loop's type parameters
             * are in scope and its where-clause is assumed inside the loop. *)
            let tv_env_loop = tv_env_add_list tparams tvc.tvc_vars in
            ignore (check_typ gtv_env tv_env_loop false (TBool wher));
            let tvc_loop = {tvc_vars = tv_env_loop; tvc_known = wher::tvc.tvc_known} in

            (* Check the parameter types; each must have kind type, and its
             * linearity is kept alongside the variable. *)
            let param_types = List.map
              (fun ((x, t), _) ->
                try
                  ( match (check_typ gtv_env tv_env_loop false t) with
                    | K2 (KType (_, lin)) -> (x, t, lin)
                    | _ -> raise (TypeError "parameter must have kind type"))
                with err -> raise (MessageExn ("checking variable " ^ (string_of_var !x), err)))
              params in
            (* The loop is typed like a polymorphic function from the loop
             * variables to __Unit.  The initial values below and every
             * "continue" in the body (via stmt_cont_type) are checked as calls
             * to this type. *)
            let tf = TAll (tparams, wher, TArrow
              ( List.map (fun (_, t, _) -> t) param_types,
                TVar ("__Unit", 0),
                Unlimited)) in

            (* Check the loop variable initial values (in the outer context) *)
            let (tenv, _) = check_call tvc tenv tf (List.map snd params) in

            (* Check and add the parameters.
             * tenv_pre_loop is the assignment state the body must restore
             * before jumping back: nonlinear loop variables assigned, linear
             * ones unassigned.  NOTE(review): presumably because passing a
             * linear value to "continue" consumes it; confirm in check_call. *)
            let tenv_pre_loop = List.fold_left
              (fun tenv_loop (x, t, lin) ->
                tenv_add x lin (lin = Nonlinear) t tenv_loop)
              tenv
              param_types in
            (* tenv_loop is the state at the start of each iteration: every
             * loop variable is assigned. *)
            let tenv_loop = List.fold_left
              (fun tenv_loop (x, t, lin) -> 
              tenv_add x lin true t tenv_loop)
              tenv
              param_types in

            (* Check the body.  The body may assume the condition and the exit
             * statement its negation, when the condition is a type-level
             * boolean.  The body must end in a jump (stmt_must_jump = true);
             * optional default continue arguments are recorded together with
             * the context they are checked in. *)
            let (tenv_bool, tbool_opt) = get_tbool tvc_loop tenv_loop e_test in
            let tvc_body = {
              tvc_vars = tvc_loop.tvc_vars;
              tvc_known = match tbool_opt with
                | None -> tvc_loop.tvc_known
                | Some tbool -> tbool::tvc_loop.tvc_known} in
            let tvc_exit = {
              tvc_vars = tvc_loop.tvc_vars;
              tvc_known = match tbool_opt with
                | None -> tvc_loop.tvc_known
                | Some tbool -> (BNot tbool)::tvc_loop.tvc_known} in
            let tenv_post_body = check_stmt
              { control with stmt_must_jump = true;
                stmt_cont_type = Some tf;
                stmt_cont_default = (match args with
                  | None -> None
                  | Some args -> Some (args, tvc_body, tenv_bool))}
              tvc_body
              tenv_bool
              s_body in
            let tenv_post_exit = check_stmt control tvc_exit tenv_bool s_exit in

            (* If the loop repeats, it jumps back to tenv_pre_loop, so
             * tenv_post_body must agree with it on every variable's
             * assignment state.  Only the check matters; the merged
             * environment is discarded. *)
            let _ = apply_tenv_updates
              tenv_pre_loop
              tenv_post_body
              (fun x assigned1 assigned2 ->
                if assigned1 = assigned2 then assigned1
                else raise (TypeError ((string_of_var x) ^ " assignment state must be same before and after loop body")))
            in
            tenv_post_exit
          )
        | SContinue args ->
          (
            (* "continue" is checked as a call to the enclosing loop's
             * continuation type, with args as the call's arguments. *)
            let tf =
              match control.stmt_cont_type with
              | Some tf -> tf
              | None -> raise (TypeError "continue statement must be inside a loop") in
            fst (check_call tvc tenv tf args)
          )
        | SIfElse (e, s2, s3) ->
          (
            let (tenv_cond, tbool_opt) = get_tbool tvc tenv e in
            (* When the condition is a type-level boolean, the then-branch may
             * assume it and the else-branch its negation. *)
            let known_with f =
              match tbool_opt with
              | None -> tvc.tvc_known
              | Some tbool -> f tbool :: tvc.tvc_known in
            let tvc_then = { tvc with tvc_known = known_with (fun b -> b) } in
            let tvc_else = { tvc with tvc_known = known_with (fun b -> BNot b) } in
            let tenv_then = check_stmt control tvc_then tenv_cond s2 in
            let tenv_else = check_stmt control tvc_else tenv_cond s3 in
            (* Both branches must agree on every variable's assignment state. *)
            apply_tenv_updates
              tenv_then
              tenv_else
              (fun x assigned1 assigned2 ->
                if assigned1 <> assigned2 then
                  raise (TypeError ("variable " ^ (string_of_var x) ^ " unassigned in one branch of if/then/else"))
                else assigned1)
          )
        | SBoolCase (tbool, s2, s3) ->
          (
            (* if[B] branches on a type-level boolean and is only allowed in
             * limited functions. *)
            (match flimit with
             | Limited _ -> ()
             | _ -> raise (TypeError "if[B] can only be used in a limited function"));

            let tenv_then =
              check_stmt control { tvc with tvc_known = tbool :: tvc.tvc_known } tenv s2 in
            let tenv_else =
              check_stmt control { tvc with tvc_known = (BNot tbool) :: tvc.tvc_known } tenv s3 in
            (* Both branches must agree on every variable's assignment state. *)
            apply_tenv_updates
              tenv_then
              tenv_else
              (fun x assigned1 assigned2 ->
                if assigned1 <> assigned2 then
                  raise (TypeError ("variable " ^ (string_of_var x) ^ " unassigned in one branch of if/then/else"))
                else assigned1)
          )
        | SExp e ->
          (
            (* Check the expression for its effect on the environment only,
             * then apply the end-of-statement control checks. *)
            let (tenv_after, _) = check_exp tvc tenv false e in
            check_control_end control tvc tenv_after
          )
      )
      with err ->
      (
        match s.stmt_pos with
        | None -> raise err
        | Some pos -> raise (PosExn (pos, err))
      )
    in

    (* A function without explicit type parameters behaves as if it had none
     * and a trivially true where-clause. *)
    let (ftparams, fwher) =
      (match ftparams_opt with
       | Some tparams_and_wher -> tparams_and_wher
       | None -> ([], BConst true)) in

    let tv_env = tv_env_add_list ftparams VarMap.empty in
    ignore (check_typ gtv_env tv_env false (TBool fwher));
    let fun_tvc = {tvc_vars = tv_env; tvc_known = [fwher]} in

    (* Add one parameter to the environment.  Its type must have kind type,
     * its linearity comes from that kind, and it starts out assigned.
     * Errors are tagged with the parameter's name. *)
    let add_param tenv_acc ((x, t), _) =
      try
        let lin =
          (match check_typ gtv_env tv_env false t with
           | K2 (KType (_, lin)) -> lin
           | _ -> raise (TypeError "parameter must have kind type")) in
        prog_trace "\t\tcalling tenv_add.";
        tenv_add x lin true t tenv_acc
      with err -> raise (MessageExn ("checking parameter " ^ (string_of_var !x), err)) in
    let param_tenv = List.fold_left add_param (VarMap.empty, VarMap.empty) fparams in

    ignore (check_typ gtv_env tv_env false (fst fret));

    (* limitany is reserved for native and struct functions.  A limited
     * function's limit must be provably non-negative under its where-clause. *)
    (match flimit with
     | Unlimited -> ()
     | LimitAny ->
       (match fstmt with
        | FunNative | FunLocalNative _ | FunStruct -> ()
        | _ -> raise (TypeError "only native functions can be declared limitany"))
     | Limited tlimit ->
       let cs = new_constraint_set [] tu fun_tvc in
       add_constraint cs (TBool (BCompare (BGeOp, tlimit, Arith.iarith_zero))) (TBool (BConst true));
       unify_constraints cs true);

    (* Only functions with a body have statements to check.  The body must end
     * in a jump unless the normalized return type is __Unit. *)
    match fstmt with
    | FunNative | FunLocalNative _ | FunStruct -> ()
    | FunBody body ->
      let must_return =
        (match norm_typ (fst fret) with
         | TNamed ("__Unit", 0) -> false
         | _ -> true) in
      let body_control = {
        stmt_ret_type = fst fret;
        stmt_cont_type = None;
        stmt_cont_default = None;
        stmt_must_jump = must_return;
      } in
      ignore (check_stmt body_control fun_tvc param_tenv body)
  in

  (* Check a single function declaration.  Any error is wrapped with the
   * function's name so the report identifies where it came from. *)
  let check_function (fdecl:fun_decl):unit =
    let in_function err =
      MessageExn ("In function " ^ (string_of_var fdecl.fun_decl_name), err) in
    prog_trace "in check_function.\n";
    (try do_check_function fdecl with err -> raise (in_function err))
  in

  List.iter check_function p.program_fun_decls;
  prog_trace "done with check program.\n";
;;

