# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TransformDiagonal bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import bijector

__all__ = [
    'TransformDiagonal',
]


class TransformDiagonal(bijector.Bijector):
  """Applies a Bijector to the diagonal of a matrix.

  The wrapped `diag_bijector` transforms the diagonal entries of the input
  (read with `tf.linalg.diag_part`). The transformed values are written back
  with `tf.linalg.set_diag`, so every off-diagonal entry passes through
  unchanged.

  #### Example

  ```python
  b = tfb.TransformDiagonal(diag_bijector=tfb.Exp())

  b.forward([[1., 0.],
             [0., 1.]])
  # ==> [[2.718, 0.],
  #      [0., 2.718]]
  ```

  """

  def __init__(self,
               diag_bijector,
               validate_args=False,
               name='transform_diagonal'):
    """Instantiates the `TransformDiagonal` bijector.

    Args:
      diag_bijector: `Bijector` instance used to transform the diagonal. This
        bijector adopts its `dtype` and `is_constant_jacobian`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
    """
    # Capture the constructor arguments before any other local is bound.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._diag_bijector = diag_bijector
      # Both directions operate on matrices, so the minimum event rank is 2.
      super(TransformDiagonal, self).__init__(
          forward_min_event_ndims=2,
          inverse_min_event_ndims=2,
          validate_args=validate_args,
          dtype=diag_bijector.dtype,
          is_constant_jacobian=diag_bijector.is_constant_jacobian,
          parameters=parameters,
          name=name)

  @property
  def diag_bijector(self):
    """The `Bijector` applied to the diagonal entries."""
    return self._diag_bijector

  def _forward(self, x):
    old_diag = tf.linalg.diag_part(x)
    new_diag = self.diag_bijector.forward(old_diag)
    return tf.linalg.set_diag(x, new_diag)

  def _inverse(self, y):
    old_diag = tf.linalg.diag_part(y)
    new_diag = self.diag_bijector.inverse(old_diag)
    return tf.linalg.set_diag(y, new_diag)

  def _forward_log_det_jacobian(self, x):
    # Work with the flattened matrices `vec(x)` and `vec(y)`. Order the
    # entries so that the `n` diagonal entries come first, followed by the
    # `n**2-n` off-diagonal entries in any order. Because this bijector is
    # the identity off the diagonal, the Jacobian is block-diagonal:
    #
    # J_vec(x) (vec(y)) =
    # -------------------------------
    # | J_diag(x) (diag(y))      0  | n entries
    # |                             |
    # | 0                        I  | n**2-n entries
    # -------------------------------
    #   n                     n**2-n
    #
    # The identity block contributes a log-det of zero. The total
    # log-det-Jacobian is therefore that of the diagonal bijector acting on
    # the length-`n` diagonal, which is treated as a rank-1 event.
    #
    # For elementwise diagonal bijectors (exp, softplus, ...) the first block
    # is itself diagonal, but nothing here relies on that.
    x_diag = tf.linalg.diag_part(x)
    return self.diag_bijector.forward_log_det_jacobian(x_diag, event_ndims=1)

  def _inverse_log_det_jacobian(self, y):
    # The same block-diagonal argument applies in the inverse direction: only
    # the diagonal bijector's inverse log-det, over a rank-1 event, remains.
    y_diag = tf.linalg.diag_part(y)
    return self.diag_bijector.inverse_log_det_jacobian(y_diag, event_ndims=1)
