#!/usr/bin/python
#Thanks go to Dr. Ofer Hadas for kindly reviewing some of the code.


#-----------------------------------------------
#zpint.py - computations over Zp
#Copyright (c) 2007, Imri Goldberg
#All rights reserved.
#
#Redistribution and use in source and binary forms,
#with or without modification, are permitted provided
#that the following conditions are met:
#
#    * Redistributions of source code must retain the
#        above copyright notice, this list of conditions
#        and the following disclaimer.
#    * Redistributions in binary form must reproduce the
#        above copyright notice, this list of conditions
#        and the following disclaimer in the documentation
#        and/or other materials provided with the distribution.
#    * Neither the name of Algorithm.co.il nor the names of
#        its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
#LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#-----------------------------------------------

import itertools

def create_zpint(p):
    """Create an int class over Zp (integers modulo p).

    Instances keep their residue in ``value`` (always ``x % p``) and mix
    with plain ints in +, -, *, /, ** and ==.  Division is answered from
    ``div_table``, which is built eagerly from all p*p products, so creating
    a class costs O(p**2) time and memory.  When p is not prime some
    quotients are not unique; the table keeps the last one found.
    """

    class zp_int(object):
        # (numerator, denominator) -> quotient, i.e. ((a*b) % p, a) -> b.
        div_table = dict((((a*b) % p, a), b) for a in range(p) for b in range(p))
        # the modulus
        n = p

        def __init__(self, x):
            if isinstance(x, int):
                self.value = x % p
            else:
                # any object with a .value attribute, e.g. another zp_int
                self.value = x.value % p

        def __int__(self):
            return self.value

        def __hash__(self):
            # Equal residues hash alike.  Also keeps instances hashable under
            # Python 3, where defining __eq__ alone disables hashing.
            return hash(self.value)

        def __add__(self, other):
            if isinstance(other, (int, zp_int)):
                return zp_int(self.value + zp_int(other).value)
            return NotImplemented

        __radd__ = __add__

        def __sub__(self, other):
            if isinstance(other, (int, zp_int)):
                return zp_int(self.value - zp_int(other).value)
            return NotImplemented

        def __rsub__(self, other):
            if isinstance(other, (int, zp_int)):
                return zp_int(zp_int(other).value - self.value)
            return NotImplemented

        def __mul__(self, other):
            if isinstance(other, (int, zp_int)):
                return zp_int(self.value*zp_int(other).value)
            return NotImplemented

        __rmul__ = __mul__

        @staticmethod
        def _quotient(numerator, denominator):
            """Return numerator/denominator for residues in range(p).

            Raises ZeroDivisionError when the denominator is zero, or when no
            b with denominator*b == numerator (mod p) exists.
            """
            if denominator == 0:
                # div_table does contain the key (0, 0), so reject it explicitly
                raise ZeroDivisionError("division by zero in Z%d" % p)
            try:
                return zp_int(zp_int.div_table[numerator, denominator])
            except KeyError:
                raise ZeroDivisionError()

        def __div__(self, other):
            if isinstance(other, (int, zp_int)):
                return zp_int._quotient(self.value, zp_int(other).value)
            return NotImplemented

        def __rdiv__(self, other):
            if isinstance(other, (int, zp_int)):
                return zp_int._quotient(zp_int(other).value, self.value)
            return NotImplemented

        # Python 3 dispatches '/' to the true-division hooks.
        __truediv__ = __div__
        __rtruediv__ = __rdiv__

        def __repr__(self):
            return "%dmod%d" % (self.value, p)

        def __pow__(self, other):
            """Raise to an integer power.

            Negative powers are powers of the inverse and raise
            ZeroDivisionError when the inverse does not exist.
            """
            exponent = int(other)
            if exponent != other:
                raise ValueError("exponent must be an integer")
            if exponent < 0:
                return zp_int._quotient(1 % p, self.value) ** (-exponent)
            return zp_int(pow(self.value, exponent, p))

        def __neg__(self):
            return zp_int(-self.value)

        def __eq__(self, other):
            if isinstance(other, (int, zp_int)):
                return self.value == zp_int(other).value
            return NotImplemented

        def __ne__(self, other):
            return not (self == other)

        def __nonzero__(self):
            return self.value != 0

        # Python 3 name for the truth-value hook.
        __bool__ = __nonzero__

    return zp_int

# Ready-made modular integer classes for the smallest primes.
z2int, z3int, z5int, z7int = (create_zpint(prime) for prime in (2, 3, 5, 7))


def create_zp_polynomial(number_class = int):
    """Create a polynomial class over the given number class.

    The number class may be int, float, one returned by create_zpint,
    or any other number-like class.  Passing an int p instead of a class
    uses create_zpint(p), i.e. coefficients modulo p.

    Polynomials keep their coefficients in ``seq``, lowest degree first,
    with high-order zero coefficients stripped.  The zero polynomial
    therefore has an empty ``seq``, and ``deg()`` returns None for it.
    """
    if isinstance(number_class, int):
        int_class = create_zpint(number_class)
    else:
        int_class = number_class

    def _multiply_sequences(left, right):
        """Return the coefficient list of the product of two coefficient lists.

        Every term is computed as left_coefficient * right_coefficient, so the
        operand order is kept for number classes whose multiplication is not
        commutative.
        """
        if not left or not right:
            return []
        product = [int_class(0)] * (len(left) + len(right) - 1)
        for i, a in enumerate(left):
            for j, b in enumerate(right):
                product[i + j] = product[i + j] + a*b
        return product

    class Polynomial(object):
        def __init__(self, seq):
            """Build from a scalar or an iterable of coefficients (lowest degree first)."""
            if isinstance(seq, (int_class, int)):
                self.seq = [int_class(seq)]
            else:
                self.seq = [int_class(x) for x in seq]
            self._remove_top_zeros()

        def __nonzero__(self):
            return len(self.seq) > 0

        # Python 3 name for the truth-value hook.
        __bool__ = __nonzero__

        def _remove_top_zeros(self):
            """remove leading zeroes from the high coefficients.
            1+x+0*x^2 will be changed to just 1+x
            """
            while len(self.seq) > 0 and self.seq[-1] == 0:
                del self.seq[-1]

        def deg(self):
            """Return the degree, or None for the zero polynomial."""
            if len(self.seq) == 0:
                return None
            return len(self.seq) - 1

        def __iter__(self):
            for i in self.seq:
                yield i

        def __getitem__(self, i):
            return self.seq[i]

        def __add__(self, other):
            if isinstance(other, (int_class, int)):
                return self + Polynomial([other])

            other = Polynomial(other)
            seq1 = self.seq
            seq2 = other.seq
            if len(seq1) < len(seq2):
                seq2, seq1 = seq1, seq2
            # pad the shorter sequence with zeros to the longer one's length
            padding = itertools.repeat(int_class(0), len(seq1) - len(seq2))
            return Polynomial(x + y for x, y in zip(seq1, itertools.chain(seq2, padding)))

        #we assume addition is commutative
        __radd__ = __add__

        def __neg__(self):
            return Polynomial(-x for x in self.seq)

        def __sub__(self, other):
            return self + (-other)

        def __rsub__(self, other):
            return other + (-self)

        def __mul__(self, other):
            if isinstance(other, (int_class, int)):
                return Polynomial(x*other for x in self.seq)
            other = Polynomial(other)
            return Polynomial(_multiply_sequences(self.seq, other.seq))

        def __rmul__(self, other):
            if isinstance(other, (int_class, int)):
                return Polynomial(other*x for x in self.seq)
            other = Polynomial(other)
            #we can't assume multiplication is commutative: other is the left factor
            return Polynomial(_multiply_sequences(other.seq, self.seq))

        def __divmod__(self, other):
            """Return (quotient, remainder) of polynomial long division.

            A scalar divisor divides every coefficient and leaves a zero
            remainder.  Raises ZeroDivisionError for a zero divisor, and
            ValueError when a leading coefficient cannot be divided exactly
            (e.g. int coefficients), instead of looping forever.
            """
            if isinstance(other, (int_class, int)):
                if other == 0:
                    raise ZeroDivisionError("can't divide by zero")
                return Polynomial(x/other for x in self.seq), Polynomial([])
            other = Polynomial(other)
            divisor_deg = other.deg()
            if divisor_deg is None:
                raise ZeroDivisionError("can't divide by the zero polynomial")
            lead = other.seq[-1]
            quotient = Polynomial([])
            remainder = self
            while remainder.deg() is not None and remainder.deg() >= divisor_deg:
                top = remainder.deg()
                # the monomial coef*x^(top-divisor_deg) cancels the leading term
                coef = remainder.seq[top] / lead
                term = Polynomial([0]*(top - divisor_deg) + [coef])
                quotient += term
                remainder -= term*other
                new_deg = remainder.deg()
                if new_deg is not None and new_deg >= top:
                    raise ValueError("leading coefficient of the divisor does not divide the remainder exactly")
            return quotient, remainder

        def __div__(self, other):
            return divmod(self, other)[0]

        def __mod__(self, other):
            return divmod(self, other)[1]

        def __rdiv__(self, other):
            return Polynomial([other]) / self

        def __rmod__(self, other):
            return Polynomial([other]) % self

        # Python 3 dispatches '/' to the true-division hooks.
        __truediv__ = __div__
        __rtruediv__ = __rdiv__

        def __pow__(self, other):
            """Raise to a non-negative integer power by repeated squaring."""
            if other < 0:
                raise ValueError("can't raise to negative powers")
            if other == 0:
                return Polynomial(1)
            if other == 1:
                return Polynomial(self)
            if other == 2:
                return self*self
            if other % 2 == 0:
                # floor division keeps the exponent an int under Python 3 too
                return (self**(other//2))**2
            return self**(other - 1)*self

        def __eq__(self, other):
            return self.seq == Polynomial(other).seq

        def __ne__(self, other):
            return not (self == other)

        def __repr__(self):
            if len(self.seq) == 0:
                return str(int_class(0))
            return " + ".join("%s*x^%d" % (str(a), i) for i, a in enumerate(self.seq))

        def differentiate(self):
            """Return the formal derivative."""
            return Polynomial(int_class(idx)*x for idx, x in list(enumerate(self.seq))[1:])

        def evaluate(self, x):
            """compute the value of the polynomial at x (Horner's scheme)"""
            x = int_class(x)
            result = int_class(0)
            if len(self.seq) == 0:
                return result
            for a in self.seq[:0:-1]:
                result = x*(a + result)
            return result + self.seq[0]

        def evaluate_slow(self, x):
            """deprecated, but still kept. compute the value of the polynomial at x"""
            x = int_class(x)
            return sum((a*(x**i) for i, a in enumerate(self.seq)), int_class(0))

    return Polynomial

# Polynomial classes over the ready-made Zp classes, plus one over plain ints.
z2poly, z3poly, z5poly, z7poly = (create_zp_polynomial(coefficients)
                                  for coefficients in (z2int, z3int, z5int, z7int))
zpoly = create_zp_polynomial(int)



def solve_fixed_qudratic(int_class, a0, a1, a2, a3):
    """this function will find a,b,c,d such that (x^2 + ax + b)(x^2 + cx + d) = (x^4 + a3x^3 + a2x^2 + a1x + a0).
    values will be computed using given int_class (which is assumed to be Zp for some p)

    Expanding the product gives the system
        a + c = a3,   b + a*c + d = a2,   a*d + b*c = a1,   b*d = a0.
    The search runs over every c and every nonzero d; a and b are then forced
    to a3 - c and a0 / d, and the remaining two equations are checked.

    int_class must expose its modulus as ``n`` (as create_zpint classes do)
    or, failing that, as ``p``.

    Returns (a, b, c, d) as int_class values, or None when no such
    factorization exists.  Raises NotImplementedError when a0 == 0.
    """
    if a0 == 0:
        raise NotImplementedError()
    modulus = getattr(int_class, 'n', None)
    if modulus is None:
        modulus = int_class.p
    for d in (int_class(x) for x in range(1, modulus)):
        try:
            b = int_class(a0) / d
        except ZeroDivisionError:
            # d is not invertible for this numerator (composite modulus): no b with b*d == a0
            continue
        for c in (int_class(x) for x in range(modulus)):
            a = int_class(a3 - c)
            if (a*d + b*c) == a1 and (a*c + b + d) == a2:
                return (a, b, c, d)
    return None

def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm).

    Only ``!= 0`` and ``%`` are used, so any type supporting them works
    (e.g. ints, or Polynomial instances from create_zp_polynomial).
    For ints the sign of the result follows Python's ``%``, so
    gcd(4, -6) == -2.
    """
    while b != 0:
        a, b = b, a % b
    return a
    
def inverse(a, b):
    """return a^-1 mod b
    This is an efficient computation using the gcd algorithm.
    (Taken straight out of "Cryptography Theory and Practice")

    Runs the extended Euclidean algorithm, tracking only the coefficient
    of a (reduced modulo b at each step), and returns t with (a*t) % b == 1.
    A negative a is first reduced modulo b.
    Raises ZeroDivisionError when gcd(a, b) != 1, i.e. no inverse exists.
    """
    if a < 0:
        # the algorithm below expects a non-negative a
        a = a % b
    b0 = b
    a0 = a
    t0 = 0
    t = 1
    # floor division: plain '/' yields floats under Python 3 and breaks the algorithm
    q = b0 // a0
    r = b0 - q*a0
    while r > 0:
        t0, t = t, (t0 - q*t) % b
        b0, a0 = a0, r
        q = b0 // a0
        r = b0 - q*a0
    if a0 != 1:
        raise ZeroDivisionError()
    return t
