Tech News

JAX: Fast Combinations Calculation


Combinadics

A fast combinations calculation in JAX.

The idea for the combinadic implementation comes from https://jamesmccaffrey.wordpress.com/2022/06/28/generating-the-mth-lexicographical-element-of-a-combination-using-the-combinadic, and further background on combinadics can be found at https://en.wikipedia.org/wiki/Combinatorial_number_system. Below I have copied and aggregated some of the details.

Introduction

The following code demonstrates calculating combinations in NumPy (with itertools.combinations) and via combinadics:

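    import itertools
    import math

    import jax.numpy as jnp
    import numpy as np

    # setup
    n = 4
    k = 3
    totalcount = math.comb(n, k)

    # numpy
    print(f'Calculate combinations "{n} choose {k}" in numpy:')
    for comb in itertools.combinations(np.arange(start=0, stop=n, dtype=np.int32), k):
        print(comb)

    # combinadics: calculateMth(n, k, x) comes from the project and is not shown here.
    # Combination m is obtained from the combinadic of the dual index
    # totalcount - 1 - m, with each digit c mapped to n - 1 - c. The index m must
    # run over all totalcount combinations (stop=n only coincides when n == C(n, k)).
    print("Calculate via combinadics:")
    actual = n - 1 - calculateMth(
        n, k, totalcount - 1 - jnp.arange(start=0, stop=totalcount, dtype=jnp.int32)
    )
    for comb in actual:
        print(comb)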

Running this code produces the following output:

    Calculate combinations "4 choose 3" in numpy:
    (0, 1, 2)
    (0, 1, 3)
    (0, 2, 3)
    (1, 2, 3)
    Calculate via combinadics:
    [0 1 2]
    [0 1 3]
    [0 2 3]
    [1 2 3]
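The excerpt does not include calculateMth itself. As a rough illustration only, here is a minimal sketch of how the same dual-index idea could be written in JAX, assuming int32 arithmetic (so the binomial coefficients involved must fit in 32 bits) and n, k known at trace time. The names binomial_table, combinadic and mth_combination are hypothetical and are not the project's API. The sketch computes the combinadic of the dual index C(n, k) - 1 - m greedily and maps each digit c to n - 1 - c, as in the example above.

```python
import itertools
import math
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def binomial_table(n: int, k: int) -> jnp.ndarray:
    """Lookup table with table[c, r] == C(c, r) for 0 <= c < n, 0 <= r <= k."""
    # Built on the host with exact Python integers, then stored as int32
    # (assumption: every C(c, r) used here fits in 32 bits).
    table = np.array(
        [[math.comb(c, r) for r in range(k + 1)] for c in range(n)], dtype=np.int32
    )
    return jnp.asarray(table)


def combinadic(n: int, k: int, x: jnp.ndarray) -> jnp.ndarray:
    """Digits c_1 > ... > c_k (all < n) with x == C(c_1, k) + ... + C(c_k, 1)."""
    table = binomial_table(n, k)
    candidates = jnp.arange(n)
    digits = []
    prev = n  # each digit must be strictly smaller than the previous one
    for r in range(k, 0, -1):  # unrolled at trace time since k is a Python int
        # Greedy step: the largest c < prev with C(c, r) <= x. C(c, r) is
        # nondecreasing in c, so the valid candidates form a prefix of
        # 0..n-1, and counting them gives the answer without a search loop.
        mask = (candidates < prev) & (table[:, r] <= x)
        c = jnp.sum(mask) - 1
        digits.append(c)
        x = x - table[c, r]
        prev = c
    return jnp.stack(digits)


def mth_combination(n: int, k: int, m: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: the m-th k-subset of range(n) in lexicographic order."""
    dual = math.comb(n, k) - 1 - m  # same dual index as totalcount - 1 - m above
    return n - 1 - combinadic(n, k, dual)


# Usage: all C(4, 3) combinations in one vectorized, compiled call.
n, k = 4, 3
all_combinations = jax.jit(jax.vmap(partial(mth_combination, n, k)))
combos = all_combinations(jnp.arange(math.comb(n, k)))
print(combos)

# Cross-check against itertools, which also yields lexicographic order.
assert combos.tolist() == [list(c) for c in itertools.combinations(range(n), k)]
```

Because n and k are bound as plain Python integers, the loop over the k digits is unrolled during tracing, and each digit is found with a masked count rather than a data-dependent loop, so the function composes with jax.vmap and jax.jit.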

A bit of theory

You can think of a combinadic as an alternate representation of an integer. Consider the integer $859$. It can be represented as a sum of powers of $10$, each weighted by a digit:

$$859 = 8 \cdot 10^2 + 5 \cdot 10^1 + 9 \cdot 10^0$$

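To make the analogy concrete, the following Python sketch greedily decomposes an integer into its combinadic for a fixed k, much as the decimal digits of 859 come from repeatedly taking the largest multiple of a power of 10 that fits. The helper names to_combinadic and from_combinadic, and the choice k = 3 in the usage line, are assumptions for illustration only.

```python
import math


def to_combinadic(x: int, k: int) -> list[int]:
    """Digits c_k > ... > c_1 >= 0 with x == C(c_k, k) + ... + C(c_1, 1)."""
    digits = []
    for r in range(k, 0, -1):
        # Greedy step, like picking the leading decimal digit: take the
        # largest c with C(c, r) <= x. Starting at c = r - 1 is safe
        # because C(r - 1, r) == 0.
        c = r - 1
        while math.comb(c + 1, r) <= x:
            c += 1
        digits.append(c)
        x -= math.comb(c, r)
    return digits


def from_combinadic(digits: list[int]) -> int:
    """Inverse mapping: sum of C(c, r) with r running from len(digits) down to 1."""
    k = len(digits)
    return sum(math.comb(c, k - i) for i, c in enumerate(digits))


digits = to_combinadic(859, 3)
print(digits)  # [18, 9, 7]
assert from_combinadic(digits) == 859
```

For k = 3 this gives $859 = \binom{18}{3} + \binom{9}{2} + \binom{7}{1} = 816 + 36 + 7$, with strictly decreasing digits playing the role of the decimal digits.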