###############################################################################
# Copyright 2019 StarkWare Industries Ltd.                                    #
#                                                                             #
# Licensed under the Apache License, Version 2.0 (the "License").             #
# You may not use this file except in compliance with the License.            #
# You may obtain a copy of the License at                                     #
#                                                                             #
# https://www.starkware.co/open-source-license/                               #
#                                                                             #
# Unless required by applicable law or agreed to in writing,                  #
# software distributed under the License is distributed on an "AS IS" BASIS,  #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.    #
# See the License for the specific language governing permissions             #
# and limitations under the License.                                          #
###############################################################################


"""
An implementation of field elements from F_(3 * 2**30 + 1).
"""

from random import randint

class FieldElement:
    """
    Represents an element of F_(3 * 2**30 + 1).

    Instances wrap a single integer `val`, always stored reduced into the range
    [0, k_modulus). Binary operators accept either another FieldElement or a plain int
    (which is reduced modulo k_modulus first); any other operand type makes the operator
    return NotImplemented so that Python can try the reflected operation or raise TypeError.
    """
    k_modulus = 3 * 2**30 + 1
    generator_val = 5

    def __init__(self, val):
        """
        Creates the field element congruent to val modulo k_modulus.
        """
        self.val = val % FieldElement.k_modulus

    @staticmethod
    def zero():
        """
        Obtains the zero element of the field.
        """
        return FieldElement(0)

    @staticmethod
    def one():
        """
        Obtains the unit element of the field.
        """
        return FieldElement(1)

    def __repr__(self):
        # Choose the shorter representation between the positive and negative values of the element.
        return repr((self.val + self.k_modulus//2) % self.k_modulus - self.k_modulus//2)

    def __eq__(self, other):
        """
        Compares by value. Ints are reduced modulo k_modulus before comparing; any other type
        compares unequal.
        """
        if isinstance(other, int):
            other = FieldElement(other)
        return isinstance(other, FieldElement) and self.val == other.val

    def __hash__(self):
        # Hash of the reduced value, so equal field elements always hash equally.
        return hash(self.val)

    @staticmethod
    def generator():
        """
        Obtains the field element whose value is generator_val.
        """
        return FieldElement(FieldElement.generator_val)

    @staticmethod
    def _coerce(other):
        """
        Returns other as a FieldElement, or None if it is neither an int nor a FieldElement.
        """
        if isinstance(other, int):
            return FieldElement(other)
        if isinstance(other, FieldElement):
            return other
        return None

    @staticmethod
    def typecast(other):
        """
        Converts an int to a FieldElement and passes a FieldElement through unchanged.

        Raises AssertionError for any other type. The error is raised explicitly (rather than
        with an `assert` statement) so that the check still runs under `python -O`.
        """
        converted = FieldElement._coerce(other)
        if converted is None:
            raise AssertionError(f'Type mismatch: FieldElement and {type(other)}.')
        return converted

    def __neg__(self):
        return self.zero() - self

    def __add__(self, other):
        other = FieldElement._coerce(other)
        if other is None:
            return NotImplemented
        # The constructor performs the modular reduction.
        return FieldElement(self.val + other.val)

    __radd__ = __add__

    def __sub__(self, other):
        other = FieldElement._coerce(other)
        if other is None:
            return NotImplemented
        return FieldElement(self.val - other.val)

    def __rsub__(self, other):
        # Computes other - self.
        other = FieldElement._coerce(other)
        if other is None:
            return NotImplemented
        return FieldElement(other.val - self.val)

    def __mul__(self, other):
        other = FieldElement._coerce(other)
        if other is None:
            return NotImplemented
        return FieldElement(self.val * other.val)

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Divides by multiplying with the inverse of other. Dividing by zero raises
        AssertionError (see inverse()).
        """
        other = FieldElement._coerce(other)
        if other is None:
            return NotImplemented
        return self * other.inverse()

    def __rtruediv__(self, other):
        # Computes other / self, e.g. 1 / FieldElement(2).
        other = FieldElement._coerce(other)
        if other is None:
            return NotImplemented
        return other * self.inverse()

    def __pow__(self, n):
        """
        Raises the element to the n-th power using square-and-multiply.

        A negative n is interpreted as the (-n)-th power of the inverse, so raising zero to a
        negative power raises AssertionError.
        """
        if n < 0:
            return self.inverse() ** (-n)
        cur_pow = self
        res = FieldElement.one()
        while n > 0:
            if n % 2 != 0:
                res *= cur_pow
            n = n // 2
            cur_pow *= cur_pow
        return res

    def inverse(self):
        """
        Obtains the multiplicative inverse using the extended Euclidean algorithm.

        Raises AssertionError if the element has no inverse (i.e., gcd(val, k_modulus) != 1,
        which is the case for zero). The check is explicit so it is not stripped by `python -O`.
        """
        t, new_t = 0, 1
        r, new_r = FieldElement.k_modulus, self.val
        while new_r != 0:
            quotient = r // new_r
            t, new_t = new_t, (t - (quotient * new_t))
            r, new_r = new_r, r - quotient * new_r
        # r now holds gcd(k_modulus, val); an inverse exists only when it equals 1.
        if r != 1:
            raise AssertionError(
                f'{self.val} has no multiplicative inverse modulo {FieldElement.k_modulus}.')
        return FieldElement(t)

    def is_order(self, n):
        """
        Naively checks that the element is of order n by raising it to all powers up to n, checking
        that the element to the n-th power is the unit, but not so for any k<n.
        """
        assert n >= 1
        unit = FieldElement.one()
        h = FieldElement.one()
        for _ in range(1, n):
            h *= self
            if h == unit:
                return False
        return h * self == unit

    def _serialize_(self):
        # The canonical (non-negative, reduced) value as a decimal string.
        return repr(self.val)

    @staticmethod
    def random_element(exclude_elements=()):
        """
        Obtains a uniformly chosen field element that is not contained in exclude_elements.

        Candidates are redrawn until one falls outside exclude_elements, so the call does not
        terminate if exclude_elements contains every element of the field.
        """
        while True:
            fe = FieldElement(randint(0, FieldElement.k_modulus - 1))
            if fe not in exclude_elements:
                return fe