Nektar++
optimize.py
Go to the documentation of this file.
1#!/usr/bin/env python
2
3# We want 1/2==0.5
4from __future__ import division
5
6"""Copyright (c) 2005-2016, University of Oxford.
7All rights reserved.
8
9University of Oxford means the Chancellor, Masters and Scholars of the
10University of Oxford, having an administrative office at Wellington
11Square, Oxford OX1 2JD, UK.
12
13This file is part of Chaste.
14
15Redistribution and use in source and binary forms, with or without
16modification, are permitted provided that the following conditions are met:
17 * Redistributions of source code must retain the above copyright notice,
18 this list of conditions and the following disclaimer.
19 * Redistributions in binary form must reproduce the above copyright notice,
20 this list of conditions and the following disclaimer in the documentation
21 and/or other materials provided with the distribution.
22 * Neither the name of the University of Oxford nor the names of its
23 contributors may be used to endorse or promote products derived from this
24 software without specific prior written permission.
25
26THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
30LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
32GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
33HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
34LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
35OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36"""
37
38"""
39This part of PyCml applies various optimising transformations to CellML
40models, in particular partial evaluation and the use of lookup tables.
41"""
42
43import operator
44
45# Common CellML processing stuff
46import pycml
47from pycml import * # Put contents in the local namespace as well
48
# Revision number taken from the SVN $Revision$ keyword: the slice drops the
# leading "$Revision: " (11 characters) and the trailing " $".
__version__ = "$Revision: 25790 $"[11:-2]
50
51
52
53######################################################################
54# Partial Evaluation #
55######################################################################
56
class PartialEvaluator(object):
    """Perform partial evaluation of a CellML model.

    The entry point is parteval().  It runs binding-time analysis, then
    repeatedly reduces or evaluates expressions, removes or re-targets the
    assignments flagged during that process, renames variables canonically,
    drops unused variables and finally collapses the model into a single
    component named 'c'.
    """

    def _debug(self, *args):
        """Output debug info from the PE process.

        Each argument is converted with str(), the results are joined with
        spaces and logged at DEBUG level on the 'partial-evaluator' logger.
        """
        logger = logging.getLogger('partial-evaluator')
        logger.debug(' '.join(map(str, args)))

    def _expr_lhs(self, expr):
        """Return a display string for the left-hand side of this expression.

        If the assigned variable is a single cellml_variable its full name is
        returned; otherwise it is treated as a pair and shown as 'first/second'
        (presumably the dependent/independent variables of an ODE).
        """
        lhs = expr.assigned_variable()
        if isinstance(lhs, cellml_variable):
            return lhs.fullname()
        else:
            return lhs[0].fullname() + u'/' + lhs[1].fullname()

    def _describe_expr(self, expr):
        """Describe this expression for debug and error messages."""
        if isinstance(expr, mathml_apply):
            if expr.is_assignment() or expr.is_ode():
                return self._expr_lhs(expr)
            else:
                return '[nested apply]'
        elif isinstance(expr, mathml_ci):
            return expr.variable.fullname()
        elif isinstance(expr, mathml_cn):
            return u'cn[' + unicode(expr) + u']'
        else:
            return '[unknown]'

    def _process_ci_elts(self, elt, func):
        """Apply func to all ci elements in the tree rooted at elt (depth first)."""
        if isinstance(elt, mathml_ci):
            func(elt)
        else:
            for e in elt.xml_element_children():
                self._process_ci_elts(e, func)

    def _rename_var(self, elt):
        """Change this ci element to use a canonical name."""
        if elt.xml_parent.localName == u'bvar':
            # The free variable in a derivative must refer directly to the ultimate source,
            # since this is assumed in later stages and in code generation.
            elt._set_variable_obj(elt.variable.get_source_variable(recurse=True))
        elt._rename()
        self._debug("Using canonical name", unicode(elt))

    def _do_reduce_eval_loop(self, expr_source):
        """Do the reduce/evaluate loop.

        expr_source is a callable that returns an iterable over expressions.
        It is called afresh on every pass, and its result is copied because
        processing may remove expressions.  Passes repeat until one finishes
        with the model's _pe_repeat flag still equal to u'no'.
        """
        while True:
            self.doc.model._pe_repeat = u'no'
            for expr in list(expr_source()):
                self._reduce_evaluate_expression(expr)
            if self.doc.model._pe_repeat == u'no':
                break
            self._debug("----- looping -----")
        del self.doc.model._pe_repeat

    def _reduce_evaluate_expression(self, expr):
        """Reduce or evaluate a single expression.

        Static expressions are evaluated:
          * an ODE's RHS is replaced by a <cn> element holding the value;
          * an assignment is flagged for removal (_pe_process = u'remove');
          * any other apply is replaced by a <cn> element;
          * non-apply elements are reduced.
        Usage counts of variables in evaluated apply expressions are then
        decremented.  Dynamic expressions are reduced.
        """
        if hasattr(expr, '_pe_process'):
            # This expression has been reduced or evaluated already, but needs further
            # processing later so hasn't been removed yet.
            return
        if expr._get_binding_time() is BINDING_TIMES.static:
            # Evaluate
            try:
                value = expr.evaluate()  # Needed here for the is_assignment case
            except Exception:
                # Report which expression failed, then propagate the error.
                print("Error evaluating %s" % self._describe_expr(expr))
                raise
            self._debug("Evaluated", self._describe_expr(expr), "to", value)
            if isinstance(expr, mathml_apply):
                if expr.is_ode():
                    # Replace the RHS with a <cn> element giving the value
                    rhs = expr.eq.rhs
                    new_elt = expr._eval_self()
                    expr.xml_insert_after(rhs, new_elt)
                    expr.xml_remove_child(rhs)
                elif expr.is_assignment():
                    # The variable assigned to will have its initial_value set,
                    # so we don't need the expression any more. Flag it for removal.
                    expr._pe_process = u'remove'
                else:
                    # Replace the expression with a <cn> element giving the value
                    new_elt = expr._eval_self()
                    expr.xml_parent.xml_insert_after(expr, new_elt)
                    expr.xml_parent.xml_remove_child(expr)
            else:
                # Replace the expression with a <cn> element giving the value
                expr._reduce()
            # Update variable usage counts for the top-level apply case
            if isinstance(expr, mathml_apply):
                if expr.is_ode() or expr.is_assignment():
                    expr._update_usage_counts(expr.eq.rhs, remove=True)
                else:
                    expr._update_usage_counts(expr, remove=True)
        else:
            # Reduce
            expr._reduce()

    def _get_assignment_exprs(self, skip_solver_info=True):
        """Get an iterable over all assignments in the model that are mathml_apply instances.

        If skip_solver_info is True, expressions that are part of the solver
        info's modifiable mathematics are excluded.
        """
        if not skip_solver_info:
            skip = set()
        else:
            skip = set(self.solver_info.get_modifiable_mathematics())
        for e in self.doc.model.get_assignments():
            if isinstance(e, mathml_apply) and e not in skip:
                assert e.is_ode() or e.is_assignment()
                yield e

    def is_instantiable(self, expr):
        """Determine whether special conditions mean that this assignment can be instantiated.

        Normally an assignment can only be instantiated if the assigned-to variable is used only
        once, in order to avoid code duplication.
        However, if the definition under consideration for instantiation is a function only of a
        single LT keying variable (and we will do LT) then code duplication doesn't really matter,
        since the whole expression will be converted to a table anyway. So we should instantiate
        regardless of multiple uses in this case.

        Note: this does check that only a single keying variable appears, but doesn't check for
        the presence of expensive functions. Of course, if there aren't any expensive functions,
        the code duplication isn't that worrying.

        Returns False when no lookup tables analyser is set or the
        'pe_instantiate_tables' option is off.
        """
        instantiate = False
        if self.lookup_tables_analyser and self.doc.model.get_option('pe_instantiate_tables'):
            keying_vars = set()
            all_keying = [True]  # list so the nested function can update it (Python 2 has no nonlocal)

            def func(ci_elt):
                if self.lookup_tables_analyser.is_keying_var(ci_elt.variable):
                    keying_vars.add(ci_elt.variable)
                else:
                    all_keying[0] = False
            self._process_ci_elts(expr.eq.rhs, func)
            instantiate = len(keying_vars) == 1 and all_keying[0]
        return instantiate

    def _is_source_of(self, v1, v2):
        """Test if v1 is a source of the mapped variable v2 (or is v2 itself)."""
        if v1 is v2:
            return True
        elif v2.get_type() == VarTypes.Mapped:
            return self._is_source_of(v1, v2.get_source_variable())
        else:
            return False

    def _check_retargetting(self, ci_elt):
        """Check if this new variable reference means a retarget needs to change.

        A retarget occurs when a kept dynamic mapped variable is changed to computed
        because its source variable(s) are only used once and are not kept. But new
        mathematics may use one or more of those sources, in which case we need to
        revert the retarget.
        """
        # Is this a retarget?
        var = ci_elt.variable
        root_defn = var.get_source_variable(recurse=True).get_dependencies()
        if root_defn:
            root_defn = root_defn[0]
        else:
            return
        if (isinstance(root_defn, mathml_apply)
                and hasattr(root_defn, '_pe_process')
                and root_defn._pe_process == 'retarget'):
            # Are we a source of the 'new' assignee?
            assignee = root_defn._cml_assigns_to
            if not self._is_source_of(assignee, var):
                if var.get_type() == VarTypes.Computed:
                    # We were the original source variable; stop retargetting
                    self._debug('Ceasing re-target of', root_defn, 'to', assignee)
                    del root_defn._pe_process
                    root_defn._cml_assigns_to = var
                else:
                    # Re-target to var instead
                    self._debug('Changing re-target of', root_defn, 'from', assignee, 'to', var)
                    var._cml_var_type = VarTypes.Computed
                    root_defn._cml_assigns_to = var
                    var._cml_depends_on = [root_defn]
                    var._cml_source_var = None
                # assignee should now map to var
                assignee._cml_var_type = VarTypes.Mapped
                assignee._cml_source_var = var
                assignee._cml_depends_on = [var]

    def parteval(self, doc, solver_info, lookup_tables_analyser=None):
        """Do the partial evaluation.

        doc is the CellML document, which is transformed in place.
        solver_info supplies additional (modifiable) mathematics that is
        processed alongside the model.  If lookup_tables_analyser is given,
        it is attached to doc and used by is_instantiable() to recognise
        lookup-table keying variables.

        On return, all remaining variables, units and non-empty <math>
        elements live in one new component named 'c'.  Groups and connections
        are removed, and expression dependency lists are refreshed.
        """
        self.doc = doc
        self.solver_info = solver_info
        self.lookup_tables_analyser = lookup_tables_analyser
        if lookup_tables_analyser:
            lookup_tables_analyser.doc = doc
        doc.partial_evaluator = self
        # Do BTA and reduce/eval of main model
        doc.model.do_binding_time_analysis()
        self._do_reduce_eval_loop(self._get_assignment_exprs)

        if solver_info.has_modifiable_mathematics():
            # Do BTA and reduce/eval of solver info section
            for expr in solver_info.get_modifiable_mathematics():
                if not (isinstance(expr, mathml_apply) and expr.is_top_level()):
                    # This is new mathematics: count its variable uses, and undo any
                    # retarget that relied on a source variable now being used again.
                    self._process_ci_elts(expr, lambda ci: ci.variable._used())
                    self._process_ci_elts(expr, self._check_retargetting)
            solver_info.do_binding_time_analysis()
            self._do_reduce_eval_loop(solver_info.get_modifiable_mathematics)

        # Process flagged expressions
        for expr in list(self._get_assignment_exprs()):
            if hasattr(expr, '_pe_process'):
                if expr._pe_process == u'remove':
                    if (expr._get_binding_time() == BINDING_TIMES.dynamic and
                            isinstance(expr._cml_assigns_to, cellml_variable) and
                            expr._cml_assigns_to.get_usage_count() > 1):
                        self._debug("Keeping", repr(expr), "due to SolverInfo")
                        continue
                    expr.xml_parent.xml_remove_child(expr)
                    self.doc.model._remove_assignment(expr)
                elif expr._pe_process == u'retarget':
                    lhs = expr.eq.lhs
                    var = expr._cml_assigns_to
                    ci = mathml_ci.create_new(lhs, var.fullname(cellml=True))
                    self._debug('Re-targetting', lhs, var, ci)
                    ci._set_variable_obj(var)
                    lhs.xml_parent.xml_insert_after(lhs, ci)
                    lhs.xml_parent.xml_remove_child(lhs)

        # Use canonical variable names in all ci elements
        for expr in list(self._get_assignment_exprs(False)):
            # If the assigned-to variable isn't used or kept, remove the assignment
            if isinstance(expr.eq.lhs, mathml_ci):
                var = expr.eq.lhs.variable
                if not (var.get_usage_count() or var.pe_keep):
                    expr.xml_parent.xml_remove_child(expr)
                    doc.model._remove_assignment(expr)
                    continue
            self._process_ci_elts(expr, self._rename_var)
        for expr in solver_info.get_modifiable_mathematics():
            self._process_ci_elts(expr, self._rename_var)
        solver_info.use_canonical_variable_names()

        # Tidy up kept variables, in case they aren't referenced in an eq'n.
        for var in doc.model.get_all_variables():
            if var.pe_keep:
                var._reduce()

        # Remove unused variables
        for var in list(doc.model.get_all_variables()):
            assert var.get_usage_count() >= 0
            if var.get_usage_count() == 0 and not var.pe_keep:
                var.xml_parent._del_variable(var)

        # Collapse into a single component
        new_comp = cellml_component.create_new(doc, u'c')
        new_comp._cml_created_by_pe = True
        old_comps = list(getattr(doc.model, u'component', []))
        doc.model._add_component(new_comp)
        # We iterate over a copy of the component list so we can delete components
        # from the model in this loop, and so the new component exists in the model
        # so we can add content to it.
        for comp in old_comps:
            # Move relevant contents into new_comp
            for units in list(getattr(comp, u'units', [])):
                # Copy all <units> elements
                # TODO: Just generate the ones we need, using _ensure_units_exist
                comp.xml_remove_child(units)
                new_comp.xml_append(units)
            for var in list(getattr(comp, u'variable', [])):
                # Only move used source variables
                self._debug('Variable', var.fullname(), 'usage', var.get_usage_count(),
                            'type', var.get_type(), 'kept', var.pe_keep)
                if (var.get_usage_count() and var.get_type() != VarTypes.Mapped) or var.pe_keep:
                    self._debug('Moving variable', var.fullname(cellml=True))
                    # Remove from where it was
                    comp._del_variable(var, keep_annotations=True)
                    # Set name to canonical version
                    var.name = var.fullname(cellml=True)
                    # Place in new component
                    new_comp._add_variable(var)
            # Don't copy reactions
            for math in list(getattr(comp, u'math', [])):
                # Copy all <math> elements with content
                if math.xml_children:
                    comp.xml_remove_child(math)
                    new_comp.xml_append(math)
                    # Invalidate cached links
                    math._unset_cached_links()
            doc.model._del_component(comp)
        # Remove groups & connections
        for group in list(getattr(doc.model, u'group', [])):
            doc.model.xml_remove_child(group)
        for conn in list(getattr(doc.model, u'connection', [])):
            doc.model.xml_remove_child(conn)

        # Remove unused variable assignments from the list
        vs = [v for v in doc.model.get_assignments() if isinstance(v, cellml_variable)]
        for v in vs:
            if v.xml_parent is not new_comp:
                doc.model._remove_assignment(v)

        # Remove interface attributes from variables.  Use getattr with a
        # default (as above) so a model left with no variables doesn't fail.
        for v in getattr(new_comp, u'variable', []):
            for iface in [u'public', u'private']:
                try:
                    delattr(v, iface + u'_interface')
                except AttributeError:
                    pass

        # Refresh expression dependency lists
        for expr in self._get_assignment_exprs(False):
            expr._cml_depends_on = list(expr.vars_in(expr.eq.rhs))
            if expr.is_ode():
                # Add dependency on the independent variable
                indep_var = expr.eq.lhs.diff.independent_variable
                if indep_var not in expr._cml_depends_on:
                    expr._cml_depends_on.append(indep_var)
                # Update ODE definition dependency if needed
                expr.eq.lhs.diff.dependent_variable._update_ode_dependency(indep_var, expr)
        return
378
379
380######################################################################
381# Lookup table analysis #
382######################################################################
383
    # NOTE(review): the "class LookupTableAnalyser(...):" header line that should
    # precede this docstring is missing from this copy of the file (the class is
    # referred to by that name in analyse_for_lut); restore it.
    """
    Analyses & annotates a CellML model to indicate where lookup
    tables can be used.

    The entry point is analyse_model().  Expressions suitable for a table are
    given lut:possible='yes' together with lut:min, lut:max, lut:step and
    lut:var attributes.  Unsuitable ones may instead be given a lut:reason
    attribute explaining why.
    """
389
390 def __init__(self):
391 """Create an analyser."""
392 # No model to analyse yet
393 self.doc = None
394 # Set default parameter values
395 self.set_params()
396
397 @property
398 def config(self):
399 """Get the current document's configuration store."""
400 return getattr(self.doc, '_cml_config', None)
401
403 """Determine if the given variable represents the transmembrane potential.
404
405 This method takes an instance of cellml_variable and returns a boolean.
406 """
407 return (var.name in [u'V', u'membrane__V'] and
408 var.get_type(follow_maps=True) == VarTypes.State)
409
410 def is_allowed_variable(self, var):
411 """Return True iff the given variable is allowed in a lookup table.
412
413 This method uses the config store in the document to check the variable object.
414 """
415 var = var.get_source_variable(recurse=True)
416 allowed = (var in self.config.lut_config or
417 (self.config.options.include_dt_in_tables and
418 var is self.solver_info.get_dt().get_source_variable(recurse=True)))
419 return allowed
420
421 def is_keying_var(self, var):
422 """Return True iff the given variable can be used as a table key.
423
424 Will check the config store if it exists. If not, the variable name must match self.table_var.
425 """
426 if self.config:
427 return var.get_source_variable(recurse=True) in self.config.lut_config
428 else:
429 return var.name == self.table_var
430
431 _LT_DEFAULTS = {'table_min': u'-100.0001',
432 'table_max': u'49.9999',
433 'table_step': u'0.01',
434 'table_var': u'V'}
435 def set_params(self, **kw):
436 """Set parameters controlling lookup table generation.
437
438 Keyword parameters control the lookup table settings, which are
439 stored as attributes on suitable expressions.
440 table_min - minimum table entry (unicode) -> lut:min
441 table_max - maximum table entry (unicode) -> lut:max
442 table_step - table step size (unicode) -> lut:step
443 table_var - the name of the variable indexing the table (unicode) -> lut:var
444 """
445 defaults = self._LT_DEFAULTS
446 for attr in defaults:
447 if attr in kw:
448 setattr(self, attr, kw[attr])
449 else:
450 setattr(self, attr, getattr(self, attr, defaults[attr]))
451 return
452
453 def get_param(self, param_name, table_var):
454 """Get the value of the lookup table parameter.
455
456 table_var is the variable object being used to key this table.
457
458 If the document has a config store, lookup the value there.
459 If that doesn't give us a value, use that given using set_params.
460 """
461 try:
462 val = self.config.lut_config[
463 table_var.get_source_variable(recurse=True)][param_name]
464 except AttributeError, KeyError:
465 val = getattr(self, param_name)
466 return val
467
468 # One of these functions is required for a lookup table to be worthwhile
469 lut_expensive_funcs = frozenset(('exp', 'log', 'ln', 'root',
470 'sin', 'cos', 'tan',
471 'sec', 'csc', 'cot',
472 'sinh', 'cosh', 'tanh',
473 'sech', 'csch', 'coth',
474 'arcsin', 'arccos', 'arctan',
475 'arcsinh', 'arccosh', 'arctanh',
476 'arcsec', 'arccsc', 'arccot',
477 'arcsech', 'arccsch', 'arccoth'))
478
    class LUTState(object):
        """Represents the state for lookup table analysis.

        One instance summarises a (sub-)expression:
          has_var   -- a keying variable was found and, after update(), any
                       keying variables found share the same source variable
          bad_vars  -- set of names of variables that rule out a lookup table
          has_func  -- an expensive function was found
          table_var -- the keying variable object, or None
        """
        def __init__(self):
            """Set the initial state.

            We assume at first a lookup table would not be suitable.
            """
            self.has_var = False
            self.bad_vars = set()
            self.has_func = False
            self.table_var = None

        def update(self, res):
            """Update the state with the results of a recursive call.

            res is the result of checking a sub-expression for suitability,
            and should be another instance of this class.
            """
            self.has_var = (self.has_var or res.has_var) and \
                           (not (self.table_var and res.table_var) or
                            self.table_var.get_source_variable(recurse=True) is
                            res.table_var.get_source_variable(recurse=True))
            # The second condition above specifies that the keying variables must be the same if they both exist
            self.bad_vars.update(res.bad_vars)
            self.has_func = self.has_func or res.has_func
            # Note the order: table_var is merged before the check below, which
            # therefore sees self's keying variable if it had one, else res's.
            self.table_var = self.table_var or res.table_var
            if not self.has_var and self.table_var:
                # Two sub-expressions have different keying variables, so consider them as bad variables
                # NOTE(review): this branch is also reached when res comes from
                # create_state_from_annotations for lut:possible='yes' (table_var
                # set, has_var False). The clone's keying variable is then
                # recorded as a bad variable. res.table_var is assumed non-None here.
                self.bad_vars.add(self.table_var.name)
                self.bad_vars.add(res.table_var.name)
                self.table_var = None

        def suitable(self):
            """Return True iff this state indicates a suitable expression for replacement with a lookup table."""
            return (self.has_var and
                    not self.bad_vars and
                    self.has_func)

        def reason(self):
            """
            Return a unicode string describing why this state indicates the
            expression is not suitable for replacement with a lookup table.

            This can be:
                'no_var' - doesn't contain the table variable
                'bad_var <vname>' - contains a variable which isn't permitted
                'no_func' - doesn't contain an expensive function
            or a comma separated combination of the above.
            """
            r = []
            if not self.has_var:
                r.append(u'no_var')
            if not self.has_func:
                r.append(u'no_func')
            for vname in self.bad_vars:
                r.append(u'bad_var ' + vname)
            return u','.join(r)
536
538 """Create a LUTState instance from an already annotated expression."""
539 state = self.LUTState()
540 possible = expr.getAttributeNS(NSS['lut'], u'possible', '')
541 if possible == u'yes':
542 varname = expr.getAttributeNS(NSS['lut'], u'var')
543 state.table_var = expr.component.get_variable_by_name(varname)
544 elif possible == u'no':
545 reason = expr.getAttributeNS(NSS['lut'], u'reason', '')
546 reasons = reason.split(u',')
547 for reason in reasons:
548 if reason == u'no_var':
549 state.has_var = False
550 elif reason == u'no_func':
551 state.has_func = False
552 elif reason.startswith(u'bad_var '):
553 state.bad_vars.add(reason[8:])
554 return state
555
    def analyse_for_lut(self, expr, var_checker_fn):
        """Check if the given expression can be replaced by a lookup table.

        The first argument is the expression to check; the second is a
        function which takes a variable object and returns True iff this
        variable is permitted within a lookup table expression.

        Returns a LUTState summarising expr.  Sub-expressions are analysed
        recursively and annotated as a side effect.  When the config option
        combine_commutative_tables is on, an unsuitable <apply> may also be
        restructured by check_commutative_tables / check_divide_by_table.

        If self.annotate_failures is True then annotate <apply> and
        <piecewise> expressions which don't qualify with the reason
        why they do not.
        This can be:
            'no_var' - doesn't contain the table variable
            'bad_var <vname>' - contains a variable which isn't permitted
            'no_func' - doesn't contain an expensive function
        or a comma separated combination of the above.
        The annotation is stored as the lut:reason attribute.

        If self.annotate_outermost_only is True then only annotate the
        outermost qualifying expression, rather than also annotating
        qualifying subexpressions.
        """
        # If this is a cloned expression, then just copy any annotations on the original.
        if isinstance(expr, mathml):
            source_expr = expr.get_original_of_clone()
        else:
            source_expr = None
        if source_expr and source_expr.getAttributeNS(NSS['lut'], u'possible', '') != '':
            LookupTableAnalyser.copy_lut_annotations(source_expr, expr)
            state = self.create_state_from_annotations(source_expr)
            DEBUG('lookup-tables', "No need to analyse clone", expr.xml(), state.suitable(), state.reason())
        else:
            # Initialise the indicators
            state = self.LUTState()
            # Process current node
            if isinstance(expr, mathml_ci):
                # Variable reference
                if var_checker_fn(expr.variable):
                    # Could be a permitted var that isn't a keying var
                    if self.is_keying_var(expr.variable):
                        assert state.table_var is None # Sanity check
                        state.has_var = True
                        state.table_var = expr.variable
                else:
                    state.bad_vars.add(expr.variable.name)
            elif isinstance(expr, mathml_piecewise):
                # Recurse into pieces & otherwise options
                # (child 1 of <otherwise> and of each <piece> is the value;
                # child 2 of a <piece> is its condition)
                if hasattr(expr, u'otherwise'):
                    r = self.analyse_for_lut(child_i(expr.otherwise, 1),
                                             var_checker_fn)
                    state.update(r)
                for piece in getattr(expr, u'piece', []):
                    r = self.analyse_for_lut(child_i(piece, 1), var_checker_fn)
                    state.update(r)
                    r = self.analyse_for_lut(child_i(piece, 2), var_checker_fn)
                    state.update(r)
            elif isinstance(expr, mathml_apply):
                # Check function
                if (not state.has_func and
                    expr.operator().localName in self.lut_expensive_funcs):
                    state.has_func = True
                # Check operands
                # (per-operand states are kept, keyed by id(), for the special cases below)
                operand_states = {}
                for operand in expr.operands():
                    r = self.analyse_for_lut(operand, var_checker_fn)
                    state.update(r)
                    operand_states[id(operand)] = r
                # Check qualifiers
                for qual in expr.qualifiers():
                    r = self.analyse_for_lut(qual, var_checker_fn)
                    state.update(r)
                # Special case additional optimisations
                if self.config and self.config.options.combine_commutative_tables and not state.suitable():
                    if isinstance(expr.operator(), reduce_commutative_nary):
                        self.check_commutative_tables(expr, operand_states)
                    elif isinstance(expr.operator(), mathml_divide):
                        self.check_divide_by_table(expr, operand_states)
            else:
                # Just recurse into children
                for e in expr.xml_children:
                    if getattr(e, 'nodeType', None) == Node.ELEMENT_NODE:
                        r = self.analyse_for_lut(e, var_checker_fn)
                        state.update(r)
            # Annotate the expression if appropriate
            # NOTE(review): placed inside the non-clone branch (clones already
            # received copied annotations); the scraped source lost indentation.
            # Confirm against the upstream file.
            if isinstance(expr, (mathml_apply, mathml_piecewise)):
                if state.suitable():
                    self.annotate_as_suitable(expr, state.table_var)
                else:
                    if self.annotate_failures:
                        expr.xml_set_attribute((u'lut:reason', NSS['lut']), state.reason())
        return state
646
647 def check_divide_by_table(self, expr, operand_states):
648 """Convert division by a table into multiplication.
649
650 This is called if expr, a division, cannot be replaced as a whole by a lookup table.
651 If the denominator can be replaced by a table, then convert expr into a multiplication
652 by the reciprocal, moving the division into the table.
653 """
654 numer, denom = list(expr.operands())
655 state = operand_states[id(denom)]
656 if state.suitable():
657 expr.safe_remove_child(numer)
658 expr.safe_remove_child(denom)
659 recip = mathml_apply.create_new(expr, u'divide', [(u'1', u'dimensionless'), denom])
660 times = mathml_apply.create_new(expr, u'times', [numer, recip])
661 expr.replace_child(expr, times, expr.xml_parent)
662 self.annotate_as_suitable(recip, state.table_var)
663
664 def check_commutative_tables(self, expr, operand_states):
665 """Check whether we can combine suitable operands into a new expression.
666
667 If expr has a commutative (and associative) n-ary operator, but is not suitable as a
668 whole to become a lookup table (checked by caller) then we might still be able to
669 do slightly better than just analysing its operands. If multiple operands can be
670 replaced by tables keyed on the same variable, these can be combined into a new
671 application of the same operator as expr, which can then be replaced as a whole
672 by a single lookup table, and made an operand of expr.
673
674 Alternatively, if at least one operand can be replaced by a table, and a subset of
675 other operands do not contain other variables, then they can be included in the single
676 table.
677 """
678 # Operands that can be replaced by tables
679 table_operands = filter(lambda op: operand_states[id(op)].suitable(), expr.operands())
680 if not table_operands:
681 return
682 # Sort these suitable operands by table_var (map var id to var & operand list, respectively)
683 table_vars, table_var_operands = {}, {}
684 for oper in table_operands:
685 table_var = operand_states[id(oper)].table_var
686 table_var_id = id(table_var)
687 if not table_var_id in table_vars:
688 table_vars[table_var_id] = table_var
689 table_var_operands[table_var_id] = []
690 table_var_operands[table_var_id].append(oper)
691 # Figure out if any operands aren't suitable by themselves but could be included in a table
692 potential_operands = {id(None): []}
693 for table_var_id in table_vars.keys():
694 potential_operands[table_var_id] = []
695 for op in expr.operands():
696 state = operand_states[id(op)]
697 if not state.suitable() and not state.bad_vars:
698 if not state.table_var in potential_operands:
699 potential_operands[id(state.table_var)] = []
700 potential_operands[id(state.table_var)].append(op)
701 # Do any combining
702 for table_var_id in table_vars.keys():
703 suitable_opers = table_var_operands[table_var_id] + potential_operands[table_var_id] + potential_operands[id(None)]
704 if len(suitable_opers) > 1:
705 # Create new sub-expression with the suitable operands
706 for oper in suitable_opers:
707 expr.safe_remove_child(oper)
708 new_expr = mathml_apply.create_new(expr, expr.operator().localName, suitable_opers)
709 expr.xml_append(new_expr)
710 self.annotate_as_suitable(new_expr, table_vars[table_var_id])
711 # Remove the operands with no table_var from consideration with other keying vars
712 potential_operands[id(None)] = []
713
714 def annotate_as_suitable(self, expr, table_var):
715 """Annotate the given expression as being suitable for a lookup table."""
717 # Remove annotations from (expr and) child expressions
718 self.remove_lut_annotations(expr)
719 for param in ['min', 'max', 'step']:
720 expr.xml_set_attribute((u'lut:' + param, NSS['lut']),
721 self.get_param('table_' + param, table_var))
722 expr.xml_set_attribute((u'lut:var', NSS['lut']), table_var.name)
723 expr.xml_set_attribute((u'lut:possible', NSS['lut']), u'yes')
724 self.doc.lookup_tables[expr] = True
725
726 @staticmethod
727 def copy_lut_annotations(from_expr, to_expr):
728 """Copy any lookup table annotations from one expression to another."""
729 for pyname, fullname in from_expr.xml_attributes.iteritems():
730 if fullname[1] == NSS['lut']:
731 to_expr.xml_set_attribute(fullname, getattr(from_expr, pyname))
732
733 def remove_lut_annotations(self, expr, remove_reason=False):
734 """Remove lookup table annotations from the given expression.
735
736 By default this will only remove annotations from expressions
737 (and sub-expressions) that can be converted to use lookup tables.
738 If remove_reason is True, then the lut:reason attributes on
739 non-qualifying expressions will also be removed.
740 """
741 # Remove from this expression
742 delete_table = False
743 for pyname in getattr(expr, 'xml_attributes', {}).keys():
744 fullname = expr.xml_attributes[pyname]
745 if fullname[1] == NSS['lut']:
746 if remove_reason or fullname[0] != u'lut:reason':
747 expr.__delattr__(pyname)
748 if fullname[0] != u'lut:reason':
749 delete_table = True
750 # Delete expr from list of lookup tables?
751 if delete_table:
752 del self.doc.lookup_tables[expr]
753 # Recurse into children
754 for e in expr.xml_children:
755 if getattr(e, 'nodeType', None) == Node.ELEMENT_NODE:
756 self.remove_lut_annotations(e, remove_reason)
757
    def analyse_model(self, doc, solver_info,
                      annotate_failures=True,
                      annotate_outermost_only=True):
        """Analyse the given document.

        This method checks all expressions (and subexpressions)
        in the given document for whether they could be converted to
        use a lookup table, and annotates them appropriately.

        By default expressions which don't qualify will be annotated
        to indicate why; set annotate_failures to False to suppress
        this.

        Also by default only the outermost suitable expression in any
        given tree will be annotated; if you want to annotate suitable
        subexpressions of a suitable expression then pass
        annotate_outermost_only as False.

        On return:
          * doc.lookup_tables is a list of the table expressions, sorted with
            element_path_cmp.
          * Each table expression carries lut:table_name (its position in that
            list) and lut:table_index.
          * doc.lookup_table_indexes maps (min, max, step, source keying
            variable) to that shared index, so tables with identical keys
            share an index variable.
        """
        self.doc = doc
        self.solver_info = solver_info
        self.annotate_failures = annotate_failures
        self.annotate_outermost_only = annotate_outermost_only
        doc.lookup_tables = {}
        # How to check for allowed variables
        if hasattr(doc, '_cml_config'):
            checker_fn = self.is_allowed_variable
        else:
            checker_fn = self.var_is_membrane_potential

        # Check all expressions
        for expr in (e for e in doc.model.get_assignments()
                     if isinstance(e, mathml_apply)):
            # Skip the first operand (the assignment's LHS) and analyse the second (its RHS).
            ops = expr.operands()
            ops.next()
            e = ops.next()
            self.analyse_for_lut(e, checker_fn)
        for expr in solver_info.get_modifiable_mathematics():
            self.analyse_for_lut(expr, checker_fn)

        if solver_info.has_modifiable_mathematics():
            # NOTE(review): the body of this `if` is missing from this copy of the
            # file. Presumably it calls self._determine_unneeded_tables(), which
            # is not invoked anywhere in the visible code. Confirm and restore.

        # Assign names (numbers) to the lookup tables found.
        # Also work out which ones can share index variables into the table.
        doc.lookup_tables = doc.lookup_tables.keys()
        doc.lookup_tables.sort(cmp=element_path_cmp)
        doc.lookup_table_indexes, n = {}, 0
        for i, expr in enumerate(doc.lookup_tables):
            expr.xml_set_attribute((u'lut:table_name', NSS['lut']), unicode(i))
            comp = expr.get_component()
            var = comp.get_variable_by_name(expr.var).get_source_variable(recurse=True)
            key = (expr.min, expr.max, expr.step, var)
            if not key in doc.lookup_table_indexes:
                doc.lookup_table_indexes[key] = unicode(n)
                n += 1
            expr.xml_set_attribute((u'lut:table_index', NSS['lut']), doc.lookup_table_indexes[key])

        if solver_info.has_modifiable_mathematics():
            # NOTE(review): the body of this `if` is also missing from this copy;
            # its content cannot be recovered from the visible code. Restore it
            # from the upstream file.

        # Re-do dependency analysis so that an expression using lookup
        # tables only depends on the keying variable.
        for expr in (e for e in doc.model.get_assignments()
                     if isinstance(e, mathml_apply)):
            expr.classify_variables(root=True,
                                    dependencies_only=True,
                                    needs_special_treatment=self.calculate_dependencies)
825
def _find_tables(self, expr, table_dict):
    """Helper method for _determine_unneeded_tables.

    Walks the element tree rooted at expr. An element annotated with
    lut:possible="yes" is stored in table_dict under its id(), and its
    children are not searched. Otherwise every child element is searched
    recursively.
    """
    is_table = expr.getAttributeNS(NSS['lut'], u'possible', '') == u'yes'
    if is_table:
        table_dict[id(expr)] = expr
        return
    for child in self.doc.model.xml_element_children(expr):
        self._find_tables(child, table_dict)
833
def _determine_unneeded_tables(self):
    """Determine whether some expressions identified as lookup tables aren't actually used.

    This occurs if some ODEs have been linearised, in which case the original definitions
    will have been analysed for lookup tables, but aren't actually used.

    For each linearised ODE, the tables used by the original definition (and
    its extended dependencies) are compared with those used by the
    replacement equations. Any table found only on the original side has its
    lookup table annotations removed and is given a lut:reason attribute.

    TODO: The original definitions might be used for computing derived quantities...
    """
    original_tables = {}
    new_tables = {}

    def collect_tables(exprs, table_dict):
        """Record in table_dict the tables used by exprs and their extended dependencies."""
        # Only these node kinds are passed on to the dependency analysis.
        exprs = [e for e in exprs
                 if isinstance(e, (mathml_ci, mathml_apply, mathml_piecewise))]
        for node in self.doc.model.calculate_extended_dependencies(exprs):
            if isinstance(node, mathml_apply):
                self._find_tables(node, table_dict)

    for u, t, eqns in self.solver_info.get_linearised_odes():
        original_defn = u.get_ode_dependency(t)
        collect_tables([original_defn], original_tables)
        collect_tables(eqns, new_tables)
    # Tables are keyed by id(), so the set difference picks out elements
    # reachable only from the original definitions.
    for id_ in set(original_tables) - set(new_tables):
        expr = original_tables[id_]
        self.remove_lut_annotations(expr)
        expr.xml_set_attribute((u'lut:reason', NSS['lut']),
                               u'Expression will not be used in generated code.')
        DEBUG('lookup-tables', 'Not annotating probably unused expression', expr)
859
def _determine_duplicate_tables(self):
    """Determine whether we have multiple tables for the same expression.

    Any expression that is identical to a previous table will be re-annotated to refer to the
    previous table, instead of declaring a new one.

    This is a temporary measure until we have proper sub-expression elimination for the Jacobian
    and residual calculations.
    """
    uniq_tables = []
    for expr in self.doc.lookup_tables:
        for table in uniq_tables:
            if expr.same_tree(table):
                lt_name = table.getAttributeNS(NSS['lut'], u'table_name', u'')
                # Need to remove old name before we can set a new one (grr amara)
                del expr.table_name
                expr.xml_set_attribute((u'lut:table_name', NSS['lut']), lt_name)
                break
        else:
            # No earlier table is identical, so this one declares a new table.
            uniq_tables.append(expr)
880
def calculate_dependencies(self, expr):
    """Determine the dependencies of an expression that might use a lookup table.

    Intended as the needs_special_treatment callback for
    mathml_apply.classify_variables, where it overrides the default recursion
    into sub-trees. Given a single sub-tree, it returns that sub-tree's
    dependency set, or None to request the default recursion.

    An expression annotated as a possible lookup table depends only on its
    keying variable (resolved to its ultimate source variable).
    """
    if expr.getAttributeNS(NSS['lut'], u'possible', '') != u'yes':
        # Not a table: fall back to the default behaviour.
        return None
    key_var_name = expr.getAttributeNS(NSS['lut'], u'var')
    local_key_var = expr.component.get_variable_by_name(key_var_name)
    return set([local_key_var.get_source_variable(recurse=True)])
897
898
899######################################################################
900# Jacobian analysis #
901######################################################################
902
903class LinearityAnalyser(object):
904 """Analyse linearity aspects of a model.
905
906 This class performs analyses to determine which ODEs have a linear
907 dependence on their state variable, discounting the presence of
908 the transmembrane potential.
909
910 This can be used to decouple the ODEs for more efficient solution,
911 especially in a multi-cellular context.
912
913 analyse_for_jacobian(doc) must be called before
915 """
916 LINEAR_KINDS = Enum('None', 'Linear', 'Nonlinear')
917
918 def analyse_for_jacobian(self, doc, V=None):
919 """Analyse the model for computing a symbolic Jacobian.
920
921 Determines automatically which variables will need to be solved
922 for using Newton's method, and stores their names in
923 doc.model._cml_nonlinear_system_variables, as a list of variable
924 objects.
925
926 Also stores doc.model._cml_linear_vars, doc.model._cml_free_var,
927 doc.model._cml_transmembrane_potential.
928 """
929 # TODO: Add error checking and tidy
930 stvs = doc.model.find_state_vars()
931 if V is None:
932 Vcname, Vvname = 'membrane', 'V'
933 V = doc.model.get_variable_by_name(Vcname, Vvname)
934 V = V.get_source_variable(recurse=True)
935 doc.model._cml_transmembrane_potential = V
936 free_var = doc.model.find_free_vars()[0]
937 lvs = self.find_linear_odes(stvs, V, free_var)
938 # Next 3 lines for benefit of rearrange_linear_odes(doc)
939 lvs.sort(key=lambda v: v.fullname())
940 doc.model._cml_linear_vars = lvs
941 doc.model._cml_free_var = free_var
942 # Store nonlinear vars in canonical order
943 nonlinear_vars = list(set(stvs) - set([V]) - set(lvs))
944 nonlinear_vars.sort(key=lambda v: v.fullname())
945 doc.model._cml_nonlinear_system_variables = nonlinear_vars
946 # Debugging
947 f = lambda var: var.fullname()
948 DEBUG('linearity-analyser', 'V=', V.fullname(), '; free var=',
949 free_var.fullname(), '; linear vars=', map(f, lvs),
950 '; nonlinear vars=', map(f, nonlinear_vars))
951 return
952
953 def _get_rhs(self, expr):
954 """Return the RHS of an assignment expression."""
955 ops = expr.operands()
956 ops.next()
957 return ops.next()
958
959 def _check_expr(self, expr, state_var, bad_vars):
960 """The actual linearity checking function.
961
962 Recursively determine the type of dependence expr has on
963 state_var. The presence of any members of bad_vars indicates
964 a non-linear dependency.
965
966 Return a member of the self.LINEAR_KINDS enum.
967 """
968 kind = self.LINEAR_KINDS
969 result = None
970 if isinstance(expr, mathml_ci):
971 var = expr.variable.get_source_variable(recurse=True)
972 if var is state_var:
973 result = kind.Linear
974 elif var in bad_vars:
975 result = kind.Nonlinear
976 elif var.get_type(follow_maps=True) == VarTypes.Computed:
977 # Recurse into defining expression
978 src_var = var.get_source_variable(recurse=True)
979 src_expr = self._get_rhs(src_var.get_dependencies()[0])
980 DEBUG('find-linear-deps', "--recurse for", src_var.name,
981 "to", src_expr)
982 result = self._check_expr(src_expr, state_var, bad_vars)
983 else:
984 result = kind.None
985 # Record the kind of this variable, for later use when
986 # rearranging linear ODEs
987 var._cml_linear_kind = result
988 elif isinstance(expr, mathml_cn):
989 result = kind.None
990 elif isinstance(expr, mathml_piecewise):
991 # If any conditions have a dependence, then we're
992 # nonlinear. Otherwise, all the pieces must be the same
993 # (and that's what we are) or we're nonlinear.
994 pieces = getattr(expr, u'piece', [])
995 conds = map(lambda p: child_i(p, 2), pieces)
996 chld_exprs = map(lambda p: child_i(p, 1), pieces)
997 if hasattr(expr, u'otherwise'):
998 chld_exprs.append(child_i(expr.otherwise, 1))
999 for cond in conds:
1000 if self._check_expr(cond, state_var, bad_vars) != kind.None:
1001 result = kind.Nonlinear
1002 break
1003 else:
1004 # Conditions all OK
1005 for e in chld_exprs:
1006 res = self._check_expr(e, state_var, bad_vars)
1007 if result is not None and res != result:
1008 # We have a difference
1009 result = kind.Nonlinear
1010 break
1011 result = res
1012 elif isinstance(expr, mathml_apply):
1013 # Behaviour depends on the operator
1014 operator = expr.operator().localName
1015 operands = expr.operands()
1016 if operator in ['plus', 'minus']:
1017 # Linear if any operand linear, and none non-linear
1018 op_kinds = map(lambda op: self._check_expr(op, state_var,
1019 bad_vars),
1020 operands)
1021 result = max(op_kinds)
1022 elif operator == 'divide':
1023 # Linear iff only numerator linear
1024 numer = operands.next()
1025 denom = operands.next()
1026 if self._check_expr(denom, state_var, bad_vars) != kind.None:
1027 result = kind.Nonlinear
1028 else:
1029 result = self._check_expr(numer, state_var, bad_vars)
1030 elif operator == 'times':
1031 # Linear iff only 1 linear operand
1032 op_kinds = map(lambda op: self._check_expr(op, state_var,
1033 bad_vars),
1034 operands)
1035 lin, nonlin = 0, 0
1036 for res in op_kinds:
1037 if res == kind.Linear: lin += 1
1038 elif res == kind.Nonlinear: nonlin += 1
1039 if nonlin > 0 or lin > 1:
1040 result = kind.Nonlinear
1041 elif lin == 1:
1042 result = kind.Linear
1043 else:
1044 result = kind.None
1045 else:
1046 # Cannot be linear; may be no dependence at all
1047 result = max(map(lambda op: self._check_expr(op, state_var,
1048 bad_vars),
1049 operands))
1050 if result == kind.Linear:
1051 result = kind.Nonlinear
1052 else:
1053 # Either a straightforward container element
1054 try:
1055 child = child_i(expr, 1)
1056 except ValueError:
1057 # Assume it's just a constant
1058 result = kind.None
1059 else:
1060 result = self._check_expr(child, state_var, bad_vars)
1061 DEBUG('find-linear-deps', "Expression", expr, "gives result", result)
1062 return result
1063
1064 def find_linear_odes(self, state_vars, V, free_var):
1065 """Identify linear ODEs.
1066
1067 For each ODE (except that for V), determine whether it has a
1068 linear dependence on the dependent variable, and thus can be
1069 updated directly, without using Newton's method.
1070
1071 We also require it to not depend on any other state variable,
1072 except for V.
1073 """
1074 kind = self.LINEAR_KINDS
1075 candidates = set(state_vars) - set([V])
1076 linear_vars = []
1077 for var in candidates:
1078 ode_expr = var.get_ode_dependency(free_var)
1079 if self._check_expr(self._get_rhs(ode_expr), var,
1080 candidates - set([var])) == kind.Linear:
1081 linear_vars.append(var)
1082 return linear_vars
1083
1084 def _clone(self, expr):
1085 """Properly clone a MathML sub-expression."""
1086 if isinstance(expr, mathml):
1087 clone = expr.clone_self(register=True)
1088 else:
1089 clone = mathml.clone(expr)
1090 return clone
1091
1092 def _make_apply(self, operator, ghs, i, filter_none=True,
1093 preserve=False):
1094 """Construct a new apply expression for g or h.
1095
1096 ghs is an iterable of (g,h) pairs for operands.
1097
1098 i indicates whether to construct g (0) or h (1).
1099
1100 filter_none indicates the behaviour of 0 under this operator.
1101 If True, it's an additive zero, otherwise it's a
1102 multiplicative zero.
1103 """
1104 # Find g or h operands
1105 ghs_i = map(lambda gh: gh[i], ghs)
1106 if not filter_none and None in ghs_i:
1107 # Whole expr is None
1108 new_expr = None
1109 else:
1110 # Filter out None subexprs
1111 if operator == u'minus':
1112 # Do we need to retain a unary minus?
1113 if len(ghs_i) == 1 or ghs_i[0] is None:
1114 # Original was -a or 0-a
1115 retain_unary_minus = True
1116 else:
1117 # Original was a-0 or a-b
1118 retain_unary_minus = False
1119 else:
1120 # Only retain if we're told to preserve as-is
1121 retain_unary_minus = preserve
1122 ghs_i = filter(None, ghs_i)
1123 if ghs_i:
1124 if len(ghs_i) > 1 or retain_unary_minus:
1125 new_expr = mathml_apply.create_new(
1126 self.__expr, operator, ghs_i)
1127 else:
1128 new_expr = self._clone(ghs_i[0]) # Clone may be unneeded
1129 else:
1130 new_expr = None
1131 return new_expr
1132
1133 def _transfer_lut(self, expr, gh, var):
1134 """Transfer lookup table annotations from expr to gh.
1135
1136 gh is a pair (g, h) s.t. expr = g + h*var.
1137
1138 If expr can be represented by a lookup table, then the lookup
1139 variable cannot be var, since if it were, then expr would have
1140 a non-linear dependence on var. Hence h must be 0, since
1141 otherwise expr would contain a (state) variable other than the
1142 lookup variable, and hence not be a candidate for a table.
1143 Thus expr=g, so we transfer the annotations to g.
1144 """
1145 if expr.getAttributeNS(NSS['lut'], u'possible', '') != u'yes':
1146 return
1147 # Paranoia check that our reasoning is correct
1148 g, h = gh
1149 assert h is None
1150 # Transfer the annotations into g
1151 LookupTableAnalyser.copy_lut_annotations(expr, g)
1152 # Make sure g has a reference to its component, for use by code generation.
1153 g._cml_component = expr.component
1154 return
1155
1156 def _rearrange_expr(self, expr, var):
1157 """Rearrange an expression into the form g + h*var.
1158
1159 Performs a post-order traversal of this expression's tree,
1160 and returns a pair (g, h)
1161 """
1162# import inspect
1163# depth = len(inspect.stack())
1164# print ' '*depth, "_rearrange_expr", prid(expr, True), var.name, expr
1165 gh = None
1166 if isinstance(expr, mathml_ci):
1167 # Variable
1168 ci_var = expr.variable.get_source_variable(recurse=True)
1169 if var is ci_var:
1170 gh = (None, mathml_cn.create_new(expr,
1171 u'1', u'dimensionless'))
1172 else:
1173 if ci_var._cml_linear_kind == self.LINEAR_KINDS.None:
1174 # Just put the <ci> in g, but with full name
1175 gh = (mathml_ci.create_new(expr, ci_var.fullname()), None)
1176 gh[0]._set_variable_obj(ci_var)
1177 else:
1178 # ci_var is a linear function of var, so rearrange
1179 # its definition
1180 if not hasattr(ci_var, '_cml_linear_split'):
1181 ci_defn = ci_var.get_dependencies()[0]
1182 ci_var._cml_linear_split = self._rearrange_expr(
1183 self._get_rhs(ci_defn), var)
1184 gh = ci_var._cml_linear_split
1185 elif isinstance(expr, mathml_piecewise):
1186 # The tests have to move into both components of gh:
1187 # "if C1 then (a1,b1) elif C2 then (a2,b2) else (a0,b0)"
1188 # maps to "(if C1 then a1 elif C2 then a2 else a0,
1189 # if C1 then b1 elif C2 then b2 else b0)"
1190 # Note that no test is a function of var.
1191 # First rearrange child expressions
1192 pieces = getattr(expr, u'piece', [])
1193 cases = map(lambda p: child_i(p, 1), pieces)
1194 cases_ghs = map(lambda c: self._rearrange_expr(c, var), cases)
1195 if hasattr(expr, u'otherwise'):
1196 ow_gh = self._rearrange_expr(child_i(expr.otherwise, 1), var)
1197 else:
1198 ow_gh = (None, None)
1199 # Now construct the new expression
1200 conds = map(lambda p: self._clone(child_i(p, 2)), pieces)
1201 def piecewise_branch(i):
1202 pieces_i = zip(map(lambda gh: gh[i], cases_ghs),
1203 conds)
1204 pieces_i = filter(lambda p: p[0] is not None,
1205 pieces_i) # Remove cases that are None
1206 ow = ow_gh[i]
1207 if pieces_i:
1208 new_expr = mathml_piecewise.create_new(
1209 expr, pieces_i, ow)
1210 elif ow:
1211 new_expr = ow
1212 else:
1213 new_expr = None
1214 return new_expr
1215 gh = (piecewise_branch(0), piecewise_branch(1))
1216 self._transfer_lut(expr, gh, var)
1217 elif isinstance(expr, mathml_apply):
1218 # Behaviour depends on the operator
1219 operator = expr.operator().localName
1220 operands = expr.operands()
1221 self.__expr = expr # For self._make_apply
1222 if operator in ['plus', 'minus']:
1223 # Just split the operation into each component
1224 operand_ghs = map(lambda op: self._rearrange_expr(op, var),
1225 operands)
1226 g = self._make_apply(operator, operand_ghs, 0)
1227 h = self._make_apply(operator, operand_ghs, 1)
1228 gh = (g, h)
1229 elif operator == 'divide':
1230 # (a, b) / (c, 0) = (a/c, b/c)
1231 numer = self._rearrange_expr(operands.next(), var)
1232 denom = self._rearrange_expr(operands.next(), var)
1233 assert denom[1] is None
1234 denom_g = denom[0]
1235 g = h = None
1236 if numer[0]:
1237 g = mathml_apply.create_new(expr, operator,
1238 [numer[0], denom_g])
1239 if numer[1]:
1240 if g:
1241 denom_g = self._clone(denom_g)
1242 h = mathml_apply.create_new(expr, operator,
1243 [numer[1], denom_g])
1244 gh = (g, h)
1245 elif operator == 'times':
1246 # (a1,b1)*(a2,b2) = (a1*a2, b1*a2 or a1*b2 or None)
1247 # Similarly for the nary case - at most one b_i is not None
1248 operand_ghs = map(lambda op: self._rearrange_expr(op, var),
1249 operands)
1250 g = self._make_apply(operator, operand_ghs, 0,
1251 filter_none=False)
1252 # Find non-None b_i, if any
1253 for i, ab in enumerate(operand_ghs):
1254 if ab[1] is not None:
1255 operand_ghs[i] = (ab[1], None)
1256 # Clone the a_i to avoid objects having 2 parents
1257 for j, ab in enumerate(operand_ghs):
1258 if j != i:
1259 operand_ghs[j] = (self._clone(operand_ghs[j][0]), None)
1260 h = self._make_apply(operator, operand_ghs, 0, filter_none=False)
1261 break
1262 else:
1263 h = None
1264 gh = (g, h)
1265 else:
1266 # (a, None) op (b, None) = (a op b, None)
1267 operand_ghs = map(lambda op: self._rearrange_expr(op, var),
1268 operands)
1269 g = self._make_apply(operator, operand_ghs, 0, preserve=True)
1270 gh = (g, None)
1271 self._transfer_lut(expr, gh, var)
1272 else:
1273 # Since this expression is linear, there can't be any
1274 # occurrence of var in it; all possible such cases are covered
1275 # above. So just clone it into g.
1276 gh = (self._clone(expr), None)
1277# print ' '*depth, "Re-arranged", prid(expr, True), "to", prid(gh[0], True), ",", prid(gh[1], True)
1278 return gh
1279
1281 """Rearrange the linear ODEs so they can be updated directly
1282 on solving.
1283
1284 Each ODE du/dt = f(u, t) can be written in the form
1285 du/dt = g(t) + h(t)u.
1286 A backward Euler update step is then as simple as
1287 u_n = (u_{n-1} + g(t)dt) / (1 - h(t)dt)
1288 (assuming that the transmembrane potential has already been
1289 solved for at t_n.
1290
1291 Stores the results in doc.model._cml_linear_update_exprs, a
1292 mapping from variable object u to pair (g, h).
1293 """
1294
1295 odes = map(lambda v: v.get_ode_dependency(doc.model._cml_free_var),
1296 doc.model._cml_linear_vars)
1297 result = {}
1298 for var, ode in itertools.izip(doc.model._cml_linear_vars, odes):
1299 # Do the re-arrangement for this variable.
1300 # We do this by a post-order traversal of its defining ODE,
1301 # constructing a pair (g, h) for each subexpression recursively.
1302 # First, get the RHS of the ODE
1303 rhs = self._get_rhs(ode)
1304 # And traverse
1305 result[var] = self._rearrange_expr(rhs, var)
1306 # Store result in model
1307 doc.model._cml_linear_update_exprs = result
1308 return result
1309
1310 def show(self, d):
1311 """Print out a more readable report on a rearrangement,
1312 as given by self.rearrange_linear_odes."""
1313 for var, expr in d.iteritems():
1314 print var.fullname()
1315 print "G:", expr[0].xml()
1316 print "H:", expr[1].xml()
1317 print "ODE:", var.get_ode_dependency(
1318 var.model._cml_free_var).xml()
1319 print
1320
1321
1322######################################################################
1323# Rush-Larsen analysis #
1324######################################################################
1325
class ExpressionMatcher(object):
    """Test whether a MathML expression matches a given tree pattern.

    Patterns are instances of the nested Pattern class, or more specifically
    one of its subclasses. The static method match on this class checks an
    expression against a pattern, returning True iff there is a match.

    Some patterns record what they matched (X.matched, N.value), so after a
    successful match those attributes describe the matched expression.
    """

    class Pattern(object):
        """Abstract base class for tree patterns."""
        def match(self, expr):
            """
            Method implemented by concrete subclasses to test a given expression.
            Returns True iff there is a match.
            """
            raise NotImplementedError

    class A(Pattern):
        """An apply expression with a given operator and operand patterns."""
        def __init__(self, operator, operands):
            self.operator = operator
            self.operands = operands

        def match(self, expr):
            if not isinstance(expr, mathml_apply):
                return False
            if expr.operator().localName != self.operator:
                return False
            expr_operands = list(expr.operands())
            if len(expr_operands) != len(self.operands):
                return False
            # Test every operand (so all placeholders are updated), then
            # combine. all() also copes with zero operands, where
            # reduce() without an initial value would raise TypeError.
            results = [pat.match(op)
                       for pat, op in zip(self.operands, expr_operands)]
            return all(results)

    class V(Pattern):
        """A variable reference."""
        def __init__(self, var=None):
            self.set_variable(var)

        def set_variable(self, var):
            if var:
                self.var = var.get_source_variable(recurse=True)
            else:
                self.var = var

        def match(self, expr):
            matched = False
            if isinstance(expr, mathml_ci) and self.var is expr.variable.get_source_variable(recurse=True):
                matched = True
            return matched

    class N(Pattern):
        """A constant number, optionally with the value specified.

        If no value is given, any number matches. After a successful match
        the value attribute holds the number matched. A failed match, or a
        match against a different expression, is not constrained by earlier
        matches.
        """
        def __init__(self, value=None):
            self._required = value
            self.value = value

        def match(self, expr):
            if not isinstance(expr, mathml_cn):
                return False
            value = expr.evaluate()
            if self._required is not None and value != self._required:
                return False
            self.value = value
            return True

    class X(Pattern):
        """A placeholder, matching anything (and noting what was matched)."""
        def __init__(self):
            self.matched = None

        def match(self, expr):
            self.matched = expr
            return True

    class R(Pattern):
        """A container that matches any number of levels of indirection/recursion.

        This can be used to wrap a pattern where we wish to allow for variable mappings
        or equations such as "var1 = var2" before we reach the 'interesting' equation.
        If the expression we're matching is a ci element we recursively find the
        ultimate non-ci defining expression and match our sub-pattern against that. If
        the expression isn't a ci, or the ultimate definition isn't an expression, we
        match our sub-pattern against it directly.
        """
        def __init__(self, pat):
            self.sub_pattern = pat

        def match(self, expr):
            while isinstance(expr, mathml_ci):
                # Find this variable's defining expression, if it is an equation
                var = expr.variable.get_source_variable(recurse=True)
                defn = var.get_dependencies()
                if defn and isinstance(defn[0], mathml_apply):
                    expr = defn[0].eq.rhs
                else:
                    # No defining equation: match the ci itself (without this
                    # break the loop would never terminate).
                    break
            return self.sub_pattern.match(expr)

    @staticmethod
    def match(pattern, expression):
        """Test for a match."""
        return pattern.match(expression)
class RushLarsenAnalyser(object):
    """Analyse a model to identify Hodgkin-Huxley style gating variable equations.

    We look for ODEs whose definition matches "alpha*(1-x) - beta*x" (where x is
    the state variable, and alpha & beta are any expression). Alternatively for
    models which already have tau & inf variables, we match against "(inf-x)/tau".

    To allow for when units conversions have been performed, we chase 'simple'
    assignments and (the semantically equivalent) variable mappings until reaching
    an 'interesting' defining equation. We also need to allow the whole RHS to be
    multiplied by a constant. If this occurs, the constant conversion factor is
    also stored; otherwise we store None.

    Stores a dictionary on the document root (doc._cml_rush_larsen) mapping
    cellml_variable instances to 4-tuples. The tuple is either
    ('ab', alpha, beta, conv) or ('ti', tau, inf, conv) depending on which
    formulation has been used. Note that the expressions in these are not
    cloned copies - they are the original objects still embedded within the
    relevant ODE. The units conversion factor 'conv' is stored as a Python double.
    """
    def __init__(self):
        """Create the patterns to match against."""
        em = ExpressionMatcher
        self._var = em.V()
        self._conv = em.N()
        # Alpha/beta form
        self._alpha = em.X()
        self._beta = em.X()
        self._ab_pattern = em.R(em.A('minus', [em.A('times', [self._alpha,
                                                              em.A('minus', [em.N(1), self._var])]),
                                               em.A('times', [self._beta, self._var])]))
        self._alt_ab_pattern = em.R(em.A('times', [self._conv, self._ab_pattern]))
        # Tau/inf form
        self._tau = em.X()
        self._inf = em.X()
        self._ti_pattern = em.R(em.A('divide', [em.A('minus', [self._inf, self._var]),
                                                self._tau]))
        self._alt_ti_pattern = em.R(em.A('times', [self._conv, self._ti_pattern]))

    def analyse_model(self, doc):
        """Find the ODEs in doc that match a gating-variable form.

        Only ODEs that LinearityAnalyser.find_linear_odes classifies as linear
        (using doc._cml_config.V_variable as the transmembrane potential) are
        considered. Matches are recorded in doc._cml_rush_larsen as described
        in the class docstring.
        """
        # First, find linear ODEs that have the potential to be gating variables
        la = LinearityAnalyser()
        V = doc._cml_config.V_variable
        state_vars = doc.model.find_state_vars()
        free_var = doc.model.find_free_vars()[0]
        linear_vars = la.find_linear_odes(state_vars, V, free_var)
        # Next, check they match dn/dt = a (1-n) - b n, or dn/dt = (inf-n)/tau
        doc._cml_rush_larsen = {}
        for var in linear_vars:
            ode_expr = var.get_ode_dependency(free_var)
            self._check_var(var, ode_expr, doc._cml_rush_larsen)

    def _match_with_conv(self, pattern, rhs):
        """Match rhs against a pattern that includes the conversion-factor placeholder.

        self._conv is shared by both such patterns and by every variable
        checked. ExpressionMatcher.N may retain a number from an earlier
        (possibly failed) attempt, which could then constrain this one.
        Clear it first so every attempt starts afresh.
        """
        self._conv.value = None
        return pattern.match(rhs)

    def _check_var(self, var, ode_expr, mapping):
        """Record var in mapping if its ODE matches one of the gating patterns."""
        rhs = ode_expr.eq.rhs
        self._var.set_variable(ode_expr.eq.lhs.diff.dependent_variable)
        if self._ab_pattern.match(rhs):
            mapping[var] = ('ab', self._alpha.matched, self._beta.matched, None)
        elif self._match_with_conv(self._alt_ab_pattern, rhs):
            mapping[var] = ('ab', self._alpha.matched, self._beta.matched, self._conv.value)
        elif self._ti_pattern.match(rhs):
            mapping[var] = ('ti', self._tau.matched, self._inf.matched, None)
        elif self._match_with_conv(self._alt_ti_pattern, rhs):
            mapping[var] = ('ti', self._tau.matched, self._inf.matched, self._conv.value)
def __init__(self, operator, operands)
Definition: optimize.py:1345
def match(pattern, expression)
Definition: optimize.py:1424
def _make_apply(self, operator, ghs, i, filter_none=True, preserve=False)
Definition: optimize.py:1093
def find_linear_odes(self, state_vars, V, free_var)
Definition: optimize.py:1064
def _transfer_lut(self, expr, gh, var)
Definition: optimize.py:1133
def _check_expr(self, expr, state_var, bad_vars)
Definition: optimize.py:959
def analyse_for_jacobian(self, doc, V=None)
Definition: optimize.py:918
def analyse_model(self, doc, solver_info, annotate_failures=True, annotate_outermost_only=True)
Definition: optimize.py:760
def annotate_as_suitable(self, expr, table_var)
Definition: optimize.py:714
def get_param(self, param_name, table_var)
Definition: optimize.py:453
def remove_lut_annotations(self, expr, remove_reason=False)
Definition: optimize.py:733
def analyse_for_lut(self, expr, var_checker_fn)
Definition: optimize.py:556
def check_divide_by_table(self, expr, operand_states)
Definition: optimize.py:647
def check_commutative_tables(self, expr, operand_states)
Definition: optimize.py:664
def copy_lut_annotations(from_expr, to_expr)
Definition: optimize.py:727
def _find_tables(self, expr, table_dict)
Definition: optimize.py:826
def _get_assignment_exprs(self, skip_solver_info=True)
Definition: optimize.py:160
def parteval(self, doc, solver_info, lookup_tables_analyser=None)
Definition: optimize.py:245
def _process_ci_elts(self, elt, func)
Definition: optimize.py:86
def _do_reduce_eval_loop(self, expr_source)
Definition: optimize.py:103
def _check_var(self, var, ode_expr, mapping)
Definition: optimize.py:1479
def Enum(*names)
Definition: enum.py:8
def child_i(elt, i)
Definition: pycml.py:3449
def DEBUG(facility, *args)
Definition: utilities.py:95