import math

from rpython.rlib.jit import JitDriver, isconstant, we_are_jitted
from rpython.rlib.rbigint import rbigint


class TypeFail(Exception):
    """Raised when a value cannot be used the way an operation requires."""

    def __init__(self, reason):
        # Short human-readable description of the mismatch.
        self.reason = reason

class Element(object):
    """
    Base class for runtime values.

    Every accessor raises TypeFail by default; each concrete value class
    overrides only the accessors that make sense for its shape.
    """
    _immutable_ = True

    def b(self):
        raise TypeFail("not b")

    def n(self):
        # N overrides this. Without a default here, calling n() on any other
        # element would escape as an AttributeError rather than a TypeFail.
        raise TypeFail("not n")

    def f(self):
        raise TypeFail("not f")

    def l(self):
        raise TypeFail("not list")

    def first(self):
        raise TypeFail("not a pair")

    def second(self):
        raise TypeFail("not a pair")

    def tagged(self, f, g):
        raise TypeFail("not tag")

    def apply(self, y):
        raise TypeFail("not hom")

class T(Element):
    """The unit value, produced by Ignore and consumed by constant arrows."""
    _immutable_ = True

    def asStr(self):
        return "*"

class B(Element):
    """A boolean value."""
    _immutable_ = True

    def __init__(self, b):
        self._b = b

    def b(self):
        return self._b

    def asStr(self):
        if self._b:
            return "true"
        return "false"

class N(Element):
    """A natural number held as an rbigint."""
    _immutable_ = True

    def __init__(self, bi):
        self._bi = bi

    def n(self):
        return self._bi

    def asStr(self):
        big = self._bi
        return big.str()

class F(Element):
    """
    A floating-point value.

    Outside JIT-compiled code the constructor rejects NaN with a TypeFail;
    traced code skips that check.
    """
    _immutable_ = True

    def __init__(self, f):
        if not we_are_jitted():
            if math.isnan(f):
                raise TypeFail("runtime NaN")
        self._f = f

    def f(self):
        return self._f

    def asStr(self):
        value = self._f
        return str(value)

class P(Element):
    """An ordered pair of elements."""
    _immutable_ = True

    def __init__(self, x, y):
        self._x = x
        self._y = y

    def first(self):
        return self._x

    def second(self):
        return self._y

    def asStr(self):
        left = self._x.asStr()
        right = self._y.asStr()
        return "(%s, %s)" % (left, right)

class L(Element):
    """Left injection into a sum; a case on it runs the first branch."""
    _immutable_ = True

    def __init__(self, x):
        self._x = x

    def tagged(self, f, g):
        # Only the left-hand arrow applies to a left-tagged value.
        return f.run(self._x)

    def asStr(self):
        inner = self._x.asStr()
        return "L(%s)" % inner

class R(Element):
    """Right injection into a sum; a case on it runs the second branch."""
    _immutable_ = True

    def __init__(self, x):
        self._x = x

    def tagged(self, f, g):
        # Only the right-hand arrow applies to a right-tagged value.
        return g.run(self._x)

    def asStr(self):
        inner = self._x.asStr()
        return "R(%s)" % inner

class H(Element):
    """
    An element of a hom type: an arrow with its first argument captured.

    apply(y) runs the arrow on the pair (captured x, y).
    """
    _immutable_ = True

    def __init__(self, f, x):
        self._f = f
        self._x = x

    def apply(self, y):
        arg = P(self._x, y)
        return self._f.run(arg)

    def asStr(self):
        # NOTE(review): arrows define no asStr(), so the arrow is interpolated
        # with plain %s and renders via its default str().
        return "%s @ %s" % (self._f, self._x.asStr())

class Xs(Element):
    """A list of elements; the list itself is treated as immutable."""
    _immutable_ = True
    _immutable_fields_ = "xs[*]",

    def __init__(self, xs):
        self.xs = xs

    def l(self):
        return self.xs

    def asStr(self):
        parts = [item.asStr() for item in self.xs]
        return "[%s]" % ", ".join(parts)

class Arrow(object):
    """Base class for executable, typecheckable arrows."""
    _immutable_ = True

    def boundVars(self):
        # NOTE(review): no subclass visible here assigns `domain` or
        # `codomain`; this relies on those attributes being set elsewhere.
        dom_vars = self.domain.boundVars()
        cod_vars = self.codomain.boundVars()
        return dom_vars | cod_vars

class Id(Arrow):
    """The identity arrow X -> X."""
    _immutable_ = True

    def run(self, x):
        return x

    def types(self, cs):
        # Domain and codomain share a single fresh type variable.
        var = cs.fresh()
        return var, var

class Comp(Arrow):
    """Composition: run f, then feed its result to g."""
    _immutable_ = True

    def __init__(self, f, g):
        self.f = f
        self.g = g

    def run(self, x):
        middle = self.f.run(x)
        return self.g.run(middle)

    def types(self, cs):
        src, mid_out = self.f.types(cs)
        mid_in, dst = self.g.types(cs)
        # f's output type must match g's input type.
        cs.unify(mid_out, mid_in)
        return src, dst

class Ignore(Arrow):
    """Discard the input and return the unit value: X -> 1."""
    _immutable_ = True

    def run(self, x):
        return T()

    def types(self, cs):
        dom = cs.fresh()
        cod = cs.concrete("1")
        return dom, cod

class First(Arrow):
    """Left projection of a pair: (A, B) -> A."""
    _immutable_ = True

    def run(self, x):
        return x.first()

    def types(self, cs):
        fst = cs.fresh()
        snd = cs.fresh()
        return cs.functor("pair", [fst, snd]), fst

class Second(Arrow):
    """Right projection of a pair: (A, B) -> B."""
    _immutable_ = True

    def run(self, x):
        return x.second()

    def types(self, cs):
        # The projected component's variable is allocated first.
        snd = cs.fresh()
        fst = cs.fresh()
        return cs.functor("pair", [fst, snd]), snd

class Pair(Arrow):
    """Run f and g on the same input and pair the results: X -> (A, B)."""
    _immutable_ = True

    def __init__(self, f, g):
        self.f = f
        self.g = g

    def run(self, x):
        left = self.f.run(x)
        right = self.g.run(x)
        return P(left, right)

    def types(self, cs):
        fsrc, fdst = self.f.types(cs)
        gsrc, gdst = self.g.types(cs)
        # Both arrows consume the same input.
        cs.unify(fsrc, gsrc)
        return fsrc, cs.functor("pair", [fdst, gdst])

class Left(Arrow):
    """Left injection: A -> A + B."""
    _immutable_ = True

    def run(self, x):
        return L(x)

    def types(self, cs):
        inl = cs.fresh()
        other = cs.fresh()
        return inl, cs.functor("sum", [inl, other])

class Right(Arrow):
    """Right injection: B -> A + B."""
    _immutable_ = True

    def run(self, x):
        return R(x)

    def types(self, cs):
        inr = cs.fresh()
        other = cs.fresh()
        return inr, cs.functor("sum", [other, inr])

class Case(Arrow):
    """Case analysis on a sum: f handles left values, g handles right ones."""
    _immutable_ = True

    def __init__(self, f, g):
        self.f = f
        self.g = g

    def run(self, x):
        # The tagged value itself picks which branch to run.
        return x.tagged(self.f, self.g)

    def types(self, cs):
        ldom, lcod = self.f.types(cs)
        rdom, rcod = self.g.types(cs)
        # Both branches must produce the same type.
        cs.unify(lcod, rcod)
        return cs.functor("sum", [ldom, rdom]), lcod

class Curry(Arrow):
    """Turn f : (X, Y) -> Z into X -> hom(Y, Z)."""
    _immutable_ = True

    def __init__(self, f):
        self._f = f

    def run(self, x):
        return H(self._f, x)

    def types(self, cs):
        fdom, fcod = self._f.types(cs)
        held = cs.fresh()
        later = cs.fresh()
        cs.unify(fdom, cs.functor("pair", [held, later]))
        return held, cs.functor("hom", [later, fcod])

class Uncurry(Arrow):
    """Turn f : X -> hom(Y, Z) into (X, Y) -> Z."""
    _immutable_ = True

    def __init__(self, f):
        self._f = f

    def run(self, x):
        closure = self._f.run(x.first())
        return closure.apply(x.second())

    def types(self, cs):
        fdom, fcod = self._f.types(cs)
        arg = cs.fresh()
        res = cs.fresh()
        cs.unify(fcod, cs.functor("hom", [arg, res]))
        return cs.functor("pair", [fdom, arg]), res

class Either(Arrow):
    """View a boolean as the sum 1 + 1: true -> left, false -> right."""
    _immutable_ = True

    def run(self, x):
        if x.b():
            return L(T())
        return R(T())

    def types(self, cs):
        one = cs.concrete("1")
        two = cs.concrete("2")
        return two, cs.functor("sum", [one, one])

class TrueArr(Arrow):
    """The constant true: 1 -> 2."""
    _immutable_ = True

    def run(self, x):
        return B(True)

    def types(self, cs):
        unit = cs.concrete("1")
        return unit, cs.concrete("2")

class FalseArr(Arrow):
    """The constant false: 1 -> 2."""
    _immutable_ = True

    def run(self, x):
        return B(False)

    def types(self, cs):
        unit = cs.concrete("1")
        return unit, cs.concrete("2")

class NotArr(Arrow):
    """Boolean negation: 2 -> 2."""
    _immutable_ = True

    def run(self, x):
        flag = x.b()
        return B(not flag)

    def types(self, cs):
        boolean = cs.concrete("2")
        return boolean, cs.concrete("2")

class Conj(Arrow):
    """Boolean AND on a pair: (2, 2) -> 2."""
    _immutable_ = True

    def run(self, x):
        # Short-circuits: the second component is only inspected when the
        # first is true, exactly like Python's `and`.
        result = x.first().b()
        if result:
            result = x.second().b()
        return B(result)

    def types(self, cs):
        two = cs.concrete("2")
        return cs.functor("pair", [two, two]), two

class Disj(Arrow):
    """Boolean OR on a pair: (2, 2) -> 2."""
    _immutable_ = True

    def run(self, x):
        # Short-circuits: the second component is only inspected when the
        # first is false, exactly like Python's `or`.
        result = x.first().b()
        if not result:
            result = x.second().b()
        return B(result)

    def types(self, cs):
        two = cs.concrete("2")
        return cs.functor("pair", [two, two]), two

class Zero(Arrow):
    """The natural number zero: 1 -> N."""
    _immutable_ = True

    def run(self, x):
        zero = rbigint.fromint(0)
        return N(zero)

    def types(self, cs):
        unit = cs.concrete("1")
        return unit, cs.concrete("N")

class Succ(Arrow):
    """Successor on naturals: N -> N."""
    _immutable_ = True

    def run(self, x):
        current = x.n()
        return N(current.int_add(1))

    def types(self, cs):
        nat = cs.concrete("N")
        return nat, cs.concrete("N")

# JIT driver for PrimRec's counting loop. The PrimRec arrow is the green
# (loop-invariant) key; the remaining count and the accumulator are red.
pr_driver = JitDriver(name="pr",
        greens=["pr"], reds=["n", "rv"],
        is_recursive=True)

class PrimRec(Arrow):
    """
    Primitive recursion on natural numbers.

    Given x : 1 -> A and f : A -> A, this is the arrow N -> A that starts
    from x(*) and applies f once per unit of its natural-number input.
    """
    _immutable_ = True
    def __init__(self, x, f):
        self._x = x
        self._f = f

    def run(self, x):
        n = x.n()
        rv = self._x.run(T())
        # Count n down to zero, applying f once per step.
        while n.tobool():
            pr_driver.jit_merge_point(pr=self, n=n, rv=rv)
            n = n.int_sub(1)
            rv = self._f.run(rv)
        return rv

    def types(self, cs):
        xdom, xcod = self._x.types(cs)
        fdom, fcod = self._f.types(cs)
        # x must be a constant 1 -> A, and f an endomorphism A -> A.
        cs.unify(xdom, cs.concrete("1"))
        cs.unify(xcod, fcod)
        cs.unify(fdom, fcod)
        return cs.concrete("N"), fcod

class Nil(Arrow):
    """The empty list: 1 -> [X]."""
    _immutable_ = True

    def run(self, x):
        empty = []
        return Xs(empty)

    def types(self, cs):
        unit = cs.concrete("1")
        return unit, cs.functor("list", [cs.fresh()])

class Cons(Arrow):
    """Prepend an element to a list: (X, [X]) -> [X]. The head goes at index 0."""
    _immutable_ = True

    def run(self, x):
        head = x.first()
        tail = x.second().l()
        return Xs([head] + tail)

    def types(self, cs):
        elem = cs.fresh()
        lst = cs.functor("list", [elem])
        return cs.functor("pair", [elem, lst]), lst

# JIT driver for fold steps. The step arrow is the green (loop-invariant) key;
# the value being folded in is red.
fold_driver = JitDriver(name="fold",
        greens=["fold"], reds=["element"],
        is_recursive=True)

def driveFold(fold, element):
    """
    Run one fold step: apply the arrow `fold` to `element`.

    Kept as a separate function so the JIT has a merge point keyed on the
    step arrow. (In Fold.run, `element` is a pair of list item and
    accumulator.)
    """
    fold_driver.jit_merge_point(fold=fold, element=element)
    return fold.run(element)

class Fold(Arrow):
    """
    The list catamorphism (a right fold).

    Given n : 1 -> A and c : (X, A) -> A, this is the arrow [X] -> A with
        fold(n, c)([])      = n(*)
        fold(n, c)(x :: xs) = c(x, fold(n, c)(xs))
    so, for example, fold(nil, cons) is the identity on lists.
    """
    _immutable_ = True
    def __init__(self, n, c):
        self._n = n
        self._c = c
    def run(self, x):
        rv = self._n.run(T())
        xs = x.l()
        # Cons stores the head at index 0, so a right fold combines the last
        # element with n(*) first and works back toward the head. Walking the
        # list front-to-back would compute a left fold instead (turning
        # fold(nil, cons) into list reversal).
        i = len(xs)
        while i > 0:
            i -= 1
            rv = driveFold(self._c, P(xs[i], rv))
        return rv
    def types(self, cs):
        ndom, ncod = self._n.types(cs)
        cdom, ccod = self._c.types(cs)
        cs.unify(ndom, cs.concrete("1"))
        x = cs.fresh()
        cs.unify(cdom, cs.functor("pair", [x, ccod]))
        cs.unify(ncod, ccod)
        return cs.functor("list", [x]), ccod

class FZero(Arrow):
    """The float constant 0.0: 1 -> F."""
    _immutable_ = True

    def run(self, x):
        return F(0.0)

    def types(self, cs):
        unit = cs.concrete("1")
        return unit, cs.concrete("F")

class FOne(Arrow):
    """The float constant 1.0: 1 -> F."""
    _immutable_ = True

    def run(self, x):
        return F(1.0)

    def types(self, cs):
        unit = cs.concrete("1")
        return unit, cs.concrete("F")

class FPi(Arrow):
    """The float constant pi: 1 -> F."""
    _immutable_ = True

    def run(self, x):
        return F(math.pi)

    def types(self, cs):
        unit = cs.concrete("1")
        return unit, cs.concrete("F")

def sign(f):
    """
    Return True when f's sign bit is clear (including +0.0 and +inf) and
    False when it is set (including -0.0 and -inf).
    """
    unit_with_sign = math.copysign(1.0, f)
    return unit_with_sign > 0.0
class FSign(Arrow):
    """Map a float to true when its sign bit is clear, else false: F -> 2."""
    _immutable_ = True

    def run(self, x):
        positive = sign(x.f())
        return B(positive)

    def types(self, cs):
        real = cs.concrete("F")
        return real, cs.concrete("2")

class FFloor(Arrow):
    """
    Floor of a float: F -> F + 1.

    Returns the left injection of the floor, or the right injection of the
    unit value when math.floor raises ValueError or OverflowError.
    """
    _immutable_ = True

    def run(self, x):
        value = x.f()
        try:
            floored = float(math.floor(value))
        except (ValueError, OverflowError):
            # e.g. Python 3's math.floor raises OverflowError on infinities.
            return R(T())
        return L(F(floored))

    def types(self, cs):
        dom = cs.concrete("F")
        return dom, cs.functor("sum", [cs.concrete("F"), cs.concrete("1")])

class FNegate(Arrow):
    """Float negation: F -> F."""
    _immutable_ = True

    def run(self, x):
        value = x.f()
        return F(-value)

    def types(self, cs):
        real = cs.concrete("F")
        return real, cs.concrete("F")

INF = float("inf")
class FRecip(Arrow):
    """Reciprocal 1/x: F -> F. A zero maps to an infinity of the same sign."""
    _immutable_ = True

    def run(self, x):
        # Same logic as Typhon. ~ C.
        value = x.f()
        if value != 0.0:
            return F(1.0 / value)
        # 1/+0.0 -> +inf and 1/-0.0 -> -inf, instead of ZeroDivisionError.
        return F(math.copysign(INF, value))

    def types(self, cs):
        real = cs.concrete("F")
        return real, cs.concrete("F")

class FLT(Arrow):
    """
    Strict less-than on a pair of floats: (F, F) -> 2.

    Orders -0.0 before +0.0. IEEE comparison treats the two zeros as
    equal, so pairs of zeros are compared by sign bit explicitly.
    """
    _immutable_ = True
    def run(self, x):
        f1 = x.first().f()
        f2 = x.second().f()
        if f1 == 0.0 and f2 == 0.0:
            # Both operands are zeros of some sign. Since -0.0 == 0.0, an
            # equality test cannot tell them apart; use the sign bits so that
            # only -0.0 < +0.0 holds (and the relation stays irreflexive).
            return B(math.copysign(1.0, f1) < 0.0 and
                     math.copysign(1.0, f2) > 0.0)
        return B(f1 < f2)
    def types(self, cs):
        f = cs.concrete("F")
        return cs.functor("pair", [f, f]), cs.concrete("2")

class FAdd(Arrow):
    """
    Float addition on a pair: (F, F) -> F.

    The NaN produced by adding opposite infinities is replaced with 0.0.
    """
    _immutable_ = True
    def run(self, x):
        y = x.first().f()
        z = x.second().f()
        # Addition of non-NaN floats only yields NaN for Infinity + -Infinity,
        # which we define to be 0.0. If either operand is a finite JIT-time
        # constant that case cannot occur, so traced code may skip the check.
        if we_are_jitted() and (isconstant(y) and not math.isinf(y) or
                                isconstant(z) and not math.isinf(z)):
            rv = y + z
        else:
            rv = y + z
            if math.isnan(rv):
                rv = 0.0
        return F(rv)
    def types(self, cs):
        f = cs.concrete("F")
        return cs.functor("pair", [f, f]), f

# Operand values for which a product can be NaN (zero times infinity).
MUL_SPECIALS = 0.0, -0.0, INF, -INF
class FMul(Arrow):
    """
    Float multiplication on a pair: (F, F) -> F.

    The NaN produced by 0 * Infinity is replaced with a zero whose sign is
    the product of the operands' signs.
    """
    _immutable_ = True
    def run(self, x):
        y = x.first().f()
        z = x.second().f()
        # The only time multiplication can NaN is 0.0 * Infinity. We define it
        # to be 0.0, but must respect source signs. Otherwise, same logic as
        # in FAdd: a JIT-time constant operand that is neither zero nor
        # infinite rules out NaN, so traced code may skip the check.
        if we_are_jitted() and (isconstant(y) and y not in MUL_SPECIALS or
                                isconstant(z) and z not in MUL_SPECIALS):
            rv = y * z
        else:
            rv = y * z
            if math.isnan(rv):
                # Signed zero: negative exactly when the operand signs differ.
                rv = 0.0 * math.copysign(1.0, y) * math.copysign(1.0, z)
        return F(rv)
    def types(self, cs):
        f = cs.concrete("F")
        return cs.functor("pair", [f, f]), f

class FSqrt(Arrow):
    """
    Square root: F -> F + 1.

    Inputs with a clear sign bit give the left injection of the root; inputs
    with the sign bit set (including -0.0) give the right injection of unit.
    """
    _immutable_ = True

    def run(self, x):
        value = x.f()
        if not sign(value):
            return R(T())
        return L(F(math.sqrt(value)))

    def types(self, cs):
        real = cs.concrete("F")
        return real, cs.functor("sum", [real, cs.concrete("1")])

class FSin(Arrow):
    """Sine: F -> F."""
    _immutable_ = True

    def run(self, x):
        angle = x.f()
        return F(math.sin(angle))

    def types(self, cs):
        real = cs.concrete("F")
        return real, cs.concrete("F")

class FCos(Arrow):
    """Cosine: F -> F."""
    _immutable_ = True

    def run(self, x):
        angle = x.f()
        return F(math.cos(angle))

    def types(self, cs):
        real = cs.concrete("F")
        return real, cs.concrete("F")

class FATan2(Arrow):
    """Two-argument arctangent atan2(first, second): (F, F) -> F."""
    _immutable_ = True

    def run(self, x):
        num = x.first().f()
        den = x.second().f()
        return F(math.atan2(num, den))

    def types(self, cs):
        real = cs.concrete("F")
        return cs.functor("pair", [real, real]), real


class BuildProblem(Exception):
    """Raised when an arrow cannot be built from the given name and arguments."""

    def __init__(self, message):
        # Description of what went wrong, for reporting to the user.
        self.message = message

# Table of zero-argument functor names to their (stateless) arrow instances.
unaryFunctors = dict([
    # Structural arrows.
    ("id", Id()),
    ("ignore", Ignore()),
    ("fst", First()),
    ("snd", Second()),
    ("left", Left()),
    ("right", Right()),
    # Booleans.
    ("either", Either()),
    ("t", TrueArr()),
    ("f", FalseArr()),
    ("not", NotArr()),
    ("conj", Conj()),
    ("disj", Disj()),
    # Naturals.
    ("zero", Zero()),
    ("succ", Succ()),
    # Lists.
    ("nil", Nil()),
    ("cons", Cons()),
    # Floats.
    ("f-zero", FZero()),
    ("f-one", FOne()),
    ("f-pi", FPi()),
    ("f-sign", FSign()),
    ("f-floor", FFloor()),
    ("f-negate", FNegate()),
    ("f-recip", FRecip()),
    ("f-lt", FLT()),
    ("f-add", FAdd()),
    ("f-mul", FMul()),
    ("f-sqrt", FSqrt()),
    ("f-sin", FSin()),
    ("f-cos", FCos()),
    ("f-atan2", FATan2()),
])

def buildUnary(name):
    """
    Return the arrow registered for a zero-argument functor name.

    Raises BuildProblem when the name is not in unaryFunctors.
    """
    if name not in unaryFunctors:
        raise BuildProblem("Invalid unary functor: " + name)
    return unaryFunctors[name]

def buildCompound(name, args):
    """
    Build an arrow from a compound functor name and its argument arrows.

    Raises BuildProblem when the name is unknown or the number of
    arguments does not match the functor.
    """
    arity = len(args)
    if arity == 2:
        first, second = args[0], args[1]
        if name == "comp":
            return Comp(first, second)
        if name == "pair":
            return Pair(first, second)
        if name == "case":
            return Case(first, second)
        if name == "pr":
            return PrimRec(first, second)
        if name == "fold":
            return Fold(first, second)
    elif arity == 1:
        only = args[0]
        if name == "curry":
            return Curry(only)
        if name == "uncurry":
            return Uncurry(only)
    raise BuildProblem("Invalid compound functor: " + name)


class Given(Arrow):
    """
    A formal parameter for a function.

    A given arrow is not executable, but it can still be typechecked: its
    type is whatever cs.givens() reports for its index.
    """
    _immutable_ = True

    def __init__(self, index):
        self.index = index

    def run(self, x):
        raise BuildProblem("given arrow cannot be run")

    def types(self, cs):
        slot = self.index
        return cs.givens(slot)
