module stlc.examples.monad where

open import Data.List

open import config
open import stlc.term
open import stlc.properities
open import stlc.catComb.compile
open import stlc.run

open import stlc.monad

open Utils

--- EXCEPTION MONAD ---

-- `A throws E`: the object-language type of computations in the `exception`
-- monad with result type `A` and error type `E`, i.e. whatever
-- `getMonadType` (from stlc.monad) yields for `exception {A} {E}`.
-- Fixity 4 lets signatures such as `nat × nat throws nat` be read as
-- `(nat × nat) throws nat`, as the examples below rely on.
-- NOTE(review): `some` is this monad's `return` and `throw_` builds errors
-- with `inr`, so this presumably unfolds to a sum with the error on the
-- right — confirm in stlc.monad.
infix 4 _throws_

_throws_ : Type  Type  Type
A throws E = getMonadType {A} (exception {A} {E})

-- Lift a term of type `A` into the exception monad without raising an error.
-- This is stlc.monad's `return` instantiated at the `exception` monad, and it
-- works for any error type `E`.
some :  {Γ A E}  Γ  A  Γ  A throws E
some = return {m = exception}

-- `try f on e`: sequence two exception-monad computations.
-- First `e` runs. Then the continuation `f` runs in the context extended with
-- `e`'s result, which `f` sees as `# 0`.
-- This is the object-level monadic let (`let_`in_`) with all three of its
-- monad arguments set to `exception`. Note the argument order: `e` is passed
-- first, then `f`.
-- What happens when `e` fails is decided by `let_`in_` in stlc.monad.
try_on_ :  {Γ A B E}  Γ , A  B throws E  Γ  A throws E  Γ  B throws E
try f on e = `let_`in_ {ma = exception} {mb = exception} {s = exception} e f

-- `throw n`: a computation that fails with the natural-number error `n`.
-- The error value `inr n` is built in `Γ` and then weakened into the extended
-- context `Γ , A`. This lets `throw_` be used directly as the continuation of
-- `try_on_`, whose bound value it ignores (see ex2).
-- NOTE(review): unlike `try_on_` and `some`, this is fixed to `nat` errors.
-- It could probably be generalised to any error type `E`.
infix 5 throw_

throw_ :  {Γ A B}  Γ  nat  Γ , A  B throws nat
throw n = weaken {Γ' = } (inr n)

--- Example 1 - Exception Monad ---

-- Continuation that swaps the pair bound at `# 0`.
-- It returns `(snd , fst)` via `some`, so it never takes the error branch.
-- The error type is fixed to `nat` by the signature.
swap :  {Γ A B}  Γ , (A × B)  (B × A) throws nat
swap = some (snd # 0 , fst # 0)

-- Bind the pair (1 , 2) in the exception monad and swap it.
-- Neither step throws, so the expected outcome is the swapped pair (2 , 1).
ex1 :   nat × nat throws nat
ex1 = try swap on some (`nat 1 , `nat 2)

-- ex1 compiled to a list of machine instructions (stlc.catComb.compile).
inst1 : List Inst
inst1 = compile ex1

-- Execute inst1.
-- NOTE(review): the leading 100 is presumably a step/fuel bound, and the
-- remaining arguments presumably form the initial machine configuration.
-- Confirm both against stlc.run.
result1 : Result
result1 = run 100  inst1  ⟨⟩  [] 

--- Example 2 - Exception Monad (throwing an error) ---

-- Bind the pair (1 , 3), then run a continuation that ignores it and throws
-- error 9. The expected outcome is therefore the error case carrying 9, not
-- a pair.
ex2 :   nat × nat throws nat
ex2 = try (throw `nat 9) on some (`nat 1 , `nat 3)

-- ex2 compiled to a list of machine instructions (stlc.catComb.compile).
inst2 : List Inst
inst2 = compile ex2

-- Execute inst2 with the same bound and initial configuration as result1.
-- See stlc.run for the meaning of the arguments.
result2 : Result
result2 = run 100  inst2  ⟨⟩  [] 


--- STATE MONAD ---

-- `A stores St`: the object-language type of computations in the `state`
-- monad with result type `A` and state type `St`. It is whatever
-- `getMonadType` (from stlc.monad) yields for `state {A} {St}`.
-- The combinators below treat such a computation as a function of the current
-- state that returns a (new state , result) pair; see `get`, `store` and
-- `runState`.
infix 4 _stores_

_stores_ : Type  Type  Type
A stores St = getMonadType {A} (state {A} {St})

-- Monadic bind for the state monad.
-- `m >>= f` runs `m`, then runs `f` with `m`'s result bound as `# 0`.
-- It is implemented with the object-level `let_`in_`, with all monad
-- arguments set to `state`.
-- It is right-associative, so `a >>= b >>= c` means `a >>= (b >>= c)`. Each
-- later continuation therefore sees every earlier result in its context, at
-- increasing de Bruijn indices.
infixr 5 _>>=_

_>>=_ :  {Γ A B S}  Γ  A stores S  Γ , A  B stores S  Γ  B stores S
m >>= f = `let_`in_ {ma = state} {mb = state} {s = state} m f

-- Read the current state.
-- It is a function of the state `# 0` that returns `(state , state)`: the
-- state is left unchanged and also returned as the result.
get :  {Γ S}  Γ  S stores S
get = ƛ ((# 0) , (# 0))

-- Replace the state with `x` and return `⟨⟩`.
-- The incoming state (the lambda's `# 0`) is ignored. `x` is weakened so it
-- can be used under that binder.
put :  {Γ S}  Γ  S  Γ  unit stores S
put x = ƛ (weaken {Γ' = } x , ⟨⟩)

-- Return `a` as the result and pass the state `# 0` through unchanged.
-- This makes it the state monad's `return`. `a` is weakened so it can be used
-- under the state binder.
store :  {Γ A S}  Γ  A  Γ  A stores S
store a = ƛ ((# 0) , weaken {Γ' = } a)

-- Run a state computation from the initial state `s`.
-- Returns the pair (final state , result).
runState :  {Γ A S}  Γ  A stores S  Γ  S  Γ  S × A
runState m s = (m · s)

-- Run a state computation from the initial state `s` and keep only the result.
-- The result is the second component of `runState`'s pair.
evalState :  {Γ A S}  Γ  A stores S  Γ  S  Γ  A
evalState M s = snd (runState M s)

-- Run a state computation from the initial state `s` and keep only the final
-- state. The final state is the first component of `runState`'s pair.
execState :  {Γ A S}  Γ  A stores S  Γ  S  Γ  S
execState M s = fst (runState M s)

open import stlc.examples.list using (Nat; s_; z; _+Nat_)

-- Add one (`s z`) to the `Nat` state.
-- The previous result (the incoming `# 0`, of type `A`) is passed through
-- unchanged, which lets `increment` be chained with `>>=`.
-- Because `>>=` is right-associative, the contexts are:
--   inside `put`:   # 0 = state read by `get`
--   inside `store`: # 0 = ⟨⟩ from `put`, # 1 = state read by `get`,
--                   # 2 = the incoming `A`   (hence `store (# 2)`)
increment :  {Γ A}  Γ , A  A stores Nat
increment = get >>= put ((# 0) +Nat (s z)) >>= store (# 2)

--- Example 3 - State Monad (incrementing the stored counter) ---

-- Start from state `z`, store `⟨⟩` as the initial result, then increment the
-- state three times.
-- The expected outcome is (s (s (s z)) , ⟨⟩), i.e. (final state , result).
ex3 :   Nat × unit
ex3 = runState (store ⟨⟩ >>= increment >>= increment >>= increment) z

-- ex3 compiled to a list of machine instructions (stlc.catComb.compile).
inst3 : List Inst
inst3 = compile ex3

-- Execute inst3.
-- NOTE(review): the leading 1000 is larger than in the exception examples,
-- presumably a step/fuel bound — confirm against stlc.run.
result3 : Result
result3 = run 1000  inst3  ⟨⟩  []