open Ast;;
open Arith;;
open Big_int;;

(* This file contains syntax-construction helpers: one wrapper function per
 * AST form, each of which also performs routine conversions (such as
 * string -> var) automatically.  In general, if an AST constructor is named
 * XSomething, the corresponding function is named xsomething.  Unlike the
 * datatype constructors, which take a tuple, the functions are all curried.
 *)

(* [var_of_string name] turns a source-level identifier into a [var] by
 * pairing it with the index 0.  Every string -> var conversion in this file
 * goes through here, so the representation can change later if need be.
 * NOTE(review): the meaning of the integer component is not visible in
 * this file -- presumably a renaming/uniquifying counter; confirm in Ast.
 *)
let var_of_string (name : string) : var =
  let index = 0 in
  (name, index);;

(* A couple of functions to simplify iarith construction *)

(* [iarith_vars s] is the int_arith for the variable named [s]. *)
let iarith_vars (s : string) =
  let v = var_of_string s in
  iarith_var v;;

(* [iarith_consti i] is the constant int_arith for the native integer [i],
 * widened to a big_int first. *)
let iarith_consti (i : int) =
  let n = big_int_of_int i in
  iarith_const n;;

(* [subt_int_arith i1 i2] builds i1 - i2, expressed with the Arith helpers
 * as i1 + (-1) * i2. *)
let subt_int_arith (i1 : int_arith) (i2 : int_arith) : int_arith =
  let minus_one = minus_big_int unit_big_int in
  let negated = mult_int_arith minus_one i2 in
  add_int_arith i1 negated;;

(* functions that wrap bool_arith construction *)

(* [bvar s] is the boolean variable named [s]. *)
let bvar (s : string) : bool_arith =
  let v = var_of_string s in
  BVar v;;

(* [bconst b] is the boolean literal [b]. *)
let bconst (b : bool) : bool_arith =
  BConst b;;

(* [bnot ba] wraps [ba] in a [BNot] node. *)
let bnot (ba : bool_arith) : bool_arith =
  BNot ba;;

(* [bbinary op ba1 ba2] applies the binary connective [op] to [ba1] and
 * [ba2], keeping them in that order. *)
let bbinary (op : bool_binary_op) (ba1 : bool_arith) (ba2 : bool_arith)
    : bool_arith =
  BBinary (op, ba1, ba2);;

(* [bcompare op i0 i1] compares the integer expressions [i0] and [i1] with
 * [op], yielding a bool_arith. *)
let bcompare (op : int_compare_op) (i0 : int_arith) (i1 : int_arith)
    : bool_arith =
  BCompare (op, i0, i1);;

(* a couple of functions that simplify construction of common bool_ariths *)

(* [bandall l] conjoins every formula in [l] into one left-nested chain of
 * [BAndOp] nodes: [a; b; c] becomes BBinary (BAndOp, BBinary (BAndOp, a, b), c).
 * The empty list yields [bconst true] and a singleton list yields its sole
 * element unchanged.
 *)
let bandall (l : bool_arith list) : bool_arith =
  match l with
  | [] -> bconst true
  | first :: rest ->
      (* Destructuring the list avoids the partial List.hd/List.tl; when
       * [rest] is empty the fold returns [first] as-is, which covers the
       * singleton case without a separate arm. *)
      List.fold_left
        (fun (so_far : bool_arith) (b : bool_arith) ->
          BBinary (BAndOp, so_far, b))
        first rest;;
	

(* Type wrappers
   Types are wrapped for consistency, and to eliminate
   the occasional var_of_string.
*)

(* [tparam s k] is a type parameter named [s] with kind [k]. *)
let tparam (s : string) (k : kind) =
  let v = var_of_string s in
  (v, k);;

(* [tarrow l t lim] is the arrow type from the argument types [l] to the
 * result type [t], annotated with the limit [lim]. *)
let tarrow (l : typ list) (t : typ) (lim : limit) : typ =
  TArrow (l, t, lim);;

(* [tnamed s] is the named type [s]. *)
let tnamed (s : string) : typ =
  let v = var_of_string s in
  TNamed v;;

(* [tvar_v v] is the type variable [v], for callers that already hold a var. *)
let tvar_v (v : var) : typ =
  TVar v;;

(* [tvar s] is the type variable named [s]. *)
let tvar (s : string) : typ =
  tvar_v (var_of_string s);;

(* [tapp t l] applies the type [t] to the argument types [l]. *)
let tapp (t : typ) (l : typ list) : typ =
  TApp (t, l);;

(* [tint i] lifts the integer expression [i] into a type. *)
let tint (i : int_arith) : typ =
  TInt i;;

(* [tint_const i] is [tint] applied to the constant [i]. *)
let tint_const (i : int) : typ =
  tint (iarith_consti i);;

(* [tint_add i j] is [tint] applied to the sum i + j. *)
let tint_add (i : int_arith) (j : int_arith) : typ =
  tint (add_int_arith i j);;

(* [tbool b] lifts the boolean expression [b] into a type. *)
let tbool (b : bool_arith) : typ =
  TBool b;;

(* [tbool_const b] is [tbool] applied to the literal [b]. *)
let tbool_const (b : bool) : typ =
  tbool (BConst b);;

(* The wrappers below pass their arguments, in order, to the matching typ
 * constructor. *)

(* [trecord lin fields]: a record type with linearity [lin]. *)
let trecord (lin : linearity) (fields : field list) : typ =
  TRecord (lin, fields);;

(* [texists params ba t]: an existential over [params], constrained by
 * [ba], with body [t]. *)
let texists (params : tparam list) (ba : bool_arith) (t : typ) : typ =
  TExists (params, ba, t);;

(* [tall params ba t]: a universal over [params], constrained by [ba],
 * with body [t]. *)
let tall (params : tparam list) (ba : bool_arith) (t : typ) : typ =
  TAll (params, ba, t);;

(* [tfun params t]: a type-level function of [params] with body [t]. *)
let tfun (params : tparam list) (t : typ) : typ =
  TFun (params, t);;

(* Expression wrappers *)

(* [param s t] is a function parameter named [s] of type [t].  The name
 * lives in a fresh ref cell -- presumably so later passes can rewrite it
 * in place; confirm against the users of Ast.param. *)
let param (s : string) (t : typ) : param =
  let name_cell = ref (var_of_string s) in
  (name_cell, t);;

(* [make_e e] wraps the raw expression [e] in an [exp] record whose
 * [exp_pos] and [exp_typ] are both [None]. *)
let make_e (e : exp_raw) : exp =
  { exp_raw = e; exp_pos = None; exp_typ = None };;

(* Each expression wrapper builds the matching exp_raw constructor and
 * finishes it with [make_e], so the result carries no position or type. *)

(* [evar_v v] references the variable [v]; the var sits in a fresh ref. *)
let evar_v (v : var) : exp =
  make_e (EVar (ref v));;

(* [evar s] references the variable named [s]. *)
let evar (s : string) : exp =
  evar_v (var_of_string s);;

(* [ebool b] is the boolean literal [b]. *)
let ebool (b : bool) : exp =
  make_e (EBool b);;

(* [eint_bi i] is the integer literal [i]. *)
let eint_bi (i : big_int) : exp =
  make_e (EInt i);;

(* [eint i] is the integer literal [i], widened to a big_int. *)
let eint (i : int) : exp =
  eint_bi (big_int_of_int i);;

(* [eunit] is the unit literal; a single value shared by all users. *)
let eunit : exp = make_e EUnit;;

(* [ecall f params] applies the expression [f] to the argument list
 * [params]. *)
let ecall (f : exp) (params : exp list) : exp =
  let call = ECall (f, params) in
  make_e call;;

(* [eoverload l] builds an [EOverload] expression over the candidate names
 * in [l], paired with a fresh cell initialised to [None] (presumably filled
 * in with the chosen overload later -- confirm in the typechecker).
 * Unlike the other expression wrappers it cannot leave [exp_pos] as [None],
 * because the typechecker does not accept overloads without a position.
 * It therefore starts from [make_e] and overrides only the position.
 * NOTE(review): 100 is an arbitrary placeholder, not a real source
 * location -- confirm no diagnostic relies on its value.
 *)
let eoverload (l : string list) : exp =
  { (make_e (EOverload (l, ref None))) with exp_pos = Some 100 };;

(* These wrappers pass their arguments, in order, to the matching exp_raw
 * constructor and finish the result with [make_e]. *)

(* [estruct_v name] is the struct expression for the var [name]. *)
let estruct_v (name : var) : exp =
  make_e (EStruct name);;

(* [estruct name] is the struct expression for the identifier [name]. *)
let estruct (name : string) : exp =
  estruct_v (var_of_string name);;

(* [eassign op op1 op2] assigns with operator [op]; [op1] is the target. *)
let eassign (op : assign_op) (op1 : exp) (op2 : exp) : exp =
  make_e (EAssign (op, op1, op2));;

(* [erecord lin fields] is a record literal with linearity [lin]; each field
 * is (name, optional type annotation, value). *)
let erecord (lin : linearity) (fields : (string * typ option * exp) list)
    : exp =
  make_e (ERecord (lin, fields));;

(* [emember s mem] selects the member [mem] of [s]. *)
let emember (s : exp) (mem : string) : exp =
  make_e (EMember (s, mem));;

(* [etapp e l] applies [e] to the named type arguments [l]. *)
let etapp (e : exp) (l : (string * typ) list) : exp =
  make_e (ETApp (e, l));;

(* [epack e t l] packs [e] at type [t] with the named witnesses [l]. *)
let epack (e : exp) (t : typ) (l : (string * typ) list) : exp =
  make_e (EPack (e, t, l));;

(* Statement wrappers *)

(* [make_s s] wraps the raw statement [s] in a [stmt] record whose
 * [stmt_pos] is [None]. *)
let make_s (s : stmt_raw) : stmt =
  { stmt_raw = s; stmt_pos = None };;

(* [sblock l] groups the statements [l] into a block. *)
let sblock (l : stmt list) : stmt =
  let block = SBlock l in
  make_s block;;

(* [make_decl s t e] pairs a binder -- the name [s] in a fresh ref cell plus
 * its optional type [t] -- with the initialiser [e]. *)
let make_decl (s : string) (t : typ option) (e : exp) =
  let binder = (ref (var_of_string s), t) in
  (binder, e);;

(* [make_mdecl l e] is the multi-binder version of [make_decl]: every
 * (name, optional type) in [l] becomes a binder, all sharing the single
 * initialiser [e]. *)
let make_mdecl (l : (string * typ option) list) (e : exp) =
  let to_binder (name, t) = (ref (var_of_string name), t) in
  (List.map to_binder l, e);;

(* Each statement wrapper passes its arguments, in order, to the matching
 * stmt_raw constructor and finishes the result with [make_s].
 * NOTE(review): several constructors take a trailing stmt ([s] / [rest]);
 * presumably the statement that follows or is scoped by the construct --
 * confirm in Ast.
 *)

let sdecl (u : unpack_spec option) (d : decl) (s : stmt) : stmt =
  make_s (SDecl (u, d, s));;

let smdecl (u : unpack_spec option) (d : mdecl) (s : stmt) : stmt =
  make_s (SMDecl (u, d, s));;

let sreturn (e : exp) : stmt =
  make_s (SReturn e);;

let swhile (cond : exp) (body : stmt) (rest : stmt) : stmt =
  make_s (SWhile (cond, body, rest));;

let sfor (tp : tparam list) (ba : bool_arith) (init : (param * exp) list)
    (cond : exp) (defaults : exp list option) (body : stmt) (rest : stmt)
    : stmt =
  make_s (SFor (tp, ba, init, cond, defaults, body, rest));;

let scontinue (l : exp list) : stmt =
  make_s (SContinue l);;

let sifelse (cond : exp) (cons : stmt) (alt : stmt) : stmt =
  make_s (SIfElse (cond, cons, alt));;

let sboolcase (ba : bool_arith) (cons : stmt) (alt : stmt) : stmt =
  make_s (SBoolCase (ba, cons, alt));;

let sexp (e : exp) : stmt =
  make_s (SExp e);;

(* Conversion functions *)

(* [tvars_of_params tp] turns each (name, kind) type parameter into the type
 * variable [TVar name], dropping the kind and keeping the list order. *)
let tvars_of_params (tp : tparam list) =
  List.map (fun p -> TVar (fst p)) tp;;

(* [tvars_of_fields fl] turns each field of [fl] into a type variable named
 * after the field with its first letter capitalised; for example, a field
 * named "count" yields [tvar "Count"].
 * String.capitalize was deprecated in OCaml 4.03 and removed in 5.0;
 * String.capitalize_ascii is the supported replacement and changes only
 * ASCII letters.
 *)
let tvars_of_fields (fl : field list) =
  List.map (fun (name, _) -> tvar (String.capitalize_ascii name)) fl;;
