module stlc.monad.impl where

open import Data.List

open import stlc.inst
open import stlc.monad public
open import stlc.catComb

--- EXCEPTION MONAD ---
-- return for the exception monad: inject the value with the left
-- injection i1 (compiled to INL in getReturnInst). The error alternative is
-- presumably the right summand of getMonadType, since exceptionBindImpl
-- re-raises errors with i2.
exceptionReturnImpl :  {A E}  CatComb A (getMonadType {A} (exception {E = E}))
exceptionReturnImpl = i1

-- bind for the exception monad, on an input pair (m , f).
-- The pair is first swapped to (f , m). The case combinator then
-- (presumably distributing over the product, so f stays in scope in both
-- branches) does one of two things:
--   * left branch (app): f is applied to the successful value;
--   * right branch (i2 . p2): f is dropped and the error is re-injected
--     with i2.
-- An error therefore short-circuits, and f is not called on it.
exceptionBindImpl :  {A B E}  CatComb ((getMonadType {A} (exception {E = E})) × (A  (getMonadType {B} (exception {E = E})))) (getMonadType {B} (exception {E = E}))
exceptionBindImpl = [ app , i2  p₂ ]   p₂ , p₁ 

--- STATE MONAD ---
-- return for the state monad: a |-> (\s -> (s , a)).
-- cur curries the body over the incoming state, so the body sees (a , s).
-- <p2 , p1> swaps that into (s , a), which puts the state FIRST in the
-- result pair.
-- NOTE(review): stateBindImpl projects results in this (state, value)
-- order; keep the two consistent.
stateReturnImpl :  {A S}  CatComb A (getMonadType {A} (state {S = S}))
stateReturnImpl = cur  p₂ , p₁ 

-- bind for the state monad: run m on the incoming state, then run f on the
-- value it produced, starting from the updated state.
-- Under cur the environment is ((m , f) , s). Read right to left, the
-- composition has four stages:
--   1. <<p1 . p1 , p2> , p2 . p1> : ((m , f) , s)  -> ((m , s) , f)
--   2. <app . p1 , p2>            : ((m , s) , f)  -> ((s' , a) , f)
--                                   where (s' , a) = m s
--   3. <<p2 , p2 . p1> , p1 . p1> : ((s' , a) , f) -> ((f , a) , s')
--   4. app . <app . p1 , p2>      : ((f , a) , s') -> f a s'
-- Step 3 takes s' from the first and a from the second component of m's
-- result, which matches the (state, value) order built by stateReturnImpl.
stateBindImpl :  {A B S}  CatComb ((getMonadType {A} (state {S = S})) × (A  (getMonadType {B} (state {S = S})))) (getMonadType {B} (state {S = S}))
stateBindImpl = cur (app   app  p₁ , p₂     p₂ , p₂  p₁  , p₁  p₁    app  p₁ , p₂     p₁  p₁ , p₂  , p₂  p₁  )

--- NONDETERMINISM MONAD ---
-- return for the nondeterminism (list) monad: the singleton list.
-- a |-> cons (a , nil). The input is paired with itself (id) and with nil,
-- which is reached from the discarded input via ! .
nondetReturnImpl :  {A}  CatComb A (getMonadType {A} (nondeterminism))
nondetReturnImpl = cons   id , nil  ! 

-- List concatenation expressed as a combinator.
-- The input pair (xs , ys) is swapped to (ys , xs), and `it` then iterates
-- over xs with ys as the environment:
--   * the base case returns the environment ys unchanged (id);
--   * each step conses the current element onto the accumulated result
--     (cons . p2 applied to (ys , (x , acc))).
-- Assuming `it` folds from the right, the result is xs ++ ys.
concatList :  {A}  CatComb (list A × list A) (list A)
concatList = it id (cons  p₂)   p₂ , p₁ 

-- Flatten a list of lists by folding concatList over it.
-- <! , id> pairs the input with a unit environment, so that `it` can
-- iterate over the outer list:
--   * the base case is nil;
--   * each step concatenates the current inner list with the accumulated
--     result (concatList . p2 drops the unit environment).
flat :  {A}  CatComb (list (list A)) (list A)
flat = it nil (concatList  p₂)   ! , id 

-- Map a function over a list: (f , xs) |-> the list of f x for each x in xs.
-- `it` iterates over xs with f as the environment:
--   * the base case is the empty list (nil . !);
--   * the step, on (f , (x , acc)), returns cons (app (f , x) , acc).
--     Here p1 selects f, p1 . p2 selects x, and p2 . p2 selects acc.
fmap :  {A B}  CatComb ((A  B) × list A) (list B)
fmap = it (nil  !) (cons   app   p₁ , p₁  p₂  , p₂  p₂ )

-- bind for the nondeterminism (list) monad, i.e. concatMap.
-- The pair (m , f) is swapped to (f , m). f is then mapped over m with fmap,
-- producing a list of lists, and flat concatenates those results.
nondetBindImpl :  {A B}  CatComb ((getMonadType {A} (nondeterminism)) × (A  (getMonadType {B} (nondeterminism)))) (getMonadType {B} (nondeterminism))
nondetBindImpl = flat  fmap   p₂ , p₁ 

--- CONTINUATION MONAD ---
-- return for the continuation monad: a |-> (\k -> k a).
-- Under cur the environment is (a , k). <p2 , p1> swaps it to (k , a), and
-- app applies the continuation k to a.
contReturnImpl :   {A R}  CatComb A (getMonadType {A} (continuation {R = R}))
contReturnImpl = cur (app   p₂ , p₁ )

-- bind for the continuation monad: (m , f) |-> (\k -> m (\a -> f a k)).
-- Outer cur:
--   * the environment is ((m , f) , k);
--   * p1 . p1 selects m, which is applied to the continuation built by the
--     inner cur.
-- Inner cur:
--   * the environment is (((m , f) , k) , a);
--   * p2 . p1 . p1 selects f, and app applies it to a (selected by p2);
--   * the result f a is then applied to k (selected by p2 . p1).
contBindImpl :  {A B R}  CatComb ((getMonadType {A} (continuation {R = R})) × (A  (getMonadType {B} (continuation {R = R})))) (getMonadType {B} (continuation {R = R}))
contBindImpl = cur (app   p₁  p₁ , cur (app   app   p₂  p₁  p₁ , p₂  , p₂  p₁ ) )

--- RETURN and BIND implementations ---
-- Select the CatComb implementation of return for a monad descriptor m.
-- Dispatches on the Monad constructor to the matching *ReturnImpl defined
-- above. The implicit fields bound in the patterns ({E = E}, {S = S}) are
-- not used on the right-hand sides.
getReturnImpl :  {A TA}  (m : Monad {A} TA)  CatComb A TA
getReturnImpl (exception {E = E}) = exceptionReturnImpl
getReturnImpl (state {S = S}) = stateReturnImpl
getReturnImpl (nondeterminism) = nondetReturnImpl
getReturnImpl (continuation) = contReturnImpl

-- Select the CatComb implementation of bind.
-- Input: a pair of a computation of type TA and a function from A into TB.
-- Output: a computation of type TB.
-- The dispatch is on the SameMonad witness s. Its constructors mirror those
-- of Monad, and it presumably guarantees that ma and mb are the same kind
-- of monad, so that TA and TB belong to the same family
-- (NOTE(review): confirm against the SameMonad definition in stlc.monad).
getBindImpl :  {A B TA TB}  {ma : Monad {A} TA}  {mb : Monad {B} TB}  (s : SameMonad ma mb)  CatComb (TA × (A  TB)) TB
getBindImpl exception = exceptionBindImpl
getBindImpl state = stateBindImpl
getBindImpl nondeterminism = nondetBindImpl
getBindImpl continuation = contBindImpl

--- RETURN and BIND compiled instructions ---
-- Hand-written machine instructions for each monad's return. Each case
-- mirrors the corresponding *ReturnImpl term.
-- The translation visible by comparing with those terms is:
--   composition g . f    -> code for f, then code for g
--   pairing <f , g>      -> PUSH, code for f, SWAP, code for g, PAIR
--   p1 / p2 / id / !     -> FST / SND / SKIP / UNIT
--   app / cur / i1       -> APP / CUR / INL
--   nil / cons           -> NIL / C
-- NOTE(review): these lists are kept in sync with the Impl terms by hand.
-- If a CatComb-to-Inst compiler exists elsewhere in the project, deriving
-- them from it would remove that risk.
getReturnInst :  {A TA}  (m : Monad {A} TA)  List Inst
getReturnInst exception = INL  []
getReturnInst state = CUR (PUSH  SND  SWAP  FST  PAIR  [])  []
getReturnInst nondeterminism = PUSH  SKIP  SWAP  UNIT  NIL  PAIR  C  []
getReturnInst continuation = CUR (PUSH  SND  SWAP  FST  PAIR  APP  [])  []

-- Hand-written machine instructions for each monad's bind. Each case
-- mirrors the corresponding *BindImpl term under this translation:
--   g . f       -> f's code, then g's
--   <f , g>     -> PUSH f SWAP g PAIR
--   [ f , g ]   -> CASE (f) (g)
--   it z s      -> IT (z) (s)
--   p1 / p2     -> FST / SND
--   id          -> SKIP
--   !           -> UNIT
--   i2          -> INR
-- Cases:
--   exception      : exceptionBindImpl
--   state          : stateBindImpl (its four stages inside one CUR)
--   nondeterminism : flat . fmap . <p2 , p1>. The first IT is fmap; the
--                    second IT is flat, and concatList is inlined into its
--                    step (the nested IT (SKIP) (SND C) block).
--   continuation   : contBindImpl (a nested CUR builds the continuation
--                    passed to m)
-- NOTE(review): these lists are kept in sync with the Impl terms by hand.
getBindInst :  {A B TA TB}  {ma : Monad {A} TA}  {mb : Monad {B} TB}  (s : SameMonad ma mb)  List Inst
getBindInst exception = PUSH  SND  SWAP  FST  PAIR  CASE (APP  []) (SND  INR  [])  []
getBindInst state = CUR (PUSH  PUSH  FST  FST  SWAP  SND  PAIR  SWAP  FST  SND  PAIR  PUSH  FST  APP  SWAP  SND  PAIR  PUSH  PUSH  SND  SWAP  FST  SND  PAIR  SWAP  FST  FST  PAIR  PUSH  FST  APP  SWAP  SND  PAIR  APP  [])  []
getBindInst nondeterminism = PUSH  SND  SWAP  FST  PAIR  IT (UNIT  NIL  []) (PUSH   PUSH   FST   SWAP  SND  FST  PAIR  APP  SWAP  SND  SND  PAIR  C  [])  PUSH  UNIT  SWAP  SKIP  PAIR  IT (NIL  []) (SND   PUSH   SND  SWAP  FST  PAIR  IT (SKIP  []) (SND  C  [])  [])  []
getBindInst continuation = CUR (PUSH  FST  FST  SWAP  CUR (PUSH  PUSH  FST  FST  SND  SWAP  SND  PAIR  APP  SWAP  FST  SND  PAIR  APP  [])  PAIR  APP  [])  []