=begin
MiniKanren Copyright (C) 2006 Scott Dial

This library is free software; you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License as published by the Free
Software Foundation; either version 2.1 of the License, or (at your option) any
later version.

This library is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along
with this library; if not, write to the Free Software Foundation, Inc.,
59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
=end

# A small miniKanren-style logic programming core.
#
# The module-level methods (MiniKanren.unify, MiniKanren.take, ...) implement
# unification, reification and answer streams. The instance methods (eq, all,
# any, defer, succeed, fail, fresh, infer) are meant to be mixed in with
# `include MiniKanren` and build or run goals.
#
# A goal is a callable taking a Subst and returning a stream of answers.
# A stream is one of:
#   nil            - no answers
#   a Proc         - a suspended computation; calling it yields a stream
#   [subst, proc]  - one answer plus a Proc producing the rest of the stream
#   anything else  - exactly one answer (normally a Subst)
module MiniKanren
  # A logic variable. Variables are compared by object identity only.
  class Var; end

  # A substitution: a Hash whose keys are bound Vars and whose values are the
  # terms they are bound to.
  class Subst < Hash; end

  # Unifies the terms u and v under the substitution s.
  #
  # Terms are Vars, Arrays of terms (unified element by element) or any other
  # object (compared with ==). New bindings are written into s itself.
  #
  # Returns s when the terms unify and nil when they do not. On failure s may
  # still contain bindings added before the mismatch was found.
  def self.unify(u, v, s)
    u = walk(u, s)
    v = walk(v, s)

    return s if u.equal?(v)

    if u.instance_of?(Var)
      # Binding one unbound Var to another needs no occurs check.
      return v.instance_of?(Var) ? ext_s(u, v, s) : ext_s_check(u, v, s)
    end
    return ext_s_check(v, u, s) if v.instance_of?(Var)

    if u.instance_of?(Array) && v.instance_of?(Array)
      return nil if u.length != v.length

      u.zip(v).each do |left, right|
        s = unify(left, right, s)
        return nil if s.nil?
      end
      return s
    end

    u == v ? s : nil
  end

  # Follows the chain of bindings for v in s and returns the first term that
  # is not a bound Var. Terms other than Vars are returned unchanged.
  def self.walk(v, s)
    v = s[v] while v.instance_of?(Var) && s.key?(v)
    v
  end

  # Binds x to v in s unless x occurs inside v, in which case nil is returned.
  def self.ext_s_check(x, v, s)
    occurs_check(x, v, s) ? nil : ext_s(x, v, s)
  end

  # Returns true when the Var x appears in the term v (looking through the
  # bindings in s and into nested Arrays), false otherwise.
  def self.occurs_check(x, v, s)
    v = walk(v, s)

    if v.instance_of?(Var)
      x.equal?(v)
    elsif v.instance_of?(Array)
      v.any? { |item| occurs_check(x, item, s) }
    else
      false
    end
  end

  # Records the binding x => v in s (mutating it) and returns s.
  def self.ext_s(x, v, s)
    s[x] = v
    s
  end

  # Extends s with a binding from every unbound Var found in v to a reified
  # name ("_.0", "_.1", ...), numbered in order of first appearance.
  # Returns s.
  def self.reify_s(v, s)
    v = walk(v, s)

    if v.instance_of?(Var)
      ext_s(v, reify_name(s.length), s)
    elsif v.instance_of?(Array)
      v.inject(s) { |acc, item| reify_s(item, acc) }
    else
      s
    end
  end

  # Returns the value of v under s with every remaining unbound Var replaced
  # by its reified name.
  def self.reify(v, s)
    v = walk_all(v, s)

    walk_all(v, reify_s(v, Subst.new))
  end

  # Like walk, but also resolves Vars nested inside Arrays. Returns a new
  # Array for Array terms.
  def self.walk_all(w, s)
    v = walk(w, s)

    if v.instance_of?(Array)
      v.map { |item| walk_all(item, s) }
    else
      v
    end
  end

  # Returns the printed name used for the n-th unbound variable, e.g. "_.0".
  def self.reify_name(n)
    "_.#{n}"
  end

  # Combines the stream ss with the stream produced by calling f.
  # When ss is suspended the two streams swap places, so answers from both
  # sides are interleaved.
  def self.mplus(ss, f)
    if ss.nil?
      f.call
    elsif ss.instance_of?(Proc)
      lambda { mplus(f.call, ss) }
    elsif ss.instance_of?(Array)
      [ss[0], lambda { mplus(ss[1].call, f) }]
    else
      [ss, f]
    end
  end

  # Collects answers from the stream produced by calling f.
  #
  # n is the maximum number of answers to collect; a false or nil n means
  # "all of them". Returns an Array of answers.
  def self.take(n, f)
    res = []
    while !n || n > 0
      ss = f.call
      return res if ss.nil?

      if ss.instance_of?(Proc)
        f = ss
      elsif ss.instance_of?(Array)
        n -= 1 if n
        res << ss[0]
        f = ss[1]
      else
        # A bare answer is the last element of the stream.
        return res << ss
      end
    end
    res
  end

  # Applies goal to every answer in the stream ss and returns the combined
  # stream of results.
  def self.bind(ss, goal)
    if ss.nil?
      nil
    elsif ss.instance_of?(Proc)
      lambda { bind(ss.call, goal) }
    elsif ss.instance_of?(Array)
      mplus(goal.call(ss[0]), lambda { bind(ss[1].call, goal) })
    else
      goal.call(ss)
    end
  end

  # Runs each goal against s and merges the resulting streams with mplus.
  # Every goal but the last receives its own copy of s, because goals mutate
  # the substitution they are given. Returns nil for an empty goal list.
  def self.mplus_all(goals, s)
    return nil if goals.empty?
    return goals[0].call(s) if goals.length == 1

    mplus(goals[0].call(s.clone), lambda { mplus_all(goals[1..-1], s) })
  end

  # Returns a goal that succeeds when u and v unify.
  def eq(u, v)
    lambda { |s| MiniKanren.unify(u, v, s) }
  end

  # Returns a goal that succeeds when every one of goals succeeds
  # (conjunction). With no goals it always succeeds.
  def all(*goals)
    return succeed if goals.empty?

    lambda { |s|
      goals.inject(s) { |stream, goal| MiniKanren.bind(stream, goal) } }
  end

  # Returns a goal that succeeds once for each of goals that succeeds
  # (disjunction). With no goals there is nothing that could succeed, so the
  # result is the failing goal.
  def any(*goals)
    return fail() if goals.empty?

    lambda { |s| lambda { MiniKanren.mplus_all(goals, s) } }
  end

  # Returns a goal that calls func with args only when the goal itself is run.
  # This lets a relation refer to itself without recursing forever while the
  # goal is being built.
  #
  # func must respond to arity and call, and must return a goal.
  # Raises ArgumentError when args cannot satisfy func's arity: too few
  # arguments, or too many for a callable that takes a fixed number. Non-lambda
  # procs are exempt from the "too many" check because they discard surplus
  # arguments when called.
  def defer(func, *args)
    arity = func.arity
    variadic = arity < 0
    fixed_arity = variadic ? -arity - 1 : arity
    lenient = func.respond_to?(:lambda?) && !func.lambda?

    too_few = args.length < fixed_arity
    too_many = !variadic && !lenient && args.length > fixed_arity
    if too_few || too_many
      raise ArgumentError, "(#{func}) wrong number of arguments " +
                           "(#{args.length} for #{fixed_arity})"
    end
    lambda { |s| func.call(*args).call(s) }
  end

  # Returns the goal that always succeeds, leaving the substitution unchanged.
  def succeed
    lambda { |s| s }
  end

  # Returns the goal that never succeeds.
  #
  # NOTE: this shares its name with Kernel#fail; where this module is mixed in,
  # a bare `fail` refers to this goal rather than raising an exception.
  def fail
    lambda { |s| nil }
  end

  # Creates fresh logic variables.
  #
  # fresh { |q| all(
  #   eq(q, true),
  #   fresh { |q|
  #     eq(q, false) } ) }
  #
  # x = fresh
  # x, y = fresh(2)
  #
  # With a block, one Var is created per block parameter and the block's value
  # is returned (n is ignored). Without a block, returns a single Var when n
  # is omitted and an Array of n Vars otherwise.
  def fresh(n = -1, &block)
    if block
      block.call(*(1..block.arity).map { Var.new })
    elsif n == -1
      Var.new
    else
      (1..n).map { Var.new }
    end
  end

  # Runs goals and returns the reified values of var, one per answer.
  #
  # infer(var, goal0, goal*, ...)
  # infer(n, var, goal0, goal*, ...)
  #
  # n limits the number of answers; without it all answers are collected. The
  # two forms are told apart by whether the second argument is a Proc, and a
  # lone argument is taken to be var. Several goals are combined with all.
  def infer(*args)
    if args.length < 2 || args[1].instance_of?(Proc)
      n = false
      v, *goals = args
    else
      n, v, *goals = args
    end

    goal = goals.length == 1 ? goals[0] : all(*goals)

    answers = MiniKanren.take(n, lambda { goal.call(Subst.new) })
    answers.map { |s| MiniKanren.reify(v, s) }
  end
end

# When this file is executed directly (rather than loaded by another script),
# evaluate the text stored after __END__ as Ruby code. The script name and a
# starting line number are handed to eval so that positions reported for that
# code refer to this file; the +4 is the distance from the eval line to the
# first line following __END__, so no lines may be inserted between them.
if $0 == __FILE__
  eval DATA.read, nil, $0, __LINE__+4
end

__END__

require 'test/unit'

# Test suite for MiniKanren. Every assertion runs a query with infer and
# compares the list of reified answers against a literal.
#
# NOTE(review): assert_equal is called as (actual, expected) throughout, while
# Test::Unit's parameter order is (expected, actual); this only affects how a
# failure is reported.
class TC_MiniKanren < Test::Unit::TestCase
  include MiniKanren

  # Conjunction (all), the succeed and fail goals, and reification of
  # variables that are left unbound.
  def test_all
    q = fresh
    assert_equal(infer(q, fail), [])
    assert_equal(infer(q, eq(true, q)), [true])
    assert_equal(infer(q, all(fail, eq(true, q))), [])
    assert_equal(infer(q, all(succeed, eq(true, q))), [true])
    assert_equal(infer(q, all(succeed, eq(:corn, q))), [:corn])
    assert_equal(infer(q, all(fail, eq(:corn, q))), [])
    assert_equal(infer(q, all(succeed, eq(false, q))), [false])
    x = fresh
    assert_equal(infer(q, all(eq(true, x), eq(true, q))), [true])
    assert_equal(infer(q, all(eq(x, true), eq(true, q))), [true])

    # An unbound query variable is reported under its reified name.
    assert_equal(infer(q, succeed), ["_.0"])
    assert_equal(infer(q, succeed), ["_.0"])

    x, y = fresh(2)
    assert_equal(infer(q, eq([x, y], q)), [["_.0", "_.1"]])
    t, u = fresh(2)
    assert_equal(infer(q, eq([t, u], q)), [["_.0", "_.1"]])

    # y keeps the first variable; x is rebound to a second, distinct one.
    x = fresh
    y = x
    x = fresh
    assert_equal(infer(q, eq([y, x, y], q)), [["_.0", "_.1", "_.0"]])

    assert_equal(infer(q, all(eq(false, q), eq(true, q))), [])
    assert_equal(infer(q, all(eq(false, q), eq(false, q))), [false])

    # x is just another name for the same variable object as q.
    x = q
    assert_equal(infer(q, eq(true, x)), [true])

    x = fresh
    assert_equal(infer(q, eq(x, q)), ["_.0"])
    assert_equal(infer(q, all(eq(true, x), eq(x, q))), [true])
    assert_equal(infer(q, all(eq(x, q), eq(true, x))), [true])
  end

  # Disjunction (any): which answers are produced, in which order, and how
  # the answer limit passed to infer truncates them.
  def test_any
    q, x = fresh(2)

    # `x == q` is evaluated by Ruby right here; x and q are distinct objects,
    # so the goal being run is eq(false, q).
    assert_equal(infer(q, eq(x == q, q)), [false])

    assert_equal(infer(q, any(
                              all(fail, succeed),
                              all(succeed, fail))),
                 [])
    assert_equal(infer(q, any(
                              all(fail, fail),
                              all(succeed, succeed))),
                 ["_.0"])
    assert_equal(infer(q, any(
                              all(succeed, succeed),
                              all(fail, fail))),
                 ["_.0"])
    assert_equal(infer(q, any(
                              all(eq(:olive, q), succeed),
                              all(eq(:oil, q), succeed))),
                 [:olive, :oil])
    assert_equal(infer(1, q, any(
                                 all(eq(:olive, q), succeed),
                                 all(eq(:oil, q), succeed))),
                 [:olive])
    assert_equal(infer(q, any(
                              all(eq(:virgin, q), fail),
                              all(eq(:olive, q), succeed),
                              all(succeed, succeed),
                              all(eq(:oil, q), succeed))),
                 [:olive, "_.0", :oil])
    assert_equal(infer(q, any(
                              all(eq(:olive, q), succeed),
                              all(succeed, succeed),
                              all(eq(:oil, q), succeed))),
                 [:olive, "_.0", :oil])
    assert_equal(infer(2, q, any(
                              all(eq(:extra, q), succeed),
                              all(eq(:virgin, q), fail),
                              all(eq(:olive, q), succeed),
                              all(eq(:oil, q), succeed))),
                 [:extra, :olive])
    x, y = fresh(2)
    assert_equal(infer(q, all(
                              eq(:split, x),
                              eq(:pea, y),
                              eq([x, y], q))),
                 [[:split, :pea]])
    assert_equal(infer(q, all(
                              any(
                                all(eq(:split, x), eq(:pea, y)),
                                all(eq(:navy, x), eq(:bean, y))),
                              eq([x, y], q))),
                 [[:split, :pea], [:navy, :bean]])
    assert_equal(infer(q, all(
                              any(
                                all(eq(:split, x), eq(:pea, y)),
                                all(eq(:navy, x), eq(:bean, y))),
                              eq([x, y, :soup], q))),
                 [[:split, :pea, :soup], [:navy, :bean, :soup]])

    # A def nested in a test method defines the helper on the test class when
    # this test runs.
    def teacupo(x)
      any(
        all(eq(:tea, x), succeed),
        all(eq(:cup, x), succeed))
    end

    assert_equal(infer(q, teacupo(q)), [:tea, :cup])

    # The expected order shows answers of the two branches being interleaved.
    assert_equal(infer(q, all(
                              any(
                                all(teacupo(x), eq(true, y), succeed),
                                all(eq(false, x), eq(true, y))),
                              eq([x, y], q))),
                 [[false, true], [:tea, true], [:cup, true]])

    x, y, z = fresh(3)
    x_ = fresh
    assert_equal(infer(q, all(
                              any(
                                all(eq(y, x), eq(z, x_)),
                                all(eq(y, x_), eq(z, x))),
                              eq([y, z], q))),
                 [["_.0", "_.1"], ["_.0", "_.1"]])

    assert_equal(infer(q, all(
                              any(
                                all(eq(y, x), eq(z, x_)),
                                all(eq(y, x_), eq(z, x))),
                              eq(false, x),
                              eq([y, z], q))),
                 [[false, "_.0"], ["_.0", false]])

    # The goal a is built but never run, so it has no effect on q.
    a = eq(true, q)
    b = eq(false, q)
    assert_equal(infer(q, b), [false])

    x = fresh
    b = all(
          eq(x, q),
          eq(false, x))
    assert_equal(infer(q, b), [false])

    x, y = fresh(2)
    assert_equal(infer(q, eq([x, y], q)), [["_.0", "_.1"]])

    v, w = fresh(2)
    x, y = v, w
    assert_equal(infer(q, eq([x, y], q)), [["_.0", "_.1"]])
  end

  # Relations written as ordinary Ruby methods, with lists encoded as nested
  # two-element arrays ([head, tail], [] for the empty list), including a
  # recursive relation built with defer.
  def test_functions
    q = fresh

    # l is the empty list.
    def nullo(l)
      eq(l, [])
    end

    # p is the pair [a, d].
    def conso(a, d, p)
      eq([a, d], p)
    end

    # p is some pair.
    def pairo(p)
      a, d = fresh(2)
      conso(a, d, p)
    end

    # d is the second element of the pair p.
    def cdro(p, d)
      a = fresh
      conso(a, d, p)
    end

    # a is the first element of the pair p.
    def caro(p, a)
      d = fresh
      conso(a, d, p)
    end

    assert_equal(infer(q, all(pairo([q, q]), eq(true, q))), [true])
    assert_equal(infer(q, all(pairo([]), eq(true, q))), [])

    # l is a proper list. The recursive call goes through defer so that it is
    # made only when the goal runs, not while the goal is being built.
    def listo(l)
      d = fresh
      any(
        all(nullo(l), succeed),
        all(pairo(l), cdro(l, d), defer(method(:listo), d)))
    end

    assert_equal(infer(q, listo([:a, [:b, [q, [:d, []]]]])), ["_.0"])

    # With q as the tail there is no end to the answers, so ask for five.
    assert_equal(infer(5, q, listo([:a, [:b, [:c, q]]])),
                 [[],
                  ["_.0", []],
                  ["_.0", ["_.1", []]],
                  ["_.0", ["_.1", ["_.2", []]]],
                  ["_.0", ["_.1", ["_.2", ["_.3", []]]]]])
  end

  # The inner block's q is a separate variable from the outer q, so binding
  # it leaves the outer q unbound.
  def test_nesting
    fresh { |q|
      assert_equal(infer(q, fresh { |q| eq(q, false) }), ["_.0"])
    }
  end
end