(* Theory GTT_Compose_Impl *)

theory GTT_Compose_Impl
  imports GTT_Compose GTT_Impl Horn_Autoref TA.Tree_Automata_Autoref_Setup
begin

(* Horn clauses over pairs of states (p, q), with p a state of A and q a state of B.
   Lemma Δε_horn.Δε_sound shows that the saturation of these clauses is exactly Δε A B.
   There are three clause families:
   - for rules f ps → p of A and f qs → q of B with the same symbol f and equal arity,
     the argument pairs zip ps qs imply (p, q);
   - an ε-transition (p, p') of A turns (p, q) into (p', q);
   - an ε-transition (q, q') of B turns (p, q) into (p, q'). *)
definition Δε_rules :: "('q, 'f) ta ⇒ ('q, 'f) ta ⇒ ('q × 'q) horn set" where
  "Δε_rules A B =
    {zip ps qs →h (p, q) |f ps p qs q. f ps → p ∈ ta_rules A ∧ f qs → q ∈ ta_rules B ∧ length ps = length qs} ∪
    {[(p, q)] →h (p', q) |p p' q. (p, p') ∈ ta_eps A} ∪
    {[(p, q)] →h (p, q') |p q q'. (q, q') ∈ ta_eps B}"

(* Context for the Horn system Δε_rules A B of two fixed tree automata A and B. *)
locale Δε_horn =
  fixes A :: "('q, 'f) ta" and B :: "('q, 'f) ta"
begin

sublocale horn "Δε_rules A B" .

(* The facts derivable from no premises: pairs of right-hand sides of nullary rules
   of A and B that share the same function symbol. *)
lemma Δε_infer0:
  "infer0 = {(p, q) |f p q. f [] → p ∈ ta_rules A ∧ f [] → q ∈ ta_rules B}"
  unfolding horn.infer0_def Δε_rules_def
  using zip_Nil[of "[]"] zip_inj[of "[]" "[]"] by auto force+

(* Facts derivable from a new fact pq together with the known facts X.
   The three summands correspond to the three clause families of Δε_rules. *)
lemma Δε_infer1:
  "infer1 pq X = {(p, q) |f ps p qs q. f ps → p ∈ ta_rules A ∧ f qs → q ∈ ta_rules B ∧ length ps = length qs ∧
    (fst pq, snd pq) ∈ set (zip ps qs) ∧ set (zip ps qs) ⊆ insert pq X} ∪
    {(p', snd pq) |p p'. (p, p') ∈ ta_eps A ∧ p = fst pq} ∪
    {(fst pq, q') |q q'. (q, q') ∈ ta_eps B ∧ q = snd pq}"
  unfolding Δε_rules_def horn_infer1_union
  apply (intro arg_cong2[of _ _ _ _ "(∪)"])
    by (auto simp: horn.infer1_def simp flip: ex_simps(1)) force+

(* The saturation of the Horn system is exactly Δε A B.
   One inclusion is shown by induction over Δε, the other by induction over saturate. *)
lemma Δε_sound:
  "Δε A B = saturate"
proof (intro set_eqI iffI, goal_cases lr rl)
  (* Δε A B ⊆ saturate: each Δε rule is matched by an infer step. *)
  case (lr x) obtain p q where x: "x = (p, q)" by (cases x)
  show ?case using lr unfolding x
  proof (induct)
    case (Δε_cong f ps p qs q) show ?case
      apply (intro infer[of "zip ps qs" "(p, q)"])
      subgoal using Δε_cong(1-3) by (auto simp: Δε_rules_def)
      subgoal using Δε_cong(3,5) by (auto simp: zip_nth_conv)
      done
  next
    case (Δε_eps1 p p' q) then show ?case
      by (intro infer[of "[(p, q)]" "(p', q)"]) (auto simp: Δε_rules_def)
  next
    case (Δε_eps2 q q' p) then show ?case
      by (intro infer[of "[(p, q)]" "(p, q')"]) (auto simp: Δε_rules_def)
  qed
next
  (* saturate ⊆ Δε A B: each infer step uses a clause of Δε_rules, i.e. a Δε rule. *)
  case (rl x) obtain p q where x: "x = (p, q)" by (cases x)
  show ?case using rl unfolding x
  proof (induct)
    case (infer as a) then show ?case
      using Δε_cong[of _ "map fst as" "fst a" A "map snd as" "snd a" B]
        Δε_eps1[of _ "fst a" A "snd a" B] Δε_eps2[of _ "snd a" B "fst a" A]
      by (auto simp: Δε_rules_def)
  qed
qed

end

(* List implementation of infer0 for Δε_rules on list-based automata.
   It returns the pairs of right-hand sides of rules of A and B that have the same
   symbol and no arguments. It is linked to Δε_horn.Δε_infer0 by the horn_list
   interpretation in locale Δε_list. *)
definition Δε_infer0_list where
  "Δε_infer0_list A B = map (map_prod r_rhs r_rhs) (
     filter (λ(ra, rb). case (ra, rb) of (f ps → p, g qs → q) ⇒ f = g ∧ ps = [] ∧ qs = []) (List.product (tal_rules A) (tal_rules B)))"

(* List implementation of infer1 for Δε_rules. Given a new fact pq and the known facts
   bs, it returns the concatenation of three lists:
   - pairs from same-symbol, same-arity rules of A and B whose argument pairs include pq
     and are all among pq # bs;
   - pq with its first component advanced along an ε-transition of A;
   - pq with its second component advanced along an ε-transition of B.
   It is linked to Δε_horn.Δε_infer1 by the horn_list interpretation in locale Δε_list. *)
definition Δε_infer1_list :: "('q, 'f) ta_list ⇒ ('q, 'f) ta_list ⇒ 'q × 'q ⇒ ('q × 'q) list ⇒ ('q × 'q) list" where
  "Δε_infer1_list A B pq bs =
    map (map_prod r_rhs r_rhs) (filter (λ(ra, rb). case (ra, rb) of (f ps → p, g qs → q) ⇒ f = g ∧ length ps = length qs ∧ (fst pq, snd pq) ∈ set (zip ps qs) ∧ set (zip ps qs) ⊆ set ((fst pq, snd pq) # bs)) (List.product (tal_rules A) (tal_rules B))) @
    map (λ(p, p'). (p', snd pq)) (filter (λ(p, p') ⇒ p = fst pq) (tal_eps A)) @
    map (λ(q, q'). (fst pq, q')) (filter (λ(q, q') ⇒ q = snd pq) (tal_eps B))"

(* Context for list-based automata A and B. It interprets the list-based Horn saturation
   (horn_list) for Δε_rules (ta_of A) (ta_of B), with Δε_infer0_list and Δε_infer1_list
   as the inference functions. *)
locale Δε_list =
  fixes A :: "('q, 'f) ta_list" and B :: "('q, 'f) ta_list"
begin

sublocale Δε_horn "ta_of A" "ta_of B" .

(* The proof obligations are the agreement of the list functions with the
   characterisations Δε_horn.Δε_infer0 and Δε_horn.Δε_infer1. *)
sublocale l: horn_list "Δε_rules (ta_of A) (ta_of B)" "Δε_infer0_list A B" "Δε_infer1_list A B"
  apply (unfold_locales)
  unfolding Δε_horn.Δε_infer0 Δε_horn.Δε_infer1 Δε_infer0_list_def Δε_infer1_list_def set_append Un_assoc[symmetric]
   apply (force simp: r_rhs_def map_prod_def split: ta_rule.splits)
  apply (intro arg_cong2[of _ _ _ _ "(∪)"])
  by (auto 0 0 simp: r_rhs_def map_prod_def) (force split: ta_rule.splits)+

(* Re-exported results of the horn_list interpretation l. *)
lemmas infer = l.infer0 l.infer1
lemmas saturate_impl_sound = l.saturate_impl_sound
lemmas saturate_impl_complete = l.saturate_impl_complete

end

(* Executable computation of Δε on list-based automata, using horn_list_impl.saturate_impl
   with the list inference functions. Some xs carries the saturated facts (see
   Δε_impl_sound). Δε_impl_complete shows that None is never returned. *)
definition Δε_impl :: "('q, 'f) ta_list ⇒ ('q, 'f) ta_list ⇒ ('q × 'q) list option" where
  "Δε_impl A B = horn_list_impl.saturate_impl (Δε_infer0_list A B) (Δε_infer1_list A B)"

(* Soundness: a successful run of Δε_impl yields exactly Δε of the represented automata. *)
lemma Δε_impl_sound:
  "Δε_impl A B = Some xs ⟹ set xs = Δε (ta_of A) (ta_of B)"
  using Δε_list.saturate_impl_sound unfolding Δε_impl_def Δε_horn.Δε_sound .

(* Completeness: Δε_impl never returns None. *)
lemma Δε_impl_complete:
  fixes A B :: "('q, 'f) ta_list"
  shows "Δε_impl A B ≠ None"
proof -
  (* Δε only relates states of the two automata ... *)
  have *: "Δε (ta_of A) (ta_of B) ⊆ ta_states (ta_of A) × ta_states (ta_of B)"
    by (auto simp: Δε_def' ta_of_def subsetD[OF ta_res_states])
  (* ... so it is finite, which is the premise used with
     Δε_list.saturate_impl_complete below. *)
  have "finite (Δε (ta_of A) (ta_of B))"
    by (cases A; cases B; rule finite_subset[OF *])
      (auto simp: ta_of_def ta_states_def r_states_def intro!: finite_cartesian_product)
  then show ?thesis unfolding Δε_impl_def
    by (intro Δε_list.saturate_impl_complete) (auto simp: Δε_horn.Δε_sound)
qed

(* Executable GTT composition on list-based GTTs.
   First it computes xs = Δε_impl (snd G1) (fst G2) and ys = Δε_impl (fst G2) (snd G1).
   If both succeed, the result is a pair of automata built with ta_list []:
   - the first uses the rules of fst G1 and fst G2, and their ε-transitions extended by xs;
   - the second uses the rules of snd G1 and snd G2, and their ε-transitions extended by ys.
   If either computation fails the result is None; GTT_comp_impl_complete shows this
   does not happen. *)
definition GTT_comp_impl :: "('q, 'f) gtt_list ⇒ ('q, 'f) gtt_list ⇒ (_, 'f) gtt_list option" where
  "GTT_comp_impl G1 G2 = (case (Δε_impl (snd G1) (fst G2), Δε_impl (fst G2) (snd G1)) of
     (Some xs, Some ys) ⇒ Some
     (ta_list [] (tal_rules (fst G1) @ tal_rules (fst G2)) (tal_eps (fst G1) @ tal_eps (fst G2) @ xs),
      ta_list [] (tal_rules (snd G1) @ tal_rules (snd G2)) (tal_eps (snd G1) @ tal_eps (snd G2) @ ys))
     | _ ⇒ None
     )"

(* Soundness: a successful GTT_comp_impl run represents GTT_comp of the represented GTTs.
   The proof relies on Δε_impl_sound for both Δε computations. *)
lemma GTT_comp_impl_sound:
  "GTT_comp_impl G1 G2 = Some G ⟹ gtt_of G = GTT_comp (gtt_of G1) (gtt_of G2)"
  by (auto simp: GTT_comp_def GTT_comp_impl_def split: option.splits dest!: Δε_impl_sound
   intro!: ta_of_eq_ta_makeI)

(* Completeness: GTT_comp_impl never returns None, because neither Δε_impl call
   does (Δε_impl_complete). *)
lemma GTT_comp_impl_complete:
  "GTT_comp_impl G1 G2 ≠ None"
  using Δε_impl_complete[of "snd G1" "fst G2"] Δε_impl_complete[of "fst G2" "snd G1"]
  by (auto simp: GTT_comp_impl_def)

text ‹The same construction again, this time implemented via automatic refinement (Autoref) on red-black-tree based automata.›

(* Counterpart of Δε_infer0_list for the RBT-based automaton representation.
   Rules are enumerated as the key lists RBT_Impl.keys (ta_rules_impl _). *)
definition Δε_infer0_rbt where
  "Δε_infer0_rbt A B = map (map_prod r_rhs r_rhs) (
     filter (λ(ra, rb). case (ra, rb) of (f ps → p, g qs → q) ⇒ f = g ∧ ps = [] ∧ qs = [])
        (List.product (RBT_Impl.keys (ta_rules_impl A)) (RBT_Impl.keys (ta_rules_impl B))))"

(* Counterpart of Δε_infer1_list for the RBT-based automaton representation.
   The known facts X are an RBT whose keys are enumerated with RBT_Impl.keys.
   Rules come from ta_rules_impl and ε-transitions from ta_eps_impl. *)
definition Δε_infer1_rbt where
  "Δε_infer1_rbt A B pq X =
    map (map_prod r_rhs r_rhs) (filter (λ(ra, rb). case (ra, rb) of (f ps → p, g qs → q) ⇒ f = g ∧ length ps = length qs ∧
     (fst pq, snd pq) ∈ set (zip ps qs) ∧ set (zip ps qs) ⊆ set (pq # (RBT_Impl.keys X))) (List.product (RBT_Impl.keys (ta_rules_impl A)) (RBT_Impl.keys (ta_rules_impl B)))) @
    map (λ(p, p'). (p', snd pq)) (filter (λ(p, p') ⇒ p = fst pq) (ta_eps_impl A)) @
    map (λ(q, q'). (fst pq, q')) (filter (λ(q, q') ⇒ q = snd pq) (ta_eps_impl B))"

(* If the RBT X' implements the set X (dflt_rs_rel) and the element relation R is
   preferred to be the identity (PREFER_id R), then the key list of X' implements X
   as a list set (list_set_rel). *)
lemma keys_list:
  fixes X :: "('a :: compare_order) set"
  assumes "(X', X) ∈ ⟨R⟩dflt_rs_rel"
  assumes "PREFER_id R"
  shows "(RBT_Impl.keys X', X) ∈ ⟨R⟩list_set_rel"
  using assms
    linorder.rbt_lookup_keys[OF linorder_class.linorder_axioms ord.is_rbt_rbt_sorted, of "X'"]
  by (auto simp: map2set_rel_def rbt_map_rel_def rbt_map_rel'_def list_set_rel_def br_def
    lt_of_comp_post_simp linorder_class.distinct_keys compare_order_class.ord_defs)

(* Iterator (it_set_rel) version of keys_list. *)
lemmas keys_it = list_set_rel_imp_it_set_rel[OF keys_list]

(* For an implementing automaton A' of A, the key list of its rule RBT implements
   ta_rules A as a list set. *)
lemma ta_rules_list:
  assumes "(A', A) ∈ dflt_ta_rel"
  shows "(RBT_Impl.keys (ta_rules_impl A'), ta_rules A) ∈ ⟨Id⟩list_set_rel"
  using fun_relD[OF ta_impl_autoref(2), OF assms]
   keys_list[of "ta_rules_impl A'" "ta_rules A"]
  by (auto simp: lt_of_comp_post_simp compare_order_class.ord_defs)

(* Iterator (it_set_rel) version of ta_rules_list. *)
lemmas ta_rules_it = list_set_rel_imp_it_set_rel[OF ta_rules_list]

(* Refinement: Δε_infer0_rbt on implementing automata A', B' refines
   horn.infer0 (Δε_rules A B) as an iterated set (it_set_rel). *)
lemma Δε_infer0_rbt:
  assumes A': "(A', A) ∈ dflt_ta_rel"
  assumes B': "(B', B) ∈ dflt_ta_rel"
  shows "(Δε_infer0_rbt A' B', horn.infer0 (Δε_rules A B)) ∈ ⟨Id⟩it_set_rel"
  unfolding Δε_horn.Δε_infer0 Δε_infer0_rbt_def
  using assms[THEN ta_rules_it]
  by (auto simp: it_set_rel_def br_def r_rhs_def split: ta_rule.split) force+

(* Refinement: Δε_infer1_rbt refines horn.infer1 (Δε_rules A B) pq X as an iterated set.
   The proof splits the three-way union of Δε_horn.Δε_infer1 and proves each
   summand separately. *)
lemma Δε_infer1_rbt:
  assumes A': "(A', A) ∈ dflt_ta_rel"
  assumes B': "(B', B) ∈ dflt_ta_rel"
  assumes pq': "(pq', pq) ∈ Id"
  assumes X': "(X', X) ∈ ⟨Id⟩dflt_rs_rel"
  shows "(Δε_infer1_rbt A' B' pq' X', horn.infer1 (Δε_rules A B) pq X) ∈ ⟨Id⟩it_set_rel"
  unfolding Δε_horn.Δε_infer1 Δε_infer1_rbt_def it_set_rel_def list_rel_id_simp Id_O_R
  apply (intro brI TrueI)
  using assms(1,2)[THEN fun_relD[OF ta_impl_autoref(3)], THEN list_set_rel_imp_it_set_rel] assms(3)
    assms(1,2)[THEN ta_rules_it] assms(4)[THEN keys_it] unfolding set_append Un_assoc[symmetric]
  apply (intro arg_cong2[of _ _ _ _ "(∪)"])
  by (auto 0 0 intro!: rev_image_eqI simp: it_set_rel_def br_def r_rhs_def simp flip: ex_simps(1) split: ta_rule.split prod.split) force+

(* Synthesises, via horn_linorder.saturate_impl_rbt, an RBT-based program ?f that
   refines horn.saturate_impl (Δε_rules A B). The premises are discharged with the
   refinement lemmas Δε_infer0_rbt and Δε_infer1_rbt. *)
schematic_goal Δε_impl_rbt:
  assumes "(A', A) ∈ dflt_ta_rel"
  assumes "(B', B) ∈ dflt_ta_rel"
  shows "(?f, horn.saturate_impl (Δε_rules A B)) ∈ ⟨⟨Id⟩dflt_rs_rel⟩nres_rel"
  apply (rule horn_linorder.saturate_impl_rbt)
   apply (rule Δε_infer0_rbt[OF assms]) 
  apply (intro fun_relI)
  apply (rule Δε_infer1_rbt[OF assms])
  apply assumption+
  done

(* Turn the program synthesised in Δε_impl_rbt into the constant Δε_rbt
   and prepare its code equations. *)
concrete_definition Δε_rbt uses Δε_impl_rbt
prepare_code_thms Δε_rbt_def

(* Refinement relation for GTTs: pairs of RBT-based automata, related componentwise
   by dflt_ta_rel. *)
abbreviation dflt_gtt_rel where
  "dflt_gtt_rel ≡ dflt_ta_rel ×r dflt_ta_rel"

(* Autoref synthesis of GTT_comp G1 G2 for RBT-based GTTs. The two Δε sets
   Δ12 = Δε (snd G1) (fst G2) and Δ21 = Δε (fst G2) (snd G1) are taken as precomputed
   RBT implementations Δ12', Δ21'. Their key lists are registered as autoref rules. *)
schematic_goal GTT_comp_rbt_aux:
  assumes [autoref_rules]: "(G1', G1) ∈ dflt_gtt_rel"
  assumes [autoref_rules]: "(G2', G2) ∈ dflt_gtt_rel"
  defines "Δ12 ≡ Δε (snd G1) (fst G2)" and "Δ21 ≡ Δε (fst G2) (snd G1)"
  assumes Δ: "(Δ12', Δ12) ∈ ⟨Id ×r Id⟩dflt_rs_rel" "(Δ21', Δ21) ∈ ⟨Id ×r Id⟩dflt_rs_rel"
  notes [autoref_rules] = Δ[THEN keys_list]
  shows "(?f, GTT_comp G1 G2) ∈ dflt_gtt_rel"
  unfolding GTT_comp_def Δ12_def[symmetric] Δ21_def[symmetric]
  by autoref

(* Turn the program synthesised in GTT_comp_rbt_aux into a constant of the same name. *)
concrete_definition GTT_comp_rbt_aux uses GTT_comp_rbt_aux

(* export_code GTT_comp_rbt_aux in Haskell *)

(* Refinement rule for bind when the first computation refines a plain RETURN b.
   If a refines RETURN b, and f' a' refines f for every a' related to b by R,
   then a ⤜ f' refines f. *)
lemma bind_RETURN:
  assumes "(a, RETURN b) ∈ ⟨R⟩nres_rel"
  assumes "⋀a'. (a', b) ∈ R ⟹ (f' a', f) ∈ ⟨Q⟩nres_rel"
  shows "(a ⤜ f', f) ∈ ⟨Q⟩nres_rel"
  using assms 
  apply (auto simp: nres_rel_def bind_def SUP_le_iff split: nres.splits)
  by (metis RES_Sup_RETURN RETURN_ref_RETURND SUP_le_iff)

(* nres_of commutes with monadic bind. *)
lemma nres_of_bind:
  "nres_of (f ⤜ g) = nres_of f ⤜ (λx. nres_of (g x))"
  by (cases f) auto

(* Full RBT implementation of GTT composition. It binds the results of Δε_rbt for both
   directions (snd G1 / fst G2 and fst G2 / snd G1) and passes them to
   GTT_comp_rbt_aux. The resulting program ?f refines RETURN (GTT_comp G1 G2). *)
schematic_goal GTT_comp_rbt:
  assumes G1: "(G1', G1) ∈ dflt_gtt_rel"
  assumes G2: "(G2', G2) ∈ dflt_gtt_rel"
  notes G1' = param_prod(4,5)[THEN fun_relD, OF G1]
  notes G2' = param_prod(4,5)[THEN fun_relD, OF G2]
  notes * = horn.saturate_impl_sound[unfolded SPEC_eq_is_RETURN, of "Δε_rules _ _", unfolded Δε_horn.Δε_sound[symmetric]]
  shows "(?f, RETURN (GTT_comp G1 G2)) ∈ ⟨dflt_gtt_rel⟩nres_rel"
  apply (rule HOL.back_subst[of "λx. (x, _) ∈ _"])
  apply (rule bind_RETURN[OF conc_trans_additional(1)[OF Δε_rbt.refine[OF G1'(2) G2'(1), THEN nres_relD] *, THEN nres_relI]])
  apply (rule bind_RETURN[OF conc_trans_additional(1)[OF Δε_rbt.refine[OF G2'(1) G1'(2), THEN nres_relD] *, THEN nres_relI]])
  apply (rule param_RETURN[THEN fun_relD])
   apply (rule GTT_comp_rbt_aux[OF G1 G2, unfolded prod_rel_id_simp]; assumption+)
  unfolding nres_of_simps(3)[symmetric] nres_of_bind[symmetric]
  by (rule refl)

(* Turn the program synthesised in GTT_comp_rbt into a constant and export it as Haskell code. *)
concrete_definition GTT_comp_rbt uses GTT_comp_rbt

export_code GTT_comp_rbt in Haskell

end