Main SCICO Classes¶
BlockArray¶
The class BlockArray
provides a way to combine arrays of
different shapes into a single object for use with other SCICO classes.
A BlockArray
consists of a list of jax.Array
objects,
which we refer to as blocks. A BlockArray
differs from a list in
that, whenever possible, BlockArray
properties and methods
(including unary and binary operators like +, -, *, …) automatically
map along the blocks, returning another BlockArray
or tuple as
appropriate. For example,
>>> import scico.numpy as snp
>>> x = snp.blockarray((
... [[1, 3, 7],
... [2, 2, 1]],
... [2, 4, 8]
... ))
>>> x.shape # returns tuple
((2, 3), (3,))
>>> x * 2 # returns BlockArray
BlockArray([...Array([[ 2, 6, 14],
[ 4, 4, 2]], dtype=...), ...Array([ 4, 8, 16], dtype=...)])
>>> y = snp.blockarray((
... [[.2],
... [.3]],
... [.4]
... ))
>>> x + y # returns BlockArray
BlockArray([...Array([[1.2, 3.2, 7.2],
[2.3, 2.3, 1.3]], dtype=...), ...Array([2.4, 4.4, 8.4], dtype=...)])
NumPy and SciPy Functions¶
scico.numpy
, scico.numpy.testing
, and
scico.scipy.special
provide wrappers around jax.numpy
,
numpy.testing
and jax.scipy.special
where many of the
functions have been extended to work with instances of BlockArray
.
In particular:
- When a tuple of tuples is passed as the shape argument to an array creation routine, a BlockArray is created.
- When a BlockArray is passed to a reduction function, the blocks are ravelled (i.e., reshaped to be 1D) and concatenated before the reduction is applied. This behavior may be prevented by passing the axis argument, in which case the function is mapped over the blocks.
- When one or more BlockArray instances are passed to a mathematical function that is not a reduction, the function is mapped over (corresponding) blocks.
For a list of array creation routines, see
>>> scico.numpy.creation_routines
('empty', ...)
For a list of reduction functions, see
>>> scico.numpy.reduction_functions
('sum', ...)
For lists of the remaining wrapped functions, see
>>> scico.numpy.mathematical_functions
('sin', ...)
>>> scico.numpy.testing_functions
('testing.assert_allclose', ...)
>>> import scico.scipy
>>> scico.scipy.special.functions
('betainc', ...)
Note that:
- Both scico.numpy.ravel and BlockArray.ravel return a BlockArray with ravelled blocks rather than the concatenation of these blocks as a single array.
- The functional and method versions of the “same” function differ in their behavior, with the method version only applying the reduction within each block, and the function version applying the reduction across all blocks. For example, scico.numpy.sum applied to a BlockArray with two blocks returns a scalar value, while BlockArray.sum returns a BlockArray with two scalar blocks.
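The difference between the function and method forms of a reduction can be mimicked with a plain Python list of NumPy arrays. This is only a sketch of the semantics described above, not SCICO code; the helper names `sum_across_blocks` and `sum_within_blocks` are hypothetical.

```python
import numpy as np

# Two blocks of different shapes, standing in for a BlockArray.
blocks = [np.array([[1, 3, 7], [2, 2, 1]]), np.array([2, 4, 8])]


def sum_across_blocks(blocks):
    """Function-style reduction: ravel, concatenate, then reduce."""
    return np.concatenate([b.ravel() for b in blocks]).sum()


def sum_within_blocks(blocks):
    """Method-style reduction: reduce each block separately."""
    return [b.sum() for b in blocks]


total = sum_across_blocks(blocks)      # one scalar for all blocks
per_block = sum_within_blocks(blocks)  # one scalar per block
```

Here `total` is a single value, while `per_block` holds one value for each block.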
Motivating Example¶
The discrete differences of a two-dimensional array, \(\mb{x} \in \mbb{R}^{n \times m}\), in the horizontal and vertical directions can be represented by the arrays \(\mb{x}_h \in \mbb{R}^{n \times (m-1)}\) and \(\mb{x}_v \in \mbb{R}^{(n-1) \times m}\) respectively. While it is usually useful to consider the output of a difference operator as a single entity, we cannot combine these two arrays into a single array since they have different shapes. We could vectorize each array and concatenate the resulting vectors, leading to \(\mb{\bar{x}} \in \mbb{R}^{n(m-1) + m(n-1)}\), which can be stored as a one-dimensional array, but this makes it hard to access the individual components \(\mb{x}_h\) and \(\mb{x}_v\).
Instead, we can construct a BlockArray
, \(\mb{x}_B =
[\mb{x}_h, \mb{x}_v]\):
>>> import scico.numpy as snp
>>> import scico.random
>>> n = 32
>>> m = 16
>>> x_h, key = scico.random.randn((n, m-1))
>>> x_v, _ = scico.random.randn((n-1, m), key=key)
>>> # Form the blockarray
>>> x_B = snp.blockarray([x_h, x_v])
>>> # The blockarray shape is a tuple of tuples
>>> x_B.shape
((32, 15), (31, 16))
>>> # Each block component can be easily accessed
>>> x_B[0].shape
(32, 15)
>>> x_B[1].shape
(31, 16)
Constructing a BlockArray¶
The recommended way to construct a BlockArray
is by using the
blockarray
function.
>>> import scico.numpy as snp
>>> import scico.random
>>> x0, key = scico.random.randn((32, 32))
>>> x1, _ = scico.random.randn((16,), key=key)
>>> X = snp.blockarray((x0, x1))
>>> X.shape
((32, 32), (16,))
>>> X.size
(1024, 16)
>>> len(X)
2
While blockarray
will accept arguments of type
ndarray
or Array
, arguments of type ndarray
will be converted to Array
type.
Operating on a BlockArray¶
Indexing¶
BlockArray
indexing works just like indexing a list.
Multiplication Between BlockArray and LinearOperator¶
The Operator
and LinearOperator
classes are designed
to work on instances of BlockArray
in addition to instances of
Array
. For example
>>> import numpy as np
>>> import scico.linop
>>> import scico.random
>>> x, key = scico.random.randn((3, 4))
>>> A_1 = scico.linop.Identity(x.shape)
>>> A_1.shape # array -> array
((3, 4), (3, 4))
>>> A_2 = scico.linop.FiniteDifference(x.shape)
>>> A_2.shape # array -> BlockArray
(((2, 4), (3, 3)), (3, 4))
>>> diag = snp.blockarray([np.array(1.0), np.array(2.0)])
>>> A_3 = scico.linop.Diagonal(diag, input_shape=(A_2.output_shape))
>>> A_3.shape # BlockArray -> BlockArray
(((2, 4), (3, 3)), ((2, 4), (3, 3)))
Operators¶
An operator is a map from \(\mathbb{R}^n\) or \(\mathbb{C}^n\)
to \(\mathbb{R}^m\) or \(\mathbb{C}^m\). In SCICO, operators
are primarily used to represent imaging systems and provide
regularization. SCICO operators are represented by instances of the
Operator
class.
SCICO Operator
objects extend the notion of “shape” and
“size” from the usual NumPy ndarray
class. Each
Operator
object has an input_shape
and output_shape
;
these shapes can be either tuples or a tuple of tuples (in the case of
a BlockArray
). The matrix_shape
attribute describes the
shape of the LinearOperator
if it were to act on vectorized,
or flattened, inputs.
For example, consider a two-dimensional array \(\mb{x} \in
\mathbb{R}^{n \times m}\). We compute the discrete differences of
\(\mb{x}\) in the horizontal and vertical directions, generating
two new arrays: \(\mb{x}_h \in \mathbb{R}^{n \times (m-1)}\) and
\(\mb{x}_v \in \mathbb{R}^{(n-1) \times m}\). We represent this
linear operator by \(\mb{A} : \mathbb{R}^{n \times m} \to
\mathbb{R}^{n \times (m-1)} \otimes \mathbb{R}^{(n-1) \times m}\). In
SCICO, this linear operator will return a BlockArray
with
the horizontal and vertical differences stored as blocks. Letting
\(y = \mb{A} x\), we have y.shape = ((n, m-1), (n-1, m))
and
A.input_shape = (n, m)
A.output_shape = ((n, m-1), (n-1, m))
A.shape = (((n, m-1), (n-1, m)), (n, m))  # (output_shape, input_shape)
A.input_size = n*m
A.output_size = n*(m-1) + (n-1)*m
A.matrix_shape = (n*(m-1) + (n-1)*m, n*m)  # (output_size, input_size)
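The size bookkeeping can be checked with a small NumPy sketch. It uses `np.diff` as a stand-in for the finite difference operator (an assumption made for illustration; it is not the SCICO implementation) and shows that the output size is the sum of the block sizes.

```python
import numpy as np

n, m = 4, 3
x = np.arange(n * m, dtype=float).reshape(n, m)

x_h = np.diff(x, axis=1)  # horizontal differences, shape (n, m-1)
x_v = np.diff(x, axis=0)  # vertical differences, shape (n-1, m)

input_size = x.size
output_size = x_h.size + x_v.size  # block sizes add
matrix_shape = (output_size, input_size)
```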
Operator Calculus¶
SCICO supports a variety of operator calculus rules, allowing new operators to be defined in terms of old ones. The following table summarizes the available operations.
Operation |
Result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Defining A New Operator¶
To define a new operator, pass a callable to the Operator
constructor:
A = Operator(input_shape=(32,), eval_fn = lambda x: 2 * x)
Or use subclassing:
>>> from scico.operator import Operator
>>> class MyOp(Operator):
...
... def _eval(self, x):
... return 2 * x
>>> A = MyOp(input_shape=(32,))
At a minimum, the _eval
function must be overridden. If either
output_shape
or output_dtype
are unspecified, they are
determined by evaluating the operator on an input of appropriate shape
and dtype.
Linear Operators¶
Linear operators are those for which
\[H(a \mb{x} + b \mb{y}) = a H(\mb{x}) + b H(\mb{y}) \;.\]
SCICO represents linear operators as instances of the class
LinearOperator
. While finite-dimensional linear operators
can always be associated with a matrix, it is often useful to
represent them in a matrix-free manner. Most of SCICO’s linear
operators are implemented matrix-free.
Using A LinearOperator¶
We implement two ways to evaluate a LinearOperator
. The
first is using standard callable syntax: A(x)
. The second mimics
the NumPy matrix multiplication syntax: A @ x
. Both methods
perform shape and type checks to validate the input before ultimately
either calling A._eval or generating a new LinearOperator
.
For linear operators that map real-valued inputs to real-valued
outputs, there are two ways to apply the adjoint: A.adj(y)
and
A.T @ y
.
For complex-valued linear operators, there are three ways to apply the
adjoint A.adj(y)
, A.H @ y
, and A.conj().T @ y
. Note that
in this case, A.T
returns the non-conjugated transpose of the
LinearOperator
.
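The distinction between the adjoint and the plain transpose can be illustrated with an explicit complex matrix in NumPy. This is a sketch with a dense matrix standing in for a LinearOperator; SCICO operators are typically matrix-free, so the code below only mirrors the algebra.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4)) + 1j * rng.standard_normal((3, 4))
x = rng.standard_normal(4) + 1j * rng.standard_normal(4)
y = rng.standard_normal(3) + 1j * rng.standard_normal(3)

adj_y = A.conj().T @ y  # adjoint (conjugate transpose) applied to y
transp_y = A.T @ y      # plain transpose: no conjugation

# Adjoint identity <A x, y> == <x, A^H y>; np.vdot conjugates its first argument.
lhs = np.vdot(A @ x, y)
rhs = np.vdot(x, adj_y)
```

The identity holds for `adj_y` but in general not for `transp_y`, which is why `A.T` and `A.H` must be distinguished for complex-valued operators.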
While the cost of evaluating the linear operator is virtually
identical for A(x)
and A @ x
, the A.H
and A.conj().T
methods are somewhat slower, especially the latter. This is because
two intermediate linear operators must be created before the function
is evaluated. Evaluating A.conj().T @ y
is equivalent to:
def f(y):
    B = A.conj()  # New LinearOperator #1
    C = B.T  # New LinearOperator #2
    return C @ y
Note: the speed differences between these methods vanish if applied inside of a jit-ed function. For instance:
f = jax.jit(lambda x: A.conj().T @ x)
Public Method |
Private Method |
|
|
|
|
|
|
The public methods perform shape and type checking to validate the input before either calling the corresponding private method or returning a composite LinearOperator.
Linear Operator Calculus¶
SCICO supports several linear operator calculus rules.
Given
A
and B
of class LinearOperator
and of appropriate shape,
x
an array of appropriate shape,
c
a scalar, and
O
an Operator
,
we have
Operation |
Result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Defining A New Linear Operator¶
To define a new linear operator, pass a callable to the
LinearOperator
constructor:
>>> from scico.linop import LinearOperator
>>> A = LinearOperator(input_shape=(32,),
... eval_fn = lambda x: 2 * x)
Or, use subclassing:
>>> class MyLinearOperator(LinearOperator):
... def _eval(self, x):
... return 2 * x
>>> A = MyLinearOperator(input_shape=(32,))
At a minimum, the _eval
method must be overridden. If the
_adj
method is not overridden, the adjoint is determined using
scico.linear_adjoint
. If either output_shape
or
output_dtype
are unspecified, they are determined by evaluating
the Operator on an input of appropriate shape and dtype.
Functionals¶
A functional is
a mapping from \(\mathbb{R}^n\) or \(\mathbb{C}^n\) to \(\mathbb{R}\).
In SCICO, functionals are
primarily used to represent a cost to be minimized
and are represented by instances of the Functional
class.
An instance of Functional
, f
, may provide three core operations.
- Evaluation
f(x) returns the value of the functional evaluated at the point x. A functional that can be evaluated has the attribute f.has_eval == True. Not all functionals can be evaluated: see Plug-and-Play.
- Gradient
f.grad(x) returns the gradient of the functional evaluated at x. Gradients are calculated using JAX reverse-mode automatic differentiation, exposed through scico.grad. Note: The gradient of a functional f can be evaluated even if that functional is not smooth. All that is required is that the functional can be evaluated, f.has_eval == True. However, the result may not be a valid gradient (or subgradient) for all inputs.
- Proximal operator
f.prox(v, lam) returns the result of the scaled proximal operator of f, i.e., the proximal operator of lambda x: lam * f(x), evaluated at the point v. The proximal operator of a functional \(f : \mathbb{R}^n \to \mathbb{R}\) is the mapping \(\mathrm{prox}_f : \mathbb{R}^n \to \mathbb{R}^n\) defined as
\[\mathrm{prox}_f (\mb{v}) = \argmin_{\mb{x}} f(\mb{x}) + \frac{1}{2} \norm{\mb{v} - \mb{x}}_2^2\;.\]
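As a concrete instance of this definition, the scaled proximal operator of the \(\ell_1\) norm has the well-known closed form of soft thresholding. The following NumPy sketch (the function names `prox_l1` and `objective` are hypothetical, and this is not the SCICO implementation) evaluates it and checks numerically that randomly perturbed points do not achieve a lower objective value.

```python
import numpy as np


def prox_l1(v, lam=1.0):
    """Proximal operator of lam * ||x||_1 (soft thresholding)."""
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)


def objective(x, v, lam):
    """Objective minimized by the scaled proximal operator."""
    return lam * np.sum(np.abs(x)) + 0.5 * np.sum((v - x) ** 2)


v = np.array([-2.0, -0.3, 0.0, 0.5, 3.0])
lam = 1.0
x_star = prox_l1(v, lam)

# The minimizer should not be beaten by nearby perturbed points.
rng = np.random.default_rng(0)
trials = x_star + 0.1 * rng.standard_normal((100, v.size))
best_trial = min(objective(t, v, lam) for t in trials)
```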
Plug-and-Play¶
For the plug-and-play framework [48],
we encapsulate generic denoisers including CNNs
in Functional
objects that cannot be evaluated.
The denoiser is applied via the proximal operator.
For examples, see Usage Examples.
Proximal Calculus¶
We support a limited subset of proximal calculus rules:
Scaled Functionals¶
Given a scalar c
and a functional f
with a defined proximal method, we can
determine the proximal method of c * f
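as
\[\mathrm{prox}_{c f} (\mb{v}, \lambda) = \mathrm{prox}_{f} (\mb{v}, c \lambda) \;,\]
where \(\mathrm{prox}_{f} (\mb{v}, \lambda)\) denotes the scaled proximal operator f.prox(v, lam).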
Note that we have made no assumptions regarding homogeneity of f
;
rather, only that the proximal method of f
is given
in the parameterized form \(\mathrm{prox}_{c f}\).
In SCICO, multiplying a Functional
by a scalar
will return a ScaledFunctional
.
This ScaledFunctional
retains the has_eval
and has_prox
attributes
from the original Functional
,
but the proximal method is modified to accommodate the additional scalar.
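The scaling rule follows directly from the definition of the scaled proximal operator: the proximal operator of lam * (c * f) is the proximal operator of (c * lam) * f. A minimal NumPy sketch, using soft thresholding as the proximal operator of the \(\ell_1\) norm, is given below; the wrapper `scaled_prox` is a hypothetical helper, not the ScaledFunctional class.

```python
import numpy as np


def prox_l1(v, lam=1.0):
    """Scaled proximal operator of the l1 norm (soft thresholding)."""
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)


def scaled_prox(prox, c):
    """Return the scaled prox of c * f, given the scaled prox of f."""
    return lambda v, lam=1.0: prox(v, c * lam)


c, lam = 3.0, 0.5
v = np.array([-4.0, -1.0, 0.2, 2.0])

prox_cf = scaled_prox(prox_l1, c)
result = prox_cf(v, lam)  # the effective threshold is c * lam = 1.5
```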
Separable Functionals¶
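A separable functional \(f : \mathbb{C}^N \to \mathbb{R}\) can be written as the sum of functionals \(f_i : \mathbb{C}^{N_i} \to \mathbb{R}\) with \(\sum_i N_i = N\). In particular,
\[f(\mb{x}) = \sum_i f_i(\mb{x}_i) \;,\]
where \(\mb{x}_i \in \mathbb{C}^{N_i}\) denotes the \(i\)-th block of \(\mb{x}\).
The proximal operator of a separable \(f\) can be written in terms of the proximal operators of the \(f_i\) (see Theorem 6.6 of [7]):
\[\left[ \mathrm{prox}_f (\mb{v}) \right]_i = \mathrm{prox}_{f_i} (\mb{v}_i) \;.\]
Separable functionals are implemented in the SeparableFunctional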
class. Separable functionals naturally accept BlockArray
inputs and return the prox as a BlockArray
.
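The blockwise structure of the proximal operator of a separable functional can be sketched with a list of NumPy arrays in place of a BlockArray. The helper `separable_prox` and the two example proximal operators are assumptions made for illustration; they are not part of the SeparableFunctional class.

```python
import numpy as np


def prox_l1(v, lam=1.0):
    """Prox of lam * ||x||_1: soft thresholding."""
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)


def prox_sq_l2(v, lam=1.0):
    """Prox of lam * (1/2)||x||_2^2: closed form v / (1 + lam)."""
    return v / (1.0 + lam)


def separable_prox(proxes, blocks, lam=1.0):
    """Apply the i-th proximal operator to the i-th block."""
    return [p(b, lam) for p, b in zip(proxes, blocks)]


blocks = [np.array([[2.0, -0.5], [0.0, -3.0]]), np.array([1.0, 2.0, 3.0])]
out = separable_prox([prox_l1, prox_sq_l2], blocks, lam=1.0)
```

Each output block has the same shape as the corresponding input block, mirroring the BlockArray-in, BlockArray-out behavior described above.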
Adding New Functionals¶
To add a new functional, create a class which
- inherits from base Functional;
- has has_eval and has_prox flags;
- has _eval and prox methods, as necessary.
For example,
import jax
import scico.functional
import scico.numpy as snp

class MyFunctional(scico.functional.Functional):
    has_eval = True
    has_prox = True

    def _eval(self, x: jax.Array) -> float:
        return snp.sum(x)

    def prox(self, x: jax.Array, lam: float) -> jax.Array:
        return x - lam
Losses¶
In SCICO, a loss is a special type of functional
\[f(\mb{x}) = \alpha \, l(\mb{y}, A(\mb{x})) \;,\]
where \(\alpha\) is a scaling parameter,
\(l\) is a functional,
\(\mb{y}\) is a set of measurements,
and \(A\) is an operator.
SCICO uses the class Loss
to represent losses.
Loss functionals commonly arise in the context of solving
inverse problems in scientific imaging,
where they are used to represent the mismatch
between predicted measurements \(A(\mb{x})\)
and actual ones \(\mb{y}\).
Optimization Algorithms¶
ADMM¶
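The Alternating Direction Method of Multipliers (ADMM) [25] [23] is an algorithm for solving problems of the form
\[\argmin_{\mb{x}, \mb{z}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that} \; \acute{A} \mb{x} + \acute{B} \mb{z} = \mb{c} \;, \tag{1}\]
where \(f\) and \(g\) are convex (but not necessarily smooth) functionals, \(\acute{A}\) and \(\acute{B}\) are linear operators, and \(\mb{c}\) is a constant vector. (For a thorough introduction and overview, see [12].)
The SCICO ADMM solver, ADMM, solves problems of the form
\[\argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;,\]
where \(f\) and the \(g_i\) are instances of Functional, and the \(C_i\) are instances of LinearOperator, by defining
\[g(\mb{z}) = \sum_{i=1}^N g_i(\mb{z}_i) \qquad \mb{z}_i = C_i \mb{x}\]
in (1), corresponding to defining
\[\acute{A} = \begin{pmatrix} C_1 \\ C_2 \\ \vdots \\ C_N \end{pmatrix} \qquad \acute{B} = -I \qquad \mb{c} = \mb{0} \;.\]
In ADMM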
, \(f\) is a Functional
, typically a
Loss
, corresponding to the forward model of an imaging
problem, and the \(g_i\) are Functional
, typically
corresponding to a regularization term or constraint. Each of the
\(g_i\) must have a proximal operator defined. It is also possible
to set f = None
, which corresponds to defining \(f = 0\),
i.e. the zero function.
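To make the structure of the iterations concrete, the following is a minimal NumPy sketch of scaled-form ADMM for the special case \(f(\mb{x}) = (1/2) \norm{A \mb{x} - \mb{y}}_2^2\), a single \(g(\mb{z}) = \lambda \norm{\mb{z}}_1\), and \(C = I\). It is not the SCICO ADMM class: the function `admm_lasso`, the penalty parameter `rho`, the scaled dual variable `u`, and the direct linear solve for the \(\mb{x}\)-update are all assumptions made for this illustration.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def admm_lasso(A, y, lam, rho=1.0, maxiter=200):
    """Scaled-form ADMM for (1/2)||A x - y||^2 + lam * ||z||_1 with z = x."""
    n = A.shape[1]
    x = np.zeros(n)
    z = np.zeros(n)
    u = np.zeros(n)
    # The x-update system matrix is fixed, so it is formed once.
    lhs = A.T @ A + rho * np.eye(n)
    Aty = A.T @ y
    for _ in range(maxiter):
        x = np.linalg.solve(lhs, Aty + rho * (z - u))  # x-update
        z = soft_threshold(x + u, lam / rho)           # z-update (prox of g)
        u = u + x - z                                  # scaled dual update
    return z


rng = np.random.default_rng(0)
A = rng.standard_normal((40, 10))
x_true = np.zeros(10)
x_true[[1, 6]] = [2.0, -1.5]
y = A @ x_true
x_hat = admm_lasso(A, y, lam=0.1)
```

The \(\mb{z}\)-update only requires the proximal operator of \(g\), which is why each \(g_i\) must have a proximal operator defined, while the cost of the \(\mb{x}\)-update depends on the structure of \(f\).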
Subproblem Solvers¶
The most computationally expensive component of the ADMM iterations is typically the \(\mb{x}\)-update,
The available solvers for this problem are:
- admm.GenericSubproblemSolver
This is the default subproblem solver as it is applicable in all cases. It is only suitable for relatively small-scale problems as it makes use of solver.minimize, which wraps scipy.optimize.minimize.
- admm.LinearSubproblemSolver
This subproblem solver can be used when \(f\) takes the form \(\norm{\mb{A} \mb{x} - \mb{y}}^2_W\). It makes use of the conjugate gradient method, and is significantly more efficient than admm.GenericSubproblemSolver when it can be used.
- admm.MatrixSubproblemSolver
This subproblem solver can be used when \(f\) takes the form \(\norm{\mb{A} \mb{x} - \mb{y}}^2_W\), and \(A\) and all of the \(C_i\) are diagonal (Diagonal) or matrix operators (MatrixOperator). It exploits a pre-computed matrix factorization for a significantly more efficient solution than conjugate gradient.
- admm.CircularConvolveSolver
This subproblem solver can be used when \(f\) takes the form \(\norm{\mb{A} \mb{x} - \mb{y}}^2_W\) and \(\mb{A}\) and all of the \(C_i\) are circulant (i.e., diagonalized by the DFT).
- admm.FBlockCircularConvolveSolver and admm.G0BlockCircularConvolveSolver
These subproblem solvers can be used when the primary linear operator is block-circulant (i.e., an operator with blocks that are diagonalized by the DFT).
For more details of these solvers and how to specify them, see the API
reference page for scico.optimize.admm
.
Proximal ADMM¶
Proximal ADMM [18] is an algorithm for solving problems of the form
\[\argmin_{\mb{x}, \mb{z}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that} \; A \mb{x} + B \mb{z} = \mb{c} \;,\]
where \(f\) and \(g\) are convex (but not necessarily smooth) functionals, \(A\) and \(B\) are linear operators, and \(\mb{c}\) is a constant vector. Although convergence per iteration is typically somewhat worse than that of ADMM, the iterations can be much cheaper than those of ADMM, giving Proximal ADMM competitive time convergence performance.
The SCICO Proximal ADMM solver, ProximalADMM
, requires
\(f\) and \(g\) to be instances of Functional
, and
to have a proximal operator defined (Functional.prox
), and
\(A\) and \(B\) are required to be instances of
LinearOperator
.
Non-Linear Proximal ADMM¶
Non-Linear Proximal ADMM [11] is an algorithm for solving problems of the form
\[\argmin_{\mb{x}, \mb{z}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that} \; H(\mb{x}, \mb{z}) = 0 \;,\]
where \(f\) and \(g\) are convex (but not necessarily smooth) functionals and \(H\) is a function of two vector variables.
The SCICO Non-Linear Proximal ADMM solver, NonLinearPADMM
, requires
\(f\) and \(g\) to be instances of Functional
, and
to have a proximal operator defined (Functional.prox
), and
\(H\) is required to be an instance of Function
.
Linearized ADMM¶
Linearized ADMM [56] [42] (Sec. 4.4.2) is an algorithm for solving problems of the form
\[\argmin_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \;,\]
where \(f\) and \(g\) are convex (but not necessarily smooth) functionals and \(C\) is a linear operator. Although convergence per iteration is typically significantly worse than that of ADMM, the \(\mb{x}\)-update can be much cheaper than that of ADMM, giving Linearized ADMM competitive time convergence performance.
The SCICO Linearized ADMM solver, LinearizedADMM
,
requires \(f\) and \(g\) to be instances of Functional
,
and to have a proximal operator defined (Functional.prox
), and
\(C\) is required to be an instance of LinearOperator
.
PDHG¶
The Primal–Dual Hybrid Gradient (PDHG) algorithm [21] [14] [43] solves problems of the form
\[\argmin_{\mb{x}} \; f(\mb{x}) + g(C \mb{x}) \;,\]
where \(f\) and \(g\) are convex (but not necessarily smooth) functionals and \(C\) is an operator. The algorithm has similar advantages over ADMM to those of Linearized ADMM, but typically exhibits better convergence properties.
The SCICO PDHG solver, PDHG
,
requires \(f\) and \(g\) to be instances of Functional
,
and to have a proximal operator defined (Functional.prox
), and
\(C\) is required to be an instance of Operator
or LinearOperator
.
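As an illustration of why only proximal operators and applications of \(C\) and its adjoint are needed, the following NumPy sketch applies the standard Chambolle-Pock form of PDHG to one-dimensional total variation denoising, with \(f(\mb{x}) = (1/2) \norm{\mb{x} - \mb{y}}_2^2\), \(g = \lambda \norm{\cdot}_1\), and \(C\) a finite difference matrix. It is not the SCICO PDHG class; the function names, the step sizes `tau` and `sigma`, and the explicit matrix `D` are assumptions made for this example.

```python
import numpy as np


def tv_objective(x, y, lam):
    """f(x) + g(C x) with f = (1/2)||x - y||^2 and g = lam * ||.||_1."""
    return 0.5 * np.sum((x - y) ** 2) + lam * np.sum(np.abs(np.diff(x)))


def pdhg_tv1d(y, lam, tau=0.45, sigma=0.45, maxiter=500):
    """PDHG for 1D TV denoising; requires tau * sigma * ||D||^2 < 1."""
    n = y.size
    D = np.diff(np.eye(n), axis=0)  # (n-1, n) forward-difference matrix
    x = y.copy()
    z = np.zeros(n - 1)
    for _ in range(maxiter):
        x_old = x
        # Primal step: prox of tau * f evaluated at x - tau * D^T z.
        x = (x_old - tau * (D.T @ z) + tau * y) / (1.0 + tau)
        # Dual step: prox of sigma * g^* is projection onto [-lam, lam].
        z = np.clip(z + sigma * (D @ (2.0 * x - x_old)), -lam, lam)
    return x


rng = np.random.default_rng(0)
clean = np.concatenate([np.zeros(20), np.ones(20)])
y = clean + 0.1 * rng.standard_normal(clean.size)
x_hat = pdhg_tv1d(y, lam=0.5)
```

No linear system involving \(C\) is solved at any point, which is the source of the cost advantage over ADMM mentioned above.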
PGM¶
The Proximal Gradient Method (PGM) [17] [10] and Accelerated Proximal Gradient Method (AcceleratedPGM) [8] are algorithms for solving problems of the form
\[\argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{x}) \;,\]
where \(g\) is convex and \(f\) is smooth and convex. The
corresponding SCICO solvers are PGM
and AcceleratedPGM
respectively. In most cases AcceleratedPGM
is expected to provide
faster convergence. In both of these classes, \(f\) and \(g\) are
both of type Functional
, where \(f\) must be differentiable,
and \(g\) must have a proximal operator defined.
While ADMM provides significantly more flexibility than PGM, and often converges faster, the latter is preferred when solving the ADMM \(\mb{x}\)-step is very computationally expensive, such as in the case of \(f(\mb{x}) = \norm{\mb{A} \mb{x} - \mb{y}}^2_W\) where \(A\) is large and does not have any special structure that would allow an efficient solution of (2).
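For comparison with ADMM, a minimal NumPy sketch of the (non-accelerated) proximal gradient iteration for \(f(\mb{x}) = (1/2) \norm{A \mb{x} - \mb{y}}_2^2\) and \(g(\mb{x}) = \lambda \norm{\mb{x}}_1\) is shown below. It is not the SCICO PGM class; the function `pgm_lasso` and the fixed choice of \(L\) as the squared spectral norm of \(A\) are assumptions made for this example.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def pgm_lasso(A, y, lam, maxiter=300):
    """Proximal gradient method with fixed step size 1 / L."""
    L = np.linalg.norm(A, 2) ** 2  # Lipschitz constant of the gradient of f
    x = np.zeros(A.shape[1])
    for _ in range(maxiter):
        grad = A.T @ (A @ x - y)                   # gradient step on f
        x = soft_threshold(x - grad / L, lam / L)  # proximal step on g
    return x


rng = np.random.default_rng(0)
A = rng.standard_normal((40, 10))
x_true = np.zeros(10)
x_true[[1, 6]] = [2.0, -1.5]
y = A @ x_true
x_hat = pgm_lasso(A, y, lam=0.1)
```

Each iteration needs only one application of \(A\) and one of its adjoint, with no linear system to solve, which is the trade-off described above.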
Step Size Options¶
The step size (usually referred to in terms of its reciprocal,
\(L\)) for the gradient descent in PGM
can be adapted via
Barzilai-Borwein methods (also called spectral methods) and iterative
line search methods.
The available step size policy classes are:
- BBStepSize
This implements the step size adaptation based on the Barzilai-Borwein method [6]. The step size \(\alpha\) is estimated as
\[\begin{split}\mb{\Delta x} = \mb{x}_k - \mb{x}_{k-1} \; \\ \mb{\Delta g} = \nabla f(\mb{x}_k) - \nabla f (\mb{x}_{k-1}) \; \\ \alpha = \frac{\mb{\Delta x}^T \mb{\Delta g}}{\mb{\Delta g}^T \mb{\Delta g}} \;.\end{split}\]Since the PGM solver uses the reciprocal of the step size, the value \(L = 1 / \alpha\) is returned.
- AdaptiveBBStepSize
This implements the adaptive Barzilai-Borwein method as introduced in [60]. The adaptive step size rule computes
\[\begin{split}\mb{\Delta x} = \mb{x}_k - \mb{x}_{k-1} \; \\ \mb{\Delta g} = \nabla f(\mb{x}_k) - \nabla f (\mb{x}_{k-1}) \; \\ \alpha^{\mathrm{BB1}} = \frac{\mb{\Delta x}^T \mb{\Delta x}} {\mb{\Delta x}^T \mb{\Delta g}} \; \\ \alpha^{\mathrm{BB2}} = \frac{\mb{\Delta x}^T \mb{\Delta g}} {\mb{\Delta g}^T \mb{\Delta g}} \;.\end{split}\]The determination of the new step size is made via the rule
\[\begin{split}\alpha = \left\{ \begin{array}{ll} \alpha^{\mathrm{BB2}} & \mathrm{~if~} \alpha^{\mathrm{BB2}} / \alpha^{\mathrm{BB1}} < \kappa \; \\ \alpha^{\mathrm{BB1}} & \mathrm{~otherwise} \end{array} \right . \;,\end{split}\]with \(\kappa \in (0, 1)\).
Since the PGM solver uses the reciprocal of the step size, the value \(L = 1 / \alpha\) is returned.
- LineSearchStepSize
This implements the line search strategy described in [8]. This strategy estimates \(L\) such that \(f(\mb{x}) \leq \hat{f}_{L}(\mb{x}, \mb{y})\) is satisfied with \(\hat{f}_{L}\) a quadratic approximation to \(f\) defined as
\[\hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y} \right\|_2^2 \;,\]with \(\mb{x}\) the potential new update and \(\mb{y}\) the current solution or current extrapolation (if using
AcceleratedPGM
).
- RobustLineSearchStepSize
This implements the robust line search strategy described in [22]. This strategy estimates \(L\) such that \(f(\mb{x}) \leq \hat{f}_{L}(\mb{x}, \mb{y})\) is satisfied with \(\hat{f}_{L}\) a quadratic approximation to \(f\) defined as
\[\hat{f}_{L}(\mb{x}, \mb{y}) = f(\mb{y}) + \nabla f(\mb{y})^H (\mb{x} - \mb{y}) + \frac{L}{2} \left\| \mb{x} - \mb{y} \right\|_2^2 \;,\]with \(\mb{x}\) the potential new update and \(\mb{y}\) the auxiliary extrapolation state. Note that this should only be used with
AcceleratedPGM
.
For more details of these step size managers and how to specify them, see
the API reference page for scico.optimize.pgm
.
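The Barzilai-Borwein estimate can be computed directly from two consecutive iterates and their gradients. The NumPy sketch below evaluates it for a simple quadratic \(f(\mb{x}) = (1/2) \mb{x}^T Q \mb{x}\); the function `bb_reciprocal_step` and the test problem are hypothetical and are not taken from the SCICO step size classes.

```python
import numpy as np


def bb_reciprocal_step(x_prev, x_curr, g_prev, g_curr):
    """Return L = 1 / alpha with alpha = (dx . dg) / (dg . dg)."""
    dx = x_curr - x_prev
    dg = g_curr - g_prev
    alpha = (dx @ dg) / (dg @ dg)
    return 1.0 / alpha


# Quadratic test problem: the gradient of f is Q x.
Q = np.diag([1.0, 4.0, 9.0])
x0 = np.array([1.0, 1.0, 1.0])
x1 = np.array([0.5, 0.2, -0.1])

L = bb_reciprocal_step(x0, x1, Q @ x0, Q @ x1)
```

For a quadratic with symmetric positive definite \(Q\), this estimate of \(L\) always lies between the smallest and largest eigenvalues of \(Q\), so it acts as a local curvature estimate.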
Learned Models¶
In SCICO, neural network models are used to represent imaging problems and provide different modes of data-driven regularization. The models are implemented in Flax, and constitute a representative sample of frequently used networks.
FlaxMap¶
SCICO interfaces with the implemented models via FlaxMap
. This provides standardized access to all trained models via the model definition and the learned parameters. Further specialized functionality, such as learned denoisers, is built on top of FlaxMap
. The specific models that have been implemented are described below.
DnCNN¶
The denoiser convolutional neural network model (DnCNN) [59], implemented as DnCNNNet
, is used to denoise images that have been corrupted with additive Gaussian noise.
ODP¶
The unrolled optimization with deep priors (ODP) [19], implemented as ODPNet
, is used to solve inverse problems in imaging by adapting classical iterative methods into an end-to-end framework that incorporates deep networks as well as knowledge of the image formation model.
The framework aims to solve the optimization problem
where \(A\) represents a linear forward model and \(r\) a regularization function encoding prior information, by unrolling the iterative solution method into a network where each iteration corresponds to a different stage in the ODP network. Different iterative solutions produce different unrolled optimization algorithms which, in turn, produce different ODP networks. The ones implemented in SCICO are described below.
Proximal Map¶
This algorithm corresponds to solving
with \(k\) corresponding to the index of the iteration, which translates to an index of the stage of the network, \(f(A \mb{x}, \mb{y})\) a fidelity term, usually an \(\ell_2\) norm, and \(\mb{x}^{k+1/2}\) a regularization representing \(\mathrm{prox}_r (\mb{x}^k)\) and usually implemented as a convolutional neural network (CNN). This proximal map representation is used when minimization problem (3) can be solved in a computationally efficient manner.
ODPProxDnBlock
uses this formulation to solve a denoising problem, which, according to [19], can be solved by
where \(A\) corresponds to the identity operator and is therefore omitted, \(\mb{y}\) is the noisy signal, \(\alpha_k > 0\) is a learned stage-wise parameter weighting the contribution of the fidelity term and \(\mb{x}^k + \mb{x}^{k+1/2}\) is the regularization, usually represented by a residual CNN.
ODPProxDblrBlock
uses this formulation to solve a deblurring problem, which, according to [19], can be solved by
where \(A\) is the blurring operator, \(K\) is the blurring kernel, \(\mb{y}\) is the blurred signal, \(\mathcal{F}\) is the DFT, \(\alpha_k > 0\) is a learned stage-wise parameter weighting the contribution of the fidelity term and \(\mb{x}^k + \mb{x}^{k+1/2}\) is the regularization represented by a residual CNN.
Gradient Descent¶
When the solution of the optimization problem in (3) cannot be simply represented by an analytical step, a formulation based on a gradient descent iteration is preferred. This yields
where \(\mb{x}^{k+1/2}\) represents \(\nabla r(\mb{x}^k)\).
ODPGrDescBlock
uses this formulation to solve a generic problem with \(\ell_2\) fidelity as
with \(\mb{y}\) the measured signal and \(\mb{x} + \mb{x}^{k+1/2}\) a residual CNN.
MoDL¶
The model-based deep learning (MoDL) [1], implemented as MoDLNet
, is used to solve inverse problems in imaging also by adapting classical iterative methods into an end-to-end deep learning framework, but, in contrast to ODP, it solves the optimization problem
by directly computing the update
via conjugate gradient. The regularization \(\mb{z}^k = \mathrm{D}_w(\mb{x}^{k})\) incorporates prior information, usually in the form of a denoiser model. In this case, the denoiser \(\mathrm{D}_w\) is shared between all the stages of the network requiring relatively less memory than other unrolling methods. This also allows for deploying a different number of iterations in testing than the ones used in training.