open Ast;;
open Big_int;;
open Num;;
open Arith;;

(* Debug tracing hook.  Tracing is currently disabled: the argument is
 * ignored and nothing is printed.  A previous body was
 *   if (!do_trace) then (Format.printf s; Format.print_flush ())
 * which can be restored to re-enable tracing.
 *)
let prog_trace (_ : 'a) : unit = ()
;;

(* Two kinds that were required to agree did not.  The string is a context
 * message for error reporting (it is filled in by unify_kinds_msg). *)
exception KindMismatch of string * kind * kind;;
(* Two types could not be unified.  The string is a context message and is
 * often "" in this file. *)
exception TypeMismatch of string * typ * typ;;
(* General type error carrying a human-readable message. *)
exception TypeError of string;;
(* NOTE(review): not raised in the visible part of this file; presumably
 * collects (candidate, failure) pairs when overload resolution fails —
 * confirm against the raising site. *)
exception OverloadFailure of (string * exn) list;;
(* An arithmetic constraint could not be proven: the list holds the facts
 * that were assumed, the second component the goal that failed. *)
exception ConstraintError of (bool_arith list * bool_arith);;

(* What a global type name stands for:
 *   GT_Abbrev (t, k): an abbreviation; norm_typ expands the name to t,
 *                     whose kind is k.
 *   GT_Named (tparams_opt, nt, k): a named (nominal) type definition nt.
 *                     tparams_opt is None for an unparameterized type, or
 *                     Some params for a type-level function over params;
 *                     k is the kind of the (applied) result.
 *)
type global_type =
| GT_Abbrev of typ * kind
| GT_Named of tparam list option * named_type * kind2
;;

(* Global type names -> their definitions. *)
type var_to_global_type = global_type VarMap.t;;
(* Substitutions from type variables to types (see make_submap, subst_typ). *)
type var_to_typ = typ VarMap.t;;
(* Type variables in scope -> their kinds. *)
type var_to_kind = kind VarMap.t;;
(* NOTE(review): not used in the visible part of this file; the meaning of
 * the (linearity * bool * typ) triple is not evident here — confirm. *)
type var_to_latyp = var VarMap.t * (linearity * bool * typ) VarMap.t;;
(* A set of variables. *)
type var_set = VarSet.t;;

(* Type-variable context used when solving constraints. *)
type tv_context =
{
  (* type variables in scope, with their kinds *)
  tvc_vars:var_to_kind;
  (* boolean facts assumed true when proving arithmetic constraints *)
  tvc_known:bool_arith list;
};;

(* The set of keys bound in env. *)
let env_domain (env:'a VarMap.t):var_set =
  let add_key key _value acc = VarSet.add key acc in
  VarMap.fold add_key env VarSet.empty
;;

(* Bind x to y in env, silently replacing any existing binding for x. *)
let env_add_replace x y env = VarMap.add x y env;;

(* Bind x to y in env.
 * Raises TypeError if x is already bound in env (no shadowing allowed). *)
let env_add (x:var) (y:'a) (env:'a VarMap.t):'a VarMap.t =
  match VarMap.mem x env with
  | true -> raise (TypeError ((string_of_var x) ^ " already declared"))
  | false -> VarMap.add x y env
;;

(* Add every (x, y) binding to env from left to right, replacing existing
 * bindings; a later pair for the same x wins. *)
let env_add_list_replace (x_y_list:(var * 'a) list) (env:'a VarMap.t):'a VarMap.t =
  let rec add_all acc bindings =
    match bindings with
    | [] -> acc
    | (x, y) :: rest -> add_all (env_add_replace x y acc) rest
  in
  add_all env x_y_list
;;

(* Add every (x, y) binding to env from left to right.
 * Raises TypeError (via env_add) at the first x that is already bound,
 * either in env or earlier in the list. *)
let env_add_list (x_y_list:(var * 'a) list) (env:'a VarMap.t):'a VarMap.t =
  let rec add_all acc bindings =
    match bindings with
    | [] -> acc
    | (x, y) :: rest -> add_all (env_add x y acc) rest
  in
  add_all env x_y_list
;;

(* Look up x in env.
 * Raises TypeError if x is not bound. *)
let get_var (env:'a VarMap.t) (x:var):'a =
  if VarMap.mem x env then VarMap.find x env
  else raise (TypeError ("variable " ^ (string_of_var x) ^ " not in scope"))
;;

(* Aliases of env_add / env_add_list, presumably named for the environment
 * each is used with (gtv_env, gt_env, tv_env).  All of them raise
 * TypeError when a name is already bound. *)
let gtv_env_add x y env = env_add x y env;;
let gt_env_add x y env = env_add x y env;;
let tv_env_add x y env = env_add x y env;;
let tv_env_add_list bindings env = env_add_list bindings env;;

(* Prepend prefix to the name part of a variable, keeping its index. *)
let add_var_prefix (prefix:string) ((name, index):var):var =
  (prefix ^ name, index)
;;

(* The variable named "__struct_" ^ name, with the same index as its argument. *)
let struct_fun : var -> var = add_var_prefix "__struct_";;

(* Two variables are equal exactly when polymorphic compare returns 0. *)
let vars_equal (x1:var) (x2:var):bool =
  match compare x1 x2 with
  | 0 -> true
  | _ -> false
;;

(* Fetch the GT_Named definition of global type x.
 * Raises InternalError if x is unbound or is an abbreviation (GT_Abbrev);
 * callers are expected to pass only names known to be named types. *)
let get_gt_named (gtv_env:var_to_global_type) (x:var):
    (tparam list option * named_type * kind2) =
  let lookup_failed () =
    raise (InternalError ("get_gt_named: " ^ (string_of_var x))) in
  let entry = try Some (VarMap.find x gtv_env) with Not_found -> None in
  match entry with
  | Some (GT_Named (tparams_opt, nt, k)) -> (tparams_opt, nt, k)
  | Some (GT_Abbrev _) | None -> lookup_failed ()
;;

(* Build the substitution [param_i -> targ_i] pairing each type parameter
 * with its argument.
 * Raises TypeError (via env_add) on a duplicate parameter name, and
 * Invalid_argument (from List.fold_left2) if the lists differ in length. *)
let make_submap (tparams:tparam list) (targs:typ list):var_to_typ =
  let bind_param acc (param_name, _kind) arg = env_add param_name arg acc in
  List.fold_left2 bind_param VarMap.empty tparams targs
;;

(* Require two kinds to be identical.
 * KType kinds must agree on both size and linearity; arrow kinds must have
 * the same number of parameters, with pairwise-equal parameter and result
 * kinds.  Raises KindMismatch ("", k1, k2) naming the pair at which the
 * mismatch was detected. *)
let rec unify_kinds (k1:kind) (k2:kind):unit =
  let mismatch () = raise (KindMismatch ("", k1, k2)) in
  match (k1, k2) with
  | (K2 inner1, K2 inner2) -> unify_kinds2 inner1 inner2
  | (KInt, KInt) | (KBool, KBool) -> ()
  | (_, _) -> mismatch ()
and unify_kinds2 (k1:kind2) (k2:kind2):unit =
  let mismatch () = raise (KindMismatch ("", K2 k1, K2 k2)) in
  match (k1, k2) with
  | (KType (size1, lin1), KType (size2, lin2)) ->
      if not (size1 = size2 && lin1 = lin2) then mismatch ()
  | (KArrow (params1, ret1), KArrow (params2, ret2))
        when (List.length params1) = (List.length params2) ->
      List.iter2 unify_kinds params1 params2;
      unify_kinds2 ret1 ret2
  | (_, _) -> mismatch ()
;;

(* Like unify_kinds, but any KindMismatch is re-raised with msg as its
 * context and with the top-level kinds k1, k2 (not the nested pair that
 * actually differed). *)
let unify_kinds_msg (k1:kind) (k2:kind) (msg:string):unit =
  try
    unify_kinds k1 k2
  with KindMismatch (_, _, _) ->
    raise (KindMismatch (msg, k1, k2))
;;

(* From here on, unify_kinds reports mismatches using the top-level kinds
 * passed in, with an empty context message. *)
let unify_kinds (k1:kind) (k2:kind):unit =
  let no_context = "" in
  unify_kinds_msg k1 k2 no_context
;;

(*
 * Check that a type is well formed, and return its kind.
 *
 * gtv_env maps global type names to their definitions; tv_env maps the
 * type variables currently in scope to their kinds.  Local type variables
 * (tv_env) take precedence over global names (gtv_env).
 *
 * If allow_inference is false, then any (TInfer (ref (TUnknown _))),
 * including unresolved unification variables inside IInfer/BInfer, is
 * rejected with TypeError.
 * NOTE(review): an earlier comment said non-integer numbers in IInfer _
 * are also disallowed here, but this function never inspects IInfer
 * coefficients — confirm whether that check lives elsewhere
 * (norm_int_arith rejects fractional values once all unknowns are resolved).
 *
 * Every exception raised while checking a subterm is wrapped as
 * TypeExn (subterm, err), so errors carry the offending (sub)type.
 *)
let rec check_typ
    (gtv_env:var_to_global_type)
    (tv_env:var_to_kind)
    (allow_inference:bool)
    (t:typ):kind =
  let rec check (t:typ):kind =
    try
    (
      match t with
      | TArrow (args, ret, limit) ->
        (
          (* Arguments and result must all be proper (sized) types. *)
          List.iter
            (fun t ->
              match check t with
              | K2 (KType _) -> ()
              | _ -> raise (TypeError "argument type must have kind 'type'"))
            args;
          (match check ret with
          | K2 (KType _) -> ()
          | _ -> raise (TypeError "return type must have kind 'type'"));
          (* Unlimited/LimitAny arrows are one word in size; a Limited
           * arrow is given size 0 after checking its limit expression. *)
          (match limit with
          | Unlimited -> (K2 (KType (wordsize, Nonlinear)))
          | LimitAny -> (K2 (KType (wordsize, Nonlinear)))
          | Limited t ->
            (
              check_int_arith t;
              (K2 (KType (0, Nonlinear)))
            ))
        )
      | TNamed x ->
        (
          let (tparams_opt, _, k) = get_gt_named gtv_env x in
          match tparams_opt with
          | None -> K2 k
          | Some tparams ->
            (
              (* A parameterized named type is a type-level function. *)
              let kparams = List.map snd tparams in
              K2 (KArrow (kparams, k))
            )
        )
      | TVar x ->
        (
          if VarMap.mem x tv_env then
            VarMap.find x tv_env
          else if VarMap.mem x gtv_env then
          (
            match (VarMap.find x gtv_env) with
            | GT_Abbrev (_, k) -> k
            | GT_Named (None, _, k) -> K2 k
            | GT_Named (Some tparams, _, k) ->
                let kparams = List.map snd tparams in
                K2 (KArrow (kparams, k))
          )
          else raise (TypeError ("variable " ^ (string_of_var x) ^ " not in scope"))
        )
      | TApp (tfun, targs) ->
        (
          let kfun = check tfun in
          let kargs = List.map check targs in
          match kfun with
          | K2 (KArrow(kparams, kret)) ->
            (
              (* Check the arity first: List.iter2 would otherwise raise
               * Invalid_argument on a length mismatch. *)
              if (List.length kparams) <> (List.length kargs) then raise (TypeError "incorrect number of type arguments") else
              List.iter2 unify_kinds kparams kargs;
              K2 kret
            )
          | _ -> raise (TypeError "this type cannot be applied to arguments")
        )
      | TInt t -> check_int_arith t; KInt
      | TBool t -> check_bool_arith t; KBool
      | TRecord (lin, fields) ->
        (
          (* Sum the field sizes; reject duplicate field names and linear
           * fields inside a nonlinear record. *)
          let (size, _) = List.fold_left
            (fun (size, names) (field_name, tf) ->
              if VarSet.mem (field_name, 0) names then raise (TypeError ("field " ^ field_name ^ " declared twice")) else
              match check tf with
              | K2 (KType (sizef, linf)) ->
                (
                  match (lin, linf) with
                  | (Nonlinear, Linear) -> raise (TypeError "nonlinear record cannot contain a linear field")
                  | _ -> (size + sizef, VarSet.add (field_name, 0) names)
                )
              | _ -> raise (TypeError "field of a record must have type of kind 'type'"))
            (0, VarSet.empty)
            fields
          in K2 (KType (size, lin))
        )
      | TExists (tparams, wher, t) ->
        (
          (* Bring the bound parameters into scope, then check the
           * where-clause (as a boolean) and the body. *)
          let tv_env = env_add_list_replace tparams tv_env in
          ignore (check_typ gtv_env tv_env allow_inference (TBool wher));
          let k = check_typ gtv_env tv_env allow_inference t in
          match k with
          | K2 (KType _) -> k
          | _ -> raise (TypeError "existential type must have kind 'type'")
        )
      | TAll (tparams, wher, t) ->
        (
          let tv_env = env_add_list_replace tparams tv_env in
          ignore (check_typ gtv_env tv_env allow_inference (TBool wher));
          let k = check_typ gtv_env tv_env allow_inference t in
          match k with
          | K2 (KType _) -> k
          | _ -> raise (TypeError "polymorphic type must have kind 'type'")
        )
      | TFun (tparams, t) ->
        (
          let tv_env = env_add_list_replace tparams tv_env in
          let k = check_typ gtv_env tv_env allow_inference t in
          match k with
          | K2 k2 -> K2 (KArrow (List.map snd tparams, k2))
          | KInt -> raise (TypeError "type functions cannot return integers")
          | KBool -> raise (TypeError "type functions cannot return booleans")
        )
      | TInfer ti ->
        (
          match !ti with
          | TKnown t -> check t
          | TUnknown (x, k, _) -> if allow_inference then k else raise (TypeError ("Cannot infer type ?" ^ (string_of_var x)))
        )
    ) with err -> raise (TypeExn (t, err))
  and check_int_arith (t:int_arith):unit =
    try
    (
      (* Every ordinary variable must be bound in tv_env with kind KInt. *)
      let check_var x _ = unify_kinds (get_var tv_env x) KInt in
      (* Every unification variable must be integer-kinded; unresolved ones
       * are only permitted when allow_inference is true. *)
      let check_ivar _ (_, ti) =
      (
        match !ti with
        | TKnown t -> unify_kinds (check t) KInt
        | TUnknown (x, k, _) ->
            if allow_inference then unify_kinds k KInt
            else raise (TypeError ("Cannot infer type ?" ^ (string_of_var x)))
      ) in
      match t with
      | IArith (_, vars) -> VarMap.iter check_var vars
      | IInfer (_, vars, ivars) ->
        (
          VarMap.iter check_var vars;
          VarMap.iter check_ivar ivars
        )
    ) with err -> raise (TypeExn (TInt t, err))
  and check_bool_arith (t:bool_arith):unit =
    try
    (
      match t with
      | BVar x -> unify_kinds (get_var tv_env x) KBool
      | BConst _ -> ()
      | BNot t -> check_bool_arith t
      | BBinary (_, t1, t2) -> check_bool_arith t1; check_bool_arith t2
      | BCompare (_, t1, t2) -> check_int_arith t1; check_int_arith t2
      | BInfer ti ->
        (
          match !ti with
          | TKnown t -> unify_kinds (check t) KBool
          | TUnknown (x, k, _) ->
              if allow_inference then unify_kinds k KBool
              else raise (TypeError ("Cannot infer type ?" ^ (string_of_var x)))
        )
    ) with err -> raise (TypeExn (TBool t, err))
  in check t
;;

(****************************************************************************
 * make_type_utils constructs an object (implemented as a record of closures,
 * for simplicity) with methods for type substitution, normalization,
 * and conversion.
 *)

(* Result of unroll_typ: either the (normalized) type itself, or the
 * linearity and fields of a named struct type, with any type arguments
 * already substituted into the fields. *)
type typ_unrolled =
| TU_Typ of typ
| TU_Struct of (linearity * field list)
;;

(* Type utilities closed over one global type environment; built by
 * make_type_utils. *)
type type_utils =
{
  (* the global type environment the utilities were built from *)
  tu_gtv_env: var_to_global_type;
  (* capture-avoiding substitution of types for type variables *)
  subst_typ: var_to_typ -> typ -> typ;
  subst_int_arith: var_to_typ -> int_arith -> int_arith;
  subst_bool_arith: var_to_typ -> bool_arith -> bool_arith;
  (* normalize only the root of a type (weak head normal form) *)
  norm_typ: typ -> typ;
  (* turn an IInfer with no remaining unknowns into an IArith *)
  norm_int_arith: int_arith -> int_arith;
  (* resolve a BInfer whose unification variable is known *)
  norm_bool_arith: bool_arith -> bool_arith;
  (* view a type of kind int / bool as arithmetic *)
  iarith_typ: typ -> int_arith;
  barith_typ: typ -> bool_arith;
  (* expand named struct types into their fields *)
  unroll_typ: typ -> typ_unrolled;
  (* fields of a record or struct type; error otherwise *)
  get_struct_or_record: typ -> (linearity * field list);
  (* conversions into the Arith solver's representation *)
  make_linear_arith: int_arith -> linear_arith;
  make_formula: bool_arith -> formula;
}
;;

let make_type_utils (gtv_env:var_to_global_type):type_utils =

  (* Apply substitution submap=[x1->t1,...,xn->tn] to t.
   * Parameters bound by TExists/TAll/TFun are renamed to fresh variables
   * (rename_params) before descending, so the substituted types cannot be
   * captured by those binders.
   *)
  let rec subst_typ (submap:var_to_typ) (t:typ):typ =
  (
    let rec sub (t:typ):typ =
    (
      match t with
      | TArrow (args, ret, Unlimited) -> TArrow (List.map sub args, sub ret, Unlimited)
      | TArrow (args, ret, LimitAny) -> TArrow (List.map sub args, sub ret, LimitAny)
      | TArrow (args, ret, Limited tlimit) -> TArrow (
          List.map sub args,
          sub ret,
          Limited (subst_int_arith submap tlimit))
      | TNamed _ -> t
      | TVar x ->
        (
          if (VarMap.mem x submap) then VarMap.find x submap
          else t
        )
      | TApp (tfun, targs) -> TApp (sub tfun, List.map sub targs)
      | TInt t -> TInt (subst_int_arith submap t)
      | TBool t -> TBool (subst_bool_arith submap t)
      | TRecord (lin, fields) ->
        (
          let fields' = List.map (fun (x, t) -> (x, sub t)) fields in
          TRecord(lin, fields')
        )
      | TExists (tparams, wher, t) ->
        (
          let (tparams, submap) = rename_params tparams submap in
          TExists (tparams, subst_bool_arith submap wher, subst_typ submap t)
        )
      | TAll (tparams, wher, t) ->
        (
          let (tparams, submap) = rename_params tparams submap in
          TAll (tparams, subst_bool_arith submap wher, subst_typ submap t)
        )
      | TFun (tparams, t) ->
        (
          let (tparams, submap) = rename_params tparams submap in
          TFun (tparams, subst_typ submap t)
        )
      | TInfer ti ->
        (
          match !ti with
          | TKnown t -> (sub t)
          | TUnknown (_, _, tv_env) ->
            (
              (* An unresolved unification variable is left untouched; it is
               * an internal error for submap to mention any variable of the
               * environment stored with that unknown. *)
              if not (VarSet.is_empty (VarSet.inter (env_domain submap) (env_domain tv_env))) then raise (InternalError "TUnknown in subst_typ") else
              t
            )
        )
    )
    (* Give each bound parameter a fresh name (keeping its kind) and extend
     * submap so that occurrences of the old name become TVar of the new one. *)
    and rename_params old_tparams submap =
    (
      let new_tparams = List.map (fun (x, k) -> (new_var x, k)) old_tparams in
      let submap = List.fold_left2
        (fun submap (old_x, _) (new_x, _) -> VarMap.add old_x (TVar new_x) submap)
          submap
          old_tparams
          new_tparams in
      (new_tparams, submap)
    ) in
    sub t
  )

  (* Apply submap to an int_arith.  Each substituted variable's type is turned
   * back into arithmetic with iarith_typ and scaled by its coefficient
   * (mult_int_arith for IArith, mult_num_arith for IInfer).
   *)
  and subst_int_arith (submap:var_to_typ) (t:int_arith):int_arith =
  (
    let add_vars init vars fmult fadd =
    (
      VarMap.fold
        (fun x i t ->
          if VarMap.mem x submap then add_int_arith
            (fmult i (iarith_typ (VarMap.find x submap)))
            t
          else fadd i x t)
        vars
        init
    ) in
    match t with
    | IArith (i0, vars) ->
        add_vars (iarith_const i0) vars mult_int_arith add_int_arith_var
    | IInfer (i0, vars, ivars) ->
      (
        (* Unification variables are substituted via their TInfer form. *)
        VarMap.fold
          (fun _ (i, ti) t ->
            add_int_arith
              (mult_num_arith i (iarith_typ (subst_typ submap (TInfer ti))))
              t)
          ivars
          (add_vars (iarith_nconst i0) vars mult_num_arith add_num_arith_var)
      )
  )

  (* Apply submap to a bool_arith; substituted variables and BInfer terms
   * are converted back with barith_typ. *)
  and subst_bool_arith (submap:var_to_typ) (t:bool_arith):bool_arith =
  (
    match t with
    | BVar x ->
      (
        if (VarMap.mem x submap) then barith_typ (VarMap.find x submap)
        else t
      )
    | BConst _ -> t
    | BNot t -> BNot (subst_bool_arith submap t)
    | BBinary (op, t1, t2) ->
        BBinary (op, subst_bool_arith submap t1, subst_bool_arith submap t2)
    | BCompare (op, t1, t2) ->
        BCompare (op, subst_int_arith submap t1, subst_int_arith submap t2)
    | BInfer ti -> barith_typ (subst_typ submap (TInfer ti))
  )

  (* Simplify expressions of the form
   *   TVar x
   *   TApp (TVar x, ...)
   *   TApp (TFun ..., ...)
   *   TInfer (ref (TKnown ...))
   * where "x" is an abbreviation or named type.
   *
   * norm_typ is as lazy as possible -- it does not recurse except
   * on the following types:
   *   TApp (t, ...) -- recurse on t before normalizing the TApp
   * If normalization changes t1 to t2, then it also normalizes t2.
   * Thus, norm_typ returns a typ that is normalized only at the
   * very root of the tree ("weak head normal form").
   * It is safe to pattern match on this
   * root, but it is not safe to do any deeper pattern matching beyond
   * the root of the tree.
   *)
  and norm_typ (t:typ):typ =
  (
    let t =
    (
      match t with
      | TVar x ->
        (
          if VarMap.mem x gtv_env then
          (
            match VarMap.find x gtv_env with
            | GT_Named _ -> TNamed x
            | GT_Abbrev (t, _) -> norm_typ t
          )
          else t
        )
      | TApp (tfun, targs) ->
        (
          let tfun = norm_typ tfun in
          match tfun with
          | TFun (tparams, tbody) ->
              (* Beta-reduce: substitute the arguments for the parameters,
               * using the same helper as the rest of the file. *)
              norm_typ (subst_typ (make_submap tparams targs) tbody)
          | _ -> TApp (tfun, targs)
        )
      | TInfer ti ->
        (
          match !ti with
          | TKnown t -> norm_typ t
          | TUnknown _ -> t
        )
      | _ -> t
    ) in
    (* To make it easier to pattern match t=TExist... without
     * causing name conflicts, rename bound variables.
     * XXX: this doesn't look very efficient
     *)
    subst_typ VarMap.empty t
  )

  (* If no unification variables remain, convert an IInfer into an IArith.
   * Resolved unification variables are expanded in place; unresolved ones
   * are kept.  Raises TypeError if the fully resolved result has a
   * non-integer constant or coefficient.
   *)
  and norm_int_arith (t:int_arith):int_arith =
  (
    match t with
    | IArith _ -> t
    | IInfer (i0, vars, ivars) ->
      (
        let t = VarMap.fold
          (fun x (i,ti) t ->
            match !ti with
            | TKnown tknown ->
                add_int_arith (mult_num_arith i (norm_int_arith (iarith_typ tknown))) t
            | TUnknown _ -> add_num_arith_ivar i x ti t)
          ivars
          (IInfer (i0, vars, VarMap.empty)) in
        match t with
        | IArith _ -> t
        | IInfer (i0, vars, ivars) ->
          (
            let has_ivars = VarMap.fold (fun _ _ _ -> true) ivars false in
            if has_ivars then t else
            (* if no unification variables remain, then convert to an IArith *)
            if not (is_integer_num i0) then raise (TypeError "attemped to infer an integer type, but found a fractional type instead") else
            let vars = VarMap.map
              (fun i ->
                if not (is_integer_num i) then raise (TypeError "attemped to infer an integer type, but found a fractional type instead") else
                big_int_of_num i)
              vars in
            IArith (big_int_of_num i0, vars)
          )
      )
  )

  (* Convert a (BInfer (ref (TKnown t))) into t; anything else is unchanged. *)
  and norm_bool_arith (t:bool_arith):bool_arith =
  (
    match t with
    | BInfer ti ->
      (
        match !ti with
        | TKnown t -> barith_typ t
        | TUnknown _ -> t
      )
    | _ -> t
  )

  (* Convert a typ (of kind int) to an int_arith.  An unresolved unification
   * variable becomes the IInfer term 1 * ?x.
   * Raises InternalError for any other form of type. *)
  and iarith_typ (t:typ):int_arith =
  (
    match norm_typ t with
    | TVar x -> iarith_var x
    | TInt t -> t
    | TInfer ti ->
      (
        match !ti with
        | TKnown t -> iarith_typ t
        | TUnknown (x, _, _) ->
            IInfer (Int 0, VarMap.empty, VarMap.add x (Int 1, ti) VarMap.empty)
      )
    | _ -> raise (InternalError "iarith_typ")
  )

  (* Convert a typ (of kind bool) to a bool_arith.
   * Raises InternalError for any other form of type. *)
  and barith_typ (t:typ):bool_arith =
  (
    match norm_typ t with
    | TVar x -> BVar x
    | TBool t -> t
    | TInfer ti -> BInfer ti
    | _ -> raise (InternalError "barith_typ")
  )

  (* Expand a named struct type, either bare or applied to type arguments,
   * into its linearity and fields (with the arguments substituted into the
   * field types).  Every other type is returned, normalized, as TU_Typ. *)
  and unroll_typ (t:typ):typ_unrolled =
  (
    let t = norm_typ t in
    match t with
    | TApp (TNamed x, targs) ->
      (
        let (tparams_opt, nt, _) = get_gt_named gtv_env x in
        match (tparams_opt, nt) with
        | (Some tparams, TStruct (lin, fields)) ->
            let submap = make_submap tparams targs in
            TU_Struct (lin,
              List.map (fun (name, t) -> (name, subst_typ submap t)) fields)
        | _ -> TU_Typ t
      )
    | TNamed x ->
      (
        let (tparams_opt, nt, _) = get_gt_named gtv_env x in
        match (tparams_opt, nt) with
        | (None, TStruct (lin, fields)) -> TU_Struct (lin, fields)
        | _ -> TU_Typ t
      )
    | _ -> TU_Typ t
  )

  (* The linearity and fields of a record or struct type.
   * Raises TypeExn (t, TypeError ...) for any other type. *)
  and get_struct_or_record (t:typ):
      (linearity * field list) =
  (
    match unroll_typ t with
    | TU_Typ (TRecord (lin, fields)) -> (lin, fields)
    | TU_Struct (lin, fields) -> (lin, fields)
    | _ -> raise (TypeExn (t, (TypeError "record or struct expected")))
  )

  (* Convert an int_arith to the solver's (coefficients, constant) form.
   * Raises NotFormula if unification variables remain after normalization. *)
  and make_linear_arith (t:int_arith):linear_arith =
  (
    match norm_int_arith t with
    | IArith (i0, vars) -> (vars, i0)
    | IInfer _ -> raise (NotFormula (TInt t))
  )

  (* Convert a bool_arith to a solver formula.
   * Raises NotFormula for an unresolved BInfer. *)
  and make_formula (t:bool_arith):formula =
  (
    match norm_bool_arith t with
    | BVar x ->
        (* Use an integer equation "x==0" to represent a boolean variable x *)
        linear_eq (VarMap.add x unit_big_int VarMap.empty, zero_big_int) (VarMap.empty, zero_big_int)
    | BConst true -> FAnd []
    | BConst false -> FOr []
    | BNot t -> FNot (make_formula t)
    | BBinary (BAndOp, t1, t2) ->
        FAnd [make_formula t1; make_formula t2]
    | BBinary (BOrOp, t1, t2) ->
        FOr [make_formula t1; make_formula t2]
    | BCompare (op, t1, t2) ->
      (
        let (t1, t2) = (make_linear_arith t1, make_linear_arith t2) in
        match op with
        | BEqOp -> linear_eq t1 t2
        | BNeOp -> linear_ne t1 t2
        | BLtOp -> linear_lt t1 t2
        | BGtOp -> linear_gt t1 t2
        | BLeOp -> linear_le t1 t2
        | BGeOp -> linear_ge t1 t2
      )
    | BInfer _ -> raise (NotFormula (TBool t))
  )

in
{
  tu_gtv_env = gtv_env;
  subst_typ = subst_typ;
  subst_int_arith = subst_int_arith;
  subst_bool_arith = subst_bool_arith;
  norm_typ = norm_typ;
  norm_int_arith = norm_int_arith;
  norm_bool_arith = norm_bool_arith;
  iarith_typ = iarith_typ;
  barith_typ = barith_typ;
  unroll_typ = unroll_typ;
  get_struct_or_record = get_struct_or_record;
  make_linear_arith = make_linear_arith;
  make_formula = make_formula;
}
;;

(****************************************************************************
 * The unification algorithm keeps a list of constraints, and repeatedly
 *   - removes a constraint from the list
 *   - attempts to satisfy the constraint, which either
 *       - fails, meaning the whole algorithm fails
 *       - succeeds, possibly adding new constraints to the list
 * The algorithm terminates successfully when the list becomes empty.
 *
 * Our unification algorithm is not complete, because some constraints
 * are difficult to handle.  In particular, the following constraints are
 * difficult:
 *   - (alpha[t1]) = t2, where alpha is a unification variable, and t2
 *     is not a unification variable (we can't start unifying this until
 *     we know what function alpha is)
 *   - b1 = t2, where b1 is a bool_arith containing unification variables,
 *     and t2 is not a unification variable (this would require the
 *     full Presburger arithmetic, with all the unification variables
 *     existentially quantified)
 * The algorithm does not process difficult constraints.  If the list
 * contains only difficult constraints, the algorithm fails.  Otherwise,
 * the algorithm selects an easy constraint from the list and processes
 * that first; hopefully, the processing of the easy constraints will
 * resolve enough unification variables to make the difficult constraints
 * easy, so that unification will eventually succeed.
 *)

(* To print better error messages, we track the original constraint
 * that generated the current constraints we're working on.
 * The pair holds the two types of that original constraint.
 *)
type constraint_context = typ * typ;;

(* A constraint contains two types that should be unified, together with
 * the context (original constraint) it was derived from. *)
type easy_constraint = typ * typ * constraint_context;;

(* Wait until all the unification variables in the list have been
 * resolved, then try to satisfy the constraint.
 * The list holds the unification variables being waited on; the second
 * component is the constraint itself.
 *)
type hard_constraint = (typ_infer ref list) * easy_constraint;;

(* Mutable worklist state for unification. *)
type constraint_set =
{
  (* constraints that can be processed now *)
  mutable cs_easy:easy_constraint list;
  (* constraints postponed until their unification variables are known *)
  mutable cs_hard:hard_constraint list;
  (* type utilities (normalization, substitution, formula conversion) *)
  cs_tu:type_utils;
  (* type-variable context; its tvc_known facts are assumed when proving
   * arithmetic constraints *)
  cs_tvc:tv_context;
};;

(* Build a constraint set holding the given easy constraints, no hard
 * constraints, and the supplied type utilities and type-variable context. *)
let new_constraint_set
    (easy:easy_constraint list)
    (tu:type_utils)
    (tvc:tv_context):constraint_set =
  let no_hard_constraints = [] in
  { cs_easy = easy; cs_hard = no_hard_constraints; cs_tu = tu; cs_tvc = tvc }
;;

(* Push the easy constraint "t1 = t2" onto the front of the constraint set,
 * recording (t1, t2) as its own original context for error messages.
 * Before calling add_constraint, call know_typ and norm_typ on t1 and t2.
 *)
let add_constraint
    (constraints:constraint_set)
    (t1:typ)
    (t2:typ):unit =
  let fresh = (t1, t2, (t1, t2)) in
  constraints.cs_easy <- fresh :: constraints.cs_easy
;;

(*
 * If must_finish is false, then unify_constraints will return when
 * there are no remaining easy constraints, even if hard constraints
 * remain.  If must_finish is true, then unify_constraints will
 * raise an exception if there are hard constraints but no easy
 * constraints.
 *)
let unify_constraints
    (cs:constraint_set)
    (must_finish:bool):unit =
  prog_trace "in unify_constraints.\n";
  let {
    subst_typ = subst_typ;
    subst_int_arith = subst_int_arith;
    subst_bool_arith = subst_bool_arith;
    norm_typ = norm_typ;
    norm_int_arith = norm_int_arith;
    norm_bool_arith = norm_bool_arith;
    barith_typ = barith_typ;
    make_formula = make_formula;
  } = cs.cs_tu in

  (* See if t is true. *)
  let rec check_arith_constraint (t:bool_arith):unit =
    prog_trace "\tin check_arith_consraint\n";
  (
    match t with
    | BBinary(BAndOp, t1, t2) ->
      (
        check_arith_constraint t1;
        check_arith_constraint t2
      )
    | _ ->
      (
        let f = make_formula t in

        if(!do_trace) then
        (
          Format.printf "\nknown: ";
          List.iter (fun t -> Format.printf "("; print_bool_arith VarMap.empty 0 t; Format.printf ") ") cs.cs_tvc.tvc_known;
          Format.printf "\nmust prove: ";
          print_bool_arith VarMap.empty 0 t;
          Format.print_newline ()
        );

        (* For efficiency, rewrite formula p1/\..../\pn==>goal
           as (p11/\.../\p1i==>goal) \/ !(p21/\.../\p2j) \/ ... \/ (!pm1/\.../\pmk)
           where no two groups in the disjunction share any common variables
         *)
        let fConnected = Arith.connected_formulas
          (f::(Arith.flatten_conjuncts
                (List.map make_formula cs.cs_tvc.tvc_known))) in
        let fGoal = (List.hd (List.hd fConnected)) in
        let fKnown = (List.tl (List.hd fConnected)) in
        let fRest = List.tl fConnected in
        (* If any formula in fAll is true, we succeed. *)
        let fAll =
            (FImplies (FAnd fKnown, fGoal))
          ::(List.map (fun fs -> FNot (FAnd fs)) fRest) in
        prog_trace "TWO\n ";  
(*
if (List.length fAll) > 7 then
(
 Format.printf "\n%d %d\n" (List.length cs.cs_tvc.tvc_known) (List.length fAll);
 Arith.print_formula 0 f; Format.print_newline ();
 List.iter (Format.printf " // "; fun f -> Arith.print_formula 0 f; Format.print_newline ()) (List.map make_formula cs.cs_tvc.tvc_known); Format.print_newline ();
 List.iter (Format.printf " ** "; fun f -> Arith.print_formula 0 f; Format.print_newline ()) fAll; Format.print_newline ()
);
*)

        let rec check_all (fs:formula list) =
	prog_trace "\tin check_all\n";
        (
          match fs with
            []   -> raise (ConstraintError (cs.cs_tvc.tvc_known, t))
          | h::t -> (prog_trace "going thru formula list\n";
                     let converted = convert_formula h in
                     prog_trace "done converting formula\n";
                     (* print_cformula 3 converted;
                     Format.print_flush (); *)
                     if not (Presburger.presburger (converted)) 
                       then (check_all t))
        ) in
        check_all fAll

(*
        let cf:cformula =
          convert_formula
            (FImplies (
              FAnd (List.map make_formula cs.cs_tvc.tvc_known), f))
        in

        if Presburger.presburger cf then
        (
	 prog_trace "==> Satisfiable\n";
        )
        else raise (ConstraintError (cs.cs_tvc.tvc_known, t))
*)
    )
  ) in

  (* Check that t1 == t2 *)
  let check_int_equiv (t1:int_arith) (t2:int_arith):unit =
  prog_trace "\t in check_int_equiv\n";
  (
    check_arith_constraint (BCompare (BEqOp, t1, t2))
  ) in

  (* Check that t1 <=> t2 *)
  let check_bool_equiv (t1:bool_arith) (t2:bool_arith):unit =
      prog_trace "\t in check_bool_equiv\n"; 
  (
    match (norm_bool_arith t1, norm_bool_arith t2) with
    | (BConst true, BConst true) -> ()
    | (BConst true, t2) -> check_arith_constraint t2
    | (t1, BConst true) -> check_arith_constraint t1
    | (t1, t2) ->
        check_arith_constraint
          (BBinary (BOrOp,
            (BBinary (BAndOp, t1, t2)),
            (BBinary (BAndOp, BNot t1, BNot t2))))
  ) in

  let subst_unknown (ti:typ_infer ref) (t:typ):unit =
  prog_trace "\t in subst_unknown\n"; 
  (
    (* Set ti to t. *)
    match (!ti, t) with
    | (_, TInfer ti2) when ti2 == ti ->
        (* If t = ti, then do nothing (setting ti to ti is a no-op) *)
        ()
    | (TUnknown (x, k, tv_env), _) ->
      (
        (*
         * Check that t does not contain ti (occurs-check).  We use
         * a cheap trick: temporarily replace ti with a fresh variable
         * and check it; if it checks, then permanently replace ti with t.
         *)
        let gtv_env = cs.cs_tu.tu_gtv_env in
        ti := TKnown (TVar (new_var x));
        (try ignore (check_typ gtv_env tv_env true t) with TypeError _ ->
          raise (TypeMismatch ("", TInfer ti, t)));
        (* occurs-check succeeded; substitute t for ti *)
        ti := TKnown t;
        unify_kinds k (check_typ gtv_env tv_env true t)
      )
    | _ -> raise (InternalError "subst_unknown")
  ) in

  let unify_int_arith (t1:int_arith) (t2:int_arith):unit =
  prog_trace "\t in unify_int_arith\n"; 
  (
    (* Try to solve the equation t1 - t2 == 0.
     *)
    let t = norm_int_arith (add_int_arith
      t1
      (mult_int_arith (minus_big_int unit_big_int) t2)) in
    match t with
    | IArith (i0, vars) ->
      (
        (* no unification variables, so pass it to the arithmetic solver *)
        try
          check_int_equiv t iarith_zero
        with err -> raise (NestedExn (TypeMismatch ("", TInt t1, TInt t2), err))
      )
    | IInfer (i0, vars, ivars) ->
      (
        (* pick a unification variable and eliminate it (Gaussian
         * elimination)
         *)
        let ivar_opt = VarMap.fold
          (fun x ivar ivar_opt -> Some (x, ivar))
          ivars
          None in
        match ivar_opt with
        | None -> raise (InternalError "unify_int_arith")
        | Some (x, (i, ti)) ->
          (
            (* we have t == i*x + trest == 0
             * we have x == (-1/i) * trest
             *)
            subst_unknown
              ti
              (TInt (norm_int_arith (mult_num_arith
                ((num_of_int (-1)) // i)
                (IInfer (i0, vars, VarMap.remove x ivars)))))
          )
      )
  ) in

  let unify_bool_arith
      (add_hard:typ_infer ref list->typ->typ->unit)
      (t1:bool_arith)
      (t2:bool_arith):unit =
  prog_trace "\t in unify_bool_arith\n";
  (
    (* Equate two boolean index formulas.
     * No real unification happens here.  First the unresolved unification
     * variables of both formulas are collected.  If there are any, the
     * equation is postponed by handing them to [add_hard] together with
     * TBool t1 / TBool t2.  Otherwise both formulas are closed and are
     * compared by the arithmetic checker [check_bool_equiv]; a failure
     * there is re-raised as
     * NestedExn (TypeMismatch ("", TBool t1, TBool t2), err).
     *)

    (* [int_unknowns_onto t acc]: the unresolved variables of the
     * normalized integer formula [t], prepended to [acc]. *)
    let int_unknowns_onto (t:int_arith) acc =
      prog_trace "\t in get_int_unknowns\n";
      match norm_int_arith t with
      | IArith _ -> acc
      | IInfer (_, _, ivars) ->
        VarMap.fold (fun _ (_, ti) found -> ti :: found) ivars acc
    in

    (* [bool_unknowns_onto t acc]: the unresolved variables of the
     * normalized boolean formula [t] (left operand's before right
     * operand's), prepended to [acc]. *)
    let rec bool_unknowns_onto (t:bool_arith) acc =
      prog_trace "\t in get_bool_unknowns\n";
      match norm_bool_arith t with
      | BVar _ | BConst _ -> acc
      | BNot inner -> bool_unknowns_onto inner acc
      | BBinary (_, lhs, rhs) ->
        bool_unknowns_onto lhs (bool_unknowns_onto rhs acc)
      | BCompare (_, lhs, rhs) ->
        int_unknowns_onto lhs (int_unknowns_onto rhs acc)
      | BInfer ti ->
        (
          match !ti with
          | TUnknown _ -> ti :: acc
          | TKnown known -> bool_unknowns_onto (barith_typ known) acc
        )
    in

    match bool_unknowns_onto t1 (bool_unknowns_onto t2 []) with
    | _ :: _ as pending -> add_hard pending (TBool t1) (TBool t2)
    | [] ->
      (
        try check_bool_equiv t1 t2
        with err -> raise (NestedExn (TypeMismatch ("", TBool t1, TBool t2), err))
      )
  ) in

  (* Unify two types built with the same quantifier (TExists or TAll).
   * [t1] / [t2] are the complete types and are used only for error
   * reporting; [tparamsN], [wherN] and [tbN] are the bound parameters,
   * the where-clause formula and the body of each side.
   * The parameter names are considered part of the type, so both lists
   * must have the same length and pairwise-equal names, otherwise
   * TypeMismatch ("", t1, t2) is raised; the kinds are unified with
   * unify_kinds.  Both sides are then renamed to the same fresh type
   * variables; the where-clauses are unified with unify_bool_arith (which
   * may defer them through [add_hard]) and the bodies are queued with
   * [add_easy].
   *)
  let unify_quantified add_easy add_hard t1 tparams1 wher1 tb1 t2 tparams2 wher2 tb2 =
    (* Structural (not physical) comparison of the arities; also guards
     * List.iter2 below, which would raise Invalid_argument otherwise. *)
    if (List.length tparams1) <> (List.length tparams2) then
      raise (TypeMismatch ("", t1, t2));
    List.iter2
      (fun ((name1, _), k1) ((name2, _), k2) ->
        if name1 <> name2 then raise (TypeMismatch ("", t1, t2));
        unify_kinds k1 k2)
      tparams1
      tparams2;
    (* Now rename the parameters and unify *)
    (* XXX: we don't add wher1/wher2 to tvc; should we? *)
    let tparams = List.map (fun (x, k) -> (new_var x, k)) tparams1 in
    let targs = List.map (fun (x, _) -> TVar x) tparams in
    let submap1 = make_submap tparams1 targs in
    let submap2 = make_submap tparams2 targs in
    unify_bool_arith
      add_hard
      (subst_bool_arith submap1 wher1)
      (subst_bool_arith submap2 wher2);
    add_easy (subst_typ submap1 tb1) (subst_typ submap2 tb2)
  in

  let unify_typ
      (add_easy:typ->typ->unit)
      (add_hard:typ_infer ref list->typ->typ->unit)
      (t1:typ)
      (t2:typ):unit =
  prog_trace "\t in unify_typ\n";
  (
    (* Perform one structural step of unifying [t1] with [t2]; both are
     * normalized with norm_typ before being inspected.
     * - Sub-goals are not solved recursively here; they are queued with
     *   [add_easy].
     * - A type application whose head is a still-unknown unification
     *   variable is parked with [add_hard] together with that variable.
     * - An unknown unification variable on either side is bound to the
     *   other side with subst_unknown.
     * - Integer / boolean index expressions are delegated to
     *   unify_int_arith / unify_bool_arith.
     * Raises TypeMismatch ("", t1, t2) (the arguments as given, not the
     * normalized forms) on a structural mismatch; the helpers called here
     * may raise their own exceptions.
     *)
    let mismatch () = raise (TypeMismatch ("", t1, t2)) in
    (* List.iter2 raises Invalid_argument on lists of different lengths, so
     * compare lengths first (structurally) and report a type mismatch. *)
    let check_same_length l1 l2 =
      if (List.length l1) <> (List.length l2) then mismatch () in
    (* Two arrows with the same limit: unify the arguments, then results. *)
    let unify_arrow args1 ret1 args2 ret2 =
      check_same_length args1 args2;
      List.iter2 add_easy args1 args2;
      add_easy ret1 ret2 in
    match (norm_typ t1, norm_typ t2) with
    | (TInfer ti1, t2) ->
      (
        match !ti1 with
        | TKnown t1 -> add_easy t1 t2
        | TUnknown _ -> subst_unknown ti1 t2
      )
    | (t1, TInfer ti2) ->
      (
        match !ti2 with
        | TKnown t2 -> add_easy t1 t2
        | TUnknown _ -> subst_unknown ti2 t1
      )
    | (TBool (BInfer ti), t2) -> add_easy (TInfer ti) t2
    | (t1, TBool (BInfer ti)) -> add_easy t1 (TInfer ti)
    | (TArrow (args1, ret1, Unlimited), TArrow (args2, ret2, Unlimited))
    | (TArrow (args1, ret1, LimitAny), TArrow (args2, ret2, LimitAny)) ->
        unify_arrow args1 ret1 args2 ret2
    | (TArrow (args1, ret1, Limited tlimit1), TArrow (args2, ret2, Limited tlimit2)) ->
      (
        unify_arrow args1 ret1 args2 ret2;
        unify_int_arith tlimit1 tlimit2
      )
    | (TNamed x1, TNamed x2) ->
        if vars_equal x1 x2 then () else mismatch ()
    | (TVar x1, TVar x2) ->
      (
        if vars_equal x1 x2 then ()
        else if VarMap.mem x1 cs.cs_tvc.tvc_vars then
        (
          match get_var cs.cs_tvc.tvc_vars x1 with
          | KInt -> check_int_equiv (iarith_var x1) (iarith_var x2)
          | KBool -> check_bool_equiv (BVar x1) (BVar x2)
          | _ -> mismatch ()
        )
        else mismatch ()
      )
    | (TApp (TInfer ti1, targs1), t2) ->
      (
        (* note: t1 here is the argument as given, t2 the normalized form *)
        match !ti1 with
        | TKnown tfun1 -> add_easy (TApp (tfun1, targs1)) t2
        | TUnknown _ -> add_hard [ti1] t1 t2
      )
    | (t1, TApp (TInfer ti2, targs2)) ->
      (
        (* note: t1 here is the normalized form, t2 the argument as given *)
        match !ti2 with
        | TKnown tfun2 -> add_easy t1 (TApp (tfun2, targs2))
        | TUnknown _ -> add_hard [ti2] t1 t2
      )
    | (TApp (t1a, targs1), TApp (t2a, targs2)) ->
      (
        check_same_length targs1 targs2;
        add_easy t1a t2a;
        List.iter2 add_easy targs1 targs2
      )
    | (TInt t1, TInt t2) -> unify_int_arith t1 t2
    | (TVar x1, TInt t2) -> unify_int_arith (iarith_var x1) t2
    | (TInt t1, TVar x2) -> unify_int_arith t1 (iarith_var x2)
    | (TBool t1, TBool t2) -> unify_bool_arith add_hard t1 t2
    | (TVar x1, TBool t2) -> unify_bool_arith add_hard (BVar x1) t2
    | (TBool t1, TVar x2) -> unify_bool_arith add_hard t1 (BVar x2)
    | (TRecord (lin1, fields1), TRecord (lin2, fields2)) ->
      (
        check_same_length fields1 fields2;
        if lin1 <> lin2 then mismatch ();
        List.iter2
          (fun (x1, tf1) (x2, tf2) ->
            if x1 <> x2 then mismatch ();
            add_easy tf1 tf2)
          fields1
          fields2
      )
    | (TExists (tparams1, wher1, tb1), TExists (tparams2, wher2, tb2)) ->
        unify_quantified add_easy add_hard t1 tparams1 wher1 tb1 t2 tparams2 wher2 tb2
    | (TAll (tparams1, wher1, tb1), TAll (tparams2, wher2, tb2)) ->
        unify_quantified add_easy add_hard t1 tparams1 wher1 tb1 t2 tparams2 wher2 tb2
    | (TFun (tparams1, tb1), TFun (tparams2, tb2)) ->
      (
        (* Parameter names are not significant for type functions, but the
         * arities and parameter kinds must agree.  Without these checks a
         * parameter list of the wrong length would be passed to
         * make_submap alongside [targs], and functions whose parameter
         * kinds differ could unify whenever their bodies do. *)
        check_same_length tparams1 tparams2;
        List.iter2 (fun (_, k1) (_, k2) -> unify_kinds k1 k2) tparams1 tparams2;
        (* Rename the parameters and unify *)
        let tparams = List.map (fun (x, k) -> (new_var x, k)) tparams1 in
        let targs = List.map (fun (x, _) -> TVar x) tparams in
        let submap1 = make_submap tparams1 targs in
        let submap2 = make_submap tparams2 targs in
        add_easy (subst_typ submap1 tb1) (subst_typ submap2 tb2)
      )
    | _ -> mismatch ()
  ) in

  let rec unify_loop ():unit =
    prog_trace "\tin unify_loop\n";
    (* Drain the constraint worklist held in [cs] until it is empty or no
     * further progress is possible.
     * - Easy constraints are popped one at a time and passed to unify_typ.
     *   Everything unify_typ queues is tagged with the root pair
     *   (t1root, t2root) of the constraint being processed, and any
     *   exception it raises is wrapped as
     *   NestedExn (TypeMismatch ("", t1root, t2root), e).
     * - When only hard constraints remain, unification variables that have
     *   become TKnown are dropped from each one; constraints left with no
     *   unknown variables are promoted back to the easy list and the loop
     *   continues.
     * - If nothing can be promoted: when must_finish is set, raise
     *   TypeMismatch on the root pair of one remaining hard constraint;
     *   otherwise return, leaving cs.cs_hard untouched.
     *)
    let still_unknown ti =
      (match !ti with TKnown _ -> false | TUnknown _ -> true) in
    (* One pass over the hard list.  Both result lists are accumulated by
     * consing, so they come out in reverse input order. *)
    let rec rescan promoted waiting = function
      | [] -> (promoted, waiting)
      | (tinfers, (a, b, ctxt)) :: rest ->
        (
          match List.filter still_unknown tinfers with
          | [] -> rescan ((norm_typ a, norm_typ b, ctxt) :: promoted) waiting rest
          | open_vars -> rescan promoted ((open_vars, (a, b, ctxt)) :: waiting) rest
        )
    in
    match cs.cs_easy with
    | (t1, t2, (t1root, t2root)) :: remaining ->
      (
        cs.cs_easy <- remaining;
        let root = (t1root, t2root) in
        let push_easy a b =
          cs.cs_easy <- (a, b, root) :: cs.cs_easy in
        let push_hard tinfers a b =
          cs.cs_hard <- (tinfers, (a, b, root)) :: cs.cs_hard in
        if !do_trace then begin
          print_typ VarMap.empty 0 t1;
          Format.printf " <=> ";
          print_typ VarMap.empty 0 t2;
          Format.print_newline ()
        end;
        (
          try unify_typ push_easy push_hard t1 t2
          with e -> raise (NestedExn (TypeMismatch ("", t1root, t2root), e))
        );
        unify_loop ()
      )
    | [] ->
      (
        match cs.cs_hard with
        | [] -> () (* success *)
        | hard ->
          (
            match rescan [] [] hard with
            | ((_ :: _) as promoted, waiting) ->
              cs.cs_easy <- promoted;
              cs.cs_hard <- waiting;
              unify_loop ()
            | ([], waiting) ->
              if must_finish then
                (* Pick a constraint and apologize for it. *)
                let (_, (_, _, (t1, t2))) = List.hd waiting in
                raise (TypeMismatch ("Could not infer types here.  Try using explicit type arguments.", t1, t2))
          )
      )
  in

  unify_loop ()
;;

