Monday, March 22, 2010

Learning Clojure by writing a (very) minimal Lisp interpreter

I wanted to learn a bit of Clojure, but I only had a few notions of Emacs Lisp. So I thought it would be fun to try writing a mini Lisp interpreter in Clojure (using Emacs, to add a level of self-reference!) and to document the process. But first, please consider that (1) I deliberately did not consult any book or website about parsing or interpreting Lisp for this exercise (beyond my memories of some vague CS notions), and (2) I was not, and have not become, a Lisp expert in any way, so this will almost certainly contain blatant errors, questionable style, grossly non-optimized algorithms and wrong (or abused) Lisp terminology. Please read it only as the learning process of someone new to Lisp/Clojure, but very enthusiastic about it.
I'll be using clojure-contrib version 1.2 (from the current master branch), since its "string" API provides some very useful functions that don't seem to be available in version 1.1:

(require '[clojure.contrib.string :as str])

If we assume that we will be processing a mini Lisp source file as one big string, the first thing we can do is strip the comments from it, using a simple regular expression:

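(defn strip-comments [s]
  ; join the remaining chunks with a space, so that the code on both sides
  ; of a removed comment (and its newline) cannot get glued together
  (str/join " " (str/split #";.*\n?" s)))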
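For comparison, here is a hedged sketch of the same step using clojure.string, which ships with Clojure itself from version 1.2 on, instead of clojure.contrib.string. It removes each comment but keeps the newline that ends it, so code on consecutive lines stays separated. The name strip-comments* is hypothetical, and like the original, this naive regex would also strip a ";" appearing inside a string literal:

```clojure
(require '[clojure.string :as cstr])

;; Hypothetical variant of strip-comments using clojure.string.
;; [^\n]* stops at the end of the line, so the newline itself is kept
;; and tokens on consecutive lines are never glued together.
(defn strip-comments* [s]
  (cstr/replace s #";[^\n]*" ""))

;; Usage
(strip-comments* "(setq x 1) ; the answer\n(print x)")
;; => "(setq x 1) \n(print x)"
```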

We can then tokenize the input string by giving the partition function (from the "string" API) a regular expression that chunks the input using parentheses and whitespace as separators (while not breaking string literals, which may contain such characters). One important feature of partition that we rely on is that it retains the separators in the result list, which is why we then need to trim every chunk and filter out the resulting empty strings:

(defn tokenize [s]
  ; separators: "(", ")", any whitespace char (incl. newlines), or a whole string literal
  (remove #(= "" %)
          (map str/trim
               (str/partition #"\(|\)|\s|\"[^\"]*\""
                              (strip-comments s)))))
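Another way to look at tokenizing is to describe the tokens themselves rather than the separators between them. A minimal sketch, assuming plain clojure.core only (tokenize* is a hypothetical name): re-seq returns every match of a regex, in order, so no trimming or filtering is needed afterwards:

```clojure
;; Hypothetical alternative to tokenize. A token is one of:
;;   ( or )        a parenthesis
;;   "..."         a string literal (may contain spaces or parens)
;;   [^\s()"]+     any other run of non-whitespace characters
;; Whitespace never matches, so it is simply skipped by re-seq.
(defn tokenize* [s]
  (re-seq #"\(|\)|\"[^\"]*\"|[^\s()\"]+" s))

;; Usage
(tokenize* "(print \"hello world\")\n(+ 1 2)")
;; => ("(" "print" "\"hello world\"" ")" "(" "+" "1" "2" ")")
```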

Next, we'd like the numeric tokens to be converted to the proper numeric type:

; Try to parse the string first as an integer, then as a float, otherwise return it unchanged
(defn parse-num [s]
  (try (Integer/parseInt s)
       (catch NumberFormatException nfe 
         (try (Float/parseFloat s) 
              (catch NumberFormatException nfe s)))))

(defn process-numeric-tokens [tokens]
  (map parse-num tokens))
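One caveat worth knowing: Java's Float.parseFloat is more permissive than a Lisp reader, and accepts inputs such as "NaN", "Infinity" or "1f", so a symbol with such a name would silently become a float. A hypothetical stricter parse-num* could first check the shape of the token with a regular expression. This sketch also swaps in Long and Double (an assumption, not what the code above uses) for wider ranges:

```clojure
;; Hypothetical stricter variant of parse-num.
(defn parse-num* [s]
  (cond
    ;; at most 18 digits always fits in a long, so parseLong cannot overflow
    (re-matches #"[+-]?\d{1,18}" s) (Long/parseLong s)
    ;; decimals such as "3.5" or "2." (and over-long integers) become doubles
    (re-matches #"[+-]?\d+(\.\d*)?" s) (Double/parseDouble s)
    ;; anything else ("NaN", "1f", "-", "foo") stays a symbol string
    :else s))

;; Usage
(map parse-num* ["42" "3.5" "NaN" "-"])
;; => (42 3.5 "NaN" "-")
```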

Once we have extracted and processed our list of tokens, we'd like to make sure that it's properly balanced, i.e. that it contains the same number of "("s as ")"s:

(defn balanced? [tokens]
  (loop [i 0, nop 0, ncp 0]
    (if (>= i (count tokens))
      (= nop ncp)
      (cond 
        (= (nth tokens i) "(") (recur (inc i) (inc nop) ncp)          
        (= (nth tokens i) ")") (recur (inc i) nop (inc ncp))                  
        :default (recur (inc i) nop ncp)))))
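Comparing the two counts does not check the order of the parentheses. A possible stricter check (well-nested? is a hypothetical name) tracks the running nesting depth, which must never go negative and must end at zero; reduced lets it stop at the first unmatched ")":

```clojure
;; Hypothetical stricter variant of balanced?: one running depth instead of
;; two counters, with an early exit as soon as a ")" has no matching "(".
(defn well-nested? [tokens]
  (let [depth (reduce (fn [depth token]
                        (let [d (case token
                                  "(" (inc depth)
                                  ")" (dec depth)
                                  depth)]          ; atoms and numbers
                          (if (neg? d) (reduced d) d)))
                      0
                      tokens)]
    (zero? depth)))
```

For example, balanced? returns true for the tokens ")" "(" since they contain one of each, whereas well-nested? returns false for them.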

With the preprocessing done, we are now ready to parse our list of tokens, that is, to build a syntax tree representation of the input program, in which every atom is a leaf and every list is a subtree. Thanks to Lisp's particular syntax, this is relatively easy... although I must admit that finding the solution was not as easy as it sounds:

; count-tokens is defined further below, so it must be forward-declared
(declare count-tokens)
(defn build-syntax-tree [tokens]
  (loop [i 0, node ()]
    (if (>= i (count tokens))
      node
      (let [token (nth tokens i)]
        (cond
           (= "(" token) ; create and visit a new node level, 
                         ; and add it to current node
             (let [sub-node (build-syntax-tree (drop (inc i) tokens))]
               (recur (+ i (count-tokens sub-node)) 
                      (concat node (list sub-node))))
           (= ")" token) ; return node
             node
           :default ; add token to current node
             (recur (inc i) (concat node (list token))))))))

As you may have noticed, this recursive function relies on count-tokens, which in turn uses atom?; both are pretty self-explanatory:

(defn atom? [elem]
  (not (seq? elem)))

(defn count-tokens [node]
  (loop [i 0, counter 2] ; count opening and closing parentheses
    (if (>= i (count node))
      counter
      (let [elem (nth node i)]
        (if (atom? elem)
          (recur (inc i) (inc counter))
          (recur (inc i) (+ counter (count-tokens elem))))))))
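The count-tokens detour is needed because build-syntax-tree has to know how many tokens to skip after each sub-list. A common alternative (a swapped-in technique, not what I did above) is to make every recursive call return both the nodes it parsed and the tokens it did not consume. A hypothetical sketch, where parse-seq and parse-program are made-up names:

```clojure
;; Parses tokens until a closing ")" or the end of the input, and returns
;; a vector [nodes remaining-tokens]. Sub-lists are turned into real lists,
;; so atom? treats them as non-atoms, exactly as in the original tree.
(defn parse-seq [tokens]
  (loop [tokens (seq tokens), nodes []]
    (if (empty? tokens)
      [nodes nil]
      (let [[token & more] tokens]
        (cond
          (= token ")") [nodes more]
          (= token "(") (let [[sub-nodes remaining] (parse-seq more)]
                          (recur remaining (conj nodes (apply list sub-nodes))))
          :else (recur more (conj nodes token)))))))

;; Top level: the program is the list of all top-level nodes.
(defn parse-program [tokens]
  (apply list (first (parse-seq tokens))))

;; Usage
(parse-program ["(" "print" "(" "f" 2 ")" ")"])
;; => (("print" ("f" 2)))
```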

We now need to extract symbols from this syntax tree: functions, defined with the defun special form (remember that this is a Lisp interpreter, not a Clojure one), and variables, defined with setq (setf would be much harder to implement, I think, and is thus outside the scope of this simple interpreter). Each symbol will be stored in a Clojure map, where it points to a Clojure vector of two elements in the case of variables (a "var" identifier and the variable node, as found in the syntax tree), or of three elements in the case of functions (a "function" identifier, the list of function parameters, and the list of body nodes, as found in the syntax tree). All this is accomplished with this function:

; symbol-table: { var-symbol -> ["var", var-node],
;                 fn-symbol  -> ["function", params, fn-nodes] }, 
;                                 where fn-nodes is a list of function 
;                                 expressions, or "statements"
(defn extract-symbol-table [node]
  (loop [i 0, symbol-table {}]
    (if (>= i (count node))
      symbol-table 
      (let [elem (nth node i)]
        (if (atom? elem)
          (cond 
            (= elem "setq")
              (do
                (assert (zero? i))
                (assert (= (count node) 3))
                (recur (inc i) 
                       (assoc symbol-table (nth node 1) 
                              ["var" (nth node 2)])))
            (= elem "defun")
              (do
                (assert (zero? i))
                (recur (inc i) 
                       (assoc symbol-table (nth node 1) 
                              ["function" (nth node 2) (drop 3 node)])))
            :default
              (recur (inc i) symbol-table))
          (recur (inc i) (merge symbol-table (extract-symbol-table elem))))))))
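Since this function is essentially a walk over every sub-list of the tree, it could also be written with Clojure's built-in tree-seq, which lazily lists every node of a tree in depth-first order. A hypothetical sketch (extract-symbol-table* is a made-up name) that keeps the same ["var" ...] / ["function" ...] layout but drops the asserts:

```clojure
;; Hypothetical tree-seq based variant of extract-symbol-table.
;; (tree-seq seq? seq tree) yields the tree itself, then every nested
;; list and atom, depth first; only lists starting with "setq" or "defun"
;; add an entry. A later definition of the same symbol wins, as with assoc.
(defn extract-symbol-table* [tree]
  (reduce (fn [table node]
            (if (seq? node)
              (case (first node)
                "setq"  (assoc table (nth node 1) ["var" (nth node 2)])
                "defun" (assoc table (nth node 1)
                               ["function" (nth node 2) (drop 3 node)])
                table)
              table))
          {}
          (tree-seq seq? seq tree)))

;; Usage
(extract-symbol-table* '(("setq" "*g*" 40)))
;; => {"*g*" ["var" 40]}
```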

We can now define a function that will execute a node, recursively of course:

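; execute-function (which itself calls get-function-instance) is defined
; further below, so both must be forward-declared
(declare execute-function get-function-instance)
(defn execute-node [node symbol-table]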
  (if (atom? node)
    ; it's an atom
    (cond
      (contains? #{"true" "t"} node) true
      (contains? #{"false" "f" "nil"} node) false
      (contains? symbol-table node) ; try to replace atom by corresponding var-node, if it exists
        (execute-node (second (get symbol-table node)) symbol-table)
      :default node)
    ; it's a node 
    (cond 
       (= (first node) "if") ; "if" node
         (do
           (assert (= (count node) 4))
           (execute-function "if" (rest node) symbol-table))
       (or (= (first node) "setq") (= (first node) "defun")) ; don't execute those
         nil
       :default
         (loop [i 0, function nil, params nil] ; not if
           (if (>= i (count node))
             (execute-function function params symbol-table)
             (let [elem (nth node i)]
               (if (zero? i)
                 (recur (inc i) elem ()) ; set first elem as function
                 (let [exec-node (execute-node elem symbol-table)]
                     (recur (inc i) function 
                            (concat params (list exec-node))))))))))) ; accum params


When execute-node reaches a function node, that is, a node that should actually do something, it hands it over to execute-function. This is where we inject semantics (admittedly, quite simple ones) into our interpreter, which until now was concerned only with "what" to compute, not with "how" to compute it. On a side note, I would say that this is where I reached the "relevance limit" of this learning project: in particular, lists are implemented... with lists! I guess that if I had gone the hardcore way, I would have implemented my own lists from lower-level components, but to keep things simple, I figured this was an acceptable place to stop (a much sillier place to stop would obviously have been to simply "eval" the whole thing in the first place!). Anyway, here is the function that does the actual work:

; function: symbol
; params: list
; symbol-table: map
(defn execute-function [function params symbol-table]
  (cond 
    (contains? symbol-table function) ; user-defined function
      (let [fn-cell (get symbol-table function)]
        (let [fn-params (nth fn-cell 1)
              fn-nodes (nth fn-cell 2)]
          (let [fn-instance (get-function-instance fn-nodes (zipmap fn-params params))]
            (doseq [fn-inst-node (take (dec (count fn-instance)) fn-instance)]
              (execute-node fn-inst-node symbol-table))
            (execute-node (last fn-instance) symbol-table))))
    (= function "+")
      (apply + params)
    (= function "-")
      (apply - params)
    (= function "print")
      (if (string? (first params))
          (println (str/replace-str "\"" "" (first params))) ; remove enclosing double quotes
          (println (first params)))
    (= function "list")
      params
    (= function "count")
      (count (first params))
    (= function "car")
      (first (first params))
    (= function "cdr")
      (rest (first params))
    (= function "=")
      (= (first params) (second params))
    (= function "if")
      (if (execute-node (first params) symbol-table)
        (execute-node (second params) symbol-table)
        (execute-node (nth params 2) symbol-table))
    (= function "progn")
      (do ; multi-statements emulation
        (doseq [node (take (dec (count params)) params)]
          (execute-node node symbol-table))        
        (execute-node (last params) symbol-table))
    :default ; unknown function (or plain data such as (1 2 3)): rebuild the list
      (if function
        (cons function params)
        ())))
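The builtins that simply compute a value from already-evaluated arguments could alternatively live in a map from symbol to Clojure function, which would keep the cond short. This is only a sketch under assumptions: builtins and call-builtin are hypothetical names, and if, progn and print are left out since they either need the symbol table or print as a side effect:

```clojure
;; Hypothetical dispatch table for the pure builtins of execute-function.
;; Each entry receives the list of (already evaluated) params.
(def builtins
  {"+"     (fn [params] (apply + params))
   "-"     (fn [params] (apply - params))
   "list"  identity
   "count" (fn [[xs]] (count xs))
   "car"   (fn [[xs]] (first xs))
   "cdr"   (fn [[xs]] (rest xs))
   "="     (fn [[a b]] (= a b))})

(defn call-builtin [function params]
  (if-let [f (get builtins function)]
    (f params)
    ;; same fallback as the :default branch: unknown heads build a data list
    (if function (cons function params) ())))

;; Usage
(call-builtin "cdr" ['(1 2 3)]) ; => (2 3)
```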

For user-defined functions, i.e. symbols found in our previously extracted symbol table, we need an extra step: we must "instantiate" the function, that is, replace every formal parameter symbol in its definition by its actual value at the time of the call. This is done by first creating a map that associates every parameter symbol with its value (using the zipmap function), and then passing it to this function, which does the rest of the job:

; Returns a function-node in which all instances of params found in 
; param-map have been replaced by their actual values
(defn get-function-instance [fn-nodes param-map]
  (loop [i 0, fn-instance ()]
    (if (>= i (count fn-nodes))
      fn-instance
      (let [elem (nth fn-nodes i)]
        (if (atom? elem)
          (if (contains? param-map elem) ; match found: replace param by param-node
            (recur (inc i) (concat fn-instance (list (get param-map elem))))
            (recur (inc i) (concat fn-instance (list elem))))
          (recur (inc i) (concat fn-instance (list (get-function-instance elem param-map)))))))))
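Clojure's standard library happens to offer a ready-made tool for this kind of substitution: clojure.walk/postwalk-replace, which replaces every occurrence of a map key anywhere in a nested form by the corresponding value. A hedged sketch of get-function-instance built on it (get-function-instance* is a made-up name); like the original, it substitutes blindly, without any notion of shadowing:

```clojure
(require '[clojure.walk :as walk])

;; Hypothetical equivalent of get-function-instance: every atom of fn-nodes
;; that is a key of param-map is replaced by its value, at any depth.
;; postwalk works bottom-up, so substituted values are not walked again.
(defn get-function-instance* [fn-nodes param-map]
  (walk/postwalk-replace param-map fn-nodes))

;; Usage: instantiating the body of test from the sample program with p = 3
(get-function-instance* '(("add" "p" ("minus" "p" "*g*"))) {"p" 3})
;; => (("add" 3 ("minus" 3 "*g*")))
```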

The last step is easy: we need a top-level function that executes the whole input program, node by node, after having performed all the preprocessing steps described previously:

(defn execute-program [input-str]
  (let [tokens (process-numeric-tokens (tokenize input-str))]
    (if (not (balanced? tokens))
      (println "Unbalanced parentheses problem")      
      (let [syntax-tree (build-syntax-tree tokens)]
        (let [symbol-table (extract-symbol-table syntax-tree)]
          (doseq [node syntax-tree]
            (execute-node node symbol-table)))))))

(execute-program (slurp (first *command-line-args*)))

Finally, here is the kind of mini Lisp program that it can run:

; Remember this is mini Lisp, not Clojure!
(setq *g* 40)

(defun add (x y)
  (+ x y))

(defun minus (x y)
  (- x y))

(defun test (p)
  (add p (minus p *g*)))

(print (test 3)) ; Should give -34
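To double-check the expected -34, here is the same program transliterated into plain Clojure. It is just a hypothetical cross-check, with test renamed to test-fn (to avoid clashing with clojure.core/test) and *g* renamed to g, and its comments trace how the interpreter instantiates the calls:

```clojure
(def g 40)

(defn add [x y] (+ x y))
(defn minus [x y] (- x y))

;; The interpreter instantiates (test 3) as (add 3 (minus 3 *g*)), then:
;;   (minus 3 *g*) -> (- 3 40)  -> -37
;;   (add 3 -37)   -> (+ 3 -37) -> -34
(defn test-fn [p] (add p (minus p g)))

(test-fn 3) ; => -34
```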

One thing I really wanted to support, and was a bit anxious about because I didn't dare to "execute it in my head" before trying it, was recursion, which I was happily surprised to see handled gracefully:

(defun recursive-count (p)
  (if (= (list) p)
    0
    (+ 1 (recursive-count (cdr p)))))

(print (recursive-count (list 1 2 3 4))) ; Should give 4
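The recursion terminates because each nested call receives the cdr of the previous argument, until the empty list makes (= (list) p) true and 0 is returned. The successive values bound to p can be listed in plain Clojure (a hypothetical illustration, not part of the interpreter):

```clojure
;; Successive arguments of recursive-count: iterate keeps applying rest
;; (the equivalent of cdr), and take-while stops at the first empty list,
;; which is where the base case returns 0.
(def instances (take-while seq (iterate rest '(1 2 3 4))))

instances         ; => ((1 2 3 4) (2 3 4) (3 4) (4))
(count instances) ; => 4, one (+ 1 ...) per non-empty argument
```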

As you'll see if you try, this interpreter is ridiculously easy to break or confound with minimal effort... but then again, the goal was simply to learn. The code is available on GitHub.