# -*- coding: utf-8 -*-
"This module provides functionality for compilation of strings as dolfin Expressions."

# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of DOLFIN.
#
# DOLFIN is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# DOLFIN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Johan Hake 2008-2009

from __future__ import print_function
from six import string_types
import re
import hashlib

from ufl.utils.sequences import product

from dolfin.compilemodules.compilemodule import (compile_extension_module,
                                                 expression_to_code_fragments,
                                                 math_header)

__all__ = ["compile_expressions"]

# C++ template for a generated dolfin::Expression subclass. It is filled in
# with %-formatting by expression_to_dolfin_expression(). Placeholders:
#   classname     - name of the generated class (used for the class itself
#                   and for its constructor)
#   members       - member declarations (fragment produced by
#                   expression_to_code_fragments)
#   value_shape   - constructor statements of the form
#                   "_value_shape.push_back(<dim>);", one per dimension
#   constructor   - further constructor statements (fragment produced by
#                   expression_to_code_fragments)
#   evalcode_cell - body of the eval() overload that receives a ufc::cell
#   evalcode      - body of the eval() overload without a cell argument
_expression_template = """class %(classname)s: public Expression
{
public:
%(members)s
  %(classname)s():Expression()
  {
%(value_shape)s
%(constructor)s
  }

  void eval(dolfin::Array<double>& values, const dolfin::Array<double>& x,
            const ufc::cell& cell) const
  {
%(evalcode_cell)s
  }

  void eval(dolfin::Array<double>& values, const dolfin::Array<double>& x) const
  {
%(evalcode)s
  }
};
"""

def flatten_and_check_expression(expr):
    """Flatten an expression specification and determine its value shape.

    Accepted forms:

    * a single string: a scalar expression, value shape ``()``;
    * a tuple or list of strings: a vector expression, value shape
      ``(len(expr),)``;
    * a tuple or list of tuples of strings: a matrix expression, value
      shape ``(rows, cols)``. All rows must have the same length.

    Returns a pair ``(flat_expr, shape)``. ``flat_expr`` is a flat sequence
    of the component strings in row-major order.

    Raises TypeError if *expr* is none of the accepted forms, has no
    components, or (matrix form) has rows of differing lengths.
    """
    if isinstance(expr, string_types):
        return (expr,), ()

    # An empty sequence has no components. Checking this here also
    # prevents an IndexError on expr[0] below.
    if isinstance(expr, (tuple, list)) and len(expr) > 0:
        if all(isinstance(row, tuple) for row in expr):
            num_cols = len(expr[0])
            # With ragged rows, (rows, cols) would not match the number
            # of flattened components.
            if any(len(row) != num_cols for row in expr):
                raise TypeError("All rows of a matrix expression must have "
                                "the same length: %s" % str(expr))
            shape = (len(expr), num_cols)
            flat = sum(expr, ())
        else:
            shape = (len(expr),)
            flat = expr
        if flat and all(isinstance(e, string_types) for e in flat):
            return flat, shape

    # Report the caller's original input rather than a flattened copy
    raise TypeError("Wrong type of expressions. Provide a 'str', a 'tuple' of 'str' or a 'tuple' of 'tuple' of 'str': %s" % str(expr))


def expression_to_dolfin_expression(expr, generic_function_members, mesh_function_members):
    """Generate code for a dolfin::Expression subclass for a single expression.

    *expr* is a str, a tuple/list of str, or a tuple/list of tuples of str
    (see flatten_and_check_expression). Each component becomes one entry
    of the ``values`` array in the generated eval() methods.

    *generic_function_members* is a sequence of ``(name, shape)`` pairs.
    For each pair the generated eval() calls ``shared_<name>->eval(...)``
    and exposes the result as a local ``<name>``. The local is a scalar
    when *shape* is empty and a ``const Array<double>&`` otherwise. With
    runtime checks enabled, eval() also verifies two things: the member's
    value size equals ``product(shape)``, and the member is not the
    expression itself.

    *mesh_function_members* is a sequence of ``(name, typename)`` pairs.
    For each pair the generated eval() reads ``(*shared_<name>)[cell.index]``
    into a local ``<name>`` of type *typename*. Such expressions need a cell,
    so their cell-free eval() only raises a dolfin_error.

    Returns a tuple ``(classname, code, members)``: the generated class
    name, the C++ code of the class, and the members reported by
    expression_to_code_fragments.

    The class name is a SHA-1 hash of the complete generated class (with
    the name left blank). Expressions that generate different code
    therefore never share a class name.
    """

    # TODO: Make this configurable through global dolfin 'debug mode' parameter?
    add_runtime_checks = True

    # Check and flatten provided expression
    expr, expr_shape = flatten_and_check_expression(expr)

    # Extract code fragments from the expr
    generic_function_member_names = [item[0] for item in generic_function_members]
    fragments, members = expression_to_code_fragments(
        expr, ["values", "x"],
        generic_function_member_names,
        mesh_function_members)

    # Generate code for the value shape (one push_back per dimension)
    value_shape_code = ["    _value_shape.push_back(%d);" % value_dim
                        for value_dim in expr_shape]

    evalcode = []

    # Runtime checks (TODO: Better to check these when updating the value instead...)
    if add_runtime_checks:
        for name, shape in generic_function_members:
            dim = product(shape)
            evalcode.append("    if (shared_%s->value_size() != %d)" % (name, dim))
            evalcode.append("      dolfin_error(\"generated code\",")
            evalcode.append("                   \"calling eval\", ")
            evalcode.append("                   \"Expecting value size %d for parameter \\'%s\\'\");" % (dim, name))
            evalcode.append("    if (shared_%s.get() == this)" % name)
            evalcode.append("      dolfin_error(\"generated code\",")
            evalcode.append("                   \"calling eval\",")
            evalcode.append("                   \"Circular eval call detected. Cannot use itself as parameter \\'%s\\' within eval\");" % name)
            evalcode.append("")

    # Generate code for evaluating genericfunction members
    for name, shape in generic_function_members:
        dim = product(shape)

        # Setup output array and call eval
        evalcode.append("    double %s__data_[%d];" % (name, dim))
        evalcode.append("    Array<double> %s__array_(%d, %s__data_);" % (name, dim, name))
        evalcode.append("    shared_%s->eval(%s__array_, x);" % (name, name))

        # Ensure const access through userdefined name
        if shape:
            # Vector valued result
            evalcode.append("    const Array<double> & %s = %s__array_;" % (name, name))
        else:
            # Scalar valued result
            evalcode.append("    const double %s = %s__array_[0];" % (name, name))
        evalcode.append("")

    # Lookup in MeshFunction<typename>(mesh, tdim)
    for name, typename in mesh_function_members:
        evalcode.append("    const %s %s = (*shared_%s)[cell.index];" % (typename, name, name))

    # Generate code for the actual expression evaluation
    evalcode.extend("    values[%d] = %s;" % (i, c) for i, c in enumerate(expr))

    # Derive the cell variant: forward the cell to the members' eval calls
    evalcode = "\n".join(evalcode)
    evalcode_cell = evalcode.replace(
        "__array_, x", "__array_, x, cell")
    if mesh_function_members:
        # Mesh function lookups need a cell, so the cell-free variant can
        # only report an error.
        # TODO: Reuse code in Function::eval(values, x) which looks up cell from x?
        evalcode = "\n".join([
            "      dolfin_error(\"generated code\",",
            "                   \"calling eval\", ",
            "                   \"Need cell to evaluate this Expression\");",
        ])

    # Connect the code fragments using the expression template code
    fragments["evalcode"] = evalcode
    fragments["evalcode_cell"] = evalcode_cell
    fragments["value_shape"] = "\n".join(value_shape_code)

    # Derive the class name from the full class body, not just the
    # cell-free eval code. That code is the same error stub for every
    # expression with mesh function members. It is also identical for
    # expressions that differ only in value shape (e.g. 4-vector vs. 2x2).
    # Either case would give different classes the same C++ name.
    fragments["classname"] = ""
    signature = _expression_template % fragments
    classname = "Expression_" + hashlib.sha1(
        signature.encode("utf-8")).hexdigest()
    fragments["classname"] = classname

    # Produce the C++ code for the expression class
    code = _expression_template % fragments
    return classname, code, members


def compile_expression_code(code, classnames=None, module_name=None,
                            additional_declarations=None, mpi_comm=None):
    """Compile C++ code defining Expression classes and return those classes.

    *code* is C++ source containing one or more class definitions;
    ``math_header`` is prepended before compilation.

    *classnames* lists the classes to fetch from the compiled module, in
    the order they should be returned. If None, every identifier that
    follows the keyword ``class`` in *code* is used. Otherwise each name
    must occur among those detected identifiers (sanity check only).

    *module_name* is currently unused and is accepted for backward
    compatibility.

    *additional_declarations* (str, default "") and *mpi_comm* are
    forwarded to compile_extension_module.

    Returns a list with the compiled classes, one per entry of
    *classnames*.
    """

    additional_declarations = additional_declarations or ""

    # Autodetect classnames
    detected_classnames = re.findall(r"class[ ]+([\w]+).*", code)

    if classnames is None:
        classnames = detected_classnames
    else:
        # Sanity check by membership, not by position. A code snippet may
        # define extra helper classes that shift positions, and a
        # positional zip() would also hide names missing at the end.
        detected = set(detected_classnames)
        missing = [name for name in classnames if name not in detected]
        assert not missing, \
            "Classes %s not found in expression code" % missing

    # Complete the code
    code = "%s\n%s" % (math_header, code)

    # Compile the extension module
    compiled_module = compile_extension_module(
        code, additional_declarations=additional_declarations,
        mpi_comm=mpi_comm)

    # Get the compiled classes
    return [getattr(compiled_module, name) for name in classnames]


def compile_expressions(cppargs,
                        generic_function_members=None,
                        mesh_function_members=None,
                        mpi_comm=None):
    """
    Compile a list of either C++ expressions or full subclasses of the
    dolfin::Expression class.

    Each item of *cppargs* is one of:

    * a str: a scalar expression; a scalar Expression is generated;
    * a tuple of str: a vector expression; a rank 1 Expression is
      generated;
    * a tuple of tuples of str: a matrix expression; a rank 2 Expression
      is generated;
    * a str containing both the words 'class' and 'Expression': C++ code
      with a complete implementation of a dolfin::Expression subclass,
      compiled as is. The class name is the first identifier following
      the keyword 'class'.

    If an expression string contains a name, it is assumed to be a scalar
    parameter name and is added as a public member of the generated
    expression.

    *generic_function_members* and *mesh_function_members* hold one entry
    per item of *cppargs*. Each entry is a list of ``(name, shape)`` pairs
    for GenericFunction members, or of ``(name, typename)`` pairs for
    MeshFunction members, respectively. Both default to empty lists and
    are ignored for C++ code items. ValueError is raised if their lengths
    differ from ``len(cppargs)``.

    *mpi_comm* is forwarded to the compiler.

    Returns ``(expression_classes, all_members)``: the compiled classes in
    the order of *cppargs*, and for each item the list of its parameter
    members (empty for C++ code items).
    """
    # FIXME: Hook up this to a more general debug mechanism
    assert isinstance(cppargs, list)

    num_args = len(cppargs)
    generic_function_members_list = (generic_function_members or
                                     [[] for _ in range(num_args)])
    mesh_function_members_list = (mesh_function_members or
                                  [[] for _ in range(num_args)])

    # zip() below would silently drop expressions on a length mismatch
    if not (len(generic_function_members_list) == num_args ==
            len(mesh_function_members_list)):
        raise ValueError(
            "Expecting one list of generic function members and one list "
            "of mesh function members per expression (%d expressions, "
            "%d and %d lists given)" % (num_args,
                                        len(generic_function_members_list),
                                        len(mesh_function_members_list)))

    # Collect code and classnames
    code_snippets = []
    classnames = []
    all_members = []
    additional_declarations = []

    for cpparg, gf_members, mf_members in zip(
            cppargs, generic_function_members_list, mesh_function_members_list):
        assert isinstance(cpparg, string_types + (tuple, list))

        if isinstance(cpparg, string_types) and "class" in cpparg and "Expression" in cpparg:
            # The cpparg includes the words 'class' and 'Expression':
            # assume it is a complete C++ code snippet
            code = cpparg

            # Get the class name
            found = re.findall(r"class[ ]+([\w]+).*", code)
            if not found:
                raise ValueError("Could not determine the class name "
                                 "in C++ code:\n%s" % code)
            classname = found[0]
            members = []

            # FIXME: Check for passed dimension?
        else:
            classname, code, members = \
                expression_to_dolfin_expression(cpparg, gf_members, mf_members)

            # Rename the generated shared_<name> members to <name>. The
            # directives go to the extension-module compiler as
            # additional declarations.
            for member_specs in (gf_members, mf_members):
                additional_declarations.extend(
                    "%%rename(%s) dolfin::%s::shared_%s;" % (name, classname, name)
                    for name, _ in member_specs)

        all_members.append(members)
        code_snippets.append(code)
        classnames.append(classname)

    expression_classes = compile_expression_code(
        "\n\n".join(code_snippets), classnames,
        additional_declarations="\n".join(additional_declarations),
        mpi_comm=mpi_comm)

    return expression_classes, all_members

