import Data.Maybe
import Control.Monad.State

-- | Values of the object language.
data Val
  = N Int     -- ^ integer value
  | S String  -- ^ string value
  | B Bool    -- ^ boolean value
  deriving (Eq, Show)

-- | Keys under which functions are stored in a 'Prog'.
data Fun
  = Def String           -- ^ a function as written in the source program
  | MixDef String [Val]  -- ^ the named function specialised (by 'mix') to the
                         --   listed values of its leading parameters
  deriving (Eq, Show)
  
-- | Expressions of the first-order object language.
data E
  = Val { val :: Val }  -- ^ constant; the 'val' selector is only defined on this constructor
  | Var String          -- ^ variable reference
  | If E E E            -- ^ conditional: test, then-branch, else-branch
  | Call Fun [E]        -- ^ call of a program function with argument expressions
  | Prim String [E]     -- ^ application of a built-in primitive (see 'prim')
  deriving (Eq, Show)

-- | Is the expression already a constant (a 'Val' node)?
isVal :: E -> Bool
isVal e = case e of
  Val _ -> True
  _     -> False

-- | A program: each function key paired with its parameter names and body.
-- Looked up with 'lookup', so the first entry for a key wins.
type Prog =
  [ ( Fun          -- function key
    , ( [String]   -- parameter names, in order
      , E          -- body
      )
    )
  ]

-- | Variable bindings, searched front to back with 'lookup'.
type Env = [(String, Val)]

-- | Evaluate an expression to a value.
--
-- @p@ supplies the function definitions and @r@ binds the variables in
-- scope. A function body is evaluated in an environment containing only its
-- own parameters (lexical scoping), so a call's result never depends on the
-- caller's local variables. Arguments are passed lazily. In a conditional,
-- any test value other than @B True@ selects the else-branch.
--
-- Calls 'error' on an unbound variable, an undefined function, or a call
-- with the wrong number of arguments.
eval :: Prog -> Env -> E -> Val
eval _ _ (Val v) = v
eval _ r (Var x) =
  fromMaybe (error ("eval: unbound variable " ++ show x)) (lookup x r)
eval p r (If c t e)
  | eval p r c == B True = eval p r t
  | otherwise            = eval p r e
eval p r (Call f args) =
  case lookup f p of
    Nothing -> error ("eval: undefined function " ++ show f)
    Just (formals, body)
      | length formals /= length args ->
          error ("eval: " ++ show f ++ " expects "
                 ++ show (length formals) ++ " argument(s), got "
                 ++ show (length args))
      | otherwise ->
          -- Only the parameters are visible in the body; appending the
          -- caller's environment would leak its variables (dynamic scoping).
          eval p (zip formals (map (eval p r) args)) body
eval p r (Prim f args) =
  prim f (map (eval p r) args)

-- | Evaluate a primitive operation on already-evaluated arguments.
--
-- Supported primitives:
--
-- * @"iszero"@ on one integer: whether it equals zero.
-- * @"dec"@ on one integer: that integer minus one.
-- * @"*"@ on two integers: their product.
--
-- Any other name, or arguments of the wrong number or kind, raises an
-- 'error' naming the primitive and the arguments it was given.
prim :: String -> [Val] -> Val
prim "iszero" [N v] = B (v == 0)
prim "dec" [N v] = N (v - 1)
prim "*" [N v1, N v2] = N (v1 * v2)
prim f vs =
  error ("prim: cannot apply " ++ show f ++ " to " ++ show vs)

-- | Sample program with a single function, @pow n m@, which computes @m@
-- raised to the power @n@ (for @n >= 0@) by recursion on @n@.
prog :: Prog
prog = [(Def "pow", (["n", "m"], powBody))]
  where
    -- if iszero n then 1 else m * pow (dec n) m
    powBody =
      If (Prim "iszero" [Var "n"])
         (Val (N 1))
         (Prim "*" [Var "m", recurse])
    recurse = Call (Def "pow") [Prim "dec" [Var "n"], Var "m"]

-- | Online partial evaluator (specialiser).
--
-- Simplifies an expression using the static (known) bindings in @r@.
-- Variables not bound in @r@ are dynamic and remain in the residual
-- expression. The state is the program: function definitions are read from
-- it, and every specialised function created along the way is added to it
-- under a 'MixDef' key.
--
-- * A variable bound in @r@ is replaced by its value.
-- * A conditional whose test reduces to @B True@ / @B False@ becomes the
--   selected branch; otherwise both branches are specialised.
-- * A call of @Def f@ is specialised on the longest prefix of arguments that
--   reduce to constants. The residual function @MixDef f vals@ takes the
--   remaining parameters and is created at most once per @vals@.
-- * A primitive whose arguments all reduce to constants is computed with
--   'prim'.
--
-- Calls of @MixDef@ functions are not handled (pattern-match failure).
-- NOTE(review): as with any online specialiser, static arguments that keep
-- changing under dynamic control still cause non-termination.
mix :: Env -> E -> State Prog E
mix _ (Val v) = return $ Val v

mix r (Var x) =
  return $ maybe (Var x) Val (lookup x r)

mix r (If c t e) = do
  c' <- mix r c
  case c' of
    Val (B True) -> mix r t
    Val (B False) -> mix r e
    _ -> do
      t' <- mix r t
      e' <- mix r e
      return $ If c' t' e'

mix r (Call (Def f) args) = do
  args' <- mapM (mix r) args
  let (valExps, exps) = span isVal args'
      vals = map val valExps
      key = MixDef f vals
  p <- get
  case lookup key p of
    -- Already specialised, or currently being specialised further up.
    Just _ -> return ()
    Nothing -> do
      let (formals, body) =
            fromMaybe (error ("mix: undefined function " ++ show f))
                      (lookup (Def f) p)
          residualFormals = drop (length vals) formals
      -- Register the specialisation before processing its body, so that a
      -- recursive call with the same static arguments reuses it instead of
      -- unfolding forever.
      add (key, (residualFormals, body))
      -- The body sees only its own static parameters: its dynamic
      -- parameters must stay variables, not pick up the caller's bindings.
      body' <- mix (zip formals vals) body
      -- Swap the placeholder for the finished definition, placed at the
      -- front of the program.
      modify (\q -> (key, (residualFormals, body')) : filter ((/= key) . fst) q)
  return $ Call key exps

mix r (Prim f args) = do
  args' <- mapM (mix r) args
  return $
    if all isVal args'
      then Val (prim f (map val args'))
      else Prim f args'

-- | Prepend an element to the list held in the state.
add :: a -> State [a] ()
add x = modify (x :)
  
  

-- | Demo. First evaluates @pow@ on fully static arguments. Then specialises
-- it to exponent @n = 2@ with the base left as the dynamic variable @x@, and
-- prints the residual call followed by every definition in the resulting
-- program.
main :: IO ()
main = do
  print $ eval prog [] (Call (Def "pow") [Val $ N 2, Val $ N 5])
  print $ eval prog [] (Call (Def "pow") [Val $ N 5, Val $ N 10])
  let (e, p') = runState (mix [] (Call (Def "pow") [Val $ N 2, Var "x"])) prog
  print e
  mapM_ print p'
  