from cammylib.sexp import Atom, Functor, Hole, unhole

class UnificationFailed(Exception):
    "Raised when unification cannot succeed; the reason is in `message`."

    def __init__(self, message):
        # Human-readable description of why unification failed.
        self.message = message

def occursCheck(index, var):
    """
    Refuse to bind a variable to something that mentions that variable.

    Asks `var` whether `index` occurs in it; if so, raises
    UnificationFailed, otherwise returns None.
    """
    if not var.occurs(index):
        return
    # The rendered form is only computed when we actually fail.
    rendered = var.asStr()
    raise UnificationFailed("Occurs check: %d in %s" % (index, rendered))

class ConstraintStore(object):
    """
    A Kanren-style constraint store.

    Every variable is an integer slot in `constraints`. A slot holding a
    Hole whose index is the slot itself is unbound; a Hole with any other
    index forwards to that slot; anything else is the slot's binding.
    """

    def __init__(self):
        # Index that the next allocated slot will receive.
        self.i = 0
        # One S-expression per allocated slot.
        self.constraints = []
        # Pairs of fresh variables, created on demand by givens().
        self.knownGivens = []

    def _allocate(self, sexp):
        "Put an S-expression into the next free slot and return its index."
        slot = self.i
        self.i = slot + 1
        self.constraints.append(sexp)
        return slot

    def givens(self, index):
        "Return the pair of variables for a given, creating pairs as needed."
        known = self.knownGivens
        while index >= len(known):
            first = self.fresh()
            second = self.fresh()
            known.append((first, second))
        return known[index]

    def fresh(self):
        "Allocate a new unbound variable and return its index."
        # An unbound variable is a hole that points at its own slot.
        return self._allocate(Hole(self.i))

    def concrete(self, symbol):
        "Allocate a variable already bound to the atom named by `symbol`."
        return self._allocate(Atom(symbol))

    def functor(self, constructor, arguments):
        """
        Allocate a variable bound to a compound term.

        `arguments` are variable indices; each is wrapped in a Hole.
        """
        holes = [Hole(argument) for argument in arguments]
        return self._allocate(Functor(constructor, holes))

    def walk(self, i):
        "Follow forwarding holes from slot `i` and return where they end."
        sexp = self.constraints[i]
        if isinstance(sexp, Hole) and sexp.index != i:
            # This slot forwards to another one; keep following.
            return self.walk(sexp.index)
        return sexp

    def unify(self, i, j):
        """
        Make variables `i` and `j` equal.

        Raises UnificationFailed when the two cannot be made equal.
        """
        left = self.walk(i)
        right = self.walk(j)

        if isinstance(left, Hole):
            if isinstance(right, Hole) and left.index == right.index:
                # Both sides already resolve to the same unbound variable.
                return
            occursCheck(left.index, right)
            self.constraints[left.index] = right
            return

        if isinstance(right, Hole):
            occursCheck(right.index, left)
            self.constraints[right.index] = left
            return

        if isinstance(left, Atom) and isinstance(right, Atom):
            if left.symbol != right.symbol:
                raise UnificationFailed(
                    "Can't unify constant types: %s vs. %s" %
                    (left.symbol, right.symbol))
            return

        if isinstance(left, Functor) and isinstance(right, Functor):
            if left.constructor != right.constructor:
                raise UnificationFailed(
                    "Can't unify compound types: %s vs. %s" %
                    (left.constructor, right.constructor))
            if len(left.arguments) != len(right.arguments):
                raise UnificationFailed("Compound types have different arity?")
            # Unify the arguments pairwise, left to right.
            for position, argument in enumerate(left.arguments):
                counterpart = right.arguments[position]
                self.unify(unhole(argument), unhole(counterpart))
            return

        # One side is an atom and the other a compound term.
        raise UnificationFailed("Quite different types: %s. vs %s" %
                (left.asStr(), right.asStr()))


# Alias letters for type variables, handed out in this order by
# TypeExtractor.addTypeAlias. Once they run out, the letters are reused
# with a numeric suffix.
LETTERS = "XYZWSTPQ"

class TypeExtractor(object):
    """
    Extract types from a constraint store, naming type variables.

    `cs` is the store to read from; only its walk() method is used here.
    """

    def __init__(self, cs):
        self.cs = cs
        # Maps a variable index to the alias chosen for it.
        self.d = {}
        # Variables whose extraction is currently in progress, innermost
        # last; used to detect a type that contains itself.
        self.seen = []

    def addTypeAlias(self, index):
        """
        Assign the next unused alias to `index`.

        The first len(LETTERS) aliases are the bare letters. After that
        the letters repeat with a suffix ("X1", "Y1", ..., "X2", ...), so
        any number of type variables can be named instead of indexing past
        the end of LETTERS.
        """
        position = len(self.d)
        alias = LETTERS[position % len(LETTERS)]
        lap = position // len(LETTERS)
        if lap:
            alias = "%s%d" % (alias, lap)
        self.d[index] = alias

    def findTypeAlias(self, index):
        "Return the alias for `index`, assigning one on first use."
        if index not in self.d:
            self.addTypeAlias(index)
        return self.d[index]

    def extract(self, var):
        """
        Extract the type of variable `var`.

        Walks `var` in the store and returns whatever the resulting
        S-expression's extractType() produces. Raises UnificationFailed if
        `var` is already being extracted further up the call chain, which
        means the type contains itself.
        """
        if var in self.seen:
            # XXX split exceptions?
            raise UnificationFailed("tried to extract infinite type")
        self.seen.append(var)
        try:
            sexp = self.cs.walk(var)
            return sexp.extractType(self)
        finally:
            # Always unwind, so an error during extraction does not leave
            # `var` marked as in progress for later calls.
            self.seen.pop()
