Generating Prime Numbers

How do you know whether a number is prime? How do you get a list of prime numbers in Python? In this post we walk through a simple and efficient method. If you want to skip the explanation and just copy the function for generating prime numbers, the full code is at the bottom.

We are going to build two functions. The first one obtains all the prime numbers from 2 up to a given natural number and returns them in a list. The second tells us whether a supplied number is prime by calling the first function and checking whether the number appears in the returned list.

For the first we will use the algorithm known as the sieve of Eratosthenes, which, despite being formulated more than 2000 years ago, remains one of the most efficient ways to list all the primes up to a given limit. Let's see how Wikipedia describes the selection process.

A table is formed with all the natural numbers between 2 and n, and the numbers that are not primes are crossed out as follows: Starting with 2, all its multiples are crossed out. Starting over, when an integer is found that has not been crossed out, that number is declared prime, and all its multiples are crossed out, and so on. The process ends when the square of the largest confirmed prime number is greater than n.

(This quote is actually taken from the Spanish Wikipedia article, which seems to me much better explained than the English one.)

Here is an illustration of the process described above:

/images/generating-prime-numbers/sieve-of-eratosthenes.gif
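
To make the quoted procedure concrete, here is a minimal sketch that records which numbers each prime crosses out for n = 30. The helper name sieve_trace is made up for this illustration and is not one of the functions built in this post.

```python
def sieve_trace(n):
    """Return {prime: [numbers newly crossed out by that prime]} for 2..n."""
    crossed = set()
    trace = {}
    p = 2
    # Stop once the square of the candidate exceeds n.
    while p * p <= n:
        if p not in crossed:
            # Multiples of p that no smaller prime has crossed out yet.
            new = [m for m in range(2 * p, n + 1, p) if m not in crossed]
            crossed.update(new)
            trace[p] = new
        p += 1
    return trace


for prime, removed in sieve_trace(30).items():
    print(prime, removed)
```

Only 2, 3 and 5 do any crossing out here, because the next prime, 7, already has a square (49) greater than 30.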

Based on this description, let's start writing down the function step by step. The first thing to do is the function definition, which should receive a natural number n that we will call max_number.

def get_prime_numbers(max_number):

Next, we create the table of natural numbers between 2 and max_number. A simple Python list will suffice.

    numbers = [True, True] + [True] * (max_number-1)

A true value means a number that has not been crossed out, while a false value means a crossed-out number. To make sure each number matches its list index, the list begins with two placeholder elements ([True, True]) that fill positions 0 and 1.
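
A quick check of that layout, using max_number = 10 as an arbitrary sample value: the list ends up with max_number + 1 entries, so the state of the number k sits at index k.

```python
max_number = 10
numbers = [True, True] + [True] * (max_number - 1)

# Indices run from 0 to max_number, so there are max_number + 1 entries.
print(len(numbers))

# The expression builds the same list as a single repetition would.
same_as_single_repeat = numbers == [True] * (max_number + 1)
print(same_as_single_repeat)

# Crossing out the number 9 means writing False at index 9.
numbers[9] = False
print(numbers.index(False))
```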

We continue by defining two variables. The first one holds the last prime number obtained (starting with 2), and the second one holds the multiple of that prime that is currently being crossed out.

    last_prime_number = 2
    i = last_prime_number

Recall that the selection process ends when the square of the last obtained prime number (last_prime_number) is greater than the argument. So:

    while last_prime_number**2 <= max_number:
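
Why is it safe to stop there? Any composite number m no larger than max_number has a divisor d with d*d <= m, so it has already been crossed out by a prime whose square does not exceed max_number. The following minimal sketch checks this for n = 50; the helper smallest_factor is invented for the check and is not used by get_prime_numbers().

```python
def smallest_factor(m):
    """Smallest divisor of m greater than 1 (m itself when m is prime)."""
    d = 2
    while d * d <= m:
        if m % d == 0:
            return d
        d += 1
    return m


n = 50
composites = [m for m in range(4, n + 1) if smallest_factor(m) != m]

# Every composite up to n has a factor whose square does not exceed n.
all_covered = all(smallest_factor(m) ** 2 <= n for m in composites)
print(all_covered)
```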

Once inside the selection process, we must cross out (that is, assign False in the list) the multiples of the last obtained prime number. Let's do that:

        i += last_prime_number
        while i <= max_number:
            numbers[i] = False
            i += last_prime_number
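
For comparison, the values this inner loop visits are the same ones range() produces with a step. A small sketch with sample values (max_number = 20 and the prime 3 are chosen only for illustration):

```python
max_number = 20
last_prime_number = 3

# Values visited by the while loop: 2p, 3p, ... up to max_number.
visited = []
i = last_prime_number
i += last_prime_number
while i <= max_number:
    visited.append(i)
    i += last_prime_number

# The same sequence expressed with range().
multiples = list(range(2 * last_prime_number, max_number + 1, last_prime_number))
print(visited)
print(multiples)
```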

Once this is done, we must repeat the procedure for the next number after the last obtained prime that has not been crossed out. Let's use a loop to fetch that number:

        j = last_prime_number + 1
        while j < max_number:
            if numbers[j]:
                last_prime_number = j
                break
            j += 1
        i = last_prime_number
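
As a side note, this search could also be written with list.index(), which accepts a start position. The following is only an alternative sketch, not the code used in this post; the sample state assumes the multiples of 2 were just crossed out for max_number = 20.

```python
max_number = 20
numbers = [True] * (max_number + 1)
for multiple in range(4, max_number + 1, 2):  # cross out the multiples of 2
    numbers[multiple] = False

last_prime_number = 2
# First index after last_prime_number that still holds True.
next_prime = numbers.index(True, last_prime_number + 1)
print(next_prime)
```

One difference to keep in mind: list.index() raises ValueError when nothing is found, whereas the loop above simply leaves last_prime_number unchanged.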

When the main loop (the outer while) finishes, the prime numbers are those within the numbers list that are not crossed out (i.e., containing True). With the following code we filter the list and return the prime numbers.

    return [i + 2 for i, not_crossed in enumerate(numbers[2:]) if not_crossed]

(The [2:] syntax drops the placeholders defined at the beginning of the list.)
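
To see what the i + 2 shift does, here is a small sketch on a hand-written state list for the numbers 0 to 7. It also shows an equivalent spelling that uses the start argument of enumerate(), offered only as an alternative.

```python
# States for the numbers 0..7; positions 0 and 1 are placeholders.
numbers = [True, True, True, True, False, True, False, True]

# The form used in this post: slice, enumerate from 0, then add 2 back.
shifted = [i + 2 for i, not_crossed in enumerate(numbers[2:]) if not_crossed]

# Equivalent form: let enumerate() start counting at 2.
started = [n for n, not_crossed in enumerate(numbers[2:], start=2) if not_crossed]

print(shifted)
print(started)
```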

With the first part ready (see the full code below), it only remains to determine whether a given number is in that list:

def is_prime(n):
    return n in get_prime_numbers(n)

Now let's test these functions:

print(get_prime_numbers(20))  # [2, 3, 5, 7, 11, 13, 17, 19]
print(is_prime(3))  # True
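
Note that is_prime() builds the whole list of primes up to n just to test one number. If only a single test is needed, a different technique, trial division, avoids the list altogether. The sketch below is not the sieve and not part of this post's code; the name is_prime_trial is hypothetical.

```python
def is_prime_trial(n):
    """Test one number by trial division (a different technique than the sieve)."""
    if n < 2:
        return False
    d = 2
    # A composite n must have a divisor d with d * d <= n.
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


print([n for n in range(20) if is_prime_trial(n)])
```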

Cython Optimization

CPU-intensive code like the function we just implemented can be sped up considerably by compiling it with Cython. In our tests, the get_prime_numbers() function runs between 7 and 8 times faster after cythonization. The process is simple.

Create the cprime.pyx file and put the following code in it:

def get_prime_numbers(int max_number):
    cdef int i, j, last_prime_number, not_crossed
    numbers = [True, True] + [True] * (max_number-1)
    last_prime_number = 2
    i = last_prime_number
    while last_prime_number**2 <= max_number:
        i += last_prime_number
        while i <= max_number:
            numbers[i] = False
            i += last_prime_number
        j = last_prime_number + 1
        while j < max_number:
            if numbers[j]:
                last_prime_number = j
                break
            j += 1
        i = last_prime_number
    return [i + 2 for i, not_crossed in enumerate(numbers[2:]) if not_crossed]

This is the earlier code, barely tweaked. We just declared the types of our variables explicitly so that Cython can convert them to native C variables.

Then create the setup.py file that will let us compile the code.

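# distutils was removed from the standard library in Python 3.12;
# setuptools provides the same setup() entry point.
from setuptools import setup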
from Cython.Build import cythonize
setup(
    name = "CPrime",
    ext_modules = cythonize("cprime.pyx"),
)

And run the following command to compile:

python setup.py build_ext --inplace

Now we can import the function as follows.

from cprime import get_prime_numbers

We can calculate the execution time and compare both functions with the following code.

def get_prime_numbers(max_number):
    numbers = [True, True] + [True] * (max_number-1)
    last_prime_number = 2
    i = last_prime_number
    while last_prime_number**2 <= max_number:
        i += last_prime_number
        while i <= max_number:
            numbers[i] = False
            i += last_prime_number
        j = last_prime_number + 1
        while j < max_number:
            if numbers[j]:
                last_prime_number = j
                break
            j += 1
        i = last_prime_number
    return [i + 2 for i, not_crossed in enumerate(numbers[2:]) if not_crossed]

if __name__ == "__main__":
    import timeit
    print(
        "Python",
        timeit.timeit(
            "get_prime_numbers(20)",
            number=100000,
            globals=globals()
        )
    )
    from cprime import get_prime_numbers as get_prime_numbers_fast
    print(
        "Cython",
        timeit.timeit(
            "get_prime_numbers_fast(20)",
            number=100000,
            globals=globals()
        )
    )

The result is:

Python 0.39204310000059195
Cython 0.05220530000224244
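
The speedup quoted earlier follows directly from these two timings. A short calculation, using the figures printed above:

```python
python_seconds = 0.39204310000059195
cython_seconds = 0.05220530000224244

# Ratio of the two timings for 100000 calls each.
speedup = python_seconds / cython_seconds
print(round(speedup, 1))
```

Keep in mind that these timings are for get_prime_numbers(20), a very small input, so the ratio may differ for larger values of max_number.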

Source Code

Python implementation:

def get_prime_numbers(max_number):
    # Create a list containing the state (crossed/not-crossed)
    # of each number from 0 to max_number. Positions 0 and 1 are
    # placeholders so that each number matches its list index.
    numbers = [True, True] + [True] * (max_number-1)
    # Start with 2. This variable always holds a prime number.
    last_prime_number = 2
    # This variable contains the current number in the list,
    # which is always a multiple of last_prime_number.
    i = last_prime_number
    # Proceed as long as the square of last_prime_number (i.e.,
    # the last obtained prime number) is less than or equal to max_number.
    while last_prime_number**2 <= max_number:
        # Cross out all multiples of the last obtained prime number.
        i += last_prime_number
        while i <= max_number:
            numbers[i] = False
            i += last_prime_number
        # Get the number immediately following the last
        # obtained prime number (last_prime_number) that is not crossed out.
        j = last_prime_number + 1
        while j < max_number:
            if numbers[j]:
                last_prime_number = j
                break
            j += 1
        i = last_prime_number
    # Return numbers in the list that are not crossed out.
    return [i + 2 for i, not_crossed in enumerate(numbers[2:]) if not_crossed]

Cython optimization (7 to 8 times faster):

def get_prime_numbers(int max_number):
    cdef int i, j, last_prime_number, not_crossed
    numbers = [True, True] + [True] * (max_number-1)
    last_prime_number = 2
    i = last_prime_number
    while last_prime_number**2 <= max_number:
        i += last_prime_number
        while i <= max_number:
            numbers[i] = False
            i += last_prime_number
        j = last_prime_number + 1
        while j < max_number:
            if numbers[j]:
                last_prime_number = j
                break
            j += 1
        i = last_prime_number
    return [i + 2 for i, not_crossed in enumerate(numbers[2:]) if not_crossed]

And to check whether a number is prime or not:

def is_prime(n):
    return n in get_prime_numbers(n)

Some examples:

>>> get_prime_numbers(70)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67]
>>> is_prime(4)
False
>>> is_prime(7)
True
>>> is_prime(20)
False
>>> is_prime(53)
True