import Mathlib.Data.List.Sort

/-- Finite sets of natural numbers, represented as strictly increasing lists
(strictness rules out duplicates, so each set has exactly one representation). -/
abbrev NatSet : Type := { xs : List ℕ // List.Sorted (· < ·) xs }

namespace NatSet

/-- The empty set, backed by the empty list. -/
def empty : NatSet :=
  Subtype.mk [] List.sorted_nil
/-- The set whose only element is `x`. -/
def singleton (x : ℕ) : NatSet :=
  Subtype.mk [x] (List.sorted_singleton x)

/-- Remove `x` from `s`; if `x ∉ s` the underlying list is returned unchanged.

`List.erase` drops the first occurrence of `x`. A strictly sorted list holds `x` at
most once, so this removes it completely. The erased list is a sublist of the
original, and sublists of a strictly sorted list are strictly sorted, which
re-establishes the invariant. -/
def remove (x : ℕ) (s : NatSet) : NatSet :=
  ⟨s.val.erase x, s.prop.sublist (List.erase_sublist ..)⟩

/-- Insert `x` into a list of naturals.

The list is walked while its head is `< x` (the final `else` branch). At the first
head `a` that is not `< x`:
* if `a = x`, the rest of the list is kept as is, so no duplicate is added;
* if `x < a`, `x` is placed directly in front of `a`.
If the end of the list is reached, `x` is appended (`[] => [x]`).
On a strictly increasing input the result is strictly increasing
(`list_insert_sorted`). -/
def list_insert (x : ℕ) : List ℕ → List ℕ
| [] => [x]
| a::as =>
  -- The order of the tests (`=` first, then `<`) is what the `split`s in
  -- `list_insert_mem` and `list_insert_sorted` step through; keep them in sync.
  if x = a
  then a::as
  else if x < a
  then x::a::as
  else a::(list_insert x as)

/-- Every element of `list_insert x xs` is either `x` itself or was already in `xs`.
(`b`, `x` and `xs` are auto-bound implicit variables.) -/
theorem list_insert_mem : b ∈ list_insert x xs → b = x ∨ b ∈ xs := by
  -- No `nil` alternative is given, so `simp [list_insert]` must close the `[]`
  -- case (`b ∈ [x]`).
  induction xs with simp [list_insert]
  | cons a as ih =>
    -- Split on the first `if` of `list_insert`: `x = a` (result is `a::as`).
    split
    next heq =>
      intro hmem
      simp_all only [List.mem_cons, or_true]
    next hne =>
      -- Split on the second `if`: `x < a` (result is `x::a::as`).
      split
      · simp
      -- Remaining branch: the result is `a::(list_insert x as)`.
      next hnlt =>
        simp [*]
        intro h
        cases h
        -- `b` is the head `a`, hence an element of `a::as`.
        · right; left; assumption
        -- `b ∈ list_insert x as`: the induction hypothesis gives `b = x ∨ b ∈ as`.
        cases ih ‹_›
        · left; assumption
        · right; right; assumption

/-- `list_insert` preserves strict sortedness: inserting into a strictly increasing
list yields a strictly increasing list. -/
theorem list_insert_sorted (hs : List.Sorted (·<·) xs)
    : List.Sorted (·<·) (list_insert x xs) := by
  -- No `nil` alternative is given, so `simp [list_insert]` must close the `[]`
  -- case (the goal is that `[x]` is sorted).
  induction xs with simp [list_insert]
  | cons a as ih =>
    split
    -- `x = a`: `list_insert` returns `a::as` unchanged, which `hs` covers.
    · simp_all only [List.sorted_cons, implies_true, and_self]
    · split
      -- `x < a`: `x` is prepended; `‹_›` presumably picks up the `x < a`
      -- hypothesis introduced by `split`.
      · exact List.Sorted.cons ‹_› hs
      -- Otherwise the result is `a :: list_insert x as`. First derive `a < x`
      -- from `¬x = a` and `¬x < a`.
      · have : a < x := by
          rename_i hne hnlt
          apply Nat.lt_of_le_of_ne
          exact Nat.le_of_not_lt hnlt
          exact fun a_1 => hne (Eq.symm a_1)

        simp only [List.sorted_cons]

        constructor
        -- Head condition: each element of `list_insert x as` is `x` (and `a < x`)
        -- or an element of `as` (above `a` by `hs`), via `list_insert_mem`.
        · intro b hmem
          cases list_insert_mem hmem
          <;> simp_all only [imp_self, List.sorted_cons, and_true, not_lt]
        -- Tail condition: follows from the induction hypothesis.
        · simp_all only [imp_self, List.sorted_cons, and_true, not_lt]

/-- Add `x` to `s`. The underlying list is updated with `list_insert`, and
`list_insert_sorted` carries the strict-sortedness proof over. -/
def insert (x : ℕ) (s : NatSet) : NatSet :=
  Subtype.mk (list_insert x s.1) (list_insert_sorted s.2)

/-- `∅` notation for `NatSet`. -/
instance : EmptyCollection NatSet := ⟨NatSet.empty⟩

/-- `{x}` notation for `NatSet`. -/
instance : Singleton ℕ NatSet := ⟨NatSet.singleton⟩

/-- `insert` (and hence `{a, b, …}` notation) for `NatSet`. -/
instance : Insert ℕ NatSet := ⟨NatSet.insert⟩

/-- `insert x ∅ = {x}`. Both sides reduce to `⟨[x], _⟩` (`list_insert x [] = [x]`),
so the law holds by `rfl`. The anonymous constructor avoids naming the class field;
core names it `insert_emptyc_eq`, so the previous `insert_empty_eq` did not name any
field. -/
instance : LawfulSingleton ℕ NatSet :=
  ⟨fun _ => rfl⟩

/-- `n ∈ s` is defined as membership in the prefix of `s.val` whose elements are `≤ n`.
Because `s.val` is strictly increasing, `n` can only occur in that prefix, so this
agrees with plain list membership. The decision procedure then stops scanning at the
first element greater than `n`. -/
instance : Membership ℕ NatSet :=
  -- TODO: binary search
  ⟨fun s n => n ∈ s.val.takeWhile (· ≤ n)⟩

/-- Decide `n ∈ s` by deciding membership in the `≤ n` prefix of `s.val`, which is
what `n ∈ s` unfolds to under the `Membership` instance. -/
instance (n : ℕ) (s : NatSet) : Decidable (n ∈ s) :=
  -- Resolve the list-membership instance by its type rather than by its
  -- auto-generated name, which can change between Lean versions.
  inferInstanceAs (Decidable (n ∈ s.val.takeWhile (· ≤ n)))

/-- Auxiliary: after `list_insert x xs`, `x` lies in the prefix kept by
`takeWhile (· ≤ x)`. Every element `list_insert` leaves in front of `x` is some `a`
with `x ≠ a` and `¬ x < a`, i.e. `a < x`, so the prefix is never cut before `x`.
No sortedness assumption on `xs` is needed. -/
theorem list_insert_mem_takeWhile (x : ℕ) (xs : List ℕ) :
    x ∈ (list_insert x xs).takeWhile (· ≤ x) := by
  induction xs with
  | nil => simp [list_insert, List.takeWhile]
  | cons a as ih =>
    by_cases heq : x = a
    · -- `list_insert` returns `a :: as`, whose head is `x` itself.
      subst heq
      simp [list_insert, List.takeWhile]
    · by_cases hlt : x < a
      · -- `x` is placed at the head of the result.
        simp [list_insert, List.takeWhile, heq, hlt]
      · -- `a` stays in front (it is `≤ x`); `x` sits in the recursive part.
        have hle : a ≤ x := Nat.le_of_not_lt hlt
        simpa [list_insert, List.takeWhile, heq, hlt, hle] using ih

/-- Inserting `x` into any set makes `x` a member. -/
theorem mem_insert (x : ℕ) (s : NatSet) : x ∈ s.insert x :=
  -- `x ∈ s.insert x` unfolds, via the `Membership` instance and `insert`, to the
  -- statement of `list_insert_mem_takeWhile` for the underlying list.
  list_insert_mem_takeWhile x s.val

-- Smoke tests. The expected results follow from `list_insert`, which keeps the
-- underlying list strictly increasing and skips elements already present.
-- expected: true
#eval 3 ∈ ({1, 2, 3} : NatSet)
-- expected: false
#eval 3 ∈ ({1, 2} : NatSet)
-- expected: false
#eval 3 ∈ ({} : NatSet)
-- expected: false (underlying lists `[1, 2]` vs `[3, 4]`)
#eval ({1, 2} : NatSet) == {3, 4}
-- expected: true; duplicates and insertion order do not matter, both sides are `[2, 3, 4]`
#eval ({2, 2, 3, 4} : NatSet) == {4, 3, 2}
-- underlying list is `[2, 3, 4]`; the printed form depends on the `Repr` instance for subtypes
#eval ({2, 2, 3, 4} : NatSet)

end NatSet