import numpy as np
def forward(f, P):
    """
    Forward operator.

    Computes the intensity of the exit wave ``exp(-f)`` after propagation
    with the Fresnel propagator(s) ``P``.

    Parameters
    ----------
    f : complex np.array
        f = B + iφ parametrizes the phase shifts φ and attenuation B; the
        exit wave is ``exp(-f)``.
    P : complex np.array
        Fresnel propagator(s), applied by multiplication after ``fft2``.
        Either a single 2-D propagator or a stack of propagators along
        axis 0 (one per distance).

    Returns
    -------
    forward : real np.array
        Intensity ``|ifft2(fft2(exp(-f)) * P)|**2``. Its shape is the
        broadcast of ``f`` against ``P``: one image per propagator when
        ``P`` is a stack.
    """
    P = np.asarray(P)
    wave_FT = np.fft.fft2(np.exp(-f))
    # ifft2 transforms the last two axes, so broadcasting wave_FT against a
    # stack of propagators propagates to every distance in a single call.
    # This also works for a single 2-D propagator.
    field_detector = np.fft.ifft2(wave_FT * P)
    return np.abs(field_detector) ** 2
def frechet_derivative(f, P, epsilon):
    """
    Frechet derivative of forward operator.

    Parameters
    ----------
    f : complex np.array
        f = B + iφ parametrizes the phase shifts φ and attenuation B;
        position at which the derivative is computed.
    P : complex np.array
        Fresnel propagator(s), applied by multiplication after ``fft2``.
    epsilon : complex np.array
        Input (direction) of the Frechet derivative.

    Returns
    -------
    frechet : real np.array
        ``Re(epsilon * u * conj(u))`` with ``u = ifft2(fft2(exp(-f)) * P)``.
    """
    wave = np.exp(-f)
    # Propagate the exit wave once and reuse the result for both factors.
    propagated = np.fft.ifft2(np.fft.fft2(wave) * P)
    # NOTE(review): epsilon scales the *propagated* field rather than
    # perturbing the exit wave, so this equals Re(epsilon) * |propagated|**2.
    # Linearizing forward() = |D(exp(-f))|**2 (D = propagation) in direction h
    # would give -2 * Re(conj(D(w)) * D(w * h)). Confirm which is intended.
    frechet = np.real(epsilon * propagated * np.conj(propagated))
    return frechet
def adjoint_frechet_derivative(f, P, epsilon, gram=None):
    """
    Adjoint of the Frechet derivative of the forward operator.

    Parameters
    ----------
    f : complex np.array
        f = B + iφ parametrizes the phase shifts φ and attenuation B;
        position at which the derivative is computed.
    P : complex np.array
        Fresnel propagator(s), applied by multiplication after ``fft2``.
    epsilon : complex np.array
        Input of the adjoint (a data-space quantity).
    gram : real np.array, optional
        Gramian matrix of the regularization term, in Fourier space. When
        given, ``epsilon`` is first replaced by
        ``ifft2((1 / gram) * fft2(epsilon))``. ``None`` (default) or an empty
        sequence such as ``[]`` means no Gram preconditioning.

    Returns
    -------
    adjoint_frechet : complex np.array
        ``-conj(w) * ifft2(conj(P) * fft2(epsilon * ifft2(fft2(w) * P)))``
        with ``w = exp(-f)``.
    """
    # Use an identity/size check rather than `gram != []`. For an ndarray that
    # comparison is elementwise: it raises on incompatible shapes in modern
    # NumPy, or yields an empty array with an ambiguous truth value.
    if gram is not None and np.size(gram) > 0:
        epsilon = np.fft.ifft2((1 / np.asarray(gram)) * np.fft.fft2(epsilon))
    wave = np.exp(-f)
    propagated = np.fft.ifft2(np.fft.fft2(wave) * P)
    # Back-propagate with the conjugate propagator (the adjoint of the
    # Fourier multiplier).
    back = np.fft.ifft2(np.conj(P) * np.fft.fft2(epsilon * propagated))
    # NOTE(review): for a stack of propagators the result keeps one slice per
    # distance (nothing is summed over axis 0). Confirm callers do the sum.
    return -(np.conj(wave) * back)
def grad(M):
    """
    Gradient operator (forward differences).

    Parameters
    ----------
    M : complex np.array
        Image whose gradient is to be computed.

    Returns
    -------
    gradient : complex np.array
        Array stacked along a new axis 0: ``gradient[0]`` is the difference
        along axis 0 and ``gradient[1]`` the difference along axis 1. The
        difference at the last row/column is ``M[-1] - M[-1]``, because the
        last sample is repeated.
    """
    # Shift by one sample along each axis, repeating the final sample so the
    # shifted array keeps the original shape.
    shifted_down = np.concatenate((M[1:, :], M[-1:, :]), axis=0)
    shifted_right = np.concatenate((M[:, 1:], M[:, -1:]), axis=1)
    return np.stack((shifted_down - M, shifted_right - M), axis=0)
def div(P):
    """
    Divergence operator (backward differences).

    The boundary rows/columns are chosen so that this is the negative
    adjoint of the forward-difference ``grad`` in this module.

    Parameters
    ----------
    P : complex np.array
        Vector field; ``P[0]`` is the component along axis 0 and ``P[1]`` the
        component along axis 1.

    Returns
    -------
    divergence : complex np.array
        Divergence of the field ``P``.
    """
    comp_x = P[0, :, :]
    comp_y = P[1, :, :]
    n_rows = np.shape(comp_x)[0]
    n_cols = np.shape(comp_y)[1]

    # Interior: comp[i] - comp[i - 1]. The wrapped entry that np.roll
    # produces at index 0 is overwritten by the boundary rule below.
    dx = comp_x - np.roll(comp_x, 1, axis=0)
    dx[0, :] = comp_x[0, :]
    dx[n_rows - 1, :] = -comp_x[n_rows - 2, :]

    dy = comp_y - np.roll(comp_y, 1, axis=1)
    dy[:, 0] = comp_y[:, 0]
    dy[:, n_cols - 1] = -comp_y[:, n_cols - 2]

    return dx + dy
def shrinkage(u, kappa):
    """
    Shrinkage (soft-thresholding) operation.

    Parameters
    ----------
    u : real np.array
        Input image.
    kappa : float
        Threshold.

    Returns
    -------
    u : real np.array
        ``max(0, u - kappa) - max(0, -u - kappa)``, which for non-negative
        ``kappa`` equals ``max(0, |u| - kappa) * sign(u)``.
    """
    # Part of u above +kappa and part below -kappa; values in between vanish.
    above = np.maximum(0, u - kappa)
    below = np.maximum(0, -u - kappa)
    return above - below
def amplitude_constraint(wavefront, amplitude, propagator, mask=None):
    """Apply amplitude constraint.

    The wavefront is propagated with ``propagator``. Inside the mask its
    modulus is replaced by ``amplitude`` while its phase is kept. The result
    is then propagated back with ``conj(propagator)``.

    Parameters
    ----------
    wavefront : complex np.array
        Wavefront to constrain.
    amplitude : np.array
        Amplitude to impose.
    propagator : complex np.array
        Propagator corresponding to effective distance of amplitude.
    mask : np.array, optional
        Zone to apply constraint (entries ``!= 0``). ``None`` (default) or an
        empty sequence such as ``[]`` constrains the whole field.

    Returns
    -------
    wavefront_constrained : complex np.array
        Wavefront after applied constraint.
    """
    # TODO: Proper handling of padding
    # Avoid `mask == []`: for an ndarray that comparison is elementwise and
    # raises (or is ambiguous) in modern NumPy.
    if mask is None or np.size(mask) == 0:
        mask = np.ones_like(amplitude)
    wavefront_aux = np.fft.ifft2(np.fft.fft2(wavefront) * propagator)  # TODO: Should be done with propagator instead
    # Apply amplitude constraint: keep the phase, impose the modulus. asarray
    # makes list masks compare elementwise rather than as a Python object.
    wavefront_aux = np.where(
        np.asarray(mask) != 0,
        amplitude * np.exp(1j * np.angle(wavefront_aux)),
        wavefront_aux,
    )
    # Back-propagate with the conjugate propagator. NOTE(review): this is the
    # exact inverse only if |propagator| == 1. Confirm for the propagators used.
    wavefront_constrained = np.fft.ifft2(np.fft.fft2(wavefront_aux) * np.conj(propagator))
    return wavefront_constrained
def error_estimate(wavefront, amplitude, propagator, mask=None):
    """
    Estimate fit to data.

    Parameters
    ----------
    wavefront : complex np.array
        Wavefront for estimation.
    amplitude : np.array
        Amplitude from measured image.
    propagator : complex np.array
        Fresnel propagator corresponding to effective propagation distance
        in measured image.
    mask : np.array, optional
        Restrict estimate to a region of interest (entries ``!= 0``).
        ``None`` (default) or an empty sequence such as ``[]`` uses every
        element.

    Returns
    -------
    error : float
        Mean squared difference between the calculated modulus
        ``|ifft2(fft2(wavefront) * propagator)|`` and the measured amplitude.
        The mean is taken over all elements, or over the masked region.

    Raises
    ------
    ValueError
        If a mask is given but selects no element.
    """
    wavefront_aux = np.fft.ifft2(np.fft.fft2(wavefront) * propagator)
    squared_error = (np.abs(wavefront_aux) - amplitude) ** 2
    if mask is None or np.size(mask) == 0:
        # Mean over every element. For 2-D inputs this equals the sum divided
        # by rows * cols. It also normalises correctly for stacked propagators.
        return np.mean(squared_error)
    region = np.broadcast_to(np.asarray(mask) != 0, squared_error.shape)
    if not np.any(region):
        raise ValueError("mask selects no pixels; cannot estimate the error")
    return np.mean(squared_error[region])