-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #379 from QuantEcon/k_array
Re-implement `next_k_array`; add `k_array_rank`
- Loading branch information
Showing
3 changed files
with
195 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
""" | ||
Useful routines for combinatorics | ||
""" | ||
from scipy.special import comb | ||
from numba import jit | ||
|
||
from .numba import comb_jit | ||
|
||
|
||
@jit(nopython=True, cache=True) | ||
def next_k_array(a): | ||
""" | ||
Given an array `a` of k distinct nonnegative integers, sorted in | ||
ascending order, return the next k-array in the lexicographic | ||
ordering of the descending sequences of the elements [1]_. `a` is | ||
modified in place. | ||
Parameters | ||
---------- | ||
a : ndarray(int, ndim=1) | ||
Array of length k. | ||
Returns | ||
------- | ||
a : ndarray(int, ndim=1) | ||
View of `a`. | ||
Examples | ||
-------- | ||
Enumerate all the subsets with k elements of the set {0, ..., n-1}. | ||
>>> n, k = 4, 2 | ||
>>> a = np.arange(k) | ||
>>> while a[-1] < n: | ||
... print(a) | ||
... a = next_k_array(a) | ||
... | ||
[0 1] | ||
[0 2] | ||
[1 2] | ||
[0 3] | ||
[1 3] | ||
[2 3] | ||
References | ||
---------- | ||
.. [1] `Combinatorial number system | ||
<https://en.wikipedia.org/wiki/Combinatorial_number_system>`_, | ||
Wikipedia. | ||
""" | ||
# Logic taken from Algotirhm T in D. Knuth, The Art of Computer | ||
# Programming, Section 7.2.1.3 "Generating All Combinations". | ||
k = len(a) | ||
if k == 1 or a[0] + 1 < a[1]: | ||
a[0] += 1 | ||
return a | ||
|
||
a[0] = 0 | ||
i = 1 | ||
x = a[i] + 1 | ||
|
||
while i < k-1 and x == a[i+1]: | ||
i += 1 | ||
a[i-1] = i - 1 | ||
x = a[i] + 1 | ||
a[i] = x | ||
|
||
return a | ||
|
||
|
||
def k_array_rank(a): | ||
""" | ||
Given an array `a` of k distinct nonnegative integers, sorted in | ||
ascending order, return its ranking in the lexicographic ordering of | ||
the descending sequences of the elements [1]_. | ||
Parameters | ||
---------- | ||
a : ndarray(int, ndim=1) | ||
Array of length k. | ||
Returns | ||
------- | ||
idx : scalar(int) | ||
Ranking of `a`. | ||
References | ||
---------- | ||
.. [1] `Combinatorial number system | ||
<https://en.wikipedia.org/wiki/Combinatorial_number_system>`_, | ||
Wikipedia. | ||
""" | ||
k = len(a) | ||
idx = int(a[0]) # Convert to Python int | ||
for i in range(1, k): | ||
idx += comb(a[i], i+1, exact=True) | ||
return idx | ||
|
||
|
||
@jit(nopython=True, cache=True) | ||
def k_array_rank_jit(a): | ||
""" | ||
Numba jit version of `k_array_rank`. | ||
Notes | ||
----- | ||
An incorrect value will be returned without warning or error if | ||
overflow occurs during the computation. It is the user's | ||
responsibility to ensure that the rank of the input array fits | ||
within the range of possible values of `np.intp`; a sufficient | ||
condition for it is `scipy.special.comb(a[-1]+1, len(a), exact=True) | ||
<= np.iinfo(np.intp).max`. | ||
""" | ||
k = len(a) | ||
idx = a[0] | ||
for i in range(1, k): | ||
idx += comb_jit(a[i], i+1) | ||
return idx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
""" | ||
Tests for util/combinatorics.py | ||
""" | ||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
from nose.tools import eq_ | ||
import scipy.special | ||
from quantecon.util.combinatorics import ( | ||
next_k_array, k_array_rank, k_array_rank_jit | ||
) | ||
|
||
|
||
class TestKArray: | ||
def setUp(self): | ||
self.k_arrays = np.array( | ||
[[0, 1, 2], | ||
[0, 1, 3], | ||
[0, 2, 3], | ||
[1, 2, 3], | ||
[0, 1, 4], | ||
[0, 2, 4], | ||
[1, 2, 4], | ||
[0, 3, 4], | ||
[1, 3, 4], | ||
[2, 3, 4], | ||
[0, 1, 5], | ||
[0, 2, 5], | ||
[1, 2, 5], | ||
[0, 3, 5], | ||
[1, 3, 5], | ||
[2, 3, 5], | ||
[0, 4, 5], | ||
[1, 4, 5], | ||
[2, 4, 5], | ||
[3, 4, 5]] | ||
) | ||
self.L, self.k = self.k_arrays.shape | ||
|
||
def test_next_k_array(self): | ||
k_arrays_computed = np.empty((self.L, self.k), dtype=int) | ||
k_arrays_computed[0] = np.arange(self.k) | ||
for i in range(1, self.L): | ||
k_arrays_computed[i] = k_arrays_computed[i-1] | ||
next_k_array(k_arrays_computed[i]) | ||
assert_array_equal(k_arrays_computed, self.k_arrays) | ||
|
||
def test_k_array_rank(self): | ||
for i in range(self.L): | ||
eq_(k_array_rank(self.k_arrays[i]), i) | ||
|
||
def test_k_array_rank_jit(self): | ||
for i in range(self.L): | ||
eq_(k_array_rank_jit(self.k_arrays[i]), i) | ||
|
||
|
||
def test_k_array_rank_arbitrary_precision(): | ||
n, k = 100, 50 | ||
a = np.arange(n-k, n) | ||
eq_(k_array_rank(a), scipy.special.comb(n, k, exact=True)-1) | ||
|
||
|
||
if __name__ == '__main__': | ||
import sys | ||
import nose | ||
|
||
argv = sys.argv[:] | ||
argv.append('--verbose') | ||
argv.append('--nocapture') | ||
nose.main(argv=argv, defaultTest=__file__) |