import Mathlib.Data.List.Sort

/-- A finite set of natural numbers, represented by a strictly increasing list.
Strictness means every element occurs at most once and elements are stored in
ascending order. -/
abbrev NatSet := { xs : List ℕ // xs.Sorted (· < ·) }

namespace NatSet

/-- The set with no elements, backed by the empty list. -/
def empty : NatSet where
  val := []
  property := List.sorted_nil

/-- The set containing exactly `x`, backed by the one-element list `[x]`. -/
def singleton (x : ℕ) : NatSet where
  val := [x]
  property := List.sorted_singleton x

/-- Remove `x` from `s`. If `x` is not an element of `s`, the underlying list
is returned unchanged.

The previous implementation only dropped `x` when it was the *first* element
and otherwise returned the set as-is, so e.g. removing `2` from `{1, 2, 3}` did
nothing. Filtering the whole list fixes that. Because `s.val` is strictly
increasing, `x` occurs at most once, so this removes exactly that occurrence.
Sortedness carries over because a filtered list is a sublist of the original. -/
def remove (x : ℕ) (s : NatSet) : NatSet :=
  ⟨s.val.filter (· != x), List.Pairwise.filter _ s.prop⟩

/-- Insert `x` into a list that is assumed (not checked) to be strictly
increasing, keeping it strictly increasing.

The list is walked until the first element `a` with `x = a` or `x < a`:
* if `x = a`, the list is returned unchanged, so no duplicate is added;
* if `x < a`, `x` is placed directly in front of `a`;
* if no such element exists, `x` ends up appended at the end (`[] => [x]`).

NOTE: the `split`-based proofs in this file unfold this definition and case on
the exact `if x = a … else if x < a …` structure, so keep that shape when
editing. -/
def list_insert (x : ℕ) : List ℕ → List ℕ
| [] => [x]
| a::as =>
  if x = a
  then a::as
  else if x < a
  then x::a::as
  else a::(list_insert x as)

/-- `list_insert` introduces no new elements besides `x`: anything in
`list_insert x xs` is either `x` itself or already in `xs`. -/
theorem list_insert_mem : b ∈ list_insert x xs → b = x ∨ b ∈ xs := by
  -- Induct on `xs`; the `nil` case has no alternative below, so the leading
  -- `simp [list_insert]` must close it.
  induction xs with simp [list_insert]
  | cons a as ih =>
    -- Case on the first test in `list_insert`: `x = a`.
    split
    next heq =>
      -- `x = a`: the list is returned unchanged.
      intro hmem
      simp_all only [List.mem_cons, or_true]
    next hne =>
      -- Second test: `x < a`.
      split
      · simp
      next hnlt =>
        -- Neither `x = a` nor `x < a`: the result is `a :: list_insert x as`.
        simp [*]
        intro h
        -- `b` is either the head `a` or comes from `list_insert x as`.
        cases h
        · right; left; assumption
        -- Tail case: the induction hypothesis splits it into `b = x` or `b ∈ as`.
        cases ih ‹_›
        · left; assumption
        · right; right; assumption

/-- Inserting into a strictly increasing list yields a strictly increasing
list; this is what lets `insert` stay inside `NatSet`. -/
theorem list_insert_sorted (hs : List.Sorted (·<·) xs)
    : List.Sorted (·<·) (list_insert x xs) := by
  -- The `nil` case (`[x]`) is closed by the leading `simp [list_insert]`.
  induction xs with simp [list_insert]
  | cons a as ih =>
    -- Case `x = a`: the list is unchanged, so `hs` already suffices.
    split
    · simp_all only [List.sorted_cons, implies_true, and_self]
    · split
      -- Case `x < a`: prepending `x` in front of the sorted `a :: as`.
      · exact List.Sorted.cons ‹_› hs
      -- Otherwise (`x ≠ a`, `¬ x < a`): the result is `a :: list_insert x as`.
      · have : a < x := by
          -- Combine `¬ x < a` (so `a ≤ x`) with `x ≠ a` to get `a < x`.
          rename_i hne hnlt
          apply Nat.lt_of_le_of_ne
          exact Nat.le_of_not_lt hnlt
          exact fun a_1 => hne (Eq.symm a_1)

        simp only [List.sorted_cons]

        -- Goal 1: `a` is below every element of `list_insert x as`.
        -- Goal 2: `list_insert x as` is itself sorted.
        constructor
        · intro b hmem
          -- Any such `b` is `x` (and `a < x` above) or an element of `as`.
          cases list_insert_mem hmem
          <;> simp_all only [imp_self, List.sorted_cons, and_true, not_lt]
        · simp_all only [imp_self, List.sorted_cons, and_true, not_lt]

/-- Insert `x` into `s` via `list_insert`; `list_insert_sorted` shows the result
is still strictly increasing. If `x` is already present the set is unchanged. -/
def insert (x : ℕ) (s : NatSet) : NatSet where
  val := list_insert x s.val
  property := list_insert_sorted s.prop

/-- `∅ : NatSet` is `NatSet.empty`. -/
instance : EmptyCollection NatSet := ⟨NatSet.empty⟩

/-- `{x} : NatSet` is `NatSet.singleton x`. -/
instance : Singleton ℕ NatSet := ⟨NatSet.singleton⟩

/-- Set-builder notation such as `{a, b, c}` inserts via `NatSet.insert`. -/
instance : Insert ℕ NatSet := ⟨NatSet.insert⟩

/-- `insert x ∅ = {x}` holds by definitional unfolding: `list_insert x []`
reduces to `[x]`. -/
instance : LawfulSingleton ℕ NatSet := ⟨fun _ => rfl⟩

/-- `n ∈ s` means `n` occurs among the elements of `s` that are `≤ n`.

Because `s.val` is strictly increasing, `n` cannot appear after the first
element greater than it. Restricting to that prefix is therefore equivalent to
plain list membership, and a decision procedure stops early.
TODO: binary search. -/
instance : Membership ℕ NatSet where
  mem := fun s n => n ∈ s.val.takeWhile (· ≤ n)

/-- Membership in a `NatSet` is decidable. It unfolds to list membership in the
prefix of elements `≤ n`, which is decided with the standard `List` instance. -/
instance (n : ℕ) (s : NatSet) : Decidable (n ∈ s) :=
  let lowers := s.val.takeWhile (· ≤ n)
  List.instDecidableMemOfLawfulBEq n lowers

/-- Inductive step for `mem_insert`: if `x` is a member after inserting it
into `⟨ns, _⟩`, it is also a member after inserting it into `⟨n :: ns, _⟩`. -/
theorem mem_insert_cons (hmem : x ∈ insert x ⟨ns, h₁⟩)
    : x ∈ insert x ⟨n::ns, h₂⟩ := by
  -- NOTE(review): the induction hypothesis does not appear to be used; both
  -- generated goals are handled by the same script below.
  induction ns
  <;> (
    -- Expose `list_insert` and the `takeWhile`-based membership.
    simp only [Membership.mem, insert, list_insert]
    -- Case on `x = n` (first test in `list_insert`).
    split
    next heq =>
      -- `x = n`: `x` is the head of the (unchanged) list.
      simp [heq]
      exact List.Mem.head _
    next hne =>
      -- Case on `x < n`.
      split
      -- `x < n`: `x` is prepended, so it is the head.
      · simp
        exact List.Mem.head ..
      next hnlt =>
        -- Otherwise `n ≤ x`, so `takeWhile (· ≤ x)` keeps `n` and continues
        -- into `list_insert x ns`.
        have hle : n ≤ x := by simp at hnlt; assumption
        simp only [List.takeWhile, hle, decide_true, le_refl]
        apply List.Mem.tail n
        -- Presumably the first `try` closes the `nil` case (goal `x ∈ [x]`)
        -- and the second closes the `cons` case via `hmem`.
        try exact List.Mem.head []
        try exact hmem
  )

/-- Inserting `x` into any set yields a set that contains `x`. -/
theorem mem_insert (x : ℕ) (s : NatSet) : x ∈ s.insert x := by
  -- Work with the underlying list `xs` and its sortedness proof `h`.
  let ⟨xs, h⟩ := s
  induction xs with
  | nil =>
    -- Inserting into the empty list gives `[x]`.
    simp [Membership.mem, insert, list_insert, List.Mem.head]
  | cons a as ih =>
    -- The tail `as` is sorted too, so `ih` applies to it; `mem_insert_cons`
    -- then lifts the result to `a :: as`.
    rw [List.sorted_cons] at h
    have hmem := ih h.right
    exact mem_insert_cons hmem

/-- Inserting two elements commutes: the order of insertion does not affect
the resulting set.

NOTE(review): only the `nil` case is proved so far; the `cons` case is `sorry`. -/
theorem insert_orderless (a b : ℕ) (s : NatSet)
    : (s.insert a).insert b = (s.insert b).insert a := by
  let ⟨xs, h⟩ := s
  induction xs with
  | nil =>
    simp [insert, list_insert]
    -- Split on whether `a` and `b` coincide.
    split
    · simp [*]
    · have hne : a ≠ b := by
        intro
        simp_all only [not_true_eq_false]
      simp only [↓reduceIte, hne]
      -- Split on the relative order of `a` and `b`; in each branch the other
      -- side's `if` is resolved with the opposite comparison.
      split
      · have hnlt : ¬a < b := Nat.not_lt_of_gt ‹_›
        exact Eq.symm (if_neg hnlt)
      · have hlt : a < b := Nat.lt_of_le_of_ne (Nat.le_of_not_lt ‹_›) hne
        exact Eq.symm (if_pos hlt)
  -- The head/tail are named `y`/`ys` so they do not shadow the theorem's own
  -- parameter `a` (previously `| cons a as ih` made the inserted `a`
  -- inaccessible inside this case).
  | cons y ys ih =>
    -- TODO: complete the proof.
    sorry

-- `{a, b, c}` is `insert a {b, c}`: `a` is the outermost insertion, so
-- `mem_insert` applies directly.
example : a ∈ ({a, b, c} : NatSet) := mem_insert a _
-- Not provable this way yet: here `a` is the innermost (singleton) element, not
-- the outermost insertion. Presumably this needs a commutation lemma such as
-- `insert_orderless`, whose `cons` case is still `sorry`, or a more general
-- membership lemma.
-- example : a ∈ ({c, b, a} : NatSet) := mem_insert a _

-- TODO: set operation instances