Nektar++
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
optimize.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 
3 # We want 1/2==0.5
4 from __future__ import division
5 
6 """Copyright (c) 2005-2016, University of Oxford.
7 All rights reserved.
8 
9 University of Oxford means the Chancellor, Masters and Scholars of the
10 University of Oxford, having an administrative office at Wellington
11 Square, Oxford OX1 2JD, UK.
12 
13 This file is part of Chaste.
14 
15 Redistribution and use in source and binary forms, with or without
16 modification, are permitted provided that the following conditions are met:
17  * Redistributions of source code must retain the above copyright notice,
18  this list of conditions and the following disclaimer.
19  * Redistributions in binary form must reproduce the above copyright notice,
20  this list of conditions and the following disclaimer in the documentation
21  and/or other materials provided with the distribution.
22  * Neither the name of the University of Oxford nor the names of its
23  contributors may be used to endorse or promote products derived from this
24  software without specific prior written permission.
25 
26 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
30 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
32 GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
33 HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
34 LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
35 OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36 """
37 
38 """
39 This part of PyCml applies various optimising transformations to CellML
40 models, in particular partial evaluation and the use of lookup tables.
41 """
42 
43 import operator
44 
45 # Common CellML processing stuff
46 import pycml
47 from pycml import * # Put contents in the local namespace as well
48 
49 __version__ = "$Revision: 25790 $"[11:-2]
50 
51 
52 
53 ######################################################################
54 # Partial Evaluation #
55 ######################################################################
56 
57 class PartialEvaluator(object):
58  """Perform partial evaluation of a CellML model."""
59  def _debug(self, *args):
60  """Output debug info from the PE process."""
61  logger = logging.getLogger('partial-evaluator')
62  logger.debug(' '.join(map(str, args)))
63 
def _expr_lhs(self, expr):
    """Return a printable name for the left-hand side of expr.

    When expr.assigned_variable() gives a cellml_variable its full name is
    returned.  Otherwise the result is treated as a two-element sequence and
    rendered as '<fullname of item 0>/<fullname of item 1>'.
    """
    target = expr.assigned_variable()
    if not isinstance(target, cellml_variable):
        # Not a plain variable: describe it as a pair of variables.
        return target[0].fullname() + u'/' + target[1].fullname()
    return target.fullname()
71 
def _describe_expr(self, expr):
    """Return a short description of expr for use in debug output.

    Top-level applies (assignments and ODEs) are described by their
    left-hand side, other applies as '[nested apply]', ci elements by the
    full name of their variable, cn elements by their text, and anything
    else as '[unknown]'.
    """
    if isinstance(expr, mathml_apply):
        top_level = expr.is_assignment() or expr.is_ode()
        if top_level:
            return self._expr_lhs(expr)
        return '[nested apply]'
    if isinstance(expr, mathml_ci):
        return expr.variable.fullname()
    if isinstance(expr, mathml_cn):
        return u'cn[' + unicode(expr) + u']'
    return '[unknown]'
85 
def _process_ci_elts(self, elt, func):
    """Call func on every ci element in the tree rooted at elt.

    A ci element is passed straight to func and is not descended into;
    any other element is searched recursively through its element children.
    """
    if isinstance(elt, mathml_ci):
        func(elt)
        return
    for child in elt.xml_element_children():
        self._process_ci_elts(child, func)
93 
def _rename_var(self, elt):
    """Make this ci element refer to its variable by canonical name."""
    inside_bvar = elt.xml_parent.localName == u'bvar'
    if inside_bvar:
        # Later stages and code generation assume that the free variable of
        # a derivative points directly at the ultimate source variable.
        ultimate_source = elt.variable.get_source_variable(recurse=True)
        elt._set_variable_obj(ultimate_source)
    elt._rename()
    self._debug("Using canonical name", unicode(elt))
102 
103  def _do_reduce_eval_loop(self, expr_source):
104  """Do the reduce/evaluate loop.
105 
106  expr_source is a callable that returns an iterable over expressions.
107  """
108  while True:
109  self.doc.model._pe_repeat = u'no'
110  for expr in list(expr_source()):
111  self._reduce_evaluate_expression(expr)
112  if self.doc.model._pe_repeat == u'no':
113  break
114  self._debug("----- looping -----")
115  del self.doc.model._pe_repeat
116 
118  """Reduce or evaluate a single expression."""
119  if hasattr(expr, '_pe_process'):
120  # This expression has been reduced or evaluated already, but needs further
121  # processing later so hasn't been removed yet.
122  return
123  if expr._get_binding_time() is BINDING_TIMES.static:
124  # Evaluate
125  try:
126  value = expr.evaluate() # Needed here for the is_assignment case
127  except:
128  print "Error evaluating", self._describe_expr(expr)
129  raise
130  self._debug("Evaluated", self._describe_expr(expr), "to", value)
131  if isinstance(expr, mathml_apply):
132  if expr.is_ode():
133  # Replace the RHS with a <cn> element giving the value
134  rhs = expr.eq.rhs
135  new_elt = expr._eval_self()
136  expr.xml_insert_after(rhs, new_elt)
137  expr.xml_remove_child(rhs)
138  elif expr.is_assignment():
139  # The variable assigned to will have its initial_value set,
140  # so we don't need the expression any more. Flag it for removal.
141  expr._pe_process = u'remove'
142  else:
143  # Replace the expression with a <cn> element giving the value
144  new_elt = expr._eval_self()
145  expr.xml_parent.xml_insert_after(expr, new_elt)
146  expr.xml_parent.xml_remove_child(expr)
147  else:
148  # Replace the expression with a <cn> element giving the value
149  expr._reduce()
150  # Update variable usage counts for the top-level apply case
151  if isinstance(expr, mathml_apply):
152  if expr.is_ode() or expr.is_assignment():
153  expr._update_usage_counts(expr.eq.rhs, remove=True)
154  else:
155  expr._update_usage_counts(expr, remove=True)
156  else:
157  # Reduce
158  expr._reduce()
159 
160  def _get_assignment_exprs(self, skip_solver_info=True):
161  """Get an iterable over all assignments in the model that are mathml_apply instances."""
162  if not skip_solver_info:
163  skip = set()
164  else:
165  skip = set(self.solver_info.get_modifiable_mathematics())
166  for e in self.doc.model.get_assignments():
167  if isinstance(e, mathml_apply) and e not in skip:
168  assert e.is_ode() or e.is_assignment()
169  yield e
170 
def is_instantiable(self, expr):
    """Decide whether this assignment may be instantiated despite multiple uses.

    An assignment is normally instantiated only when the assigned-to variable
    is used exactly once, so that code is not duplicated.  If, however, the
    definition depends on nothing but a single lookup table keying variable
    (and lookup tables are going to be used), duplication is harmless because
    the whole expression becomes a table anyway; in that case instantiation
    is allowed regardless of the usage count.

    Note: only the variables are inspected (exactly one keying variable and
    no others).  The presence of expensive functions is not checked; without
    any, the duplicated code is cheap in any case.

    Returns False unless a lookup tables analyser is set and the model option
    'pe_instantiate_tables' is enabled.
    """
    analyser = self.lookup_tables_analyser
    if not (analyser and self.doc.model.get_option('pe_instantiate_tables')):
        return False
    keying_vars = set()
    non_keying = []

    def classify(ci_elt):
        variable = ci_elt.variable
        if self.lookup_tables_analyser.is_keying_var(variable):
            keying_vars.add(variable)
        else:
            non_keying.append(variable)

    self._process_ci_elts(expr.eq.rhs, classify)
    return len(keying_vars) == 1 and not non_keying
197 
198  def _is_source_of(self, v1, v2):
199  """Test if v1 is a source of the mapped variable v2."""
200  if v1 is v2:
201  return True
202  elif v2.get_type() == VarTypes.Mapped:
203  return self._is_source_of(v1, v2.get_source_variable())
204  else:
205  return False
206 
207  def _check_retargetting(self, ci_elt):
208  """Check if this new variable reference means a retarget needs to change.
209 
210  A retarget occurs when a kept dynamic mapped variable is changed to computed
211  because its source variable(s) are only used once and are not kept. But new
212  mathematics may use one or more of those sources, in which case we need to
213  revert the retarget.
214  """
215  # Is this a retarget?
216  var = ci_elt.variable
217  root_defn = var.get_source_variable(recurse=True).get_dependencies()
218  if root_defn:
219  root_defn = root_defn[0]
220  else:
221  return
222  if (isinstance(root_defn, mathml_apply)
223  and hasattr(root_defn, '_pe_process')
224  and root_defn._pe_process == 'retarget'):
225  # Are we a source of the 'new' assignee?
226  assignee = root_defn._cml_assigns_to
227  if not self._is_source_of(assignee, var):
228  if var.get_type() == VarTypes.Computed:
229  # We were the original source variable; stop retargetting
230  self._debug('Ceasing re-target of', root_defn, 'to', assignee)
231  del root_defn._pe_process
232  root_defn._cml_assigns_to = var
233  else:
234  # Re-target to var instead
235  self._debug('Changing re-target of', root_defn, 'from', assignee, 'to', var)
236  var._cml_var_type = VarTypes.Computed
237  root_defn._cml_assigns_to = var
238  var._cml_depends_on = [root_defn]
239  var._cml_source_var = None
240  # assignee should now map to var
241  assignee._cml_var_type = VarTypes.Mapped
242  assignee._cml_source_var = var
243  assignee._cml_depends_on = [var]
244 
245  def parteval(self, doc, solver_info, lookup_tables_analyser=None):
246  """Do the partial evaluation."""
247  self.doc = doc
248  self.solver_info = solver_info
249  self.lookup_tables_analyser = lookup_tables_analyser
250  if lookup_tables_analyser:
251  lookup_tables_analyser.doc = doc
252  doc.partial_evaluator = self
253  # Do BTA and reduce/eval of main model
254  doc.model.do_binding_time_analysis()
256 
257  if solver_info.has_modifiable_mathematics():
258  # Do BTA and reduce/eval of solver info section
259  for expr in solver_info.get_modifiable_mathematics():
260  if not (isinstance(expr, mathml_apply) and expr.is_top_level()):
261  self._process_ci_elts(expr, lambda ci: ci.variable._used())
262  self._process_ci_elts(expr, self._check_retargetting)
263  solver_info.do_binding_time_analysis()
264  self._do_reduce_eval_loop(solver_info.get_modifiable_mathematics)
265 
266  # Process flagged expressions
267  for expr in list(self._get_assignment_exprs()):
268  if hasattr(expr, '_pe_process'):
269  if expr._pe_process == u'remove':
270  if (expr._get_binding_time() == BINDING_TIMES.dynamic and
271  isinstance(expr._cml_assigns_to, cellml_variable) and
272  expr._cml_assigns_to.get_usage_count() > 1):
273  self._debug("Keeping", repr(expr), "due to SolverInfo")
274  continue
275  expr.xml_parent.xml_remove_child(expr)
276  self.doc.model._remove_assignment(expr)
277  elif expr._pe_process == u'retarget':
278  lhs = expr.eq.lhs
279  var = expr._cml_assigns_to
280  ci = mathml_ci.create_new(lhs, var.fullname(cellml=True))
281  self._debug('Re-targetting', lhs, var, ci)
282  ci._set_variable_obj(var)
283  lhs.xml_parent.xml_insert_after(lhs, ci)
284  lhs.xml_parent.xml_remove_child(lhs)
285 
286  # Use canonical variable names in all ci elements
287  for expr in list(self._get_assignment_exprs(False)):
288  # If the assigned-to variable isn't used or kept, remove the assignment
289  if isinstance(expr.eq.lhs, mathml_ci):
290  var = expr.eq.lhs.variable
291  if not (var.get_usage_count() or var.pe_keep):
292  expr.xml_parent.xml_remove_child(expr)
293  doc.model._remove_assignment(expr)
294  continue
295  self._process_ci_elts(expr, self._rename_var)
296  for expr in solver_info.get_modifiable_mathematics():
297  self._process_ci_elts(expr, self._rename_var)
298  solver_info.use_canonical_variable_names()
299 
300  # Tidy up kept variables, in case they aren't referenced in an eq'n.
301  for var in doc.model.get_all_variables():
302  if var.pe_keep:
303  var._reduce()
304 
305  # Remove unused variables
306  for var in list(doc.model.get_all_variables()):
307  assert var.get_usage_count() >= 0
308  if var.get_usage_count() == 0 and not var.pe_keep:
309  var.xml_parent._del_variable(var)
310 
311  # Collapse into a single component
312  new_comp = cellml_component.create_new(doc, u'c')
313  new_comp._cml_created_by_pe = True
314  old_comps = list(getattr(doc.model, u'component', []))
315  doc.model._add_component(new_comp)
316  # We iterate over a copy of the component list so we can delete components
317  # from the model in this loop, and so the new component exists in the model
318  # so we can add content to it.
319  for comp in old_comps:
320  # Move relevant contents into new_comp
321  for units in list(getattr(comp, u'units', [])):
322  # Copy all <units> elements
323  # TODO: Just generate the ones we need, using _ensure_units_exist
324  comp.xml_remove_child(units)
325  new_comp.xml_append(units)
326  for var in list(getattr(comp, u'variable', [])):
327  # Only move used source variables
328  self._debug('Variable', var.fullname(), 'usage', var.get_usage_count(),
329  'type', var.get_type(), 'kept', var.pe_keep)
330  if (var.get_usage_count() and var.get_type() != VarTypes.Mapped) or var.pe_keep:
331  self._debug('Moving variable', var.fullname(cellml=True))
332  # Remove from where it was
333  comp._del_variable(var, keep_annotations=True)
334  # Set name to canonical version
335  var.name = var.fullname(cellml=True)
336  # Place in new component
337  new_comp._add_variable(var)
338  # Don't copy reactions
339  for math in list(getattr(comp, u'math', [])):
340  # Copy all <math> elements with content
341  if math.xml_children:
342  comp.xml_remove_child(math)
343  new_comp.xml_append(math)
344  # Invalidate cached links
345  math._unset_cached_links()
346  doc.model._del_component(comp)
347  # Remove groups & connections
348  for group in list(getattr(doc.model, u'group', [])):
349  doc.model.xml_remove_child(group)
350  for conn in list(getattr(doc.model, u'connection', [])):
351  doc.model.xml_remove_child(conn)
352 
353  # Remove unused variable assignments from the list
354  vs = [v for v in doc.model.get_assignments() if isinstance(v, cellml_variable)]
355  for v in vs:
356  if not v.xml_parent is new_comp:
357  doc.model._remove_assignment(v)
358 
359  # Remove interface attributes from variables
360  for v in new_comp.variable:
361  for iface in [u'public', u'private']:
362  try:
363  delattr(v, iface+u'_interface')
364  except AttributeError:
365  pass
366 
367  # Refresh expression dependency lists
368  for expr in self._get_assignment_exprs(False):
369  expr._cml_depends_on = list(expr.vars_in(expr.eq.rhs))
370  if expr.is_ode():
371  # Add dependency on the independent variable
372  indep_var = expr.eq.lhs.diff.independent_variable
373  if not indep_var in expr._cml_depends_on:
374  expr._cml_depends_on.append(indep_var)
375  # Update ODE definition dependency if needed
376  expr.eq.lhs.diff.dependent_variable._update_ode_dependency(indep_var, expr)
377  return
378 
379 
380 ######################################################################
381 # Lookup table analysis #
382 ######################################################################
383 
384 class LookupTableAnalyser(object):
385  """
386  Analyses & annotates a CellML model to indicate where lookup
387  tables can be used.
388  """
389 
390  def __init__(self):
391  """Create an analyser."""
392  # No model to analyse yet
393  self.doc = None
394  # Set default parameter values
395  self.set_params()
396 
@property
def config(self):
    """Return the configuration store of the current document, or None.

    None is returned when self.doc has no '_cml_config' attribute, which
    includes the case where no document has been set yet.
    """
    document = self.doc
    return getattr(document, '_cml_config', None)
401 
403  """Determine if the given variable represents the transmembrane potential.
404 
405  This method takes an instance of cellml_variable and returns a boolean.
406  """
407  return (var.name in [u'V', u'membrane__V'] and
408  var.get_type(follow_maps=True) == VarTypes.State)
409 
def is_allowed_variable(self, var):
    """Tell whether the given variable may appear in a lookup table expression.

    The ultimate source of var is looked up in the document configuration's
    lut_config.  If it is not there, the variable is still allowed when the
    include_dt_in_tables option is set and the variable is the ultimate
    source of the solver's dt variable.
    """
    source = var.get_source_variable(recurse=True)
    config = self.config
    if source in config.lut_config:
        return True
    include_dt = config.options.include_dt_in_tables
    if not include_dt:
        return include_dt
    dt_source = self.solver_info.get_dt().get_source_variable(recurse=True)
    return source is dt_source
420 
def is_keying_var(self, var):
    """Tell whether the given variable can be used to key a lookup table.

    With a config store available, the ultimate source of var must be listed
    in its lut_config.  Without one, the variable's name has to equal
    self.table_var.
    """
    if not self.config:
        return var.name == self.table_var
    source = var.get_source_variable(recurse=True)
    return source in self.config.lut_config
430 
# Fallback lookup table settings, used for any parameter never set explicitly.
_LT_DEFAULTS = {
    'table_min': u'-100.0001',
    'table_max': u'49.9999',
    'table_step': u'0.01',
    'table_var': u'V',
}

def set_params(self, **kw):
    """Set the parameters that control lookup table generation.

    Recognised keyword parameters (stored as attributes of the same name,
    and later written to suitable expressions as the attribute shown):
        table_min  - minimum table entry (unicode)   -> lut:min
        table_max  - maximum table entry (unicode)   -> lut:max
        table_step - table step size (unicode)       -> lut:step
        table_var  - name of the variable indexing the table (unicode)
                     -> lut:var

    A parameter that is not passed keeps its current value, or takes the
    value from _LT_DEFAULTS if it has never been set.  Keywords other than
    those listed above are ignored.
    """
    for name, fallback in self._LT_DEFAULTS.items():
        if name in kw:
            value = kw[name]
        else:
            value = getattr(self, name, fallback)
        setattr(self, name, value)
452 
def get_param(self, param_name, table_var):
    """Get the value of a lookup table parameter.

    param_name is the parameter to look up (e.g. 'table_min').
    table_var is the variable object being used to key this table.

    If the document has a config store, the value is looked up there, under
    the ultimate source variable of table_var.  If that does not give a value
    (there is no config store, or it has no entry for the variable or for the
    parameter), the value given using set_params is returned instead.
    """
    try:
        val = self.config.lut_config[
            table_var.get_source_variable(recurse=True)][param_name]
    except (AttributeError, KeyError):
        # The exception types must be given as a tuple: the old form
        # 'except AttributeError, KeyError' caught only AttributeError (and
        # bound it to the name KeyError), so a missing config entry raised
        # instead of falling back to the set_params value.
        val = getattr(self, param_name)
    return val
467 
# A lookup table is only worthwhile for an expression that uses at least
# one of these functions.
lut_expensive_funcs = frozenset([
    'exp', 'log', 'ln', 'root',
    'sin', 'cos', 'tan',
    'sec', 'csc', 'cot',
    'sinh', 'cosh', 'tanh',
    'sech', 'csch', 'coth',
    'arcsin', 'arccos', 'arctan',
    'arcsinh', 'arccosh', 'arctanh',
    'arcsec', 'arccsc', 'arccot',
    'arcsech', 'arccsch', 'arccoth',
])
478 
class LUTState(object):
    """Accumulated findings of the lookup table analysis for one expression."""

    def __init__(self):
        """Start from the assumption that a lookup table is not suitable."""
        self.has_var = False     # a keying variable has been seen
        self.bad_vars = set()    # names of variables that rule a table out
        self.has_func = False    # an expensive function has been seen
        self.table_var = None    # the keying variable object, once found

    def update(self, res):
        """Fold the result for a sub-expression into this state.

        res is another instance of this class, obtained by checking a
        sub-expression for suitability.
        """
        seen_var = self.has_var or res.has_var
        if seen_var:
            if self.table_var and res.table_var:
                # Both sides have a keying variable: they must share the
                # same ultimate source for a single table to be possible.
                mine = self.table_var.get_source_variable(recurse=True)
                theirs = res.table_var.get_source_variable(recurse=True)
                seen_var = mine is theirs
            else:
                seen_var = True
        self.has_var = seen_var
        self.bad_vars.update(res.bad_vars)
        if not self.has_func:
            self.has_func = res.has_func
        if not self.table_var:
            self.table_var = res.table_var
        if not self.has_var and self.table_var:
            # The sub-expressions are keyed on different variables, so both
            # keying variables count as bad for the combined expression.
            for clashing in (self.table_var, res.table_var):
                self.bad_vars.add(clashing.name)
            self.table_var = None

    def suitable(self):
        """Return True iff the expression can be replaced by a lookup table."""
        if not self.has_var:
            return self.has_var
        if self.bad_vars:
            return False
        return self.has_func

    def reason(self):
        """
        Return a unicode string saying why the expression is not suitable
        for replacement by a lookup table.

        The string is a comma separated combination of:
         'no_var' - doesn't contain the table variable
         'no_func' - doesn't contain an expensive function
         'bad_var <vname>' - contains a variable which isn't permitted
        """
        parts = []
        if not self.has_var:
            parts.append(u'no_var')
        if not self.has_func:
            parts.append(u'no_func')
        parts.extend([u'bad_var ' + vname for vname in self.bad_vars])
        return u','.join(parts)
536 
538  """Create a LUTState instance from an already annotated expression."""
539  state = self.LUTState()
540  possible = expr.getAttributeNS(NSS['lut'], u'possible', '')
541  if possible == u'yes':
542  varname = expr.getAttributeNS(NSS['lut'], u'var')
543  state.table_var = expr.component.get_variable_by_name(varname)
544  elif possible == u'no':
545  reason = expr.getAttributeNS(NSS['lut'], u'reason', '')
546  reasons = reason.split(u',')
547  for reason in reasons:
548  if reason == u'no_var':
549  state.has_var = False
550  elif reason == u'no_func':
551  state.has_func = False
552  elif reason.startswith(u'bad_var '):
553  state.bad_vars.add(reason[8:])
554  return state
555 
def analyse_for_lut(self, expr, var_checker_fn):
    """Check if the given expression can be replaced by a lookup table.

    The first argument is the expression to check; the second is a
    function which takes a variable object and returns True iff this
    variable is permitted within a lookup table expression.

    The expression tree is walked recursively and the findings for all
    sub-expressions are merged into a single LUTState, which is returned.
    <apply> and <piecewise> expressions that qualify are annotated via
    annotate_as_suitable.

    If self.annotate_failures is True then annotate <apply> and
    <piecewise> expressions which don't qualify with the reason
    why they do not.
    This can be:
     'no_var' - doesn't contain the table variable
     'bad_var <vname>' - contains a variable which isn't permitted
     'no_func' - doesn't contain an expensive function
    or a comma separated combination of the above.
    The annotation is stored as the lut:reason attribute.

    If self.annotate_outermost_only is True then only annotate the
    outermost qualifying expression, rather than also annotating
    qualifying subexpressions.
    """
    # If this is a cloned expression, then just copy any annotations on the original.
    if isinstance(expr, mathml):
        source_expr = expr.get_original_of_clone()
    else:
        source_expr = None
    if source_expr and source_expr.getAttributeNS(NSS['lut'], u'possible', '') != '':
        # The original has already been analysed: reuse its annotations
        # rather than walking the clone again.
        LookupTableAnalyser.copy_lut_annotations(source_expr, expr)
        state = self.create_state_from_annotations(source_expr)
        DEBUG('lookup-tables', "No need to analyse clone", expr.xml(), state.suitable(), state.reason())
    else:
        # Initialise the indicators
        state = self.LUTState()
        # Process current node
        if isinstance(expr, mathml_ci):
            # Variable reference
            if var_checker_fn(expr.variable):
                # Could be a permitted var that isn't a keying var
                if self.is_keying_var(expr.variable):
                    assert state.table_var is None # Sanity check
                    state.has_var = True
                    state.table_var = expr.variable
            else:
                # Variable not permitted: record its name as a bad variable
                state.bad_vars.add(expr.variable.name)
        elif isinstance(expr, mathml_piecewise):
            # Recurse into pieces & otherwise options
            if hasattr(expr, u'otherwise'):
                r = self.analyse_for_lut(child_i(expr.otherwise, 1),
                                         var_checker_fn)
                state.update(r)
            for piece in getattr(expr, u'piece', []):
                # Both children of each piece are analysed
                r = self.analyse_for_lut(child_i(piece, 1), var_checker_fn)
                state.update(r)
                r = self.analyse_for_lut(child_i(piece, 2), var_checker_fn)
                state.update(r)
        elif isinstance(expr, mathml_apply):
            # Check function
            if (not state.has_func and
                expr.operator().localName in self.lut_expensive_funcs):
                state.has_func = True
            # Check operands
            # (the state of each operand is kept, keyed by id(operand), for
            # the special case handling below)
            operand_states = {}
            for operand in expr.operands():
                r = self.analyse_for_lut(operand, var_checker_fn)
                state.update(r)
                operand_states[id(operand)] = r
            # Check qualifiers
            for qual in expr.qualifiers():
                r = self.analyse_for_lut(qual, var_checker_fn)
                state.update(r)
            # Special case additional optimisations
            # (only tried when the expression as a whole is not suitable and
            # the combine_commutative_tables option is set)
            if self.config and self.config.options.combine_commutative_tables and not state.suitable():
                if isinstance(expr.operator(), reduce_commutative_nary):
                    self.check_commutative_tables(expr, operand_states)
                elif isinstance(expr.operator(), mathml_divide):
                    self.check_divide_by_table(expr, operand_states)
        else:
            # Just recurse into children
            for e in expr.xml_children:
                if getattr(e, 'nodeType', None) == Node.ELEMENT_NODE:
                    r = self.analyse_for_lut(e, var_checker_fn)
                    state.update(r)
    # Annotate the expression if appropriate
    if isinstance(expr, (mathml_apply, mathml_piecewise)):
        if state.suitable():
            self.annotate_as_suitable(expr, state.table_var)
        else:
            if self.annotate_failures:
                expr.xml_set_attribute((u'lut:reason', NSS['lut']), state.reason())
    return state
646 
def check_divide_by_table(self, expr, operand_states):
    """Turn a division by a table into a multiplication by a table.

    Called when expr, a division, cannot be replaced as a whole by a lookup
    table.  If its denominator on its own is suitable for a table, expr is
    replaced by numerator * (1 / denominator), and the reciprocal is
    annotated as suitable so that the division ends up inside the table.

    operand_states maps id(operand) to the analysis state of that operand.
    """
    numerator, denominator = list(expr.operands())
    denom_state = operand_states[id(denominator)]
    if not denom_state.suitable():
        return
    for operand in (numerator, denominator):
        expr.safe_remove_child(operand)
    reciprocal = mathml_apply.create_new(
        expr, u'divide', [(u'1', u'dimensionless'), denominator])
    product = mathml_apply.create_new(expr, u'times', [numerator, reciprocal])
    expr.replace_child(expr, product, expr.xml_parent)
    self.annotate_as_suitable(reciprocal, denom_state.table_var)
663 
def check_commutative_tables(self, expr, operand_states):
    """Check whether we can combine suitable operands into a new expression.

    If expr has a commutative (and associative) n-ary operator, but is not suitable as a
    whole to become a lookup table (checked by caller) then we might still be able to
    do slightly better than just analysing its operands.  If multiple operands can be
    replaced by tables keyed on the same variable, these can be combined into a new
    application of the same operator as expr, which can then be replaced as a whole
    by a single lookup table, and made an operand of expr.

    Alternatively, if at least one operand can be replaced by a table, and a subset of
    other operands do not contain other variables, then they can be included in the single
    table.

    operand_states maps id(operand) to the analysis state of that operand.
    """
    # Operands that can be replaced by tables
    table_operands = [op for op in expr.operands() if operand_states[id(op)].suitable()]
    if not table_operands:
        return
    # Sort these suitable operands by table_var (map var id to var & operand list, respectively)
    table_vars, table_var_operands = {}, {}
    for oper in table_operands:
        table_var = operand_states[id(oper)].table_var
        table_var_id = id(table_var)
        if table_var_id not in table_vars:
            table_vars[table_var_id] = table_var
            table_var_operands[table_var_id] = []
        table_var_operands[table_var_id].append(oper)
    # Figure out if any operands aren't suitable by themselves but could be included in a table.
    # Operands without any keying variable are collected under id(None).
    no_var_id = id(None)
    potential_operands = {no_var_id: []}
    for table_var_id in table_vars:
        potential_operands[table_var_id] = []
    for op in expr.operands():
        state = operand_states[id(op)]
        if not state.suitable() and not state.bad_vars:
            # The dictionary is keyed by id(), so the lookup must use the id as well.
            # Testing the variable object itself never matched, which reset the list
            # each time and kept only the last such operand.
            potential_operands.setdefault(id(state.table_var), []).append(op)
    # Do any combining
    for table_var_id in list(table_vars):
        suitable_opers = (table_var_operands[table_var_id]
                          + potential_operands[table_var_id]
                          + potential_operands[no_var_id])
        if len(suitable_opers) > 1:
            # Create new sub-expression with the suitable operands
            for oper in suitable_opers:
                expr.safe_remove_child(oper)
            new_expr = mathml_apply.create_new(expr, expr.operator().localName, suitable_opers)
            expr.xml_append(new_expr)
            self.annotate_as_suitable(new_expr, table_vars[table_var_id])
            # Remove the operands with no table_var from consideration with other keying vars
            potential_operands[no_var_id] = []
713 
def annotate_as_suitable(self, expr, table_var):
    """Mark expr as replaceable by a lookup table keyed on table_var.

    Sets the lut:min, lut:max, lut:step, lut:var and lut:possible attributes
    on expr and records expr in self.doc.lookup_tables.
    """
    if self.annotate_outermost_only:
        # Only the outermost expression keeps its annotations, so strip any
        # that are on expr or on expressions below it.
        self.remove_lut_annotations(expr)
    for param in ['min', 'max', 'step']:
        attr_name = (u'lut:' + param, NSS['lut'])
        attr_value = self.get_param('table_' + param, table_var)
        expr.xml_set_attribute(attr_name, attr_value)
    expr.xml_set_attribute((u'lut:var', NSS['lut']), table_var.name)
    expr.xml_set_attribute((u'lut:possible', NSS['lut']), u'yes')
    self.doc.lookup_tables[expr] = True
725 
@staticmethod
def copy_lut_annotations(from_expr, to_expr):
    """Copy every attribute in the lut namespace from from_expr onto to_expr."""
    for pyname, fullname in from_expr.xml_attributes.iteritems():
        if not fullname[1] == NSS['lut']:
            # Attribute from some other namespace: not a table annotation.
            continue
        value = getattr(from_expr, pyname)
        to_expr.xml_set_attribute(fullname, value)
732 
def remove_lut_annotations(self, expr, remove_reason=False):
    """Strip lookup table annotations from expr and its sub-expressions.

    By default only the annotations on expressions that can be converted to
    use lookup tables are removed.  If remove_reason is True, the lut:reason
    attributes on non-qualifying expressions are removed as well.

    An expression that loses any annotation other than lut:reason is also
    deleted from self.doc.lookup_tables.
    """
    had_table_annotation = False
    attributes = getattr(expr, 'xml_attributes', {})
    for pyname in attributes.keys():
        fullname = expr.xml_attributes[pyname]
        if not fullname[1] == NSS['lut']:
            continue
        is_reason = fullname[0] == u'lut:reason'
        if is_reason and not remove_reason:
            continue
        expr.__delattr__(pyname)
        if not is_reason:
            had_table_annotation = True
    if had_table_annotation:
        # expr no longer counts as a lookup table.
        del self.doc.lookup_tables[expr]
    # Give the same treatment to all element children.
    for child in expr.xml_children:
        if getattr(child, 'nodeType', None) == Node.ELEMENT_NODE:
            self.remove_lut_annotations(child, remove_reason)
757 
758  def analyse_model(self, doc, solver_info,
759  annotate_failures=True,
760  annotate_outermost_only=True):
761  """Analyse the given document.
762 
763  This method checks all expressions (and subexpressions)
764  in the given document for whether they could be converted to
765  use a lookup table, and annotates them appropriately.
766 
767  By default expressions which don't qualify will be annotated
768  to indicate why; set annotate_failures to False to suppress
769  this.
770 
771  Also by default only the outermost suitable expression in any
772  given tree will be annotated; if you want to annotate suitable
773  subexpressions of a suitable expression then pass
774  annotate_outermost_only as False.
775  """
776  self.doc = doc
777  self.solver_info = solver_info
778  self.annotate_failures = annotate_failures
779  self.annotate_outermost_only = annotate_outermost_only
780  doc.lookup_tables = {}
781  # How to check for allowed variables
782  if hasattr(doc, '_cml_config'):
783  checker_fn = self.is_allowed_variable
784  else:
785  checker_fn = self.var_is_membrane_potential
786 
787  # Check all expressions
788  for expr in (e for e in doc.model.get_assignments()
789  if isinstance(e, mathml_apply)):
790  ops = expr.operands()
791  ops.next()
792  e = ops.next()
793  self.analyse_for_lut(e, checker_fn)
794  for expr in solver_info.get_modifiable_mathematics():
795  self.analyse_for_lut(expr, checker_fn)
796 
797  if solver_info.has_modifiable_mathematics():
799 
800  # Assign names (numbers) to the lookup tables found.
801  # Also work out which ones can share index variables into the table.
802  doc.lookup_tables = doc.lookup_tables.keys()
803  doc.lookup_tables.sort(cmp=element_path_cmp)
804  doc.lookup_table_indexes, n = {}, 0
805  for i, expr in enumerate(doc.lookup_tables):
806  expr.xml_set_attribute((u'lut:table_name', NSS['lut']), unicode(i))
807  comp = expr.get_component()
808  var = comp.get_variable_by_name(expr.var).get_source_variable(recurse=True)
809  key = (expr.min, expr.max, expr.step, var)
810  if not key in doc.lookup_table_indexes:
811  doc.lookup_table_indexes[key] = unicode(n)
812  n += 1
813  expr.xml_set_attribute((u'lut:table_index', NSS['lut']), doc.lookup_table_indexes[key])
814 
815  if solver_info.has_modifiable_mathematics():
817 
818  # Re-do dependency analysis so that an expression using lookup
819  # tables only depends on the keying variable.
820  for expr in (e for e in doc.model.get_assignments()
821  if isinstance(e, mathml_apply)):
822  expr.classify_variables(root=True,
823  dependencies_only=True,
824  needs_special_treatment=self.calculate_dependencies)
825 
826  def _find_tables(self, expr, table_dict):
827  """Helper method for _determine_unneeded_tables."""
828  if expr.getAttributeNS(NSS['lut'], u'possible', '') == u'yes':
829  table_dict[id(expr)] = expr
830  else:
831  for e in self.doc.model.xml_element_children(expr):
832  self._find_tables(e, table_dict)
833 
835  """Determine whether some expressions identified as lookup tables aren't actually used.
836 
837  This occurs if some ODEs have been linearised, in which case the original definitions
838  will have been analysed for lookup tables, but aren't actually used.
839 
840  TODO: The original definitions might be used for computing derived quantities...
841  """
842  original_tables = {}
843  new_tables = {}
844  def f(exprs, table_dict):
845  exprs = filter(lambda n: isinstance(n, (mathml_ci, mathml_apply, mathml_piecewise)), exprs)
846  for node in self.doc.model.calculate_extended_dependencies(exprs):
847  if isinstance(node, mathml_apply):
848  self._find_tables(node, table_dict)
849  for u, t, eqns in self.solver_info.get_linearised_odes():
850  original_defn = u.get_ode_dependency(t)
851  f([original_defn], original_tables)
852  f(eqns, new_tables)
853  for id_ in set(original_tables.keys()) - set(new_tables.keys()):
854  expr = original_tables[id_]
855  self.remove_lut_annotations(expr)
856  expr.xml_set_attribute((u'lut:reason', NSS['lut']),
857  u'Expression will not be used in generated code.')
858  DEBUG('lookup-tables', 'Not annotating probably unused expression', expr)
859 
861  """Determine whether we have multiple tables for the same expression.
862 
863  Any expression that is identical to a previous table will be re-annotated to refer to the
864  previous table, instead of declaring a new one.
865 
866  This is a temporary measure until we have proper sub-expression elimination for the Jacobian
867  and residual calculations.
868  """
869  uniq_tables = []
870  for expr in self.doc.lookup_tables:
871  for table in uniq_tables:
872  if expr.same_tree(table):
873  lt_name = table.getAttributeNS(NSS['lut'], u'table_name', u'')
874  # Need to remove old name before we can set a new one (grr amara)
875  del expr.table_name
876  expr.xml_set_attribute((u'lut:table_name', NSS['lut']), lt_name)
877  break
878  else:
879  uniq_tables.append(expr)
880 
881  def calculate_dependencies(self, expr):
882  """Determine the dependencies of an expression that might use a lookup table.
883 
884  This method is suitable for use as the needs_special_treatment function in
885  mathml_apply.classify_variables. It is used to override the default recursion
886  into sub-trees. It takes a single sub-tree as argument, and returns either
887  the dependency set for that sub-tree, or None to use the default recursion.
888 
889  Expressions that can use a lookup table only depend on the keying variable.
890  """
891  if expr.getAttributeNS(NSS['lut'], u'possible', '') == u'yes':
892  key_var_name = expr.getAttributeNS(NSS['lut'], u'var')
893  key_var = expr.component.get_variable_by_name(key_var_name).get_source_variable(recurse=True)
894  return set([key_var])
895  # If not a table, use default behaviour
896  return None
897 
898 
899 ######################################################################
900 # Jacobian analysis #
901 ######################################################################
902 
903 class LinearityAnalyser(object):
904  """Analyse linearity aspects of a model.
905 
906  This class performs analyses to determine which ODEs have a linear
907  dependence on their state variable, discounting the presence of
908  the transmembrane potential.
909 
910  This can be used to decouple the ODEs for more efficient solution,
911  especially in a multi-cellular context.
912 
913  analyse_for_jacobian(doc) must be called before
914  rearrange_linear_odes(doc).
915  """
916  LINEAR_KINDS = Enum('None', 'Linear', 'Nonlinear')
917 
918  def analyse_for_jacobian(self, doc, V=None):
919  """Analyse the model for computing a symbolic Jacobian.
920 
921  Determines automatically which variables will need to be solved
922  for using Newton's method, and stores their names in
923  doc.model._cml_nonlinear_system_variables, as a list of variable
924  objects.
925 
926  Also stores doc.model._cml_linear_vars, doc.model._cml_free_var,
927  doc.model._cml_transmembrane_potential.
928  """
929  # TODO: Add error checking and tidy
930  stvs = doc.model.find_state_vars()
931  if V is None:
932  Vcname, Vvname = 'membrane', 'V'
933  V = doc.model.get_variable_by_name(Vcname, Vvname)
934  V = V.get_source_variable(recurse=True)
935  doc.model._cml_transmembrane_potential = V
936  free_var = doc.model.find_free_vars()[0]
937  lvs = self.find_linear_odes(stvs, V, free_var)
938  # Next 3 lines for benefit of rearrange_linear_odes(doc)
939  lvs.sort(key=lambda v: v.fullname())
940  doc.model._cml_linear_vars = lvs
941  doc.model._cml_free_var = free_var
942  # Store nonlinear vars in canonical order
943  nonlinear_vars = list(set(stvs) - set([V]) - set(lvs))
944  nonlinear_vars.sort(key=lambda v: v.fullname())
945  doc.model._cml_nonlinear_system_variables = nonlinear_vars
946  # Debugging
947  f = lambda var: var.fullname()
948  DEBUG('linearity-analyser', 'V=', V.fullname(), '; free var=',
949  free_var.fullname(), '; linear vars=', map(f, lvs),
950  '; nonlinear vars=', map(f, nonlinear_vars))
951  return
952 
953  def _get_rhs(self, expr):
954  """Return the RHS of an assignment expression."""
955  ops = expr.operands()
956  ops.next()
957  return ops.next()
958 
959  def _check_expr(self, expr, state_var, bad_vars):
960  """The actual linearity checking function.
961 
962  Recursively determine the type of dependence expr has on
963  state_var. The presence of any members of bad_vars indicates
964  a non-linear dependency.
965 
966  Return a member of the self.LINEAR_KINDS enum.
967  """
968  kind = self.LINEAR_KINDS
969  result = None
970  if isinstance(expr, mathml_ci):
971  var = expr.variable.get_source_variable(recurse=True)
972  if var is state_var:
973  result = kind.Linear
974  elif var in bad_vars:
975  result = kind.Nonlinear
976  elif var.get_type(follow_maps=True) == VarTypes.Computed:
977  # Recurse into defining expression
978  src_var = var.get_source_variable(recurse=True)
979  src_expr = self._get_rhs(src_var.get_dependencies()[0])
980  DEBUG('find-linear-deps', "--recurse for", src_var.name,
981  "to", src_expr)
982  result = self._check_expr(src_expr, state_var, bad_vars)
983  else:
984  result = kind.None
985  # Record the kind of this variable, for later use when
986  # rearranging linear ODEs
987  var._cml_linear_kind = result
988  elif isinstance(expr, mathml_cn):
989  result = kind.None
990  elif isinstance(expr, mathml_piecewise):
991  # If any conditions have a dependence, then we're
992  # nonlinear. Otherwise, all the pieces must be the same
993  # (and that's what we are) or we're nonlinear.
994  pieces = getattr(expr, u'piece', [])
995  conds = map(lambda p: child_i(p, 2), pieces)
996  chld_exprs = map(lambda p: child_i(p, 1), pieces)
997  if hasattr(expr, u'otherwise'):
998  chld_exprs.append(child_i(expr.otherwise, 1))
999  for cond in conds:
1000  if self._check_expr(cond, state_var, bad_vars) != kind.None:
1001  result = kind.Nonlinear
1002  break
1003  else:
1004  # Conditions all OK
1005  for e in chld_exprs:
1006  res = self._check_expr(e, state_var, bad_vars)
1007  if result is not None and res != result:
1008  # We have a difference
1009  result = kind.Nonlinear
1010  break
1011  result = res
1012  elif isinstance(expr, mathml_apply):
1013  # Behaviour depends on the operator
1014  operator = expr.operator().localName
1015  operands = expr.operands()
1016  if operator in ['plus', 'minus']:
1017  # Linear if any operand linear, and none non-linear
1018  op_kinds = map(lambda op: self._check_expr(op, state_var,
1019  bad_vars),
1020  operands)
1021  result = max(op_kinds)
1022  elif operator == 'divide':
1023  # Linear iff only numerator linear
1024  numer = operands.next()
1025  denom = operands.next()
1026  if self._check_expr(denom, state_var, bad_vars) != kind.None:
1027  result = kind.Nonlinear
1028  else:
1029  result = self._check_expr(numer, state_var, bad_vars)
1030  elif operator == 'times':
1031  # Linear iff only 1 linear operand
1032  op_kinds = map(lambda op: self._check_expr(op, state_var,
1033  bad_vars),
1034  operands)
1035  lin, nonlin = 0, 0
1036  for res in op_kinds:
1037  if res == kind.Linear: lin += 1
1038  elif res == kind.Nonlinear: nonlin += 1
1039  if nonlin > 0 or lin > 1:
1040  result = kind.Nonlinear
1041  elif lin == 1:
1042  result = kind.Linear
1043  else:
1044  result = kind.None
1045  else:
1046  # Cannot be linear; may be no dependence at all
1047  result = max(map(lambda op: self._check_expr(op, state_var,
1048  bad_vars),
1049  operands))
1050  if result == kind.Linear:
1051  result = kind.Nonlinear
1052  else:
1053  # Either a straightforward container element
1054  try:
1055  child = child_i(expr, 1)
1056  except ValueError:
1057  # Assume it's just a constant
1058  result = kind.None
1059  else:
1060  result = self._check_expr(child, state_var, bad_vars)
1061  DEBUG('find-linear-deps', "Expression", expr, "gives result", result)
1062  return result
1063 
1064  def find_linear_odes(self, state_vars, V, free_var):
1065  """Identify linear ODEs.
1066 
1067  For each ODE (except that for V), determine whether it has a
1068  linear dependence on the dependent variable, and thus can be
1069  updated directly, without using Newton's method.
1070 
1071  We also require it to not depend on any other state variable,
1072  except for V.
1073  """
1074  kind = self.LINEAR_KINDS
1075  candidates = set(state_vars) - set([V])
1076  linear_vars = []
1077  for var in candidates:
1078  ode_expr = var.get_ode_dependency(free_var)
1079  if self._check_expr(self._get_rhs(ode_expr), var,
1080  candidates - set([var])) == kind.Linear:
1081  linear_vars.append(var)
1082  return linear_vars
1083 
1084  def _clone(self, expr):
1085  """Properly clone a MathML sub-expression."""
1086  if isinstance(expr, mathml):
1087  clone = expr.clone_self(register=True)
1088  else:
1089  clone = mathml.clone(expr)
1090  return clone
1091 
1092  def _make_apply(self, operator, ghs, i, filter_none=True,
1093  preserve=False):
1094  """Construct a new apply expression for g or h.
1095 
1096  ghs is an iterable of (g,h) pairs for operands.
1097 
1098  i indicates whether to construct g (0) or h (1).
1099 
1100  filter_none indicates the behaviour of 0 under this operator.
1101  If True, it's an additive zero, otherwise it's a
1102  multiplicative zero.
1103  """
1104  # Find g or h operands
1105  ghs_i = map(lambda gh: gh[i], ghs)
1106  if not filter_none and None in ghs_i:
1107  # Whole expr is None
1108  new_expr = None
1109  else:
1110  # Filter out None subexprs
1111  if operator == u'minus':
1112  # Do we need to retain a unary minus?
1113  if len(ghs_i) == 1 or ghs_i[0] is None:
1114  # Original was -a or 0-a
1115  retain_unary_minus = True
1116  else:
1117  # Original was a-0 or a-b
1118  retain_unary_minus = False
1119  else:
1120  # Only retain if we're told to preserve as-is
1121  retain_unary_minus = preserve
1122  ghs_i = filter(None, ghs_i)
1123  if ghs_i:
1124  if len(ghs_i) > 1 or retain_unary_minus:
1125  new_expr = mathml_apply.create_new(
1126  self.__expr, operator, ghs_i)
1127  else:
1128  new_expr = self._clone(ghs_i[0]) # Clone may be unneeded
1129  else:
1130  new_expr = None
1131  return new_expr
1132 
1133  def _transfer_lut(self, expr, gh, var):
1134  """Transfer lookup table annotations from expr to gh.
1135 
1136  gh is a pair (g, h) s.t. expr = g + h*var.
1137 
1138  If expr can be represented by a lookup table, then the lookup
1139  variable cannot be var, since if it were, then expr would have
1140  a non-linear dependence on var. Hence h must be 0, since
1141  otherwise expr would contain a (state) variable other than the
1142  lookup variable, and hence not be a candidate for a table.
1143  Thus expr=g, so we transfer the annotations to g.
1144  """
1145  if expr.getAttributeNS(NSS['lut'], u'possible', '') != u'yes':
1146  return
1147  # Paranoia check that our reasoning is correct
1148  g, h = gh
1149  assert h is None
1150  # Transfer the annotations into g
1151  LookupTableAnalyser.copy_lut_annotations(expr, g)
1152  # Make sure g has a reference to its component, for use by code generation.
1153  g._cml_component = expr.component
1154  return
1155 
    def _rearrange_expr(self, expr, var):
        """Rearrange an expression into the form g + h*var.

        Performs a post-order traversal of this expression's tree,
        and returns a pair (g, h).  Either member may be None, which
        stands for a zero term (see _make_apply).

        This relies on the _cml_linear_kind attribute that _check_expr
        stores on each variable it visits.  The split computed for a
        referenced variable's definition is cached on that variable as
        _cml_linear_split.

        NOTE(review): the _cml_linear_split cache is looked up without
        regard to var - confirm that a variable is never split with
        respect to two different state variables.
        """
#        import inspect
#        depth = len(inspect.stack())
#        print ' '*depth, "_rearrange_expr", prid(expr, True), var.name, expr
        gh = None
        if isinstance(expr, mathml_ci):
            # Variable
            ci_var = expr.variable.get_source_variable(recurse=True)
            if var is ci_var:
                # var itself: g = 0, h = 1
                gh = (None, mathml_cn.create_new(expr,
                                                 u'1', u'dimensionless'))
            else:
                if ci_var._cml_linear_kind == self.LINEAR_KINDS.None:
                    # Just put the <ci> in g, but with full name
                    gh = (mathml_ci.create_new(expr, ci_var.fullname()), None)
                    gh[0]._set_variable_obj(ci_var)
                else:
                    # ci_var is a linear function of var, so rearrange
                    # its definition
                    if not hasattr(ci_var, '_cml_linear_split'):
                        ci_defn = ci_var.get_dependencies()[0]
                        ci_var._cml_linear_split = self._rearrange_expr(
                            self._get_rhs(ci_defn), var)
                    gh = ci_var._cml_linear_split
        elif isinstance(expr, mathml_piecewise):
            # The tests have to move into both components of gh:
            # "if C1 then (a1,b1) elif C2 then (a2,b2) else (a0,b0)"
            # maps to "(if C1 then a1 elif C2 then a2 else a0,
            #           if C1 then b1 elif C2 then b2 else b0)"
            # Note that no test is a function of var.
            # First rearrange child expressions
            pieces = getattr(expr, u'piece', [])
            cases = map(lambda p: child_i(p, 1), pieces)
            cases_ghs = map(lambda c: self._rearrange_expr(c, var), cases)
            if hasattr(expr, u'otherwise'):
                ow_gh = self._rearrange_expr(child_i(expr.otherwise, 1), var)
            else:
                ow_gh = (None, None)
            # Now construct the new expression
            conds = map(lambda p: self._clone(child_i(p, 2)), pieces)
            def piecewise_branch(i):
                # Build the piecewise for component i (0 for g, 1 for h)
                pieces_i = zip(map(lambda gh: gh[i], cases_ghs),
                               conds)
                pieces_i = filter(lambda p: p[0] is not None,
                                  pieces_i) # Remove cases that are None
                ow = ow_gh[i]
                if pieces_i:
                    new_expr = mathml_piecewise.create_new(
                        expr, pieces_i, ow)
                elif ow:
                    # Only the otherwise case is left, so no test is needed
                    new_expr = ow
                else:
                    new_expr = None
                return new_expr
            gh = (piecewise_branch(0), piecewise_branch(1))
            self._transfer_lut(expr, gh, var)
        elif isinstance(expr, mathml_apply):
            # Behaviour depends on the operator
            operator = expr.operator().localName
            operands = expr.operands()
            self.__expr = expr # For self._make_apply
            if operator in ['plus', 'minus']:
                # Just split the operation into each component
                operand_ghs = map(lambda op: self._rearrange_expr(op, var),
                                  operands)
                g = self._make_apply(operator, operand_ghs, 0)
                h = self._make_apply(operator, operand_ghs, 1)
                gh = (g, h)
            elif operator == 'divide':
                # (a, b) / (c, 0) = (a/c, b/c)
                numer = self._rearrange_expr(operands.next(), var)
                denom = self._rearrange_expr(operands.next(), var)
                assert denom[1] is None
                denom_g = denom[0]
                g = h = None
                if numer[0]:
                    g = mathml_apply.create_new(expr, operator,
                                                [numer[0], denom_g])
                if numer[1]:
                    if g:
                        # The denominator is already used in g, so h needs a copy
                        denom_g = self._clone(denom_g)
                    h = mathml_apply.create_new(expr, operator,
                                                [numer[1], denom_g])
                gh = (g, h)
            elif operator == 'times':
                # (a1,b1)*(a2,b2) = (a1*a2, b1*a2 or a1*b2 or None)
                # Similarly for the nary case - at most one b_i is not None
                operand_ghs = map(lambda op: self._rearrange_expr(op, var),
                                  operands)
                g = self._make_apply(operator, operand_ghs, 0,
                                     filter_none=False)
                # Find non-None b_i, if any
                for i, ab in enumerate(operand_ghs):
                    if ab[1] is not None:
                        # Put b_i in the g slot so that h can be built
                        # by asking _make_apply for component 0
                        operand_ghs[i] = (ab[1], None)
                        # Clone the a_i to avoid objects having 2 parents
                        for j, ab in enumerate(operand_ghs):
                            if j != i:
                                operand_ghs[j] = (self._clone(operand_ghs[j][0]), None)
                        h = self._make_apply(operator, operand_ghs, 0, filter_none=False)
                        break
                else:
                    # No operand has a var coefficient
                    h = None
                gh = (g, h)
            else:
                # (a, None) op (b, None) = (a op b, None)
                operand_ghs = map(lambda op: self._rearrange_expr(op, var),
                                  operands)
                g = self._make_apply(operator, operand_ghs, 0, preserve=True)
                gh = (g, None)
            self._transfer_lut(expr, gh, var)
        else:
            # Since this expression is linear, there can't be any
            # occurrence of var in it; all possible such cases are covered
            # above.  So just clone it into g.
            gh = (self._clone(expr), None)
#        print ' '*depth, "Re-arranged", prid(expr, True), "to", prid(gh[0], True), ",", prid(gh[1], True)
        return gh
1279 
1280  def rearrange_linear_odes(self, doc):
1281  """Rearrange the linear ODEs so they can be updated directly
1282  on solving.
1283 
1284  Each ODE du/dt = f(u, t) can be written in the form
1285  du/dt = g(t) + h(t)u.
1286  A backward Euler update step is then as simple as
1287  u_n = (u_{n-1} + g(t)dt) / (1 - h(t)dt)
1288  (assuming that the transmembrane potential has already been
1289  solved for at t_n.
1290 
1291  Stores the results in doc.model._cml_linear_update_exprs, a
1292  mapping from variable object u to pair (g, h).
1293  """
1294 
1295  odes = map(lambda v: v.get_ode_dependency(doc.model._cml_free_var),
1296  doc.model._cml_linear_vars)
1297  result = {}
1298  for var, ode in itertools.izip(doc.model._cml_linear_vars, odes):
1299  # Do the re-arrangement for this variable.
1300  # We do this by a post-order traversal of its defining ODE,
1301  # constructing a pair (g, h) for each subexpression recursively.
1302  # First, get the RHS of the ODE
1303  rhs = self._get_rhs(ode)
1304  # And traverse
1305  result[var] = self._rearrange_expr(rhs, var)
1306  # Store result in model
1307  doc.model._cml_linear_update_exprs = result
1308  return result
1309 
1310  def show(self, d):
1311  """Print out a more readable report on a rearrangement,
1312  as given by self.rearrange_linear_odes."""
1313  for var, expr in d.iteritems():
1314  print var.fullname()
1315  print "G:", expr[0].xml()
1316  print "H:", expr[1].xml()
1317  print "ODE:", var.get_ode_dependency(
1318  var.model._cml_free_var).xml()
1319  print
1320 
1321 
1322 ######################################################################
1323 # Rush-Larsen analysis #
1324 ######################################################################
1325 
class ExpressionMatcher(object):
    """Test whether a MathML expression matches a given tree pattern.

    Patterns are instances of the nested Pattern class, or more specifically
    one of its subclasses.  The static method match on this class checks an
    expression against a pattern, returning True iff there is a match.

    Note that some patterns keep state from matching: X remembers the last
    expression it was matched against, and N created without a value keeps
    the first value it matches.
    """

    class Pattern(object):
        """Abstract base class for tree patterns."""
        def match(self, expr):
            """
            Method implemented by concrete subclasses to test a given expression.
            Returns True iff there is a match.
            """
            raise NotImplementedError

    class A(Pattern):
        """An apply expression.

        operator is the local name of the operator element to match, and
        operands is a list of patterns, one per operand.
        """
        def __init__(self, operator, operands):
            self.operator = operator
            self.operands = operands

        def match(self, expr):
            if not isinstance(expr, mathml_apply):
                return False
            if expr.operator().localName != self.operator:
                return False
            expr_operands = list(expr.operands())
            if len(expr_operands) != len(self.operands):
                return False
            # Every operand pattern is tried, even after a failure, so that
            # the state recorded by the sub-patterns does not depend on
            # short-circuiting.  all() also copes with an empty operand
            # list, which a reduce() without an initial value does not.
            results = [pat.match(op)
                       for pat, op in zip(self.operands, expr_operands)]
            return all(results)

    class V(Pattern):
        """A variable reference.

        Matches a ci element whose source variable is the source variable
        of the variable given to the constructor or to set_variable.
        """
        def __init__(self, var=None):
            self.set_variable(var)

        def set_variable(self, var):
            if var:
                self.var = var.get_source_variable(recurse=True)
            else:
                self.var = var

        def match(self, expr):
            return (isinstance(expr, mathml_ci)
                    and self.var is expr.variable.get_source_variable(recurse=True))

    class N(Pattern):
        """A constant number, optionally with the value specified.

        If no value is specified, the value of the first cn element this
        pattern is matched against is stored in self.value, and later
        matches then require that same value.  Set self.value back to None
        to match any number again.
        """
        def __init__(self, value=None):
            self.value = value

        def match(self, expr):
            matched = False
            if isinstance(expr, mathml_cn):
                value = expr.evaluate()
                if self.value is None:
                    self.value = value
                if self.value == value:
                    matched = True
            return matched

    class X(Pattern):
        """A placeholder, matching anything (and noting what was matched)."""
        def __init__(self):
            self.matched = None

        def match(self, expr):
            self.matched = expr
            return True

    class R(Pattern):
        """A container that matches any number of levels of indirection/recursion.

        This can be used to wrap a pattern where we wish to allow for variable mappings
        or equations such as "var1 = var2" before we reach the 'interesting' equation.
        If the expression we're matching is a ci element we recursively find the
        ultimate non-ci defining expression and match our sub-pattern against that.  If
        the expression isn't a ci, or the ultimate definition isn't an expression, we
        match our sub-pattern against it directly.
        """
        def __init__(self, pat):
            self.sub_pattern = pat

        def match(self, expr):
            while isinstance(expr, mathml_ci):
                # Find this variable's defining expression, if it is an equation
                var = expr.variable.get_source_variable(recurse=True)
                defn = var.get_dependencies()
                if defn and isinstance(defn[0], mathml_apply):
                    expr = defn[0].eq.rhs
                else:
                    # Not defined by an equation: stop here and match the
                    # ci itself, rather than looping on it forever.
                    break
            return self.sub_pattern.match(expr)

    @staticmethod
    def match(pattern, expression):
        """Test for a match."""
        return pattern.match(expression)
1427 
1428 class RushLarsenAnalyser(object):
1429  """Analyse a model to identify Hodgkin-Huxley style gating variable equations.
1430 
1431  We look for ODEs whose definition matches "alpha*(1-x) - beta*x" (where x is
1432  the state variable, and alpha & beta are any expression). Alternatively for
1433  models which already have tau & inf variables, we match against "(inf-x)/tau".
1434 
1435  To allow for when units conversions have been performed, we chase 'simple'
1436  assignments and (the semantically equivalent) variable mappings until reaching
1437  an 'interesting' defining equation. We also need to allow the whole RHS to be
1438  multiplied by a constant. If this occurs, the constant conversion factor is
1439  also stored; otherwise we store None.
1440 
1441  Stores a dictionary on the document root mapping cellml_variable instances to
1442  4-tuples. The tuple is either ('ab', alpha, beta, conv) or ('ti', tau, inf, conv)
1443  depending on which formulation has been used. Note that the expressions in these
1444  are not cloned copies - they are the original objects still embedded within the
1445  relevant ODE. The units conversion factor 'conv' is stored as a Python double.
1446  """
1447  def __init__(self):
1448  """Create the patterns to match against."""
1449  em = ExpressionMatcher
1450  self._var = em.V()
1451  self._conv = em.N()
1452  # Alpha/beta form
1453  self._alpha = em.X()
1454  self._beta = em.X()
1455  self._ab_pattern = em.R(em.A('minus', [em.A('times', [self._alpha,
1456  em.A('minus', [em.N(1), self._var])]),
1457  em.A('times', [self._beta, self._var])]))
1458  self._alt_ab_pattern = em.R(em.A('times', [self._conv, self._ab_pattern]))
1459  # Tau/inf form
1460  self._tau = em.X()
1461  self._inf = em.X()
1462  self._ti_pattern = em.R(em.A('divide', [em.A('minus', [self._inf, self._var]),
1463  self._tau]))
1464  self._alt_ti_pattern = em.R(em.A('times', [self._conv, self._ti_pattern]))
1465 
1466  def analyse_model(self, doc):
1467  # First, find linear ODEs that have the potential to be gating variables
1468  la = LinearityAnalyser()
1469  V = doc._cml_config.V_variable
1470  state_vars = doc.model.find_state_vars()
1471  free_var = doc.model.find_free_vars()[0]
1472  linear_vars = la.find_linear_odes(state_vars, V, free_var)
1473  # Next, check they match dn/dt = a (1-n) - b n
1474  doc._cml_rush_larsen = {}
1475  for var in linear_vars:
1476  ode_expr = var.get_ode_dependency(free_var)
1477  self._check_var(var, ode_expr, doc._cml_rush_larsen)
1478 
1479  def _check_var(self, var, ode_expr, mapping):
1480  rhs = ode_expr.eq.rhs
1481  self._var.set_variable(ode_expr.eq.lhs.diff.dependent_variable)
1482  if self._ab_pattern.match(rhs):
1483  mapping[var] = ('ab', self._alpha.matched, self._beta.matched, None)
1484  elif self._alt_ab_pattern.match(rhs):
1485  mapping[var] = ('ab', self._alpha.matched, self._beta.matched, self._conv.value)
1486  elif self._ti_pattern.match(rhs):
1487  mapping[var] = ('ti', self._tau.matched, self._inf.matched, None)
1488  elif self._alt_ti_pattern.match(rhs):
1489  mapping[var] = ('ti', self._tau.matched, self._inf.matched, self._conv.value)