# -*- coding: utf-8 -*-
# Copyright (C) 2011-2017 Martin Sandve Alnæs
#
# This file is part of UFLACS.
#
# UFLACS is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# UFLACS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with UFLACS. If not, see <http://www.gnu.org/licenses/>
"""Algorithms for factorizing argument dependent monomials."""
import numpy
from itertools import chain
from ufl import as_ufl, conditional
from ufl.classes import Argument
from ufl.classes import Division
from ufl.classes import Product
from ufl.classes import Sum
from ufl.classes import Conditional
from ufl.classes import Zero
from ufl.algorithms import extract_type
from ffc.log import error
from ffc.uflacs.analysis.dependencies import compute_dependencies
from ffc.uflacs.analysis.modified_terminals import analyse_modified_terminal, strip_modified_terminal
def _build_arg_sets(V):
"Build arg_sets = { argument number: set(j for j where V[j] is a modified Argument with this number) }"
arg_sets = {}
for i, v in enumerate(V):
arg = strip_modified_terminal(v)
if not isinstance(arg, Argument):
continue
num = arg.number()
arg_set = arg_sets.get(num)
if arg_set is None:
arg_set = {}
arg_sets[num] = arg_set
arg_set[i] = v
return arg_sets
def _build_argument_indices_from_arg_sets(V, arg_sets):
"Build ordered list of indices to modified arguments."
# Build set of all indices of V referring to modified arguments
arg_indices = set()
for js in arg_sets.values():
arg_indices.update(js)
# Make a canonical ordering of vertex indices for modified arguments
def arg_ordering_key(i):
"Return a key for sorting argument vertex indices based on the properties of the modified terminal."
mt = analyse_modified_terminal(arg_ordering_key.V[i])
return mt.argument_ordering_key()
arg_ordering_key.V = V
ordered_arg_indices = sorted(arg_indices, key=arg_ordering_key)
return ordered_arg_indices
def build_argument_indices(V):
    """Return the list of indices into V of modified arguments, in canonical order."""
    return _build_argument_indices_from_arg_sets(V, _build_arg_sets(V))
def build_argument_dependencies(dependencies, arg_indices):
    """Preliminary algorithm: for each vertex, list the argument vertex
    indices it (directly or indirectly) depends on.

    Parameters:
    - dependencies: dependencies[i] is a sequence of vertex indices that
      vertex i depends on.  Every non-argument dependency j of vertex i must
      satisfy j < i, since A[j] is read while building A[i].
    - arg_indices: iterable of the vertex indices that are modified arguments.

    Returns a numpy object array A with A[i] a sorted list of the *distinct*
    argument vertex indices vertex i depends on.  Argument vertices with no
    dependencies of their own get an empty list.
    """
    # Build the membership set once: "in" on a list is O(len) per test
    arg_index_set = set(arg_indices)
    n = len(dependencies)
    A = numpy.empty(n, dtype=object)
    for i, deps in enumerate(dependencies):
        # Collect into a set so shared subexpressions (e.g. x + x) do not
        # duplicate entries, which would otherwise grow exponentially with depth
        argdeps = set()
        for j in deps:
            if j in arg_index_set:
                argdeps.add(j)
            else:
                argdeps.update(A[j])
        A[i] = sorted(argdeps)
    return A
class Factors(object):  # TODO: Refactor code in this file by using a class like this
    """Container for a factor vector FV and its expression -> index map e2fi.

    Stub for a planned refactoring: the functions in this module still pass
    FV and e2fi around explicitly.
    """

    def __init__(self):
        # Unique non-argument factor expressions, in insertion order
        self.FV = []
        # Mapping from each expression in FV to its index in FV
        self.e2fi = {}

    def add(self, expr):
        """Add expr to FV if not already present and return its index in FV."""
        return add_to_fv(expr, self.FV, self.e2fi)
def add_to_fv(expr, FV, e2fi):
    """Return the index of expr in factor vector FV, appending expr if it is new.

    e2fi maps each expression already in FV to its index.  A new expression
    gets index len(e2fi), so FV and e2fi must be kept in sync
    (len(FV) == len(e2fi)) for FV[index] == expr to hold.
    """
    existing = e2fi.get(expr)
    if existing is not None:
        return existing
    new_index = len(e2fi)
    FV.append(expr)
    e2fi[expr] = new_index
    return new_index
# Shared empty factorization, returned for every vertex that does not depend
# on any Argument.  Reusing this single object saves memory; it must never be
# mutated.
noargs = dict()
def handle_sum(v, si, deps, SV_factors, FV, sv2fv, e2fi):
    """Factorize the binary sum vertex v = SV[si] w.r.t. Arguments.

    Terms sharing an argument key are merged: f*arg + g*arg -> (f+g)*arg,
    and keys present in only one operand keep that operand's factor index.
    If neither operand depends on Arguments, v itself is recorded in FV,
    sv2fv[si] is set to its index and the shared ``noargs`` dict is returned.
    Otherwise returns a new {argkey: FV index} dict.
    """
    if len(deps) != 2:
        error("Assuming binary sum here. This can be fixed if needed.")

    left = SV_factors[deps[0]]
    right = SV_factors[deps[1]]
    all_keys = sorted(set(left) | set(right))

    if not all_keys:
        # non-arg + non-arg
        sv2fv[si] = add_to_fv(v, FV, e2fi)
        return noargs

    # f*arg + g*arg = (f+g)*arg
    expected_rank = len(all_keys[0])
    merged = {}
    for key in all_keys:
        if len(key) != expected_rank:
            error("Expecting equal argument rank terms among summands.")
        left_fi = left.get(key)
        right_fi = right.get(key)
        if left_fi is None:
            merged[key] = right_fi
        elif right_fi is None:
            merged[key] = left_fi
        else:
            merged[key] = add_to_fv(FV[left_fi] + FV[right_fi], FV, e2fi)
    return merged
def handle_product(v, si, deps, SV_factors, FV, sv2fv, e2fi):
    """Factorize the binary product vertex v = SV[si] w.r.t. Arguments.

    - non-arg * non-arg: v itself is recorded in FV, sv2fv[si] points to it,
      and the shared ``noargs`` dict is returned.
    - non-arg * arg or arg * non-arg: the non-argument operand multiplies
      every factor of the argument-dependent operand.
    - arg * arg: every pair of factors is multiplied, keyed by the sorted
      concatenation of the two argument keys; pairs mapping to the same key
      are summed.

    Returns the {argkey: FV index} factorization of v.
    """
    if len(deps) != 2:
        error("Assuming binary product here. This can be fixed if needed.")

    fac0 = SV_factors[deps[0]]
    fac1 = SV_factors[deps[1]]

    if not fac0 and not fac1:  # non-arg * non-arg
        # Record non-argument product
        factors = noargs
        f0 = FV[sv2fv[deps[0]]]
        f1 = FV[sv2fv[deps[1]]]
        assert f1 * f0 == v
        sv2fv[si] = add_to_fv(v, FV, e2fi)
        assert FV[sv2fv[si]] == v
    elif not fac0:  # non-arg * arg
        # Record products of non-arg operand with each factor of arg-dependent operand
        f0 = FV[sv2fv[deps[0]]]
        factors = {}
        for k1 in sorted(fac1):
            fi1 = fac1[k1]
            factors[k1] = add_to_fv(f0 * FV[fi1], FV, e2fi)
    elif not fac1:  # arg * non-arg
        # Record products of non-arg operand with each factor of arg-dependent operand
        f1 = FV[sv2fv[deps[1]]]
        factors = {}
        for k0 in sorted(fac0):
            f0 = FV[fac0[k0]]
            factors[k0] = add_to_fv(f1 * f0, FV, e2fi)
    else:  # arg * arg
        # Record products of each factor of arg-dependent operand
        factors = {}
        for k0 in sorted(fac0):
            f0 = FV[fac0[k0]]
            for k1 in sorted(fac1):
                f1 = FV[fac1[k1]]
                # Sorted key for canonical representation
                argkey = tuple(sorted(k0 + k1))
                fi = add_to_fv(f0 * f1, FV, e2fi)
                previous = factors.get(argkey)
                if previous is not None:
                    # Distinct (k0, k1) pairs can sort to the same argkey;
                    # accumulate the contributions instead of silently
                    # overwriting the earlier one.
                    fi = add_to_fv(FV[previous] + FV[fi], FV, e2fi)
                factors[argkey] = fi
    return factors
def handle_division(v, si, deps, SV_factors, FV, sv2fv, e2fi):
    """Factorize the division vertex v = SV[si] w.r.t. Arguments.

    Only the numerator may depend on Arguments:
    (sum_k f_k*arg_k) / g -> sum_k (f_k/g)*arg_k.
    If the numerator is Argument-free as well, v itself is recorded in FV,
    sv2fv[si] points to it and the shared ``noargs`` dict is returned.
    """
    numerator = SV_factors[deps[0]]
    denominator = SV_factors[deps[1]]
    assert not denominator, "Cannot divide by arguments."

    if not numerator:
        # non-arg / non-arg: record the whole subexpression
        sv2fv[si] = add_to_fv(v, FV, e2fi)
        return noargs

    # arg / non-arg: divide each factor of the numerator by the denominator
    divisor = FV[sv2fv[deps[1]]]
    return {key: add_to_fv(FV[numerator[key]] / divisor, FV, e2fi)
            for key in sorted(numerator)}
def handle_conditional(v, si, deps, SV_factors, FV, sv2fv, e2fi):
    """Factorize the conditional vertex v = conditional(c, t, f) w.r.t. Arguments.

    The condition must not depend on Arguments.  If neither branch does, v
    itself is recorded in FV, sv2fv[si] points to it and ``noargs`` is
    returned.  Otherwise each argument key gets its own conditional factor:
        conditional(c, sum_i fi*ui, sum_j fj*uj)
          -> sum_i conditional(c, fi, 0)*ui + sum_j conditional(c, 0, fj)*uj
    (keys present in both branches get conditional(c, fi, fj)).
    """
    cond_factors = SV_factors[deps[0]]
    true_factors = SV_factors[deps[1]]
    false_factors = SV_factors[deps[2]]
    assert not cond_factors, "Cannot have argument in condition."

    if not true_factors and not false_factors:
        # non-arg ? non-arg : non-arg -- record the whole subexpression
        sv2fv[si] = add_to_fv(v, FV, e2fi)
        return noargs

    condition = FV[sv2fv[deps[0]]]
    true_value = FV[sv2fv[deps[1]]]
    false_value = FV[sv2fv[deps[2]]]

    # Term conditional(c, argument, non-argument) is not legal unless non-argument is 0.0
    assert true_factors or isinstance(true_value, Zero)
    assert false_factors or isinstance(false_value, Zero)
    assert () not in true_factors
    assert () not in false_factors

    zero = as_ufl(0.0)
    result = {}
    for key in sorted(set(true_factors) | set(false_factors)):
        fi_true = true_factors.get(key)
        fi_false = false_factors.get(key)
        branch_true = zero if fi_true is None else FV[fi_true]
        branch_false = zero if fi_false is None else FV[fi_false]
        result[key] = add_to_fv(conditional(condition, branch_true, branch_false), FV, e2fi)
    return result
def handle_operator(v, si, deps, SV_factors, FV, sv2fv, e2fi):
    """Record the Argument-free operator vertex v = SV[si] as a factor.

    Operators other than sum, product, division and conditional are assumed
    never to be applied to Argument-dependent operands; ``error`` is called
    if one is.  v is added to FV, sv2fv[si] points to it, and the shared
    ``noargs`` dict is returned.
    """
    for d in deps:
        if SV_factors[d]:
            error("Assuming that a {0} cannot be applied to arguments. If this is wrong please report a bug.".format(type(v)))
            break
    sv2fv[si] = add_to_fv(v, FV, e2fi)
    return noargs
def compute_argument_factorization(SV, SV_deps, SV_targets, rank):
    """Factorize a scalar expression graph w.r.t. scalar Argument components.

    Arguments:
    - SV: list of scalar expression graph vertices.  Operands must precede
      the vertices that use them, because each vertex is factorized from the
      already computed factorizations of its operands.
    - SV_deps: SV_deps[si] holds the indices into SV of the operands of SV[si].
    - SV_targets: indices into SV of the target expressions to factorize.
    - rank: expected length of every argument key in the result
      (0 for functionals and expressions).

    Returns the tuple (IMs, AV, FV, FV_deps, FV_targets):
    - IMs: one dict per entry of SV_targets, representing the final
      integrand of rank r:
          IM = { (ai1_1, ..., ai1_r): fi1, (ai2_1, ..., ai2_r): fi2, }
      This mapping represents the factorization of SV[target] w.r.t.
      Arguments s.t.:
          SV[target] := sum(FV[fik] * product(AV[ai] for ai in aik)
                            for aik, fik in IM.items())
      where := means equivalence in the mathematical sense,
      of course in a different technical representation.
    - AV: the scalar argument component subgraph, AV[ai] = SV[arg_indices[ai]]
      with arg_indices = build_argument_indices(SV).
    - FV: expression graph vertex list of all non-argument factors,
      none of which depend on Arguments.
    - FV_deps: dependencies between the vertices of FV,
      as computed by compute_dependencies(e2fi, FV).
    - FV_targets: flat list of the indices into FV needed for the final
      result (the sorted factor indices of each IM, concatenated in target
      order).
    """
    # Extract argument component subgraph
    arg_indices = build_argument_indices(SV)
    AV = [SV[si] for si in arg_indices]
    # Map SV index -> AV index.  Its keys are exactly the modified Argument
    # vertices, so it doubles as an O(1) membership test in the loop below
    # ("in" on the arg_indices list would be O(len(arg_indices)) per vertex).
    sv2av = {si: ai for ai, si in enumerate(arg_indices)}
    assert all(AV[ai] == SV[si] for ai, si in enumerate(arg_indices))
    assert all(AV[ai] == SV[si] for si, ai in sv2av.items())

    # Data structure for building non-argument factors
    FV = []
    e2fi = {}

    # Adding 0.0 as an expression to fix issue in conditional: argument-
    # dependent vertices keep their default si2fi entry 0, and
    # handle_conditional looks up FV[si2fi[...]] for all its operands.
    add_to_fv(as_ufl(0.0), FV, e2fi)

    # Adding 1.0 as an expression allows avoiding special representation
    # of arguments when first visited by representing "v" as "1*v"
    one_index = add_to_fv(as_ufl(1.0), FV, e2fi)

    # Adding 2 as an expression fixes an issue with FV entries that change
    # K*K -> K**2
    add_to_fv(as_ufl(2), FV, e2fi)

    # Intermediate factorization for each vertex in SV on the format
    #   SV_factors[si] = noargs (an empty dict)  if SV[si] does not depend on arguments
    #   SV_factors[si] = { argkey: fi }          if SV[si] does depend on arguments, where:
    #       FV[fi] is the expression SV[si] with arguments factored out
    #       argkey is a tuple with indices into SV for each of the
    #       argument components SV[si] depends on
    #   SV_factors[si] = { argkey1: fi1, argkey2: fi2, ... }
    #       if SV[si] is a linear combination of multiple argkey configurations
    SV_factors = numpy.empty(len(SV), dtype=object)
    # si2fi[si] is the index into FV of an Argument-free vertex SV[si];
    # entries of Argument-dependent vertices are left at 0.
    si2fi = numpy.zeros(len(SV), dtype=int)

    # Factorize each subexpression in order:
    for si, v in enumerate(SV):
        deps = SV_deps[si]

        # These handlers insert values in si2fi and SV_factors
        if not len(deps):
            if si in sv2av:
                # v is a modified Argument
                factors = {(si,): one_index}
            else:
                # v is a modified non-Argument terminal
                si2fi[si] = add_to_fv(v, FV, e2fi)
                factors = noargs
        else:
            if isinstance(v, Sum):
                handler = handle_sum
            elif isinstance(v, Product):
                handler = handle_product
            elif isinstance(v, Division):
                handler = handle_division
            elif isinstance(v, Conditional):
                handler = handle_conditional
            else:  # All other operators
                handler = handle_operator
            factors = handler(v, si, deps, SV_factors, FV, si2fi, e2fi)

        SV_factors[si] = factors

    assert not noargs, "This dict was not supposed to be filled with anything!"
    assert len(FV) == len(e2fi)

    # Get the factorizations of the target values
    IMs = []
    for si in SV_targets:
        if SV_factors[si] == {}:
            if rank == 0:
                # Functionals and expressions: store as no args * factor
                factors = {(): si2fi[si]}
            else:
                # Zero form of arity 1 or higher: make factors empty
                factors = {}
        else:
            # Forms of arity 1 or higher:
            # Map argkeys from indices into SV to indices into AV,
            # and resort keys for canonical representation
            factors = {tuple(sorted(sv2av[ti] for ti in argkey)): fi
                       for argkey, fi in SV_factors[si].items()}
        # Expecting all term keys to have length == rank
        # (this assumption will eventually have to change if we
        # implement joint bilinear+linear form factorization here)
        assert all(len(k) == rank for k in factors)
        IMs.append(factors)

    # Recompute dependencies in FV
    FV_deps = compute_dependencies(e2fi, FV)

    # Indices into FV that are needed for final result.  chain.from_iterable
    # flattens the per-target sorted lists into one list of indices
    # (chain(<generator>) would yield the lists themselves).
    FV_targets = list(chain.from_iterable(sorted(IM.values()) for IM in IMs))

    return IMs, AV, FV, FV_deps, FV_targets
def rebuild_scalar_graph_from_factorization(AV, FV, IM):
    """Rebuild a scalar expression graph from an argument factorization.

    The new vertex list SV starts with the argument vertices AV followed by
    the factor vertices FV.  For each argument key of IM (in sorted order)
    the monomial FV[IM[argkey]] * AV[a0] * AV[a1] * ... is built by
    successive binary products, each partial product being added as a
    vertex.  Then the running sum of all monomials is built, each partial
    sum being added as a vertex.  Equal expressions are added only once.

    Returns (SV, se2i, dependencies) where se2i maps each vertex expression
    to its index in SV and dependencies = compute_dependencies(se2i, SV).
    """
    # TODO: What about multiple target_variables?

    # Initial graph: argument vertices first, then factor vertices
    SV = list(AV) + list(FV)
    se2i = {}
    for position, expr in enumerate(SV):
        se2i[expr] = position

    def add_vertex(expr):
        # Register expr as a new vertex unless an equal one already exists
        if expr not in se2i:
            se2i[expr] = len(SV)
            SV.append(expr)

    # Factorization monomials: coefficient first, then arguments in key order
    monomials = []
    for argkey in sorted(IM.keys()):
        term = FV[IM[argkey]]
        for argindex in argkey:
            term = term * AV[argindex]
            add_vertex(term)
        monomials.append(term)

    # Sum of factorization monomials, recording every partial sum
    total = 0
    for term in monomials:
        total = total + term
        add_vertex(total)

    return SV, se2i, compute_dependencies(se2i, SV)