dubfi.linalg.mpi_worker_helpers
===============================

.. py:module:: dubfi.linalg.mpi_worker_helpers

.. autoapi-nested-parse::

   Helper functions for MPI linear algebra worker (private module).

   This module contains alternative implementations of helper functions used by
   :mod:`dubfi.linalg.mpi_worker`.


Functions
---------

.. autoapisummary::

   dubfi.linalg.mpi_worker_helpers._trace_product_self_parallel
   dubfi.linalg.mpi_worker_helpers._trace_product_self_ensopt
   dubfi.linalg.mpi_worker_helpers._trace_product_self_ens


Module Contents
---------------

.. py:function:: _trace_product_self_parallel(s, e, hens, x, loc, chainT)

   Compute part of trace product for :class:`MpiGradOpWorker` (special case).

   Helper for :meth:`MpiGradOpWorker._trace_product` in the commonly used case.

   Runtime:
   ``O(n_state * n_obs**3) + O(n_state * n_ens * n_obs**2) + O(n_state**2 * n_ens * n_obs)``

   Required memory (theoretically):
   ``16 * n_obs**2`` bytes per numba thread

   :param s: start index in reduced observation dimension
   :type s: int
   :param e: end index in reduced observation dimension
   :type e: int
   :param hens: centered ensemble observation operator
   :type hens: array, shape (n_ens, n_state, n_obs)
   :param x: ensemble of states
   :type x: array, shape (n_ens, n_obs)
   :param loc: localization, real-symmetric matrix
   :type loc: array, shape (n_obs, n_obs)
   :param chainT: chained array, real-symmetric matrix
   :type chainT: array, shape (n_obs, n_obs)


.. py:function:: _trace_product_self_ensopt(s, e, hens, x, loc, chain)

   Compute part of trace product for :class:`MpiGradOpWorker` (special case).

   Helper for :meth:`MpiGradOpWorker._trace_product` in the commonly used case,
   assuming that ``loc`` and ``chain`` are symmetric.
   This function can be faster than :func:`_trace_product_self_parallel`
   if ``n_state > 2*n_obs**2``.

   Runtime: ``O(n_ens**2 * n_obs**3) + ...``

   Required memory: ``16 * (n_ens + 1) * n_obs**2`` bytes

   Define::

      s[k, i1, j2] = (hens[m1, k, i1] x[m1, j1] + i1 <-> j1) loc[i1, j1] chain[j1, j2]   (sum over m1, j1)

   Compute::

      result[k, l] = s[k, i1, j2] s[l, j2, i1]   (sum over i1=s:e, j2=0:n_obs)

   Consider fixed ``m1``, ``m2``. Then ``result`` consists of two parts::

      A: hens[m1, k, i1] x[m1, j1] loc[i1, j1] chain[j1, j2] hens[m2, l, j2] x[m2, i2] loc[i2, j2] chain[i2, i1]
      B: hens[m1, k, i1] x[m1, j1] loc[i1, j1] chain[j1, j2] hens[m2, l, i2] x[m2, j2] loc[i2, j2] chain[i2, i1]


.. py:function:: _trace_product_self_ens(s, e, hens, x, loc, chain)

   Compute part of trace product for :class:`MpiGradOpWorker` (special case).

   Helper for :meth:`MpiGradOpWorker._trace_product` in the commonly used case,
   assuming that ``loc`` and ``chain`` are symmetric.

   Runtime: ``O(n_ens**2 * n_obs**3) + ...``

   Define::

      s[k, i1, j2] = (hens[m1, k, i1] x[m1, j1] + i1 <-> j1) loc[i1, j1] chain[j1, j2]   (sum over m1, j1)

   Compute::

      result[k, l] = s[k, i1, j2] s[l, j2, i1]   (sum over i1=s:e, j2=0:n_obs)

   Consider fixed ``m1``, ``m2``. Then ``result`` consists of two parts::

      A: hens[m1, k, i1] x[m1, j1] loc[i1, j1] chain[j1, j2] hens[m2, l, j2] x[m2, i2] loc[i2, j2] chain[i2, i1]
      B: hens[m1, k, i1] x[m1, j1] loc[i1, j1] chain[j1, j2] hens[m2, l, i2] x[m2, j2] loc[i2, j2] chain[i2, i1]
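
The "Define" and "Compute" formulas above can be written out as a short NumPy sketch. It is not the module's implementation: the function name ``trace_product_self_reference`` is hypothetical, and it builds the full intermediate tensor in memory rather than following the optimized loop structure of the helpers. It also assumes that ``+ i1 <-> j1`` means "add the same term with ``i1`` and ``j1`` exchanged". Such a sketch may be useful as a slow cross-check on small inputs.

```python
import numpy as np


def trace_product_self_reference(start, end, hens, x, loc, chain):
    """Naive einsum version of the documented formula (illustrative only).

    hens:  array, shape (n_ens, n_state, n_obs)
    x:     array, shape (n_ens, n_obs)
    loc:   array, shape (n_obs, n_obs)
    chain: array, shape (n_obs, n_obs)
    Returns an array of shape (n_state, n_state).
    """
    # outer[k, i, j] = sum over m1 of hens[m1, k, i] * x[m1, j]
    outer = np.einsum("mki,mj->kij", hens, x)
    # Assumed meaning of "+ i1 <-> j1": add the term with i and j exchanged.
    sym = outer + outer.transpose(0, 2, 1)
    # s_full[k, i1, j2] = sum over j1 of sym[k, i1, j1] * loc[i1, j1] * chain[j1, j2]
    s_full = np.einsum("kij,ij,jn->kin", sym, loc, chain)
    # result[k, l] = sum over i1 in start:end and all j2
    #                of s_full[k, i1, j2] * s_full[l, j2, i1]
    return np.einsum(
        "kij,lji->kl",
        s_full[:, start:end, :],
        s_full[:, :, start:end],
    )
```

Because the sum over ``i1`` is restricted to ``start:end``, the results for disjoint index ranges add up to the result for the full range ``0:n_obs``. This matches the description of each helper as computing only part of the trace product.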