Skip to content

Commit

Permalink
Merge pull request #450 from QBatista/invalid_inputs_brent_max
Browse files Browse the repository at this point in the history
ENH: Add errors for invalid inputs for `brent_max`
  • Loading branch information
mmcky authored Dec 11, 2018
2 parents 083e003 + cde2f7b commit ab0c261
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
12 changes: 10 additions & 2 deletions quantecon/optimize/scalar_maximization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def brent_max(func, a, b, args=(), xtol=1e-5, maxiter=500):
info : tuple
A tuple of the form (status_flag, num_iter). Here status_flag
indicates whether or not the maximum number of function calls was
attained. A value of 0 implies that the maximum was not hit.
attained. A value of 0 implies that the maximum was not hit.
The value `num_iter` is the number of function calls.
Example
Expand All @@ -49,7 +49,15 @@ def f(x):
```
"""

if not np.isfinite(a):
raise ValueError("a must be finite.")

if not np.isfinite(b):
raise ValueError("b must be finite.")

if not a < b:
raise ValueError("a must be less than b.")

maxfun = maxiter
status_flag = 0

Expand Down
30 changes: 24 additions & 6 deletions quantecon/optimize/tests/test_scalar_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,43 @@
"""
import numpy as np
from numpy.testing import assert_almost_equal
from nose.tools import raises
from numba import njit

from quantecon.optimize import brent_max


@njit
def f(x):
"""
A function for testing on.
"""
return -(x + 2.0)**2 + 1.0


def test_brent_max():
"""
Uses the function f defined above to test the scalar maximization
Uses the function f defined above to test the scalar maximization
routine.
"""
true_fval = 1.0
true_xf = -2.0
xf, fval, info = brent_max(f, -2, 2)
assert_almost_equal(true_fval, fval, decimal=4)
assert_almost_equal(true_xf, xf, decimal=4)



@njit
def g(x, y):
"""
A multivariate function for testing on.
"""
return -x**2 + y



def test_brent_max():
"""
Uses the function f defined above to test the scalar maximization
Uses the function f defined above to test the scalar maximization
routine.
"""
y = 5
Expand All @@ -46,6 +51,21 @@ def test_brent_max():
assert_almost_equal(true_xf, xf, decimal=4)


@raises(ValueError)
def test_invalid_a_brent_max():
brent_max(f, -np.inf, 2)


@raises(ValueError)
def test_invalid_b_brent_max():
brent_max(f, -2, np.inf)


@raises(ValueError)
def test_invalid_a_b_brent_max():
brent_max(f, 1, 0)


if __name__ == '__main__':
import sys
import nose
Expand All @@ -54,5 +74,3 @@ def test_brent_max():
argv.append('--verbose')
argv.append('--nocapture')
nose.main(argv=argv, defaultTest=__file__)


0 comments on commit ab0c261

Please sign in to comment.