Skip to content

Commit

Permalink
Merge pull request #356 from moorepants/system-generator-pass-kwargs
Browse files Browse the repository at this point in the history
kwargs can be passed from System.generate_ode_function to the matrix generator
  • Loading branch information
jbm950 authored Jul 24, 2016
2 parents b473f4c + 4a85a06 commit ab85da4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
16 changes: 13 additions & 3 deletions pydy/codegen/ode_function_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,16 +641,26 @@ class CythonODEFunctionGenerator(ODEFunctionGenerator):

def __init__(self, *args, **kwargs):

self._options = {'tmp_dir': None,
'prefix': 'pydy_codegen',
'cse': True,
'verbose': False}
for k, v in self._options.items():
self._options[k] = kwargs.pop(k, v)

if Cython is None:
raise ImportError('Cython must be installed to use this class.')
else:
super(CythonODEFunctionGenerator, self).__init__(*args, **kwargs)

__init__.__doc__ = ODEFunctionGenerator.__init__.__doc__

@staticmethod
def _cythonize(outputs, inputs):
return CythonMatrixGenerator(inputs, outputs).compile()
def _cythonize(self, outputs, inputs):
g = CythonMatrixGenerator(inputs, outputs,
prefix=self._options['prefix'],
cse=self._options['cse'])
return g.compile(tmp_dir=self._options['tmp_dir'],
verbose=self._options['verbose'])

def _set_eval_array(self, f):

Expand Down
4 changes: 2 additions & 2 deletions pydy/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,8 @@ def _kwargs_for_gen_ode_func(self):

def generate_ode_function(self, **kwargs):
"""Calls ``pydy.codegen.ode_function_generators.generate_ode_function``
with the appropriate arguments, and sets the
``evaluate_ode_function`` attribute to the resulting function.
with the appropriate arguments, and sets the ``evaluate_ode_function``
attribute to the resulting function.
Parameters
----------
Expand Down
19 changes: 19 additions & 0 deletions pydy/tests/test_system.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
#!/usr/bin/env python

import os
import warnings
import tempfile
import shutil

import numpy as np
from numpy import testing
import sympy as sm
import sympy.physics.mechanics as me
from scipy.integrate import odeint
theano = sm.external.import_module('theano')
Cython = sm.external.import_module('Cython')

from ..system import System
from ..models import multi_mass_spring_damper, n_link_pendulum_on_cart
Expand Down Expand Up @@ -424,6 +428,21 @@ def test_integrate(self):
with testing.assert_raises(NotImplementedError):
sys.generate_ode_function(generator='made-up')

# Test pass kwargs to the generators.
if Cython:
self.tempdirpath = tempfile.mkdtemp()
prefix = 'my_test_file'
self.sys.generate_ode_function(generator='cython',
prefix=prefix,
tmp_dir=self.tempdirpath)
assert [True for f in os.listdir(self.tempdirpath)
if f.startswith(prefix)]
else:
warnings.warn("Cython was not found so the related tests are being"
" skipped.", PyDyImportWarning)
def cleanup(self):
shutil.rmtree(self.tempdirpath)


def test_specifying_coordinate_issue_339():
"""This test ensures that you can use derivatives as specified values."""
Expand Down

0 comments on commit ab85da4

Please sign in to comment.