Skip to content

Commit

Permalink
ENH: Add LP solution method to DiscreteDP
Browse files Browse the repository at this point in the history
  • Loading branch information
oyamad committed Jun 19, 2021
1 parent da02c6c commit 545b89b
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 12 deletions.
188 changes: 188 additions & 0 deletions quantecon/markov/_ddp_linprog_simplex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import numpy as np
from numba import jit
from .utilities import _find_indices
from ..optimize.linprog_simplex import solve_tableau, PivOptions
from ..optimize.pivoting import _pivoting


@jit(nopython=True, cache=True)
def ddp_linprog_simplex(R, Q, beta, a_indices, a_indptr, sigma,
max_iter=10**6, piv_options=PivOptions(),
tableau=None, basis=None, v=None):
r"""
Numba jit complied function to solve a discrete dynamic program via
linear programming, using `optimize.linprog_simplex` routines. The
problem has to be represented in state-action pair form with 1-dim
reward ndarray `R` of shape (n,), 2-dim transition probability
ndarray `Q` of shapce (L, n), and disount factor `beta`, where n is
the number of states and L is the number of feasible state-action
pairs.
The approach exploits the fact that the optimal value function is
the smallest function that satisfies :math:`v \geq T v`, where
:math:`T` is the Bellman operator, and hence it is a (unique)
solution to the linear program:
minimize::
\sum_{s \in S} v(s)
subject to ::
v(s) \geq r(s, a) + \beta \sum_{s' \in S} q(s'|s, a) v(s')
\quad ((s, a) \in \mathit{SA}).
This function solves its dual problem:
maximize::
\sum_{(s, a) \in \mathit{SA}} r(s, a) y(s, a)
subject to::
\sum_{a: (s', a) \in \mathit{SA}} y(s', a) -
\sum_{(s, a) \in \mathit{SA}} \beta q(s'|s, a) y(s, a) = 1
\quad (s' \in S),
y(s, a) \geq 0 \quad ((s, a) \in \mathit{SA}),
where the optimal value function is obtained as an optimal dual
solution and an optimal policy as an optimal basis.
Parameters
----------
R : ndarray(float, ndim=1)
Reward ndarray, of shape (n,).
Q : ndarray(float, ndim=2)
Transition probability ndarray, of shape (L, n).
beta : scalar(float)
Discount factor. Must be in [0, 1).
a_indices : ndarray(int, ndim=1)
Action index ndarray, of shape (L,).
a_indptr : ndarray(int, ndim=1)
Action index pointer ndarray, of shape (n+1,).
sigma : ndarray(int, ndim=1)
ndarray containing the initial feasible policy, of shape (n,).
To be modified in place to store the output optimal policy.
max_iter : int, optional(default=10**6)
Maximum number of iteration in the linear programming solver.
piv_options : PivOptions, optional
PivOptions namedtuple to set tolerance values used in the linear
programming solver.
tableau : ndarray(float, ndim=2), optional
Temporary ndarray of shape (n+1, L+n+1) to store the tableau.
Modified in place.
basis : ndarray(int, ndim=1), optional
Temporary ndarray of shape (n,) to store the basic variables.
Modified in place.
v : ndarray(float, ndim=1), optional
Output ndarray of shape (n,) to store the optimal value
function. Modified in place.
Returns
-------
success : bool
True if the algorithm succeeded in finding an optimal solution.
num_iter : int
The number of iterations performed.
v : ndarray(float, ndim=1)
Optimal value function (view to `v` if supplied).
sigma : ndarray(int, ndim=1)
Optimal policy (view to `sigma`).
"""
L, n = Q.shape

if tableau is None:
tableau = np.empty((n+1, L+n+1))
if basis is None:
basis = np.empty(n, dtype=np.int_)
if v is None:
v = np.empty(n)

_initialize_tableau(R, Q, beta, a_indptr, tableau)
_find_indices(a_indices, a_indptr, sigma, out=basis)

# Phase 1
for i in range(n):
_pivoting(tableau, basis[i], i)

# Phase 2
success, status, num_iter = \
solve_tableau(tableau, basis, max_iter-n, skip_aux=True,
piv_options=piv_options)

# Obtain solution
for i in range(n):
v[i] = tableau[-1, L+i] * (-1)

for i in range(n):
sigma[i] = a_indices[basis[i]]

return success, num_iter+n, v, sigma


@jit(nopython=True, cache=True)
def _initialize_tableau(R, Q, beta, a_indptr, tableau):
"""
Initialize the `tableau` array.
Parameters
----------
R : ndarray(float, ndim=1)
Reward ndarray, of shape (n,).
Q : ndarray(float, ndim=2)
Transition probability ndarray, of shape (L, n).
beta : scalar(float)
Discount factor. Must be in [0, 1).
a_indptr : ndarray(int, ndim=1)
Action index pointer ndarray, of shape (n+1,).
tableau : ndarray(float, ndim=2)
Empty ndarray of shape (n+1, L+n+1) to store the tableau.
Modified in place.
Returns
-------
tableau : ndarray(float, ndim=2)
View to `tableau`.
"""
L, n = Q.shape

for j in range(L):
for i in range(n):
tableau[i, j] = Q[j, i] * (-beta)

for i in range(n):
for j in range(a_indptr[i], a_indptr[i+1]):
tableau[i, j] += 1

tableau[:n, L:-1] = 0

for i in range(n):
tableau[i, L+i] = 1
tableau[i, -1] = 1

for j in range(L):
tableau[-1, j] = R[j]

tableau[-1, L:] = 0

return tableau
66 changes: 59 additions & 7 deletions quantecon/markov/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@
* value iteration;
* policy iteration;
* modified policy iteration.
* modified policy iteration;
* linear programming.
Policy iteration computes an exact optimal policy in finitely many
iterations, while value iteration and modified policy iteration return
Expand All @@ -97,6 +98,10 @@
:math:`\mathrm{span}(T v - v) < [(1 - \beta) / \beta] \varepsilon` is
satisfied, where :math:`\mathrm{span}(z) = \max(z) - \min(z)`.
The linear programming method solves the problem as a linear program by
the simplex method with `optimize.linprog_simplex` routines (implemented
only for dense matrix formulation).
References
----------
Expand All @@ -107,9 +112,9 @@
import warnings
import numpy as np
import scipy.sparse as sp
from numba import jit

from .core import MarkovChain
from ._ddp_linprog_simplex import ddp_linprog_simplex
from .utilities import (
_fill_dense_Q, _s_wise_max_argmax, _s_wise_max, _find_indices,
_has_sorted_sa_indices, _generate_a_indptr
Expand Down Expand Up @@ -280,6 +285,16 @@ class DiscreteDP:
>>> res.num_iter # Number of iterations
3
*Linear programming*
>>> res = ddp.solve(method='linear_programming', v_init=[0, 0])
>>> res.sigma # Optimal policy function
array([0, 0])
>>> res.v # Optimal value function
array([ -8.57142857, -20. ])
>>> res.num_iter # Number of iterations (within the LP solver)
4
"""
def __init__(self, R, Q, beta, s_indices=None, a_indices=None):
self._sa_pair = False
Expand Down Expand Up @@ -710,14 +725,14 @@ def solve(self, method='policy_iteration',
method : str, optinal(default='policy_iteration')
Solution method, str in {'value_iteration', 'vi',
'policy_iteration', 'pi', 'modified_policy_iteration',
'mpi'}.
'mpi', 'linear_programming', 'lp'}.
v_init : array_like(float, ndim=1), optional(default=None)
Initial value function, of length n. If None, `v_init` is
set such that v_init(s) = max_a r(s, a) for value iteration
and policy iteration; for modified policy iteration,
v_init(s) = min_(s_next, a) r(s_next, a)/(1 - beta) to guarantee
convergence.
set such that v_init(s) = max_a r(s, a) for value iteration,
policy iteration, and linear programming; for modified
policy iteration, v_init(s) = min_(s_next, a)
r(s_next, a)/(1 - beta) to guarantee convergence.
epsilon : scalar(float), optional(default=None)
Value for epsilon-optimality. If None, the value stored in
Expand Down Expand Up @@ -750,6 +765,9 @@ def solve(self, method='policy_iteration',
epsilon=epsilon,
max_iter=max_iter,
k=k)
elif method in ['linear_programming', 'lp']:
res = self.linprog_simplex(v_init=v_init,
max_iter=max_iter)
else:
raise ValueError('invalid method')

Expand Down Expand Up @@ -896,6 +914,40 @@ def midrange(z):

return res

def linprog_simplex(self, v_init=None, max_iter=None):
if self.beta == 1:
raise NotImplementedError(self._error_msg_no_discounting)

if self._sparse:
raise NotImplementedError('method invalid for sparse formulation')

if max_iter is None:
max_iter = self.max_iter * self.num_states

# What for initial condition?
if v_init is None:
v_init = self.s_wise_max(self.R)
v_init = np.asarray(v_init)

sigma = self.compute_greedy(v_init)

ddp_sa = self.to_sa_pair_form(sparse=False)
R, Q = ddp_sa.R, ddp_sa.Q
a_indices, a_indptr = ddp_sa.a_indices, ddp_sa.a_indptr

_, num_iter, v, sigma = ddp_linprog_simplex(
R, Q, self.beta, a_indices, a_indptr, sigma, max_iter=max_iter
)

res = DPSolveResult(v=v,
sigma=sigma,
num_iter=num_iter,
mc=self.controlled_mc(sigma),
method='linear programming',
max_iter=max_iter)

return res

def controlled_mc(self, sigma):
"""
Returns the controlled Markov chain for a given policy `sigma`.
Expand Down
35 changes: 30 additions & 5 deletions quantecon/markov/tests/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,26 @@ def test_modified_policy_iteration_k0(self):
# Check sigma == sigma_star
assert_array_equal(res.sigma, self.sigma_star)

def test_linear_programming(self):
for ddp in self.ddps:
if ddp._sparse:
assert_raises(NotImplementedError, ddp.solve,
method='linear_programming')
else:
res = ddp.solve(method='linear_programming')

v_init = [0, 1]
res_init = ddp.solve(method='linear_programming',
v_init=v_init)

# Check v == v_star
assert_allclose(res.v, self.v_star)
assert_allclose(res_init.v, self.v_star)

# Check sigma == sigma_star
assert_array_equal(res.sigma, self.sigma_star)
assert_array_equal(res_init.sigma, self.sigma_star)


def test_ddp_beta_0():
n, m = 3, 2
Expand All @@ -122,13 +142,18 @@ def test_ddp_beta_0():

ddp0 = DiscreteDP(R, Q, beta)
ddp1 = ddp0.to_sa_pair_form()
methods = ['vi', 'pi', 'mpi']
ddp2 = ddp0.to_sa_pair_form(sparse=False)
methods = ['vi', 'pi', 'mpi', 'lp']

for ddp in [ddp0, ddp1]:
for ddp in [ddp0, ddp1, ddp2]:
for method in methods:
res = ddp.solve(method=method, v_init=v_init)
assert_array_equal(res.sigma, sigma_star)
assert_array_equal(res.v, v_star)
if method == 'lp' and ddp._sparse:
assert_raises(NotImplementedError, ddp.solve,
method=method, v_init=v_init)
else:
res = ddp.solve(method=method, v_init=v_init)
assert_array_equal(res.sigma, sigma_star)
assert_array_equal(res.v, v_star)


def test_ddp_sorting():
Expand Down

0 comments on commit 545b89b

Please sign in to comment.