Spaces:
Sleeping
Sleeping
| """ | |
| Utility functions for integer math. | |
| TODO: rename, cleanup, perhaps move the gmpy wrapper code | |
| here from settings.py | |
| """ | |
| import math | |
| from bisect import bisect | |
| from .backend import xrange | |
| from .backend import BACKEND, gmpy, sage, sage_utils, MPZ, MPZ_ONE, MPZ_ZERO | |
| small_trailing = [0] * 256 | |
| for j in range(1,8): | |
| small_trailing[1<<j::1<<(j+1)] = [j] * (1<<(7-j)) | |
| def giant_steps(start, target, n=2): | |
| """ | |
| Return a list of integers ~= | |
| [start, n*start, ..., target/n^2, target/n, target] | |
| but conservatively rounded so that the quotient between two | |
| successive elements is actually slightly less than n. | |
| With n = 2, this describes suitable precision steps for a | |
| quadratically convergent algorithm such as Newton's method; | |
| with n = 3 steps for cubic convergence (Halley's method), etc. | |
| >>> giant_steps(50,1000) | |
| [66, 128, 253, 502, 1000] | |
| >>> giant_steps(50,1000,4) | |
| [65, 252, 1000] | |
| """ | |
| L = [target] | |
| while L[-1] > start*n: | |
| L = L + [L[-1]//n + 2] | |
| return L[::-1] | |
| def rshift(x, n): | |
| """For an integer x, calculate x >> n with the fastest (floor) | |
| rounding. Unlike the plain Python expression (x >> n), n is | |
| allowed to be negative, in which case a left shift is performed.""" | |
| if n >= 0: return x >> n | |
| else: return x << (-n) | |
| def lshift(x, n): | |
| """For an integer x, calculate x << n. Unlike the plain Python | |
| expression (x << n), n is allowed to be negative, in which case a | |
| right shift with default (floor) rounding is performed.""" | |
| if n >= 0: return x << n | |
| else: return x >> (-n) | |
| if BACKEND == 'sage': | |
| import operator | |
| rshift = operator.rshift | |
| lshift = operator.lshift | |
| def python_trailing(n): | |
| """Count the number of trailing zero bits in abs(n).""" | |
| if not n: | |
| return 0 | |
| low_byte = n & 0xff | |
| if low_byte: | |
| return small_trailing[low_byte] | |
| t = 8 | |
| n >>= 8 | |
| while not n & 0xff: | |
| n >>= 8 | |
| t += 8 | |
| return t + small_trailing[n & 0xff] | |
| if BACKEND == 'gmpy': | |
| if gmpy.version() >= '2': | |
| def gmpy_trailing(n): | |
| """Count the number of trailing zero bits in abs(n) using gmpy.""" | |
| if n: return MPZ(n).bit_scan1() | |
| else: return 0 | |
| else: | |
| def gmpy_trailing(n): | |
| """Count the number of trailing zero bits in abs(n) using gmpy.""" | |
| if n: return MPZ(n).scan1() | |
| else: return 0 | |
| # Small powers of 2 | |
| powers = [1<<_ for _ in range(300)] | |
| def python_bitcount(n): | |
| """Calculate bit size of the nonnegative integer n.""" | |
| bc = bisect(powers, n) | |
| if bc != 300: | |
| return bc | |
| bc = int(math.log(n, 2)) - 4 | |
| return bc + bctable[n>>bc] | |
| def gmpy_bitcount(n): | |
| """Calculate bit size of the nonnegative integer n.""" | |
| if n: return MPZ(n).numdigits(2) | |
| else: return 0 | |
| #def sage_bitcount(n): | |
| # if n: return MPZ(n).nbits() | |
| # else: return 0 | |
| def sage_trailing(n): | |
| return MPZ(n).trailing_zero_bits() | |
| if BACKEND == 'gmpy': | |
| bitcount = gmpy_bitcount | |
| trailing = gmpy_trailing | |
| elif BACKEND == 'sage': | |
| sage_bitcount = sage_utils.bitcount | |
| bitcount = sage_bitcount | |
| trailing = sage_trailing | |
| else: | |
| bitcount = python_bitcount | |
| trailing = python_trailing | |
| if BACKEND == 'gmpy' and 'bit_length' in dir(gmpy): | |
| bitcount = gmpy.bit_length | |
| # Used to avoid slow function calls as far as possible | |
| trailtable = [trailing(n) for n in range(256)] | |
| bctable = [bitcount(n) for n in range(1024)] | |
| # TODO: speed up for bases 2, 4, 8, 16, ... | |
| def bin_to_radix(x, xbits, base, bdigits): | |
| """Changes radix of a fixed-point number; i.e., converts | |
| x * 2**xbits to floor(x * 10**bdigits).""" | |
| return x * (MPZ(base)**bdigits) >> xbits | |
| stddigits = '0123456789abcdefghijklmnopqrstuvwxyz' | |
| def small_numeral(n, base=10, digits=stddigits): | |
| """Return the string numeral of a positive integer in an arbitrary | |
| base. Most efficient for small input.""" | |
| if base == 10: | |
| return str(n) | |
| digs = [] | |
| while n: | |
| n, digit = divmod(n, base) | |
| digs.append(digits[digit]) | |
| return "".join(digs[::-1]) | |
| def numeral_python(n, base=10, size=0, digits=stddigits): | |
| """Represent the integer n as a string of digits in the given base. | |
| Recursive division is used to make this function about 3x faster | |
| than Python's str() for converting integers to decimal strings. | |
| The 'size' parameters specifies the number of digits in n; this | |
| number is only used to determine splitting points and need not be | |
| exact.""" | |
| if n <= 0: | |
| if not n: | |
| return "0" | |
| return "-" + numeral(-n, base, size, digits) | |
| # Fast enough to do directly | |
| if size < 250: | |
| return small_numeral(n, base, digits) | |
| # Divide in half | |
| half = (size // 2) + (size & 1) | |
| A, B = divmod(n, base**half) | |
| ad = numeral(A, base, half, digits) | |
| bd = numeral(B, base, half, digits).rjust(half, "0") | |
| return ad + bd | |
| def numeral_gmpy(n, base=10, size=0, digits=stddigits): | |
| """Represent the integer n as a string of digits in the given base. | |
| Recursive division is used to make this function about 3x faster | |
| than Python's str() for converting integers to decimal strings. | |
| The 'size' parameters specifies the number of digits in n; this | |
| number is only used to determine splitting points and need not be | |
| exact.""" | |
| if n < 0: | |
| return "-" + numeral(-n, base, size, digits) | |
| # gmpy.digits() may cause a segmentation fault when trying to convert | |
| # extremely large values to a string. The size limit may need to be | |
| # adjusted on some platforms, but 1500000 works on Windows and Linux. | |
| if size < 1500000: | |
| return gmpy.digits(n, base) | |
| # Divide in half | |
| half = (size // 2) + (size & 1) | |
| A, B = divmod(n, MPZ(base)**half) | |
| ad = numeral(A, base, half, digits) | |
| bd = numeral(B, base, half, digits).rjust(half, "0") | |
| return ad + bd | |
| if BACKEND == "gmpy": | |
| numeral = numeral_gmpy | |
| else: | |
| numeral = numeral_python | |
| _1_800 = 1<<800 | |
| _1_600 = 1<<600 | |
| _1_400 = 1<<400 | |
| _1_200 = 1<<200 | |
| _1_100 = 1<<100 | |
| _1_50 = 1<<50 | |
| def isqrt_small_python(x): | |
| """ | |
| Correctly (floor) rounded integer square root, using | |
| division. Fast up to ~200 digits. | |
| """ | |
| if not x: | |
| return x | |
| if x < _1_800: | |
| # Exact with IEEE double precision arithmetic | |
| if x < _1_50: | |
| return int(x**0.5) | |
| # Initial estimate can be any integer >= the true root; round up | |
| r = int(x**0.5 * 1.00000000000001) + 1 | |
| else: | |
| bc = bitcount(x) | |
| n = bc//2 | |
| r = int((x>>(2*n-100))**0.5+2)<<(n-50) # +2 is to round up | |
| # The following iteration now precisely computes floor(sqrt(x)) | |
| # See e.g. Crandall & Pomerance, "Prime Numbers: A Computational | |
| # Perspective" | |
| while 1: | |
| y = (r+x//r)>>1 | |
| if y >= r: | |
| return r | |
| r = y | |
| def isqrt_fast_python(x): | |
| """ | |
| Fast approximate integer square root, computed using division-free | |
| Newton iteration for large x. For random integers the result is almost | |
| always correct (floor(sqrt(x))), but is 1 ulp too small with a roughly | |
| 0.1% probability. If x is very close to an exact square, the answer is | |
| 1 ulp wrong with high probability. | |
| With 0 guard bits, the largest error over a set of 10^5 random | |
| inputs of size 1-10^5 bits was 3 ulp. The use of 10 guard bits | |
| almost certainly guarantees a max 1 ulp error. | |
| """ | |
| # Use direct division-based iteration if sqrt(x) < 2^400 | |
| # Assume floating-point square root accurate to within 1 ulp, then: | |
| # 0 Newton iterations good to 52 bits | |
| # 1 Newton iterations good to 104 bits | |
| # 2 Newton iterations good to 208 bits | |
| # 3 Newton iterations good to 416 bits | |
| if x < _1_800: | |
| y = int(x**0.5) | |
| if x >= _1_100: | |
| y = (y + x//y) >> 1 | |
| if x >= _1_200: | |
| y = (y + x//y) >> 1 | |
| if x >= _1_400: | |
| y = (y + x//y) >> 1 | |
| return y | |
| bc = bitcount(x) | |
| guard_bits = 10 | |
| x <<= 2*guard_bits | |
| bc += 2*guard_bits | |
| bc += (bc&1) | |
| hbc = bc//2 | |
| startprec = min(50, hbc) | |
| # Newton iteration for 1/sqrt(x), with floating-point starting value | |
| r = int(2.0**(2*startprec) * (x >> (bc-2*startprec)) ** -0.5) | |
| pp = startprec | |
| for p in giant_steps(startprec, hbc): | |
| # r**2, scaled from real size 2**(-bc) to 2**p | |
| r2 = (r*r) >> (2*pp - p) | |
| # x*r**2, scaled from real size ~1.0 to 2**p | |
| xr2 = ((x >> (bc-p)) * r2) >> p | |
| # New value of r, scaled from real size 2**(-bc/2) to 2**p | |
| r = (r * ((3<<p) - xr2)) >> (pp+1) | |
| pp = p | |
| # (1/sqrt(x))*x = sqrt(x) | |
| return (r*(x>>hbc)) >> (p+guard_bits) | |
| def sqrtrem_python(x): | |
| """Correctly rounded integer (floor) square root with remainder.""" | |
| # to check cutoff: | |
| # plot(lambda x: timing(isqrt, 2**int(x)), [0,2000]) | |
| if x < _1_600: | |
| y = isqrt_small_python(x) | |
| return y, x - y*y | |
| y = isqrt_fast_python(x) + 1 | |
| rem = x - y*y | |
| # Correct remainder | |
| while rem < 0: | |
| y -= 1 | |
| rem += (1+2*y) | |
| else: | |
| if rem: | |
| while rem > 2*(1+y): | |
| y += 1 | |
| rem -= (1+2*y) | |
| return y, rem | |
| def isqrt_python(x): | |
| """Integer square root with correct (floor) rounding.""" | |
| return sqrtrem_python(x)[0] | |
| def sqrt_fixed(x, prec): | |
| return isqrt_fast(x<<prec) | |
| sqrt_fixed2 = sqrt_fixed | |
| if BACKEND == 'gmpy': | |
| if gmpy.version() >= '2': | |
| isqrt_small = isqrt_fast = isqrt = gmpy.isqrt | |
| sqrtrem = gmpy.isqrt_rem | |
| else: | |
| isqrt_small = isqrt_fast = isqrt = gmpy.sqrt | |
| sqrtrem = gmpy.sqrtrem | |
| elif BACKEND == 'sage': | |
| isqrt_small = isqrt_fast = isqrt = \ | |
| getattr(sage_utils, "isqrt", lambda n: MPZ(n).isqrt()) | |
| sqrtrem = lambda n: MPZ(n).sqrtrem() | |
| else: | |
| isqrt_small = isqrt_small_python | |
| isqrt_fast = isqrt_fast_python | |
| isqrt = isqrt_python | |
| sqrtrem = sqrtrem_python | |
| def ifib(n, _cache={}): | |
| """Computes the nth Fibonacci number as an integer, for | |
| integer n.""" | |
| if n < 0: | |
| return (-1)**(-n+1) * ifib(-n) | |
| if n in _cache: | |
| return _cache[n] | |
| m = n | |
| # Use Dijkstra's logarithmic algorithm | |
| # The following implementation is basically equivalent to | |
| # http://en.literateprograms.org/Fibonacci_numbers_(Scheme) | |
| a, b, p, q = MPZ_ONE, MPZ_ZERO, MPZ_ZERO, MPZ_ONE | |
| while n: | |
| if n & 1: | |
| aq = a*q | |
| a, b = b*q+aq+a*p, b*p+aq | |
| n -= 1 | |
| else: | |
| qq = q*q | |
| p, q = p*p+qq, qq+2*p*q | |
| n >>= 1 | |
| if m < 250: | |
| _cache[m] = b | |
| return b | |
| MAX_FACTORIAL_CACHE = 1000 | |
| def ifac(n, memo={0:1, 1:1}): | |
| """Return n factorial (for integers n >= 0 only).""" | |
| f = memo.get(n) | |
| if f: | |
| return f | |
| k = len(memo) | |
| p = memo[k-1] | |
| MAX = MAX_FACTORIAL_CACHE | |
| while k <= n: | |
| p *= k | |
| if k <= MAX: | |
| memo[k] = p | |
| k += 1 | |
| return p | |
| def ifac2(n, memo_pair=[{0:1}, {1:1}]): | |
| """Return n!! (double factorial), integers n >= 0 only.""" | |
| memo = memo_pair[n&1] | |
| f = memo.get(n) | |
| if f: | |
| return f | |
| k = max(memo) | |
| p = memo[k] | |
| MAX = MAX_FACTORIAL_CACHE | |
| while k < n: | |
| k += 2 | |
| p *= k | |
| if k <= MAX: | |
| memo[k] = p | |
| return p | |
| if BACKEND == 'gmpy': | |
| ifac = gmpy.fac | |
| elif BACKEND == 'sage': | |
| ifac = lambda n: int(sage.factorial(n)) | |
| ifib = sage.fibonacci | |
| def list_primes(n): | |
| n = n + 1 | |
| sieve = list(xrange(n)) | |
| sieve[:2] = [0, 0] | |
| for i in xrange(2, int(n**0.5)+1): | |
| if sieve[i]: | |
| for j in xrange(i**2, n, i): | |
| sieve[j] = 0 | |
| return [p for p in sieve if p] | |
| if BACKEND == 'sage': | |
| # Note: it is *VERY* important for performance that we convert | |
| # the list to Python ints. | |
| def list_primes(n): | |
| return [int(_) for _ in sage.primes(n+1)] | |
| small_odd_primes = (3,5,7,11,13,17,19,23,29,31,37,41,43,47) | |
| small_odd_primes_set = set(small_odd_primes) | |
| def isprime(n): | |
| """ | |
| Determines whether n is a prime number. A probabilistic test is | |
| performed if n is very large. No special trick is used for detecting | |
| perfect powers. | |
| >>> sum(list_primes(100000)) | |
| 454396537 | |
| >>> sum(n*isprime(n) for n in range(100000)) | |
| 454396537 | |
| """ | |
| n = int(n) | |
| if not n & 1: | |
| return n == 2 | |
| if n < 50: | |
| return n in small_odd_primes_set | |
| for p in small_odd_primes: | |
| if not n % p: | |
| return False | |
| m = n-1 | |
| s = trailing(m) | |
| d = m >> s | |
| def test(a): | |
| x = pow(a,d,n) | |
| if x == 1 or x == m: | |
| return True | |
| for r in xrange(1,s): | |
| x = x**2 % n | |
| if x == m: | |
| return True | |
| return False | |
| # See http://primes.utm.edu/prove/prove2_3.html | |
| if n < 1373653: | |
| witnesses = [2,3] | |
| elif n < 341550071728321: | |
| witnesses = [2,3,5,7,11,13,17] | |
| else: | |
| witnesses = small_odd_primes | |
| for a in witnesses: | |
| if not test(a): | |
| return False | |
| return True | |
| def moebius(n): | |
| """ | |
| Evaluates the Moebius function which is `mu(n) = (-1)^k` if `n` | |
| is a product of `k` distinct primes and `mu(n) = 0` otherwise. | |
| TODO: speed up using factorization | |
| """ | |
| n = abs(int(n)) | |
| if n < 2: | |
| return n | |
| factors = [] | |
| for p in xrange(2, n+1): | |
| if not (n % p): | |
| if not (n % p**2): | |
| return 0 | |
| if not sum(p % f for f in factors): | |
| factors.append(p) | |
| return (-1)**len(factors) | |
| def gcd(*args): | |
| a = 0 | |
| for b in args: | |
| if a: | |
| while b: | |
| a, b = b, a % b | |
| else: | |
| a = b | |
| return a | |
| # Comment by Juan Arias de Reyna: | |
| # | |
| # I learn this method to compute EulerE[2n] from van de Lune. | |
| # | |
| # We apply the formula EulerE[2n] = (-1)^n 2**(-2n) sum_{j=0}^n a(2n,2j+1) | |
| # | |
| # where the numbers a(n,j) vanish for j > n+1 or j <= -1 and satisfies | |
| # | |
| # a(0,-1) = a(0,0) = 0; a(0,1)= 1; a(0,2) = a(0,3) = 0 | |
| # | |
| # a(n,j) = a(n-1,j) when n+j is even | |
| # a(n,j) = (j-1) a(n-1,j-1) + (j+1) a(n-1,j+1) when n+j is odd | |
| # | |
| # | |
| # But we can use only one array unidimensional a(j) since to compute | |
| # a(n,j) we only need to know a(n-1,k) where k and j are of different parity | |
| # and we have not to conserve the used values. | |
| # | |
| # We cached up the values of Euler numbers to sufficiently high order. | |
| # | |
| # Important Observation: If we pretend to use the numbers | |
| # EulerE[1], EulerE[2], ... , EulerE[n] | |
| # it is convenient to compute first EulerE[n], since the algorithm | |
| # computes first all | |
| # the previous ones, and keeps them in the CACHE | |
| MAX_EULER_CACHE = 500 | |
| def eulernum(m, _cache={0:MPZ_ONE}): | |
| r""" | |
| Computes the Euler numbers `E(n)`, which can be defined as | |
| coefficients of the Taylor expansion of `1/cosh x`: | |
| .. math :: | |
| \frac{1}{\cosh x} = \sum_{n=0}^\infty \frac{E_n}{n!} x^n | |
| Example:: | |
| >>> [int(eulernum(n)) for n in range(11)] | |
| [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521] | |
| >>> [int(eulernum(n)) for n in range(11)] # test cache | |
| [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521] | |
| """ | |
| # for odd m > 1, the Euler numbers are zero | |
| if m & 1: | |
| return MPZ_ZERO | |
| f = _cache.get(m) | |
| if f: | |
| return f | |
| MAX = MAX_EULER_CACHE | |
| n = m | |
| a = [MPZ(_) for _ in [0,0,1,0,0,0]] | |
| for n in range(1, m+1): | |
| for j in range(n+1, -1, -2): | |
| a[j+1] = (j-1)*a[j] + (j+1)*a[j+2] | |
| a.append(0) | |
| suma = 0 | |
| for k in range(n+1, -1, -2): | |
| suma += a[k+1] | |
| if n <= MAX: | |
| _cache[n] = ((-1)**(n//2))*(suma // 2**n) | |
| if n == m: | |
| return ((-1)**(n//2))*suma // 2**n | |
| def stirling1(n, k): | |
| """ | |
| Stirling number of the first kind. | |
| """ | |
| if n < 0 or k < 0: | |
| raise ValueError | |
| if k >= n: | |
| return MPZ(n == k) | |
| if k < 1: | |
| return MPZ_ZERO | |
| L = [MPZ_ZERO] * (k+1) | |
| L[1] = MPZ_ONE | |
| for m in xrange(2, n+1): | |
| for j in xrange(min(k, m), 0, -1): | |
| L[j] = (m-1) * L[j] + L[j-1] | |
| return (-1)**(n+k) * L[k] | |
| def stirling2(n, k): | |
| """ | |
| Stirling number of the second kind. | |
| """ | |
| if n < 0 or k < 0: | |
| raise ValueError | |
| if k >= n: | |
| return MPZ(n == k) | |
| if k <= 1: | |
| return MPZ(k == 1) | |
| s = MPZ_ZERO | |
| t = MPZ_ONE | |
| for j in xrange(k+1): | |
| if (k + j) & 1: | |
| s -= t * MPZ(j)**n | |
| else: | |
| s += t * MPZ(j)**n | |
| t = t * (k - j) // (j + 1) | |
| return s // ifac(k) | |