An implementation of Euler's partition function in Python
This post looks at various ways to implement Euler's Partition Function P(n) (outlined here on Wolfram's MathWorld in equation (11)) in Python. The description of Euler's function is taken entirely from the MathWorld reference above; this post simply looks at a Python implementation and at how to test both the correctness and the performance of our implementations.
From MathWorld:
P(n) gives the number of ways to sum positive integers to n. For example, 4 can be written as the following sums
$4 = 4$
$\ \ = 3 + 1$
$\ \ = 2 + 2$
$\ \ = 2 + 1 + 1$
$\ \ = 1 + 1 + 1 + 1$
This gives 5 ways to sum positive integers to 4, thus P(4) = 5.
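Before turning to Euler's recurrence, we can make the definition concrete with a short brute-force counter. The sketch below (with a hypothetical helper name, count_partitions, that is not used anywhere else in this post) counts partitions by bounding the largest allowed part, which is enough to confirm P(4) = 5.

from typing import Optional

def count_partitions(n: int, max_part: Optional[int] = None) -> int:
    # count the ways to write n as a sum of positive integers no larger than max_part
    if max_part is None:
        max_part = n
    if n == 0:
        return 1  # the empty sum
    if n < 0 or max_part == 0:
        return 0
    # either include at least one part equal to max_part, or exclude max_part entirely
    return count_partitions(n - max_part, max_part) + count_partitions(n, max_part - 1)

assert count_partitions(4) == 5

This direct enumeration quickly becomes impractical as n grows, which is part of what makes Euler's recurrence below worth implementing.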
Testing Our Partition Function for Correctness
The On-Line Encyclopedia of Integer Sequences (OEIS A000041) gives P(n) for n = 0..49, which we will use to test our implementation of P(n).
Our test function will take a function, P(n), which in turn takes a single integer argument. The test function will iterate over the list of known P(n) from OEIS above and assert that calling our function returns the correct value.
from typing import Callable

def test_pn(pn_implementation: Callable[[int], int]) -> None:
    pn_truths = [
        1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135, 176, 231,
        297, 385, 490, 627, 792, 1002, 1255, 1575, 1958, 2436, 3010,
        3718, 4565, 5604, 6842, 8349, 10143, 12310, 14883, 17977,
        21637, 26015, 31185, 37338, 44583, 53174, 63261, 75175,
        89134, 105558, 124754, 147273, 173525
    ]
    for n, pn_truth in enumerate(pn_truths):
        assert pn_implementation(n) == pn_truth, (
            f"Error: Implementation gave P({n}) == {pn_implementation(n)}. "
            f"Correct P({n}) == {pn_truth}"
        )
Implementing P(n)
Equation (11) in the MathWorld reference on Partition Function P gives P(n) as
$P(n) = \sum_{k=1}^{n} (-1)^{k+1}[P(n - \frac{1}{2}k(3k-1))+P(n - \frac{1}{2}k(3k+1))]$
This recurrence equation is naturally expressed as a recursive function with a base case of P(0) = 1, where P(m) is treated as 0 whenever m < 0.
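As a quick sanity check, expanding the recurrence for n = 4 (the example from the introduction) gives
$P(4) = [P(3) + P(2)] - [P(-1) + P(-3)] + [P(-8) + P(-11)] - [P(-18) + P(-22)]$
$\ \ \ \ \ \ \ \ \ = (3 + 2) - 0 + 0 - 0 = 5$
which agrees with the five partitions of 4 listed above.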
Let's look first at an implementation that most closely resembles the function above:
def P(n: int) -> int:
    if n == 0:
        return 1
    total = 0
    for k in range(1, n+1):
        total += pow(-1, k+1) * (P(n - k*(3*k-1)//2) + P(n - k*(3*k+1)//2))
    return total
This implementation, while simple, is awfully slow. Even our test suite, which checks all cases for n < 50, will not run in a reasonable amount of time. Instead, we will spot-check our results against a few known values of P(n).
assert P(0) == 1
assert P(1) == 1
assert P(5) == 7
assert P(10) == 42
To get a baseline for measuring our function's performance, we will use the Jupyter cell magic %%time to measure the execution time of P(30).
%%time
P(30)
Our implementation of P(n), for even moderately large values of n, takes on the order of tens of seconds! To get an understanding of what is taking up so much of our function's execution time, we can profile it with the cell magic %%prun
%%prun
P(30)
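As an aside, outside of Jupyter a roughly equivalent profile can be collected with the standard library's cProfile module. A minimal sketch, assuming P is defined at the top level as above:

import cProfile

# profile the same call and sort the report by cumulative time
cProfile.run("P(30)", sort="cumulative")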
As we can see, our function P(n) was called a whopping 51,580,377 times! A major reason for the slow computation time is that the same P(n) will be repeatedly calculated within the recursive calls. To resolve this, we can cache computed values of P(n) to avoid repeated calculations.
To do this, we can create a class that generates a callable object which stores computed values of P(n) in a member variable.
class PWithCache:
    def __init__(self):
        self.computed_pn = {0: 1}

    def __call__(self, n: int) -> int:
        if n in self.computed_pn:
            return self.computed_pn[n]
        total = 0
        for k in range(1, n+1):
            # note: the call to self() here actually invokes self.__call__ (i.e. this function)
            # so this function is also recursive, but with self.computed_pn stored between
            # calls
            total += pow(-1, k+1) * (self(n - k*(3*k-1)//2) + self(n - k*(3*k+1)//2))
        self.computed_pn[n] = total
        return total
P = PWithCache()
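As an aside, the standard library can provide the same memoization: functools.lru_cache (or functools.cache on Python 3.9+) wraps a plain recursive function and stores its results in an internal dictionary, much like self.computed_pn above. A minimal sketch, not the implementation used in the rest of this post (the name P_cached is just for illustration):

from functools import lru_cache

@lru_cache(maxsize=None)
def P_cached(n: int) -> int:
    if n == 0:
        return 1
    total = 0
    for k in range(1, n+1):
        total += pow(-1, k+1) * (P_cached(n - k*(3*k-1)//2) + P_cached(n - k*(3*k+1)//2))
    return total

assert P_cached(30) == 5604

We'll stick with the class-based version here, since creating a fresh PWithCache() makes it explicit when the cache is reset between timing runs.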
Our first call begins without any P(n) computed, but the caching between recursive calls will immediately result in a dramatic speed-up, with P(30) now taking on the order of milliseconds.
%%time
assert P(30) == 5604
Of course, on subsequent calls P(30) has already been computed, so the call amounts to a single dictionary look-up, which takes on the order of hundreds of nanoseconds.
%%timeit
assert P(30) == 5604
Finally, we'll run our implementation against our test suite:
test_pn(P)
Additional Speed-Ups
Let us profile P(n) once again for n = 1000 to characterize its performance.
%%timeit
P = PWithCache()
P(1000)
P = PWithCache()
%prun P(1000)
One interesting thing we will notice right away is that P(n) is called 1,001,001 times while pow() is called only 500,500 times, even though each pass through the summation loop calls pow() once, so we might expect pow() to be called far more often than P(n). The reason is that the terms $n - \frac{1}{2}k(3k-1)$ and $n - \frac{1}{2}k(3k+1)$ frequently evaluate to negative numbers; for those recursive calls the summation loop is bypassed and 0 is returned without pow() ever being reached.
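To see why the negative arguments dominate, note that $n - \frac{1}{2}k(3k-1)$ drops below zero as soon as $\frac{1}{2}k(3k-1) > n$, i.e. for every k beyond roughly $\sqrt{2n/3}$. For n = 30, for example, k = 5 already gives
$30 - \frac{1}{2}\cdot 5\cdot(3\cdot 5 - 1) = 30 - 35 = -5$
so for large n the vast majority of the loop's recursive calls would be made with a negative argument.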
Because function calls in Python are expensive, it would be better to simply replace the term with 0 rather than call P(n) with n < 0. The implementation below first calculates the terms $n - \frac{1}{2}k(3k-1)$ and $n - \frac{1}{2}k(3k+1)$ and replaces the respective P(n) terms with 0 if they are negative.
class PWithCache:
    def __init__(self):
        self.computed_pn = {0: 1}

    def __call__(self, n: int) -> int:
        if n in self.computed_pn:
            return self.computed_pn[n]
        total = 0
        for k in range(1, n+1):
            minus_one_term = n - k*(3*k-1)//2
            plus_one_term = n - k*(3*k+1)//2
            first_term = 0 if minus_one_term < 0 else self(minus_one_term)
            second_term = 0 if plus_one_term < 0 else self(plus_one_term)
            total += pow(-1, k+1) * (first_term + second_term)
        self.computed_pn[n] = total
        return total
P = PWithCache()
test_pn(P)
%%timeit
P = PWithCache()
P(1000)
P = PWithCache()
%prun P(1000)
We can see that by avoiding unnecessary calls to P(n), we have decreased the number of times it is called from 1,001,001 to a mere 33,476, resulting in a roughly 2x speed-up!
Examining Calls to pow()
From our profiling metrics above, we see that our calls to pow() are now taking up nearly half of our execution time. Re-examining our function, we can take a second look at the term $(-1)^{k+1}$
$P(n) = \sum_{k=1}^{n} (-1)^{k+1}[P(n - \frac{1}{2}k(3k-1))+P(n - \frac{1}{2}k(3k+1))]$
Careful consideration shows that the purpose of the term $(-1)^{k+1}$ is to subtract the intermediate result when k is even and to add it when k is odd. That is, $(-1)^{k+1}$ evaluates to -1 when k is even and to +1 when k is odd.
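As a quick sanity check (purely illustrative, not part of the implementation), we can confirm that a parity test reproduces the same signs:

# (-1)**(k+1) is +1 for odd k and -1 for even k, so k % 2 distinguishes the same two cases
assert all(pow(-1, k+1) == (1 if k % 2 else -1) for k in range(1, 101))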
If we simply want to know whether k is odd or even in order to add or subtract our intermediate result, we can do this more quickly with the modulus operator than by computing $(-1)^{k+1}$. Let us analyze this by taking a random positive integer up to 1000 and comparing the time to calculate pow(-1, random_num) and random_num % 2.
from random import randint
random_num = randint(1, 1000)
%%timeit
pow(-1, random_num)
%%timeit
random_num % 2
As we can see, random_num % 2 is roughly an order of magnitude faster than pow(-1, random_num). Let us go back and replace the term $(-1)^{k+1}$ with the use of a modulus operator to govern whether we add or subtract the intermediate result.
class PWithCache:
    def __init__(self):
        self.computed_pn = {0: 1}

    def __call__(self, n: int) -> int:
        if n in self.computed_pn:
            return self.computed_pn[n]
        total = 0
        for k in range(1, n+1):
            minus_one_term = n - k*(3*k-1)//2
            plus_one_term = n - k*(3*k+1)//2
            first_term = 0 if minus_one_term < 0 else self(minus_one_term)
            second_term = 0 if plus_one_term < 0 else self(plus_one_term)
            if k % 2:
                total += first_term + second_term
            else:
                total -= first_term + second_term
        self.computed_pn[n] = total
        return total
P = PWithCache()
test_pn(P)
%%timeit
P = PWithCache()
P(1000)
P = PWithCache()
%prun P(1000)
Once again, we see a nearly 2x speed-up by eliminating the 500,500 function calls to pow(), with all remaining calls now being the recursive calls to P(n).
An Analysis of Results
Our final implementation has grown significantly in lines of code and no longer closely resembles the mathematical form of our function. We've seen significant speed-ups, but at the cost of additional complexity and lines of code. The purpose of this post is to look at how profiling can help us understand (and potentially improve) the performance of our code. Whether optimization is a worthwhile endeavor depends on the situation, and the old adage that "premature optimization is the root of all evil" is always worth keeping in mind.