bytecode/Bytecode/Basic.lean
2025-04-09 15:37:27 -04:00

81 lines
2.2 KiB
Text

import Batteries.Data.Rat.Basic
import Batteries.Control.Lawful.MonadLift
-- Provides the `LawfulMonad (OptionT m)` instance.
import Batteries.Control.OptionT
/-- Arithmetic expressions over rational literals with four binary operators. -/
inductive Ast
  /-- A rational literal. -/
  | lit (n : Rat)
  /-- Addition `l + r`. -/
  | add (l r : Ast)
  /-- Subtraction `l - r`. -/
  | sub (l r : Ast)
  /-- Multiplication `l * r`. -/
  | mul (l r : Ast)
  /-- Division `l / r`. -/
  | div (l r : Ast)
-- Notation instances: they let `Ast` values be written with ordinary numeric
-- literals and arithmetic operators (e.g. `3/2 + 0.5 : Ast`). They only build
-- syntax-tree nodes; nothing is evaluated here.

/-- A rational number lifts to an `Ast.lit` node. -/
instance : Coe Rat Ast := ⟨Ast.lit⟩

/-- Decimal literals such as `0.5` become `Ast.lit` nodes holding the rational value. -/
instance : OfScientific Ast :=
  ⟨fun mantissa exponentSign decimalExponent =>
    Ast.lit (Rat.ofScientific mantissa exponentSign decimalExponent)⟩

/-- Natural-number literals become `Ast.lit` nodes (the `Nat` is cast to `Rat`). -/
instance : OfNat Ast n := ⟨Ast.lit n⟩

/-- `a + b` builds an `Ast.add` node. -/
instance : Add Ast := ⟨Ast.add⟩

/-- `a - b` builds an `Ast.sub` node. -/
instance : Sub Ast := ⟨Ast.sub⟩

/-- `a * b` builds an `Ast.mul` node. -/
instance : Mul Ast := ⟨Ast.mul⟩

/-- `a / b` builds an `Ast.div` node. -/
instance : Div Ast := ⟨Ast.div⟩
/-- Reference semantics: evaluate an expression directly with `Rat` arithmetic.
Division uses `Rat`'s own `/`, so a zero divisor yields whatever `Rat` defines. -/
def Ast.interpret (e : Ast) : Rat :=
  match e with
  | .lit value => value
  | .add a b => a.interpret + b.interpret
  | .sub a b => a.interpret - b.interpret
  | .mul a b => a.interpret * b.interpret
  | .div a b => a.interpret / b.interpret
#eval Ast.interpret <| 3/2 + 0.5
/-- The stack-machine monad. The state is an `Array Rat` used as a stack whose
end is the top (see `M.push` / `M.pop`). The `OptionT` layer lets an operation
fail with `none`, e.g. when popping an empty stack. -/
abbrev M := OptionT (StateM (Array Rat))
/-- Push `n` onto the top of the stack by appending it to the end of the array. -/
def M.push (n : Rat) : M Unit := modify (·.push n)
/-- Pop the top of the stack and return it.
Fails with `none` if the stack is empty. In either case the state is replaced
by `Array.pop` of the old stack. -/
def M.pop : M Rat := OptionT.mk do
  -- `getModify` stores `Array.pop` of the state and returns the state *before*
  -- that update. So `top` is the whole old stack, not just its top element;
  -- `back?` then extracts the popped value (or `none` if the stack was empty).
  let top ← getModify Array.pop
  return top.back?
/-- Common code shape for a binary operator. It runs `l`, pops its result, runs
`r`, pops its result, and finally pushes `f` applied to the two values (left
operand first). It fails if either pop finds an empty stack.
NOTE(review): declared `abbrev`, presumably so `Ast.compile_sound` can unfold it
with `simp only [M.op, …]`. That proof relies on this exact statement order
(each operand's code followed by a `pop`), so keep the two in sync. -/
abbrev M.op (l r : M Unit) (f : Rat → Rat → Rat) : M Unit := do
  l
  -- from here on, `l` names the popped value, shadowing the computation
  let l ← pop
  r
  -- likewise `r` now names the popped right-hand value
  let r ← pop
  push (f l r)
/-- Compile an expression to stack-machine code. Literals push their value, and
each binary node compiles its operands left-to-right and combines them with
`M.op`. `Ast.compile_sound` proves the result equals `M.push ast.interpret`. -/
def Ast.compile : Ast → M Unit
  | lit n => M.push n
  | add l r => M.op l.compile r.compile (·+·)
  | sub l r => M.op l.compile r.compile (·-·)
  | mul l r => M.op l.compile r.compile (·*·)
  | div l r => M.op l.compile r.compile (·/·)
/-- Run compiled code starting from an empty stack, then pop and return the top
value. Returns `none` if the code fails or leaves the stack empty. Any values
below the top are discarded together with the final state. -/
def M.run (m : M Unit) : Option Rat :=
  let program : M Rat := do
    m
    pop
  (OptionT.run program #[]).fst
#eval Ast.compile (3/2 + 0.5) |>.run
/-- Pushing a value and immediately popping it is the same as returning it.
`Ast.compile_sound` uses this as its key rewrite. -/
@[simp]
theorem M.push_pop : (do push x; pop) = pure x := by
  -- Both sides are `M Rat` values, which unfold to functions of the initial
  -- stack; compare them pointwise on an arbitrary stack `xs`.
  funext xs
  -- Unfold the `OptionT`/`StateT` plumbing behind `push` (built on `modify`) and
  -- `pop` (built on `getModify`) down to plain function application, then close
  -- the goal with `Array.pop_push` and `Array.back?_push`.
  simp only [bind, OptionT.bind, OptionT.mk, StateT.bind, push, modify, modifyGet,
    MonadStateOf.modifyGet, monadLift, MonadLift.monadLift, OptionT.lift, StateT.modifyGet, pure,
    StateT.pure, pop, getModify, Array.pop_push, Array.back?_push, OptionT.pure]
-- Unfinished, commented-out attempt at a manual `LawfulMonad M` instance,
-- kept for reference:
-- instance : LawfulMonad M where
--   map_const := rfl
--   id_map {a} x := by
--     simp [Functor.map, OptionT.bind, OptionT.mk, StateT.instLawfulMonad]
/-- Compiler correctness: the compiled code of `ast` is equal to simply pushing
its interpreted value. -/
theorem Ast.compile_sound (ast : Ast)
    : ast.compile = M.push ast.interpret := by
  induction ast
  -- `lit n`: both sides are `M.push n` by definition.
  · rfl
  -- Remaining binary cases (`add`/`sub`/`mul`/`div`): unfold `M.op`, `compile`
  -- and `interpret`, and rewrite each operand's code with its induction
  -- hypothesis (picked up by `*`). Then reassociate binds so that each
  -- `push`/`pop` pair becomes adjacent and cancels via `M.push_pop` / `pure_bind`.
  repeat (
    simp only [M.op, compile, interpret, *]
    repeat rw [←bind_assoc, M.push_pop, pure_bind]
  )