open Ast;;
open Big_int;;
open Num;;

(* Tracing hook called on entry to most functions of this file.
 * Tracing is currently switched off: the argument is ignored and nothing
 * is printed.  The disabled implementation was:
 *   if (!do_trace) then (Format.printf s; Format.print_flush ())
 *)
let prog_trace _msg : unit = ();;

(* Exception carrying a typ.  It is not raised anywhere in this section;
 * presumably it reports a type that cannot be read as a formula -- confirm
 * against the raise sites. *)
exception NotFormula of typ;;

(* Linear arithmetic expression  a1*x1 + ... + an*xn + a0.
 * The map binds each variable xi to its coefficient ai; the second
 * component of the pair is the constant term a0. *)
type (*new_*)linear_arith = big_int VarMap.t * big_int;;
(*
(* a1*x1 + ... + an*xn + a0 *)
type linear_arith = int VarMap.t * int;;
*)

(* Atomic constraint over a linear_arith: the expression is compared with
 * zero, or required to be divisible by an integer step. *)
type atomic =
| FEquality of linear_arith        (* a1*x1 + ... + an*xn + a0 == 0 *)
| FInequality of linear_arith      (* a1*x1 + ... + an*xn + a0 >= 0 *)
| FStride of linear_arith * int    (* a1*x1 + ... + an*xn + a0 divisible by step *)
;;

(* Formula built from atomic constraints with n-ary conjunction and
 * disjunction, negation, implication and quantification over a var. *)
type formula =
| FAtomic of atomic
(* | FUnknown *)
(* | FVar of var *)
| FAnd of formula list
| FOr of formula list
| FNot of formula
| FImplies of formula * formula
| FForall of var * formula
| FExists of var * formula
;;

(* String rendering of a linear_arith as produced by convert_linear_arith:
 * a list of (coefficient, variable name) pairs plus the constant term,
 * every number being the decimal text of a big_int. *)
type (*new_*)clinear_arith = (string * string) list * string;;
(*
type clinear_arith = (int * string) list * int;;
*)

(* Counterpart of atomic in which the linear expression is a
 * clinear_arith (strings) instead of a linear_arith. *)
type catomic =
  CEquality of clinear_arith        (* a1*x1 + ... + an*xn + a0 == 0 *)
| CInequality of clinear_arith      (* a1*x1 + ... + an*xn + a0 >= 0 *)
| CStride of clinear_arith * int    (* a1*x1 + ... + an*xn + a0 divisible by step *)
;;

(* Counterpart of formula whose leaves are catomic values; the connectives
 * mirror those of formula one for one (see convert_formula). *)
type cformula =
| CAtomic of catomic
| CAnd of cformula list
| COr of cformula list
| CNot of cformula
| CImplies of cformula * cformula
| CForall of var * cformula
| CExists of var * cformula
;;

(* Render a linear_arith as a clinear_arith: every binding of the
 * coefficient map becomes a (coefficient text, variable name) pair and the
 * constant term becomes its decimal text.  The pairs are consed while
 * folding over the map, so they come out in the reverse of the fold order. *)
let (*new_*)convert_linear_arith ((coeffs, const): (*new_*)linear_arith): (*new_*)clinear_arith =
  prog_trace "in new convert_linear_arith\n";
  let push x c acc = (string_of_big_int c, (string_of_var x)) :: acc in
  let terms : (string * string) list = VarMap.fold push coeffs [] in
  (terms, string_of_big_int const)
;;

(*
let convert_linear_arith ((vars, i): linear_arith): clinear_arith =
        let varlist:(int * string) list =
                (VarMap.fold (fun k d a-> (d, (string_of_var k))::a) vars [])
        in
        (varlist, i)
;;
*)

(* Map an atomic onto the catomic with the same shape, converting the
 * embedded linear expression with convert_linear_arith.  The stride step
 * is passed through unchanged. *)
let convert_atomic (a:atomic):catomic =
  prog_trace "in convert_atomic\n";
  let conv = convert_linear_arith in
  match a with
  | FStride (la, step) -> CStride (conv la, step)
  | FEquality la -> CEquality (conv la)
  | FInequality la -> CInequality (conv la)
;;

(* Converts formulas to the string-based C format used for the Omega
 * constraint checker.
 *
 * A formula is made of atomic, and, or, not, implies, forall and exists.
 * An atomic is an integer linear expression a0 + a1*x1 + ... + an*xn
 * together with a relation to 0 (== 0 or >= 0), or with an int step that
 * must divide the expression.
 *
 * A linear_arith in Clay format is a big_int VarMap.t * big_int.
 * A clinear_arith in C format is a (string * string) list * string: in each
 * pair the first string is the text of the big_int coefficient and the
 * second is the name of the variable it belongs to; the trailing string is
 * the text of the free-standing big_int constant.
 *
 * The conversion is purely structural: every connective is mapped to the
 * C constructor of the same name and quantified variables are kept as is.
 *)
let rec convert_formula (f:formula):cformula =
  prog_trace "in convert_formula\n";
  let convert_all fs = List.map convert_formula fs in
  match f with
  | FAtomic a -> CAtomic (convert_atomic a)
  | FAnd conjuncts -> CAnd (convert_all conjuncts)
  | FOr disjuncts -> COr (convert_all disjuncts)
  | FNot body -> CNot (convert_formula body)
  | FImplies (p, q) ->
      CImplies ((convert_formula p), (convert_formula q))
  | FForall (x, body) -> CForall (x, (convert_formula body))
  | FExists (x, body) -> CExists (x, (convert_formula body))
;;

(*
let print_linear_arith ((vars, i):linear_arith):unit =
  VarMap.iter
    (fun x ix -> Format.printf "%d*%s+" ix (string_of_var x))
    vars;
  Format.printf "%d" i
;;
*)

(* Print a linear_arith on the standard Format formatter as
 * a1*x1+...+an*xn+a0: one coefficient*variable+ group per binding of the
 * map, followed by the constant term. *)
let (*new_*)print_linear_arith ((coeffs, const):(*new_*)linear_arith):unit =
  prog_trace "in new print_linear_arith\n";
  let print_term x c =
    Format.printf "%s*%s+" (string_of_big_int c) (string_of_var x)
  in
  VarMap.iter print_term coeffs;
  Format.printf "%s" (string_of_big_int const)
;;

(* Print an atomic constraint on the standard Format formatter:
 * the linear expression followed by ==0 or >=0, or the step, a bar and
 * then the expression for a stride.  [prec] is accepted but not used. *)
let print_atomic (prec:int) (a:atomic):unit =
  prog_trace "in print_atomic\n";
  match a with
  | FStride (la, step) ->
      Format.printf "%d | " step;
      print_linear_arith la
  | FEquality la ->
      print_linear_arith la;
      Format.printf "==0"
  | FInequality la ->
      print_linear_arith la;
      Format.printf ">=0"
;;

(* Pretty-print formula [f] on the standard Format formatter.
 *
 * [prec] is the precedence of the enclosing context.  It is handed to
 * format_block together with the precedence of the operator being printed:
 * quantifiers 0, implies 1, or 2, and 3, not 4.
 * NOTE(review): format_block is defined outside this section; presumably
 * it opens a box and parenthesises when the operator binds less tightly
 * than the context -- confirm.
 *
 * Lists are printed right-nested: the head, the operator, then the rest of
 * the list as a formula of the same kind.  A one-element list prints as its
 * element.  The empty conjunction prints as true and the empty disjunction
 * as false, the neutral elements of && and || (the disjunction case used
 * to print true, which inverted the meaning of an empty FOr).
 *)
let rec print_formula (prec:int) (f:formula):unit =
  prog_trace "in print_formula\n";
  let block = format_block prec in
  match f with
  | FAtomic a -> print_atomic prec a
  | FAnd [] -> Format.printf "true"
  | FAnd [f1] -> print_formula prec f1
  | FAnd (fh::ft) ->
      block 3 (fun () ->
        print_formula 3 fh;
        Format.printf " &&@ ";
        print_formula 3 (FAnd ft))
  (* an empty disjunction has no satisfiable alternative: it is false *)
  | FOr [] -> Format.printf "false"
  | FOr [f1] -> print_formula prec f1
  | FOr (fh::ft) ->
      block 2 (fun () ->
        print_formula 2 fh;
        Format.printf " ||@ ";
        print_formula 2 (FOr ft))
  | FNot f ->
      block 4 (fun () ->
        Format.printf "!";
        print_formula 4 f)
  | FImplies (f1, f2) ->
      block 1 (fun () ->
        print_formula 2 f1;
        Format.printf " ==>@ ";
        print_formula 2 f2)
  | FForall (x, f1) ->
      block 0 (fun () ->
        Format.printf "all %s." (string_of_var x);
        print_formula 0 f1)
  | FExists (x, f1) ->
      block 0 (fun () ->
        Format.printf "exists %s." (string_of_var x);
        print_formula 0 f1)
;;

(* Print the variable terms of a clinear_arith on the standard Format
 * formatter, one coefficient*variable+ group per pair and nothing for the
 * empty list, i.e. the same layout print_linear_arith uses for the
 * big_int representation.
 *
 * The previous version also emitted the conjunction separator " &&@ "
 * between consecutive terms (copied from the formula printer), which
 * rendered a sum such as 2*x+3*y+ as 2*x+ && 3*y+.  The terms of a linear
 * expression are joined by their trailing + only.
 *)
let print_strstr_list (vars:((string * string) list)):unit =
  prog_trace "in print_strstr_list\n";
  List.iter
    (fun (coeff, name) -> Format.printf "%s*%s+" coeff name)
    vars
;;
         
(* Print a clinear_arith on the standard Format formatter: the variable
 * terms via print_strstr_list, then the constant term string. *)
let print_clinear_arith ((terms, const):clinear_arith):unit =
  prog_trace "in print_clinear_arith\n";
  let () = print_strstr_list terms in
  Format.printf "%s" const
;;

(* Print a catomic constraint on the standard Format formatter:
 * the linear expression followed by ==0 or >=0, or the step, a bar and
 * then the expression for a stride.  [prec] is accepted but not used. *)
let print_catomic (prec:int) (a:catomic):unit =
  prog_trace "in print_catomic\n";
  match a with
  | CStride (la, step) ->
      Format.printf "%d | " step;
      print_clinear_arith la
  | CEquality la ->
      print_clinear_arith la;
      Format.printf "==0"
  | CInequality la ->
      print_clinear_arith la;
      Format.printf ">=0"
;;

(* Pretty-print cformula [f] on the standard Format formatter.
 *
 * [prec] is the precedence of the enclosing context.  It is handed to
 * format_block together with the precedence of the operator being printed:
 * quantifiers 0, implies 1, or 2, and 3, not 4.
 * NOTE(review): format_block is defined outside this section; presumably
 * it opens a box and parenthesises when the operator binds less tightly
 * than the context -- confirm.
 *
 * Lists are printed right-nested: the head, the operator, then the rest of
 * the list as a formula of the same kind.  A one-element list prints as its
 * element.  The empty conjunction prints as true and the empty disjunction
 * as false, the neutral elements of && and || (the disjunction case used
 * to print true, which inverted the meaning of an empty COr).
 * Each case first reports its constructor through prog_trace.
 *)
let rec print_cformula (prec:int) (f:cformula):unit =
  prog_trace "in print_cformula\n";
  let block = format_block prec in
  match f with
  | CAtomic a -> prog_trace "CAtomic \t"; print_catomic prec a
  | CAnd [] -> prog_trace "CAnd1 \t"; Format.printf "true"
  | CAnd [f1] -> prog_trace "CAnd2 \t"; print_cformula prec f1
  | CAnd (fh::ft) -> prog_trace "CAnd3 \t";
      block 3 (fun () ->
        print_cformula 3 fh;
        Format.printf " &&@ ";
        print_cformula 3 (CAnd ft))
  (* an empty disjunction has no satisfiable alternative: it is false *)
  | COr [] -> prog_trace "COr1 \t"; Format.printf "false"
  | COr [f1] -> prog_trace "COr2 \t"; print_cformula prec f1
  | COr (fh::ft) -> prog_trace "COr3 \t";
      block 2 (fun () ->
        print_cformula 2 fh;
        Format.printf " ||@ ";
        print_cformula 2 (COr ft))
  | CNot f -> prog_trace "CNot \t";
      block 4 (fun () ->
        Format.printf "!";
        print_cformula 4 f)
  | CImplies (f1, f2) -> prog_trace "CImplies \t";
      block 1 (fun () ->
        print_cformula 2 f1;
        Format.printf " ==>@ ";
        print_cformula 2 f2)
  | CForall (x, f1) -> prog_trace "CForall \t";
      block 0 (fun () ->
        Format.printf "all %s." (string_of_var x);
        print_cformula 0 f1)
  | CExists (x, f1) -> prog_trace "CExists \t";
      block 0 (fun () ->
        Format.printf "exists %s." (string_of_var x);
        print_cformula 0 f1)
;;


(* Coefficient of variable [x] in the coefficient map [vars]:
 * the bound value when [x] is present, zero otherwise. *)
let (*new_*)coefficient (x:var) (vars:big_int VarMap.t):big_int =
  prog_trace "in new coeficcient\n";
  match VarMap.mem x vars with
  | true -> VarMap.find x vars
  | false -> zero_big_int
;;

(*
let coefficient (x:var) (vars:int VarMap.t):int =
  if VarMap.mem x vars then VarMap.find x vars else 0
;;
*)

(*
let linear_add ((vars1, i1):linear_arith) ((vars2, i2):linear_arith):linear_arith =
  let vars = VarMap.fold
    (fun x i vars -> VarMap.add x ((coefficient x vars) + i) vars)
    vars2
    vars1
  in (vars, i1 + i2)
;;
*)

(* Sum of two linear expressions: for every variable of the second operand
 * its coefficient is added to the one already accumulated from the first
 * operand (zero when absent), and the constant terms are added.
 * Coefficients that become zero are kept in the map. *)
let (*new_*)linear_add ((vars1, i1):(*new_*)linear_arith) ((vars2, i2):(*new_*)linear_arith):(*new_*)linear_arith =
  prog_trace "in new linear add\n";
  let accumulate x c2 acc =
    VarMap.add x (add_big_int ((*new_*)coefficient x acc) c2) acc
  in
  let summed = VarMap.fold accumulate vars2 vars1 in
  (summed, add_big_int i1 i2)
;;

(*
let linear_subtract ((vars1, i1):linear_arith) ((vars2, i2):linear_arith):linear_arith =
  let vars = VarMap.fold
    (fun x2 j2 vars1 -> VarMap.add x2 ((coefficient x2 vars1) - j2) vars1)
    vars2
    vars1
  in (vars, i1 - i2)
;;
*)

(* Difference of two linear expressions (first minus second): for every
 * variable of the second operand its coefficient is subtracted from the one
 * accumulated from the first operand (zero when absent), and the constant
 * terms are subtracted.  Coefficients that become zero are kept in the map. *)
let (*new_*)linear_subtract ((vars1, i1):(*new_*)linear_arith) ((vars2, i2):(*new_*)linear_arith):(*new_*)linear_arith =
  prog_trace "in new linear_subtract\n";
  let deduct x c2 acc =
    VarMap.add x (sub_big_int ((*new_*)coefficient x acc) c2) acc
  in
  let diff = VarMap.fold deduct vars2 vars1 in
  (diff, sub_big_int i1 i2)
;;

(* l1 == l2, encoded as the atomic constraint  l1 - l2 == 0. *)
let linear_eq (l1:linear_arith) (l2:linear_arith):formula =
  prog_trace "in linear_eq\n";
  let diff = linear_subtract l1 l2 in
  FAtomic (FEquality diff)
;;

(* l1 != l2, encoded as the negation of  l1 - l2 == 0. *)
let linear_ne (l1:linear_arith) (l2:linear_arith):formula =
  let diff = linear_subtract l1 l2 in
  FNot (FAtomic (FEquality diff))
;;

(*
let linear_gt (l1:linear_arith) (l2:linear_arith):formula =
  (* l1 > l2  <==>  l1 - l2 > 0 <==> l1 - l2 >= 1 <==> l1 - l2 - 1 >= 0 *)
  let (vars, i) = linear_subtract l1 l2 in
  FAtomic (FInequality (vars, i - 1))
;;
*)

(* l1 > l2.  Over the integers:
 *   l1 > l2  <==>  l1 - l2 >= 1  <==>  l1 - l2 - 1 >= 0
 * so the constant term of the difference is decreased by one. *)
let (*new_*)linear_gt (l1:(*new_*)linear_arith) (l2:(*new_*)linear_arith):formula =
  let (coeffs, const) = (*new_*)linear_subtract l1 l2 in
  let shifted = sub_big_int const unit_big_int in
  FAtomic (FInequality (coeffs, shifted))
;;

(* l1 >= l2, encoded as the atomic constraint  l1 - l2 >= 0. *)
let linear_ge (l1:linear_arith) (l2:linear_arith):formula =
  let diff = linear_subtract l1 l2 in
  FAtomic (FInequality diff)
;;

(*
let linear_lt (l1:linear_arith) (l2:linear_arith):formula =
  (* l1 < l2  <==>  l2 - l1 > 0 <==> l2 - l1 >= 1 <==> l2 - l1 - 1 >= 0 *)
  let (vars, i) = linear_subtract l2 l1 in
  FAtomic (FInequality (vars, i - 1))
;;
*)

(* l1 < l2.  Over the integers:
 *   l1 < l2  <==>  l2 - l1 >= 1  <==>  l2 - l1 - 1 >= 0
 * so the operands are swapped and the constant term decreased by one. *)
let (*new_*)linear_lt (l1:(*new_*)linear_arith) (l2:(*new_*)linear_arith):formula =
  let (coeffs, const) = (*new_*)linear_subtract l2 l1 in
  let shifted = sub_big_int const unit_big_int in
  FAtomic (FInequality (coeffs, shifted))
;;

(* l1 <= l2, encoded as the atomic constraint  l2 - l1 >= 0. *)
let linear_le (l1:linear_arith) (l2:linear_arith):formula =
  let diff = linear_subtract l2 l1 in
  FAtomic (FInequality diff)
;;

(* Basic int_arith values.
 *   iarith_zero      the constant 0 with no variables (IArith)
 *   iarith_const i   the big_int constant i with no variables (IArith)
 *   iarith_nconst i  the constant i with both maps empty (IInfer)
 *   iarith_var x     the variable x with coefficient one, constant 0 *)
let iarith_zero = IArith (zero_big_int, VarMap.empty);;
let iarith_const i = IArith (i, VarMap.empty);;
let iarith_nconst i = IInfer (i, VarMap.empty, VarMap.empty);;
let iarith_var x =
  let only_x = VarMap.add x unit_big_int VarMap.empty in
  IArith (zero_big_int, only_x)
;;

(* Re-express an IArith as an IInfer: the big_int constant and coefficients
 * are converted to num and the third map is empty.  An IInfer is returned
 * unchanged. *)
let weaken_int_arith (t:int_arith):int_arith =
  match t with
  | IInfer _ -> t
  | IArith (const, coeffs) ->
      IInfer
        ( Big_int const,
          VarMap.map num_of_big_int coeffs,
          VarMap.empty)
;;

(* compute i*t *)
(* compute i*t for a num factor i.
 * A zero factor yields iarith_zero.  Otherwise an IArith is first weakened
 * to IInfer, and the constant and the coefficients of both maps are
 * multiplied by i; the typ_infer component of the third map is untouched. *)
let rec mult_num_arith (i:num) (t:int_arith):int_arith =
  if i =/ (Int 0) then iarith_zero
  else begin
    match t with
    | IArith _ -> mult_num_arith i (weaken_int_arith t)
    | IInfer (const, coeffs, icoeffs) ->
        let scale k = i */ k in
        let scale_inferred (k, ti) = (i */ k, ti) in
        IInfer (
          scale const,
          VarMap.map scale coeffs,
          VarMap.map scale_inferred icoeffs)
  end
;;

(* compute i*t *)
(* compute i*t for a big_int factor i.
 * A zero factor yields iarith_zero.  An IArith stays an IArith with its
 * constant and coefficients multiplied by i; an IInfer is delegated to
 * mult_num_arith with i converted to num. *)
let mult_int_arith (i:big_int) (t:int_arith):int_arith =
  if eq_big_int i zero_big_int then iarith_zero
  else begin
    match t with
    | IInfer _ -> mult_num_arith (Big_int i) t
    | IArith (const, coeffs) ->
        let scale = mult_big_int i in
        IArith (scale const, VarMap.map scale coeffs)
  end
;;

(* compute i*x + t *)
(* compute i*x + t for a num coefficient i.
 * A zero coefficient returns t unchanged.  Otherwise an IArith is first
 * weakened to IInfer; i is added to the coefficient of x, and the binding
 * is removed when the sum is zero. *)
let rec add_num_arith_var (i:num) (x:var) (t:int_arith):int_arith =
  if i =/ (Int 0) then t
  else begin
    match t with
    | IArith _ -> add_num_arith_var i x (weaken_int_arith t)
    | IInfer (const, coeffs, icoeffs) ->
        let updated =
          match VarMap.mem x coeffs with
          | false -> VarMap.add x i coeffs
          | true ->
              let total = i +/ (VarMap.find x coeffs) in
              if total =/ (Int 0) then VarMap.remove x coeffs
              else VarMap.add x total coeffs
        in
        IInfer (const, updated, icoeffs)
  end
;;

(* compute i*x + t *)
(* compute i*x + t for a big_int coefficient i.
 * A zero coefficient returns t unchanged.  On an IArith, i is added to the
 * coefficient of x and the binding is removed when the sum is zero; an
 * IInfer is delegated to add_num_arith_var with i converted to num. *)
let add_int_arith_var (i:big_int) (x:var) (t:int_arith):int_arith =
  if eq_big_int i zero_big_int then t
  else begin
    match t with
    | IInfer _ -> add_num_arith_var (Big_int i) x t
    | IArith (const, coeffs) ->
        let updated =
          match VarMap.mem x coeffs with
          | false -> VarMap.add x i coeffs
          | true ->
              let total = add_big_int i (VarMap.find x coeffs) in
              if eq_big_int total zero_big_int then VarMap.remove x coeffs
              else VarMap.add x total coeffs
        in
        IArith (const, updated)
  end
;;

(* compute i*(x,ti) + t *)
(* compute i*(x,ti) + t, where the term lives in the third map of IInfer.
 * A zero coefficient returns t unchanged.  Otherwise an IArith is first
 * weakened to IInfer; i is added to the coefficient stored for x and the
 * binding is removed when the sum is zero.  When x was already bound, the
 * new binding carries the [ti] passed here, not the stored reference. *)
let rec add_num_arith_ivar (i:num) (x:var) (ti:typ_infer ref) (t:int_arith):int_arith =
  if i =/ (Int 0) then t
  else begin
    match t with
    | IArith _ -> add_num_arith_ivar i x ti (weaken_int_arith t)
    | IInfer (const, coeffs, icoeffs) ->
        let updated =
          match VarMap.mem x icoeffs with
          | false -> VarMap.add x (i, ti) icoeffs
          | true ->
              let total = i +/ (fst (VarMap.find x icoeffs)) in
              if total =/ (Int 0) then VarMap.remove x icoeffs
              else VarMap.add x (total, ti) icoeffs
        in
        IInfer (const, coeffs, updated)
  end
;;

(* compute t1 + t2 *)
(* compute t1 + t2.
 * Two IArith operands give an IArith.  If exactly one operand is an IArith
 * it is weakened to IInfer first, so every other combination gives an
 * IInfer.  The sum starts from t1 with the constants added and then folds
 * each term of t2 in through the add_*_arith_var / add_num_arith_ivar
 * helpers, which drop coefficients that cancel to zero. *)
let rec add_int_arith (t1:int_arith) (t2:int_arith):int_arith =
  match (t1, t2) with
  | (IArith (c1, coeffs1), IArith (c2, coeffs2)) ->
      let seed = IArith (add_big_int c1 c2, coeffs1) in
      VarMap.fold (fun x c acc -> add_int_arith_var c x acc) coeffs2 seed
  | (IInfer (c1, coeffs1, icoeffs1), IInfer (c2, coeffs2, icoeffs2)) ->
      let seed = IInfer (c1 +/ c2, coeffs1, icoeffs1) in
      let with_vars =
        VarMap.fold (fun x c acc -> add_num_arith_var c x acc) coeffs2 seed
      in
      VarMap.fold
        (fun x (c, ti) acc -> add_num_arith_ivar c x ti acc)
        icoeffs2
        with_vars
  | (IArith _, IInfer _) -> add_int_arith (weaken_int_arith t1) t2
  | (IInfer _, IArith _) -> add_int_arith t1 (weaken_int_arith t2)
;;

(* compute t1 = t2 (without using any context information) *)
(* compute t1 = t2 (without using any context information).
 * The difference t1 + (-1)*t2 is formed; the answer is true only when that
 * difference is an IArith with a zero constant and an empty coefficient
 * map.  Any difference that is not an IArith yields false. *)
let eq_int_arith (t1:int_arith) (t2:int_arith):bool =
  let minus_one = big_int_of_int (-1) in
  let diff = add_int_arith t1 (mult_int_arith minus_one t2) in
  match diff with
  | IArith (const, coeffs) ->
      (* the fold returns true only when the map has no binding at all *)
      eq_big_int const zero_big_int
      && VarMap.fold (fun _ _ _ -> false) coeffs true
  | _ -> false
;;

(* Return the set of all variables bound in the coefficient map of a
 * linear_arith (including those whose coefficient is zero). *)
let vars_in_linear_arith (la:linear_arith):VarSet.t =
  let collect x _coeff acc = VarSet.add x acc in
  VarMap.fold collect (fst la) VarSet.empty
;;

(* Return the set of all variables appearing in an atomic, i.e. those of
 * its linear expression; the stride step plays no part. *)
let vars_in_atomic (a:atomic):VarSet.t =
  match a with
  | FEquality la | FInequality la | FStride (la, _) ->
      vars_in_linear_arith la
;;

(* Return the set of all variables appearing in a formula.
 * Quantified formulas are not supported: FForall and FExists raise
 * NotImplemented.  vars_in_formulas is the union over a list of formulas
 * (the empty set for the empty list). *)
let rec vars_in_formula (f:formula):VarSet.t =
  match f with
  | FAtomic a -> vars_in_atomic a
  | FAnd fs | FOr fs -> vars_in_formulas fs
  | FNot g -> vars_in_formula g
  | FImplies (f1, f2) -> VarSet.union (vars_in_formula f1) (vars_in_formula f2)
  | FForall _ -> raise (NotImplemented "FForall")
  | FExists _ -> raise (NotImplemented "FExists")
and vars_in_formulas (fs:formula list):VarSet.t =
  List.fold_left
    (fun acc g -> VarSet.union acc (vars_in_formula g))
    VarSet.empty
    fs
;;

(*
 * Flatten a list of conjuncts: every FAnd found at the front of the
 * remaining list is replaced by its own members, which are then flattened
 * in turn, so nested top-level FAnd nodes disappear.  Formulas of any other
 * kind are kept as they are, in their original order.
 *)
let rec flatten_conjuncts (fs:formula list):formula list =
  match fs with
  | [] -> []
  | FAnd nested :: rest -> flatten_conjuncts (nested @ rest)
  | other :: rest -> other :: flatten_conjuncts rest
;;

(*
 * Given a list of formulas fs = [f1;...;fn],
 * organize f1...fn into k different sublists, where
 * the set of variables in the formulas of each sublist is disjoint from
 * the set of variables in the formulas of all other sublists.
 *
 * The first formula f1 will appear as the first element
 * of the first list (the return value has the form [[f1;...];...]).
 *
 * Use depth-first search.  Two data structures:
 *   reflist: [ref (f1,vars(f1),int);...;ref (fn,vars(fn),int)]
 *   varmap: x --> (f,vars(f),int) ref list
 * (the "int" is the connected component number, where 0 indicates
 *  that the formula is not yet a member of a connected component)
 *
 * Within a sublist the formulas keep their relative order from fs, and the
 * sublists are ordered by the position of their first formula in fs.
 *
 * NOTE(review): for an empty input list the result is [[]] (one empty
 * sublist), not []; see the first case of subdivide.
 * Raises whatever vars_in_formula raises (NotImplemented on quantified
 * formulas).
 *)
let connected_formulas (fs:formula list):formula list list =
  (* one mutable node per formula, initially unmarked (component 0) *)
  let reflist = List.map
    (fun f -> ref (f, vars_in_formula f, 0))
    fs in
  (* prepend node r to the list of nodes recorded for variable x *)
  let addMapping r x varmap =
  (
    if VarMap.mem x varmap then
      VarMap.add x (r::(VarMap.find x varmap)) varmap
    else VarMap.add x [r] varmap
  ) in
  (* index: variable --> all nodes whose formula mentions it *)
  let varmap = List.fold_left
    (fun varmap -> fun r ->
      let (_, vars, _) = !r in
      VarSet.fold (addMapping r) vars varmap)
    VarMap.empty
    reflist in
  (* If node r isn't already marked, mark it with integer i and recurse *)
  let rec mark i r =
  (
    let (f, vars, fi) = !r in
    if fi = 0 then
    (
      r := (f, vars, i);
      (* recursively mark everything that node r points to: *)
      VarSet.iter
        (fun x ->
          List.iter (mark i) (VarMap.find x varmap))
        vars
    )
  ) in
  (* mark each unmarked component with a unique integer *)
  let n = ref 0 in
  let () = List.iter
    (fun r ->
      let (_, _, i) = !r in
      if i = 0 then
      (
        incr n;
        mark !n r
      ))
    reflist in
  (* stable-sort the components by component number, then
   * split into separate sublists by component number *)
  let lst = List.map (fun r -> let (f, _, i) = !r in (f, i)) reflist in
  let lst = List.stable_sort
    (fun (_, i1) -> fun (_, i2) -> compare i1 i2)
    lst in
  (* cut the sorted list wherever the component number changes; the
   * recursive call always receives a non-empty list, so List.hd and
   * List.tl are applied to a non-empty result *)
  let rec subdivide lst =
  (
    match lst with
    | [] -> [[]]
    | [(f, _)] -> [[f]]
    | (f1, i1)::(f2, i2)::t ->
        let tt = subdivide ((f2, i2)::t) in
        if i1 = i2 then (f1::(List.hd tt))::(List.tl tt)
        else [f1]::tt
  ) in
  subdivide lst
;;

