Python 中的整数平方根
- 2025-01-14 08:50:00
- admin 原创
- 163
问题描述:
在 Python 或标准库中是否有整数平方根?我希望它是精确的(即返回一个整数),并且如果输入不是完全平方,则引发异常。
我尝试使用这个代码:
def isqrt(n):
i = int(math.sqrt(n) + 0.5)
if i**2 == n:
return i
raise ValueError('input was not a perfect square')
但它很丑陋,我不太相信它能处理大整数。我可以遍历平方,如果超出了值就放弃,但我认为这样做会有点慢。另外,这肯定已经在某个地方实现了?
另请参阅:检查一个数字是否为完全平方数。
解决方案 1:
注意:现在math.isqrt
在 stdlib 中,自 Python 3.8 起可用。
牛顿法对于整数非常有效:
def isqrt(n):
x = n
y = (x + 1) // 2
while y < x:
x = y
y = (x + n // x) // 2
return x
这将返回x x不超过n的最大整数x。如果要检查结果是否恰好是平方根,只需执行乘法即可检查n*是否为完全平方数。
我在我的博客上讨论了该算法和其他三种计算平方根的算法。
解决方案 2:
更新: Python 3.8math.isqrt
在标准库中有一个函数!
我使用小输入(0…2 22)和大输入(2 50001 )对此处的每个(正确)函数进行了基准测试。在这两种情况下, gmpy2.isqrt
mathmandan 建议的明显赢家排在第一,其次是 Python 3.8 math.isqrt
,第三是NPE 链接的 ActiveState 配方。ActiveState 配方有一堆可以用移位代替的除法,这使其速度更快(但仍落后于本机函数):
def isqrt(n):
if n > 0:
x = 1 << (n.bit_length() + 1 >> 1)
while True:
y = (x + n // x) >> 1
if y >= x:
return x
x = y
elif n == 0:
return 0
else:
raise ValueError("square root not defined for negative numbers")
基准测试结果:
gmpy2.isqrt()
(mathmandan):小 0.08 µs,大 0.07 msint(gmpy2.isqrt())
*: 小 0.3 µs,大 0.07 msPython 3.8
math.isqrt
:小 0.13 µs,大 0.9 msActiveState(如上所述优化):小 0.6 µs,大 17.0 ms
ActiveState (NPE):小 1.0 µs,大 17.3 ms
castlebravo 长针:小 4 µs,大 80 ms
mathmandan 改进:小 2.7 µs,大 120 ms
martineau(经过此修正):小 2.3 µs,大 140 ms
nibot : 8 µs 小,1000 ms 大
mathmandan : 小 1.8 µs,大 2200 ms
castlebravo 牛顿法:小 1.5 µs,大 19000 ms
user448810 : 小 1.4 µs,大 20000 ms
(* 由于gmpy2.isqrt
返回一个gmpy2.mpz
对象,其行为大部分类似于但不完全类似于int
,您可能需要将其转换回以int
满足某些用途。)
解决方案 3:
抱歉回复这么晚;我只是偶然发现了这个页面。如果以后有人访问这个页面,python 模块 gmpy2 旨在处理非常大的输入,其中包括一个整数平方根函数。
例子:
>>> import gmpy2
>>> gmpy2.isqrt((10**100+1)**2)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001L)
>>> gmpy2.isqrt((10**100+1)**2 - 1)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000L)
当然,所有内容都会有“mpz”标签,但 mpz 与 int 兼容:
>>> gmpy2.mpz(3)*4
mpz(12)
>>> int(gmpy2.mpz(12))
12
请参阅我的其他答案,了解有关此方法相对于该问题的其他答案的性能的讨论。
下载:https://code.google.com/p/gmpy/
解决方案 4:
这是一个非常简单的实现:
def i_sqrt(n):
i = n.bit_length() >> 1 # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while m*m > n:
m >>= 1
i -= 1
for k in xrange(i-1, -1, -1):
x = m | (1 << k)
if x*x <= n:
m = x
return m
这只是一个二分查找。将值初始化m
为不超过平方根的最大 2 的幂,然后检查是否可以设置每个较小的位,同时保持结果不大于平方根。(按降序逐位检查。)
对于相当大的值n
(例如,大约10**6000
,或大约20000
位),这似乎是:
比user448810 描述的牛顿法实现更快。
比我的其他答案
gmpy2
中的内置方法慢得多。与 nibot 描述的长手写平方根类似,但速度稍慢。
所有这些方法都可以成功处理这种大小的输入,但在我的计算机上,此函数大约需要 1.5 秒,而 @Nibot 的函数大约需要 0.9 秒,@user448810 的函数大约需要 19 秒,而 gmpy2 内置方法则需要不到一毫秒(!)。示例:
>>> import random
>>> import timeit
>>> import gmpy2
>>> r = random.getrandbits
>>> t = timeit.timeit
>>> t('i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # This function
1.5102493192883117
>>> t('exact_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # Nibot
0.8952787937686366
>>> t('isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # user448810
19.326695976676184
>>> t('gmpy2.isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # gmpy2
0.0003599147067689046
>>> all(i_sqrt(n)==isqrt(n)==exact_sqrt(n)[0]==int(gmpy2.isqrt(n)) for n in (r(1500) for i in xrange(1500)))
True
这个函数可以很容易地推广,尽管它不是那么好,因为我对的初始猜测不太精确m
:
def i_root(num, root, report_exactness = True):
i = num.bit_length() / root
m = 1 << i
while m ** root < num:
m <<= 1
i += 1
while m ** root > num:
m >>= 1
i -= 1
for k in xrange(i-1, -1, -1):
x = m | (1 << k)
if x ** root <= num:
m = x
if report_exactness:
return m, m ** root == num
return m
不过,请注意gmpy2
还有一种i_root
方法。
事实上,这种方法可以适用于任何(非负、递增)函数,f
以确定“的整数逆f
”。但是,要选择有效的初始值,m
您仍然需要了解有关的信息f
。
编辑:感谢@Greggo 指出i_sqrt
可以重写该函数以避免使用任何乘法。这带来了令人印象深刻的性能提升!
def improved_i_sqrt(n):
assert n >= 0
if n == 0:
return 0
i = n.bit_length() >> 1 # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while (m << i) > n: # (m<<i) = m*(2^i) = m*m
m >>= 1
i -= 1
d = n - (m << i) # d = n-m^2
for k in xrange(i-1, -1, -1):
j = 1 << k
new_diff = d - (((m<<1) | j) << k) # n-(m+2^k)^2 = n-m^2-2*m*2^k-2^(2k)
if new_diff >= 0:
d = new_diff
m |= j
return m
请注意,根据构造,k
的第位m << 1
未设置,因此可以使用按位或来实现的加法(m<<1) + (1<<k)
。最终我将(2*m*(2**k) + 2**(2*k))
写为(((m<<1) | (1<<k)) << k)
,因此它需要三次移位和一次按位或(然后进行减法以得到new_diff
)。也许还有更有效的方法来得到这个?无论如何,它比乘法要好得多m*m
!与上面比较:
>>> t('improved_i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5.
0.10908999762373242
>>> all(improved_i_sqrt(n) == i_sqrt(n) for n in xrange(10**6))
True
解决方案 5:
长手平方根算法
事实证明,有一种可以手工计算平方根的算法,类似于长除法。该算法的每次迭代都会产生结果平方根的一位数字,同时消耗您要寻找的平方根的数字的两位数字。虽然该算法的“长手”版本以十进制指定,但它可以在任何进制中工作,二进制是最简单的实现,并且可能是执行速度最快的(取决于底层大数表示)。
由于该算法对数字进行逐位运算,因此它可以对任意大的完全平方数产生精确的结果,对于非完全平方数,可以根据需要产生任意多的精度数字(小数点右边)。
“Dr. Math”网站上有两篇很好的文章解释了该算法:
二进制平方根
手写平方根
以下是 Python 中的实现:
def exact_sqrt(x):
"""Calculate the square root of an arbitrarily large integer.
The result of exact_sqrt(x) is a tuple (a, r) such that a**2 + r = x, where
a is the largest integer such that a**2 <= x, and r is the "remainder". If
x is a perfect square, then r will be zero.
The algorithm used is the "long-hand square root" algorithm, as described at
http://mathforum.org/library/drmath/view/52656.html
Tobin Fricke 2014-04-23
Max Planck Institute for Gravitational Physics
Hannover, Germany
"""
N = 0 # Problem so far
a = 0 # Solution so far
# We'll process the number two bits at a time, starting at the MSB
L = x.bit_length()
L += (L % 2) # Round up to the next even number
for i in xrange(L, -1, -1):
# Get the next group of two bits
n = (x >> (2*i)) & 0b11
# Check whether we can reduce the remainder
if ((N - a*a) << 2) + n >= (a<<2) + 1:
b = 1
else:
b = 0
a = (a << 1) | b # Concatenate the next bit of the solution
N = (N << 2) | n # Concatenate the next bit of the problem
return (a, N-a*a)
您可以轻松修改此函数以进行额外的迭代来计算平方根的小数部分。我最感兴趣的是计算大完全平方的根。
我不确定这与“整数牛顿法”算法相比如何。我怀疑牛顿法更快,因为它原则上可以在一次迭代中生成多位解决方案,而“长手”算法每次迭代只会生成一位解决方案。
源代码仓库:https ://gist.github.com/tobin/11233492
解决方案 6:
一种选择是使用decimal
模块,并以足够精确的浮点数执行此操作:
import decimal
def isqrt(n):
nd = decimal.Decimal(n)
with decimal.localcontext() as ctx:
ctx.prec = n.bit_length()
i = int(nd.sqrt())
if i**2 != n:
raise ValueError('input was not a perfect square')
return i
我认为这应该有效:
>>> isqrt(1)
1
>>> isqrt(7**14) == 7**7
True
>>> isqrt(11**1000) == 11**500
True
>>> isqrt(11**1000+1)
Traceback (most recent call last):
File "<ipython-input-121-e80953fb4d8e>", line 1, in <module>
isqrt(11**1000+1)
File "<ipython-input-100-dd91f704e2bd>", line 10, in isqrt
raise ValueError('input was not a perfect square')
ValueError: input was not a perfect square
解决方案 7:
Python的默认math
库有一个整数平方根函数:
math.isqrt(n)
返回非负整数n的整数平方根。这是n的精确平方根的下限,或者等效地是满足a² ≤ n 的最大整数 a 。
解决方案 8:
好像你可以这样检查:
if int(math.sqrt(n))**2 == n:
print n, 'is a perfect square'
更新:
正如您所指出的,上述方法在较大的 值时会失败。对于这些,下面的方法看起来很有希望,这是 Martin Guy @ UKC 于 1985 年 6 月编写的示例 C 代码的改编版,用于维基百科文章“计算平方根的方法n
”中提到的相对简单的二进制数字逐位计算方法:
from math import ceil, log
def isqrt(n):
res = 0
bit = 4**int(ceil(log(n, 4))) if n else 0 # smallest power of 4 >= the argument
while bit:
if n >= res + bit:
n -= res + bit
res = (res >> 1) + bit
else:
res >>= 1
bit >>= 2
return res
if __name__ == '__main__':
from math import sqrt # for comparison purposes
for i in range(17)+[2**53, (10**100+1)**2]:
is_perfect_sq = isqrt(i)**2 == i
print '{:21,d}: math.sqrt={:12,.7G}, isqrt={:10,d} {}'.format(
i, sqrt(i), isqrt(i), '(perfect square)' if is_perfect_sq else '')
输出:
0: math.sqrt= 0, isqrt= 0 (perfect square)
1: math.sqrt= 1, isqrt= 1 (perfect square)
2: math.sqrt= 1.414214, isqrt= 1
3: math.sqrt= 1.732051, isqrt= 1
4: math.sqrt= 2, isqrt= 2 (perfect square)
5: math.sqrt= 2.236068, isqrt= 2
6: math.sqrt= 2.44949, isqrt= 2
7: math.sqrt= 2.645751, isqrt= 2
8: math.sqrt= 2.828427, isqrt= 2
9: math.sqrt= 3, isqrt= 3 (perfect square)
10: math.sqrt= 3.162278, isqrt= 3
11: math.sqrt= 3.316625, isqrt= 3
12: math.sqrt= 3.464102, isqrt= 3
13: math.sqrt= 3.605551, isqrt= 3
14: math.sqrt= 3.741657, isqrt= 3
15: math.sqrt= 3.872983, isqrt= 3
16: math.sqrt= 4, isqrt= 4 (perfect square)
9,007,199,254,740,992: math.sqrt=9.490627E+07, isqrt=94,906,265
100,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,020,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001: math.sqrt= 1E+100, isqrt=10,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001 (perfect square)
解决方案 9:
下面的脚本提取整数平方根。它不使用除法,只使用位移位,因此速度非常快。它使用牛顿法计算平方根的倒数,这是一项因Quake III Arena而闻名的技术,如 Wikipedia 文章“快速平方根倒数”中所述。
该算法计算的策略s = sqrt(Y)
如下。
将参数 Y 减小到 [1/4, 1) 范围内的 y,即 y = Y/B,其中 1/4 <= y < 1,其中 B 是 2 的偶数次方,因此
B = 2**(2*k)
对于某个整数 k。我们想要找到 X,其中 x = X/B,并且 x = 1 / sqrt(y)。使用二次极小极大多项式确定 X 的首次近似值。
使用牛顿法细化X。
计算
s = X*Y/(2**(3*k))
。
我们实际上并不创建分数或执行任何除法。所有算术都是用整数完成的,我们使用位移位来除以 B 的各种幂。
范围缩减让我们找到一个很好的初始近似值,以提供给牛顿法。以下是区间 [1/4, 1] 中平方根倒数的 2 次极小极大多项式近似版本:
/256)
(抱歉,我在这里颠倒了 x 和 y 的含义,以符合惯例)。此近似值的最大误差约为 0.0355 ~= 1/28。以下是显示误差的图表:
使用这个多项式,我们的初始 x 至少具有 4 或 5 位精度。牛顿法的每一轮都会使精度翻倍,因此如果我们需要的话,不需要很多轮就能获得数千位。
""" Integer square root
Uses no divisions, only shifts
"Quake" style algorithm,
i.e., Newton's method for 1 / sqrt(y)
Uses a quadratic minimax polynomial for the first approximation
Written by PM 2Ring 2022.01.23
"""
def int_sqrt(y):
if y < 0:
raise ValueError("int_sqrt arg must be >= 0, not %s" % y)
if y < 2:
return y
# print("
*", y, "*")
# Range reduction.
# Find k such that 1/4 <= y/b < 1, where b = 2 ** (k*2)
j = y.bit_length()
# Round k*2 up to the next even number
k2 = j + (j & 1)
# k and some useful multiples
k = k2 >> 1
k3 = k2 + k
k6 = k3 << 1
kd = k6 + 1
# b cubed
b3 = 1 << k6
# Minimax approximation: x/b ~= 1 / sqrt(y/b)
x = (((463 * y * y) >> k2) - (896 * y) + (698 << k2)) >> 8
# print(" ", x, h)
# Newton's method for 1 / sqrt(y/b)
epsilon = 1 << k
for i in range(1, 99):
dx = x * (b3 - y * x * x) >> kd
x += dx
# print(f" {i}: {x} {dx}")
if abs(dx) <= epsilon:
break
# s == sqrt(y)
s = x * y >> k3
# Adjust if too low
ss = s + 1
return ss if ss * ss <= y else s
def test(lo, hi, step=1):
for y in range(lo, hi, step):
s = int_sqrt(y)
ss = s + 1
s2, ss2 = s * s, ss * ss
assert s2 <= y < ss2, (y, s2, ss2)
print("ok")
test(0, 100000, 1)
这段代码肯定比和慢。其目的只是为了说明算法。如果用 C 实现,看看它有多快会很有趣……math.isqrt
`decimal.Decimal.sqrt`
这是在 SageMathCell 服务器上运行的实时版本hi
。设置<= 0 可计算并显示 中设置的单个值的结果lo
。您可以在输入框中输入表达式,例如设置hi
为 0 并设置lo
为2 * 10**100
可获得sqrt(2) * 10**50
。
解决方案 10:
受到所有答案的启发,决定用纯 C++ 实现这些答案中的几种最佳方法。众所周知,C++ 总是比 Python 快。
为了粘合 C++ 和 Python,我使用了Cython。它允许将 C++ 转换为 Python 模块,然后直接从 Python 函数调用 C++ 函数。
另外作为补充,我不仅提供了采用 Python 的代码,还提供带有测试的纯 C++。
以下是纯 C++ 测试的时间:
Test 'GMP', bits 64, time 0.000001 sec
Test 'AndersKaseorg', bits 64, time 0.000003 sec
Test 'Babylonian', bits 64, time 0.000006 sec
Test 'ChordTangent', bits 64, time 0.000018 sec
Test 'GMP', bits 50000, time 0.000118 sec
Test 'AndersKaseorg', bits 50000, time 0.002777 sec
Test 'Babylonian', bits 50000, time 0.003062 sec
Test 'ChordTangent', bits 50000, time 0.009120 sec
和相同的 C++ 函数,但采用的 Python 模块具有时间安排:
Bits 50000
math.isqrt: 2.819 ms
gmpy2.isqrt: 0.166 ms
ISqrt_GMP: 0.252 ms
ISqrt_AndersKaseorg: 3.338 ms
ISqrt_Babylonian: 3.756 ms
ISqrt_ChordTangent: 10.564 ms
从某种意义上说,我的 Cython-C++ 是一个很好的框架,适合那些想要直接从 Python 编写和测试自己的 C++ 方法的人。
正如您在上面的计时示例中注意到的那样,我使用了以下方法:
math.isqrt,来自标准库的实现。
gmpy2.isqrt,GMPY2 库的实现。
ISqrt_GMP - 与 GMPY2 相同,但使用我的 Cython 模块,在那里我
<gmpxx.h>
直接使用 C++ GMP 库()。ISqrt_AndersKaseorg,代码取自@AndersKaseorg 的回答。
ISqrt_Babylonian,取自 Wikipedia文章的方法,即所谓的巴比伦方法。据我所知,这是我自己的实现。
ISqrt_ChordTangent,这是我自己的方法,我称之为 Chord-Tangent,因为它使用弦和切线来迭代缩短搜索间隔。这种方法在我的另一篇文章中有适度详细的描述。这种方法很好,因为它不仅搜索平方根,还搜索任何 K 的 K 次方根。我画了一张小图来展示这个算法的细节。
关于编译 C++/Cython 代码,我使用了GMP库。您需要先安装它,在 Linux 下很容易通过sudo apt install libgmp-dev
。
在 Windows 下最简单的方法是安装非常棒的程序VCPKG,这是一个软件包管理器,类似于 Linux 中的 APT。VCPKG 使用Visual Studio从源代码编译所有软件包(不要忘记安装Visual Studio 的社区版本)。安装 VCPKG 后,您可以通过 安装 GMP vcpkg install gmp
。您也可以安装MPIR,这是 GMP 的替代分支,您可以通过 安装它vcpkg install mpir
。
在 Windows 下安装 GMP 后,请编辑我的 Python 代码并替换包含目录和库文件的路径。安装结束时的 VCPKG 应显示包含 GMP 库的 ZIP 文件的路径,其中包含 .lib 和 .h 文件。
您可能注意到,在 Python 代码中,我还设计了一个特别方便的cython_compile()
函数,用于将任何 C++ 代码编译成 Python 模块。这个函数非常好,因为它允许您轻松地将任何 C++ 代码插入 Python,并且可以多次重复使用。
如果您有任何问题或建议,或者您的电脑上某些功能无法运行,请在评论中写下。
下面我首先展示 Python 代码,然后展示 C++ 代码。请参阅Try it online!
上面的 C++ 代码链接,在 GodBolt 服务器上在线运行代码。这两个代码片段都是完全可从头运行的,无需对它们进行任何编辑。
def cython_compile(srcs):
import json, hashlib, os, glob, importlib, sys, shutil, tempfile
srch = hashlib.sha256(json.dumps(srcs, sort_keys = True, ensure_ascii = True).encode('utf-8')).hexdigest().upper()[:12]
pdir = 'cyimp'
if len(glob.glob(f'{pdir}/cy{srch}*')) == 0:
class ChDir:
def __init__(self, newd):
self.newd = newd
def __enter__(self):
self.curd = os.getcwd()
os.chdir(self.newd)
return self
def __exit__(self, ext, exv, tb):
os.chdir(self.curd)
os.makedirs(pdir, exist_ok = True)
with tempfile.TemporaryDirectory(dir = pdir) as td, ChDir(str(td)) as chd:
os.makedirs(pdir, exist_ok = True)
for k, v in srcs.items():
with open(f'cys{srch}_{k}', 'wb') as f:
f.write(v.replace('{srch}', srch).encode('utf-8'))
import numpy as np
from setuptools import setup, Extension
from Cython.Build import cythonize
sys.argv += ['build_ext', '--inplace']
setup(
ext_modules = cythonize(
Extension(
f'{pdir}.cy{srch}', [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] in ['pyx', 'c', 'cpp'], srcs.keys())],
depends = [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] not in ['pyx', 'c', 'cpp'], srcs.keys())],
extra_compile_args = ['/O2', '/std:c++latest',
'/ID:/dev/_3party/vcpkg_bin/gmp/include/',
],
),
compiler_directives = {'language_level': 3, 'embedsignature': True},
annotate = True,
),
include_dirs = [np.get_include()],
)
del sys.argv[-2:]
for f in glob.glob(f'{pdir}/cy{srch}*'):
shutil.copy(f, f'./../')
print('Cython module:', f'cy{srch}')
return importlib.import_module(f'{pdir}.cy{srch}')
def cython_import():
srcs = {
'lib.h': """
#include <cstring>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#pragma comment(lib, "D:/dev/_3party/vcpkg_bin/gmp/lib/gmp.lib")
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.sstatic.net/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
""",
'main.pyx': r"""
# distutils: language = c++
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
import numpy as np
cimport numpy as np
cimport cython
from libc.stdint cimport *
cdef extern from "cys{srch}_lib.h" nogil:
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt);
@cython.boundscheck(False)
@cython.wraparound(False)
def ISqrt(method, n):
mask64 = (1 << 64) - 1
def ToLimbs():
return np.copy(np.frombuffer(n.to_bytes((n.bit_length() + 63) // 64 * 8, 'little'), dtype = np.uint64))
words = (n.bit_length() + 63) // 64
t = n
r = np.zeros((words,), dtype = np.uint64)
for i in range(words):
r[i] = np.uint64(t & mask64)
t >>= 64
return r
def FromLimbs(x):
return int.from_bytes(x.tobytes(), 'little')
n = 0
for i in range(x.shape[0]):
n |= int(x[i]) << (i * 64)
return n
n = ToLimbs()
cdef uint64_t[:] cn = n
cdef uint64_t ccnt = len(n)
cdef uint64_t cmethod = {'GMP': 0, 'AndersKaseorg': 1, 'Babylonian': 2, 'ChordTangent': 3}[method]
with nogil:
(ISqrt_GMP_Py if cmethod == 0 else ISqrt_AndersKaseorg_Py if cmethod == 1 else ISqrt_Babylonian_Py if cmethod == 2 else ISqrt_ChordTangent_Py)(
<uint64_t *>&cn[0], <uint64_t *>&ccnt
)
return FromLimbs(n[:ccnt])
""",
}
return cython_compile(srcs)
def main():
import math, gmpy2, timeit, random
mod = cython_import()
fs = [
('math.isqrt', math.isqrt),
('gmpy2.isqrt', gmpy2.isqrt),
('ISqrt_GMP', lambda n: mod.ISqrt('GMP', n)),
('ISqrt_AndersKaseorg', lambda n: mod.ISqrt('AndersKaseorg', n)),
('ISqrt_Babylonian', lambda n: mod.ISqrt('Babylonian', n)),
('ISqrt_ChordTangent', lambda n: mod.ISqrt('ChordTangent', n)),
]
times = [0] * len(fs)
ntests = 1 << 6
bits = 50000
for i in range(ntests):
n = random.randrange(1 << (bits - 1), 1 << bits)
ref = None
for j, (fn, f) in enumerate(fs):
timeit_cnt = 3
tim = timeit.timeit(lambda: f(n), number = timeit_cnt) / timeit_cnt
times[j] += tim
x = f(n)
if j == 0:
ref = x
else:
assert x == ref, (fn, ref, x)
print('Bits', bits)
print('
'.join([f'{fs[i][0]:>19}: {round(times[i] / ntests * 1000, 3):>7} ms' for i in range(len(fs))]))
if __name__ == '__main__':
main()
和 C++:
在线尝试一下!
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
std::string IntToStr(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return n.get_str();
else {
std::ostringstream ss;
ss << n;
return ss.str();
}
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.sstatic.net/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
//std::cout << "x_begin, x_end = " << IntToStr(x_begin) << ", " << IntToStr(x_end) << ", n " << IntToStr(n) << std::endl;
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
// Testing
#include <chrono>
#include <random>
#include <vector>
#include <iomanip>
inline double Time() {
static auto const gtb = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - gtb)
.count();
}
template <typename T, typename F>
std::vector<T> Test0(std::string const & test_name, size_t bits, size_t ntests, F && f) {
std::mt19937_64 rng{123};
std::vector<T> nums;
for (size_t i = 0; i < ntests; ++i) {
T n = 0;
for (size_t j = 0; j < bits; j += 32) {
size_t const cbits = std::min<size_t>(32, bits - j);
n <<= cbits;
n ^= u32(rng()) >> (32 - cbits);
}
nums.push_back(n);
}
auto tim = Time();
for (auto & n: nums)
n = f(n);
tim = Time() - tim;
std::cout << "Test " << std::setw(15) << ("'" + test_name + "'")
<< ", bits " << std::setw(6) << bits << ", time "
<< std::fixed << std::setprecision(6) << std::setw(9) << tim / ntests << " sec" << std::endl;
return nums;
}
void Test() {
auto f = [](auto ty, size_t bits, size_t ntests){
using T = std::decay_t<decltype(ty)>;
auto tim = Time();
auto a = Test0<T>("GMP", bits, ntests, [](auto const & x){ return ISqrt_GMP<T>(x); });
auto b = Test0<T>("AndersKaseorg", bits, ntests, [](auto const & x){ return ISqrt_AndersKaseorg<T>(x); });
ASSERT(b == a);
auto c = Test0<T>("Babylonian", bits, ntests, [](auto const & x){ return ISqrt_Babylonian<T>(x); });
ASSERT(c == a);
auto d = Test0<T>("ChordTangent", bits, ntests, [](auto const & x){ return KthRoot_ChordTangent<T>(x); });
ASSERT(d == a);
std::cout << "Bits " << bits << " nums " << ntests << " time " << std::fixed << std::setprecision(1) << (Time() - tim) << " sec" << std::endl;
};
for (auto p: std::vector<std::pair<int, int>>{{15, 1 << 19}, {30, 1 << 19}})
f(u64(), p.first, p.second);
for (auto p: std::vector<std::pair<int, int>>{{64, 1 << 15}, {8192, 1 << 10}, {50000, 1 << 5}})
f(mpz_class(), p.first, p.second);
}
int main() {
try {
Test();
return 0;
} catch (std::exception const & ex) {
std::cout << "Exception: " << ex.what() << std::endl;
return -1;
}
}
解决方案 11:
您的函数因输入过多而失败:
In [26]: isqrt((10**100+1)**2)
ValueError: input was not a perfect square
ActiveState 网站上有一个配方,希望它更可靠,因为它只使用整数数学。它基于早期的 StackOverflow 问题:编写自己的平方根函数
解决方案 12:
浮点数无法在计算机上精确表示。您可以测试所需的接近度,将 epsilon 设置为 Python 浮点数精度范围内的较小值。
def isqrt(n):
epsilon = .00000000001
i = int(n**.5 + 0.5)
if abs(i**2 - n) < epsilon:
return i
raise ValueError('input was not a perfect square')
解决方案 13:
我用循环比较了这里给出的不同方法:
for i in range (1000000): # 700 msec
r=int(123456781234567**0.5+0.5)
if r**2==123456781234567:rr=r
else:rr=-1
发现这个是最快的,不需要数学导入。太长可能会失败,但看看这个
15241576832799734552675677489**0.5 = 123456781234567.0
解决方案 14:
尝试这个条件(无需额外计算):
def isqrt(n):
i = math.sqrt(n)
if i != int(i):
raise ValueError('input was not a perfect square')
return i
如果您需要它返回一个int
(而不是float
带有尾随零的),那么可以分配第二个变量或计算int(i)
两次。
扫码咨询,免费领取项目管理大礼包!