open Ast;;
open Arith;;
open Big_int;;
open List;;

open Codeprinter;;

(* Raised for AST shapes this printer cannot render, and when the final
   output buffer cannot be written to the target file. *)
exception ClayPrinter of string;;

(* Whitespace emitted after the separator string between two elements:
   a Codeprinter space, a Codeprinter newline, or nothing. *)
type separator =
  | SpaceSep
  | NewlineSep
  | NoSep

(* Options for [print_list]:
   - [sep_str]   : text printed after every element except the last;
   - [sep_opt]   : whitespace printed after [sep_str] between elements;
   - [final_sep] : whether [sep_str] is also printed after the last element. *)
type print_list_opts = {
  sep_str : string;
  sep_opt : separator;
  final_sep : bool;
};;

(* "a, b, c" -- no trailing comma. *)
let comma_list = { sep_str = ","; sep_opt = SpaceSep; final_sep = false };;
(* One element per line, each terminated by ";". *)
let struct_list = { sep_str = ";"; sep_opt = NewlineSep; final_sep = true };;

(* [print_list opt l print_fcn] prints every element of [l] with
   [print_fcn].  Between two consecutive elements it prints [opt.sep_str]
   followed by the whitespace chosen by [opt.sep_opt]; after the last
   element it prints [opt.sep_str] once more only when [opt.final_sep] is
   set.  An empty list prints nothing. *)
let print_list (opt:print_list_opts) (l:'a list) (print_fcn:'a -> unit):unit =
  let print_gap () =
    print_string opt.sep_str;
    match opt.sep_opt with
        SpaceSep -> print_space ()
      | NewlineSep -> print_newline ()
      | NoSep -> ()
  in
  let rec walk = function
      [] -> ()
    | [last] ->
        print_fcn last;
        if opt.final_sep then print_string opt.sep_str
    | x :: rest ->
        print_fcn x;
        print_gap ();
        walk rest
  in
  walk l
;;

(* Print a list as "a, b, c": comma-separated, a space after each comma,
   no trailing comma. *)
let print_comma_list l print_elem = print_list comma_list l print_elem;;


(* Print a kind in Clay surface syntax: "int", "bool", or a kind2.
   A KType prints as an optional "@" (linear) prefix, "type", then its
   size when that differs from [wordsize]; a KArrow prints as
   "ret<-(arg, ...)". *)
let rec print_kind (k:kind):unit =
  match k with
      KInt -> print_string "int"
    | KBool -> print_string "bool"
    | K2 k2 -> print_kind2 k2
and print_kind2 (k:kind2):unit =
  match k with
      KType (size, lin) ->
        (match lin with
             Linear -> print_string "@"
           | _ -> ());
        print_string "type";
        (* NOTE(review): routed through [print_string] so the size lands in
           the active Codeprinter output; a bare [print_int] would write to
           stdout if Codeprinter does not shadow Stdlib.print_int. *)
        if size <> wordsize then print_string (string_of_int size)
    | KArrow (args, ret) ->
        print_kind2 ret;
        print_string "<-(";
        print_comma_list args print_kind;
        print_string ")"
;;

(* Seeing as how this module is disgustingly imperative already, I'm
   going to finish the damage and use a mutable variable for the variable
   name map instead of passing and returning it all the time.

   [vmap] maps an original variable to the variable it is printed as
   ([print_var] looks it up and falls back to the variable itself).
   [rename_var] also stores, under the sentinel key [(name, -1)], the most
   recently allocated [(name, j)] for [name], so the next rename of [name]
   gets index [j + 1].  Nothing in this file resets the map, so renames
   accumulate across calls. *)
   
let vmap:var VarMap.t ref = ref VarMap.empty;;

(* Print [v] under its renamed form when [vmap] has an entry for it,
   otherwise under its own name. *)
let print_var (v:var):unit =
  let shown =
    try VarMap.find v !vmap with Not_found -> v
  in
  print_string (string_of_var shown)
;;

(* Allocate a fresh printing name for the variable [(name, i)].
   The per-name counter is kept in [vmap] under the sentinel key
   [(name, -1)]: the first rename of [name] yields index 0, the next 1,
   and so on.  Both [(name, i)] and the sentinel are bound to the new
   variable [(name, j)], which is also returned. *)
let rename_var ((name, i):var):var =
  let j =
    try 1 + snd (VarMap.find (name, -1) !vmap)
    with Not_found -> 0
  in
  vmap := VarMap.add (name, i) (name, j) (VarMap.add (name, -1) (name, j) !vmap);
  (name, j)
;;

(* Give every type parameter a fresh printing name via [rename_var],
   keeping its kind unchanged. *)
let rename_tparams (tp:tparam list):tparam list =
  let rename_one (v, k) = (rename_var v, k) in
  map rename_one tp
;;

(* Print the coefficient prefix of one linear term [i*x]; the caller
   prints the variable itself.  [is_first] tells whether this is the first
   term of the expression.
   - i >= 0: "+" unless first, then "i*" unless i = 1.
   - i < 0:  "0" if first (a leading negative term reads "0-..."), then
     "-" followed by the prefix of -i treated as a first term, so no
     extra "+" appears.
   The function must be [rec]: the negative branch calls itself. *)
let rec print_int_prefix (is_first:bool) (i:num):unit =
  if Num.ge_num i (Num.Int 0) then begin
    if not is_first then print_string "+";
    if not (Num.eq_num i (Num.Int 1)) then begin
      print_string (Num.string_of_num i);
      print_string "*"
    end
  end else begin
    if is_first then print_string "0";
    print_string "-";
    (* Print the magnitude; [true] suppresses a second sign. *)
    print_int_prefix true (Num.minus_num i)
  end
;;

(* Print a linear integer expression: each coefficient*variable term
   (via [print_int_prefix]) followed by the constant.  A constant that
   would otherwise stand alone with a leading minus is written "(0-n)",
   matching the way [print_int_prefix] avoids a leading unary minus. *)
let rec print_int_arith (i:int_arith):unit =
  (* Print a big_int constant, writing a negative one as "(0-n)". *)
  let print_big_const (n:big_int):unit =
    if Big_int.lt_big_int n zero_big_int then begin
      print_string "(0";
      print_string (Big_int.string_of_big_int n);
      print_string ")"
    end else
      print_string (Big_int.string_of_big_int n)
  in
  match i with
      IArith (i0, vars) ->
        (* [nothing_printed] stays true until the first term is printed. *)
        let nothing_printed =
          VarMap.fold
            (fun x coeff is_first ->
               print_int_prefix is_first (Num.num_of_big_int coeff);
               print_var x;
               false)
            vars
            true
        in
        let zero_const = eq_big_int i0 zero_big_int in
        if not (nothing_printed || zero_const) then print_string "+";
        if nothing_printed || not zero_const then print_big_const i0
    | IInfer (i0, vars, ivars) ->
        let nothing_printed =
          VarMap.fold
            (fun x coeff is_first ->
               print_int_prefix is_first coeff;
               print_var x;
               false)
            vars
            true
        in
        let nothing_printed =
          VarMap.fold
            (fun _ (coeff, ti) is_first ->
               print_int_prefix is_first coeff;
               print_string "(";
               print_typ (TInfer ti);
               print_string ")";
               false)
            ivars
            nothing_printed
        in
        let zero = Num.Int 0 in
        if nothing_printed then begin
          (* A lone negative constant uses the same "(0-n)" form as the
             IArith case instead of a bare leading minus. *)
          if Num.lt_num i0 zero then begin
            print_string "(0";
            print_string (Num.string_of_num i0);
            print_string ")"
          end else
            print_string (Num.string_of_num i0)
        end else if Num.gt_num i0 zero then begin
          print_string "+";
          print_string (Num.string_of_num i0)
        end else if Num.lt_num i0 zero then
          (* string_of_num already supplies the "-" between term and constant. *)
          print_string (Num.string_of_num i0)

(* Print a boolean constraint.  Compound operands of "!", "&&" and "||"
   are parenthesized (see [print_bool_operand]) so the printed text keeps
   the tree's grouping whatever precedence the reader applies. *)
and print_bool_arith (ba:bool_arith):unit =
  match ba with
      BVar v -> print_var v
    | BConst true -> print_string "true"
    | BConst false -> print_string "false"
    | BNot t -> print_string "!"; print_bool_operand true t
    | BBinary (BOrOp, t1, t2) ->
        print_bool_operand false t1;
        print_string "||";
        print_bool_operand false t2
    | BBinary (BAndOp, t1, t2) ->
        print_bool_operand false t1;
        print_string "&&";
        print_bool_operand false t2
    | BCompare (op, t1, t2) ->
        print_int_arith t1;
        print_string (string_of_int_compare_op op);
        print_int_arith t2
    | BInfer ti -> print_typ (TInfer ti)

(* Print an operand of a boolean operator.  Binary operands are always
   parenthesized; comparisons only when [paren_compare] is set (the
   operand of "!", where "!a<b" would otherwise regroup). *)
and print_bool_operand (paren_compare:bool) (ba:bool_arith):unit =
  let needs_parens =
    match ba with
        BBinary _ -> true
      | BCompare _ -> paren_compare
      | _ -> false
  in
  if needs_parens then begin
    print_string "(";
    print_bool_arith ba;
    print_string ")"
  end else
    print_bool_arith ba

(* Print a type in Clay surface syntax. *)
and print_typ (t:typ):unit =
  (* "kind v, kind v, ..." for a type-parameter list. *)
  let print_tparams tparams =
    print_comma_list tparams (fun (v, k) -> print_kind k; print_space (); print_var v)
  in
  (* Shared layout of "all[...] t" and "exists[...] t"; a condition that is
     literally [BConst true] is omitted. *)
  let print_quantified keyword tparams cond body =
    print_string keyword;
    print_tparams tparams;
    (match cond with
         BConst true -> ()
       | ba -> print_string ";"; print_bool_arith ba);
    print_string "]";
    print_space ();
    print_typ body
  in
  match t with
      TArrow (args, ret, limit) ->
        print_typ ret;
        print_string "<-";
        (match limit with
             Unlimited -> ()
           | LimitAny -> print_string "limitany"
           | Limited l -> print_string "limited["; print_int_arith l; print_string "]");
        print_string "(";
        print_comma_list args print_typ;
        print_string ")"
    | TNamed v -> print_var v
    | TVar v -> print_var v
    | TApp (func, args) ->
        print_typ func;
        print_string "[";
        print_comma_list args print_typ;
        print_string "]"
    | TInt ia -> print_int_arith ia
    | TBool ba -> print_bool_arith ba
    | TRecord (lin, fields) ->
        (match lin with
             Linear -> print_string "@["
           | Nonlinear -> print_string ".[");
        (* A field whose name equals its 1-based position is printed without
           a name; any other field prints as "typ name".  The list is built
           in reverse and flipped once to stay linear in the field count. *)
        let (_, rev_fields) =
          fold_left
            (fun (pos, acc) (field_name, ft) ->
               let suffix =
                 if field_name = string_of_int pos then "" else " " ^ field_name
               in
               (pos + 1, (suffix, ft) :: acc))
            (1, [])
            fields
        in
        print_comma_list (rev rev_fields)
          (fun (suffix, ft) -> print_typ ft; print_string suffix);
        print_string "]"
    | TAll (tparams, cond, tb) -> print_quantified "all[" tparams cond tb
    | TExists (tparams, cond, tb) -> print_quantified "exists[" tparams cond tb
    | TFun (tparams, tb) ->
        print_string "fun[";
        print_tparams tparams;
        print_string "]"; print_space ();
        print_typ tb
    | TInfer ti ->
        (match !ti with
             TKnown known -> print_typ known
           | TUnknown (v, _, _) -> print_string ("?" ^ string_of_var v))
;;

(* Print explicit type arguments as "name=typ, name=typ, ...". *)
let print_type_params (params:(string * typ) list):unit =
  let print_binding (name, t) =
    print_string name;
    print_string "=";
    print_typ t
  in
  print_comma_list params print_binding
;;

(* Infix operator for a two-argument overloaded builtin, keyed by the first
   name in the EOverload candidate list.
   NOTE(review): shifts follow the builtin names (lshift -> "<<",
   rshift -> ">>"); confirm against the Clay parser.
   @raise ClayPrinter for names with no known operator. *)
let overload_op_string (fname:string):string =
  match fname with
      "i32_add" -> "+"
    | "i32_subtract" -> "-"
    | "s32_mult" -> "*"
    | "is32_lt" -> "<"
    | "is32_gt" -> ">"
    | "is32_le" -> "<="
    | "is32_ge" -> ">="
    | "is32_eq" -> "=="
    | "is32_ne" -> "!="
    | "bool_and" -> "&&"
    | "bool_or" -> "||"
    | "iu32_lshift" -> "<<"
    | "iu8_rshift" -> ">>"
    | "iu32_and" -> "&"
    | "iu32_or" -> "|"
    | "iu32_xor" -> "^"
    | _ -> raise (ClayPrinter ("Failed to understand EOverload function name list; first element was " ^ fname))
;;

(* Surface operator for each assignment form; every compound form ends
   in "=" (e.g. XorAssignOp is "^=", LShiftAssignOp is "<<="). *)
let assign_op_string op =
  match op with
      AssignOp -> "="
    | MultAssignOp -> "*="
    | DivAssignOp -> "/="
    | ModAssignOp -> "%="
    | AddAssignOp -> "+="
    | SubAssignOp -> "-="
    | LShiftAssignOp -> "<<="
    | RShiftAssignOp -> ">>="
    | BitwiseAndAssignOp -> "&="
    | BitwiseOrAssignOp -> "|="
    | XorAssignOp -> "^="
;;

(* Print an expression in Clay surface syntax.  Calls to an EOverload are
   printed infix ("a + b"); any other call prints as "f(args)".
   @raise ClayPrinter for an EOverload that is not called with exactly two
   arguments, has an empty name list or an unknown name, or appears
   outside a call. *)
let rec print_exp_raw (e:exp_raw):unit =
  match e with
      EVar v -> print_string (string_of_var !v)
    | EBool b -> print_string (if b then "true" else "false")
    | EInt bi ->
        if Big_int.lt_big_int bi Big_int.zero_big_int then begin
          (* Written as a subtraction from zero.  Print the magnitude, since
             string_of_big_int of a negative value already carries a "-". *)
          print_string "(0 - ";
          print_string (Big_int.string_of_big_int (Big_int.minus_big_int bi));
          print_string ")"
        end else
          print_string (Big_int.string_of_big_int bi)
    | EUnit -> print_string "()"
    | ECall (func, params) ->
        (match func.exp_raw with
             EOverload ([], _) ->
               raise (ClayPrinter "Cannot understand EOverload with an empty function name list.")
           | EOverload (fname :: _, _) ->
               (match params with
                    [lhs; rhs] ->
                      let op_string = overload_op_string fname in
                      print_operand lhs;
                      print_space ();
                      print_string op_string;
                      print_space ();
                      print_operand rhs
                  | _ ->
                      raise (ClayPrinter "Cannot understand EOverload called without two parameters."))
           | _ ->
               print_exp func;
               print_string "(";
               print_comma_list params print_exp;
               print_string ")")
    | EOverload _ -> raise (ClayPrinter "Attempted to print EOverload outside of ECall.")
    | EAssign (op, arg1, arg2) ->
        print_exp arg1;
        print_space ();
        print_string (assign_op_string op);
        print_space ();
        print_exp arg2
    | EStruct v -> print_string "struct"; print_space (); print_string (string_of_var v)
    | ERecord (lin, fields) ->
        (match lin with
             Linear -> print_string "@("
           | Nonlinear -> print_string ".(");
        let print_field (name, typ_opt, fe) =
          (match typ_opt with
               Some t -> print_typ t; print_space ()
             | None -> ());
          print_string name;
          print_string "=";
          print_exp fe
        in
        print_comma_list fields print_field;
        print_string ")"
    | EMember (me, name) -> print_exp me; print_string ("." ^ name)
    | ETApp (te, params) ->
        print_exp te;
        print_string "[";
        print_type_params params;
        print_string "]"
    | EPack (pe, t, params) ->
        print_string "pack[";
        print_typ t;
        if params <> [] then begin
          print_string "][";
          print_type_params params
        end;
        print_string "](";
        print_exp pe;
        print_string ")"

(* Print an operand of an infix operator.  An operand that is itself an
   infix (EOverload) call is parenthesized so that "(a + b) * c" keeps its
   grouping instead of printing as "a + b * c". *)
and print_operand (e:exp):unit =
  match e.exp_raw with
      ECall (f, _) ->
        (match f.exp_raw with
             EOverload _ -> print_string "("; print_exp e; print_string ")"
           | _ -> print_exp e)
    | _ -> print_exp e

and print_exp (e:exp):unit =
  print_exp_raw e.exp_raw;;

(* Print a statement in Clay surface syntax.  Declarations, while loops and
   the SFor placeholder print their continuation [rest] after themselves,
   so one call prints the remaining statement chain.  SFor is not rendered
   yet: it prints a placeholder line. *)
let rec print_stmt_raw (s:stmt_raw):unit =
  (* Optional unpack annotation "[v=name, ...]" right after "let". *)
  let print_unpack_spec = function
      None -> ()
    | Some pairs ->
        print_string "[";
        print_comma_list pairs
          (fun (n1, n2) -> print_string (string_of_var n2); print_string "="; print_string n1);
        print_string "]"
  in
  (* "typ name" when a type annotation is present, otherwise just "name". *)
  let print_typed_name (v, t_opt) =
    (match t_opt with
         None -> ()
       | Some t -> print_typ t; print_space ());
    print_string (string_of_var !v)
  in
  (* Common tail of both let forms: " = e;", newline, then the continuation. *)
  let print_binding_tail e rest =
    print_space ();
    print_string "=";
    print_space ();
    print_exp e;
    print_string ";";
    print_newline ();
    print_stmt rest
  in
  (* A statement printed one block level deeper. *)
  let print_nested body =
    start_block ();
    print_stmt body;
    end_block ()
  in
  (* Shared layout of "if (e) {...} else {...}" and "if [b] {...} else {...}". *)
  let print_if_else opener print_cond closer conseq alt =
    print_string "if"; print_space (); print_string opener;
    print_cond ();
    print_string closer;
    print_nested conseq;
    print_string "} else {";
    print_nested alt;
    print_string "}"; print_newline ()
  in
  match s with
      SBlock stmts -> iter print_stmt stmts
    | SDecl (unpack, ((v, t), e), rest) ->
        print_string "let";
        print_unpack_spec unpack;
        print_space ();
        print_typed_name (v, t);
        print_binding_tail e rest
    | SMDecl (unpack, (decls, e), rest) ->
        print_string "let";
        print_unpack_spec unpack;
        print_space ();
        print_string "(";
        print_comma_list decls print_typed_name;
        print_string ")";
        print_binding_tail e rest
    | SReturn e ->
        print_string "return"; print_space (); print_exp e; print_string ";"; print_newline ()
    | SWhile (cond, body, rest) ->
        print_string "while"; print_space ();
        print_string "(";
        print_exp cond;
        print_string ") {";
        print_nested body;
        print_string "}"; print_newline ();
        print_stmt rest
    | SFor (_, _, _, _, _, _, rest) ->
        print_string "stupid for";   (* placeholder until SFor is rendered *)
        print_newline ();
        print_stmt rest
    | SContinue exps ->
        print_string "continue(";
        print_comma_list exps print_exp;
        print_string ");";
        print_newline ()
    | SIfElse (cond, conseq, alt) ->
        print_if_else "(" (fun () -> print_exp cond) ") {" conseq alt
    | SBoolCase (ba, conseq, alt) ->
        print_if_else "[" (fun () -> print_bool_arith ba) "] {" conseq alt
    | SExp e -> print_exp e; print_string ";"; print_newline ()

and print_stmt (s:stmt):unit =
  print_stmt_raw s.stmt_raw
;;

(* Print one function declaration: optional "inline", "native" and "\"C\""
   modifiers, the return type, the name, optional "[kind v, ...; cond]"
   type parameters, the parameter list, an optional limit annotation, and
   then ";" for native functions or a braced body.  Two newlines follow. *)
let print_function (f:fun_decl):unit =
  let keyword s = print_string s; print_space () in
  if f.fun_decl_is_inline then keyword "inline";
  (match f.fun_decl_stmt with
       FunNative -> keyword "native"
     | _ -> ());
  (match f.fun_decl_linkage with
       LinkageC -> keyword "\"C\""
     | _ -> ());
  let (ret_typ, _) = f.fun_decl_ret in
  print_typ ret_typ;
  print_space ();
  print_string (string_of_var f.fun_decl_name);
  (match f.fun_decl_tparams with
       None -> ()
     | Some (tparams, cond) ->
         print_string "[";
         print_comma_list tparams
           (fun (name, k) -> print_kind k; print_space (); print_string (string_of_var name));
         print_string ";";
         print_space ();
         print_bool_arith cond;
         print_string "]");
  print_string "(";
  print_comma_list (map fst f.fun_decl_params)
    (fun (v, t) -> print_typ t; print_space (); print_string (string_of_var !v));
  print_string ")";
  (match f.fun_decl_limit with
       Unlimited -> ()
     | LimitAny -> print_space (); print_string "limitany"
     | Limited i -> print_space (); print_string "limited["; print_int_arith i; print_string "]");
  (match f.fun_decl_stmt with
       FunNative -> print_string ";"
     | FunBody body ->
         print_string "{";
         start_block ();
         print_stmt body;
         end_block ();
         print_string "}"
     | _ -> print_string "aw, shucks."; print_newline ());
  print_newline ();
  print_newline ()
;;

(* Print one type declaration, followed by two newlines:
   - AbbrevSpec: "typedef name = t"
   - StructSpec: "[@]type[N] name[[kind v, ...]] = struct {t f; ...}"
   - NativeSpec: "[@]type[N] name[[kind v, ...]] = native"
   "@" marks a linear type; N is [nbits], printed only when it differs
   from [wordsize].  Any other spec prints "#preprocessed#". *)
let print_type_decl ((name, spec):type_decl):unit =
  (* "[@]type[N] name[[kind v, ...]] = " shared by struct and native types. *)
  let print_header nbits lin tp_opt =
    (match lin with
         Linear -> print_string "@"
       | Nonlinear -> ());
    print_string "type";
    (* NOTE(review): routed through [print_string] so the width lands in the
       Codeprinter output even if Codeprinter does not shadow
       Stdlib.print_int. *)
    if nbits <> wordsize then print_string (string_of_int nbits);
    print_space (); print_var name;
    (match tp_opt with
         Some tp ->
           print_string "[";
           print_comma_list tp (fun (v, k) -> print_kind k; print_space (); print_var v);
           print_string "]"
       | None -> ());
    print_space (); print_string "="; print_space ()
  in
  (match spec with
       AbbrevSpec t ->
         print_string "typedef"; print_space ();
         print_var name;
         print_space (); print_string "="; print_space ();
         print_typ t; print_newline ()
     | StructSpec (nbits, lin, tp_opt, fields) ->
         print_header nbits lin tp_opt;
         print_string "struct";
         print_space (); print_string "{"; start_block ();
         print_list struct_list fields (fun (fname, t) -> print_typ t; print_space (); print_string fname);
         end_block (); print_string "}"; print_newline ()
     | NativeSpec (nbits, lin, tp_opt) ->
         print_header nbits lin tp_opt;
         print_string "native";
         print_newline ()
     | _ -> print_string "#preprocessed#");
  print_newline ();
  print_newline ()
;;
   
   

(* Pretty-print [p] (type declarations first, then functions) into an
   in-memory buffer and write it to [file_name], which is created or
   truncated with mode 0o644.  The descriptor is closed exactly once on
   every path.
   @raise ClayPrinter if the buffer is not written completely; any other
   exception from printing or writing propagates after the file is closed. *)
let print_program (p:program) (file_name:string):unit =
  let b = Buffer.create 1000 in
  let pr = make_new_printer b 105 "" in
  let write_buffer (output_file:Unix.file_descr) (buf:Buffer.t):bool =
    let len = Buffer.length buf in
    len = Unix.write output_file (Buffer.contents buf) 0 len
  in
  let output_file =
    Unix.openfile file_name [Unix.O_WRONLY; Unix.O_CREAT; Unix.O_TRUNC] 0o644
  in
  (* The handler covers only printing and writing.  The final close sits
     outside it, so a failing close is not retried on a descriptor that
     has already been closed (which would mask the real error). *)
  (try
     set_printer pr;
     iter print_type_decl p.program_type_decls;
     iter print_function p.program_fun_decls;
     if not (write_buffer output_file b) then
       raise (ClayPrinter "Unable to write output Clayprinter output")
   with e ->
     Unix.close output_file;
     raise e);
  Unix.close output_file
;;
