Source code for phaseretrieval

#This file is part of the PyPhase software.
#
#Copyright (c) Max Langer (2019) 
#
#max.langer@creatis.insa-lyon.fr
#
#This software is a computer program whose purpose is to allow development,
#implementation, and deployment of phase retrieval algorithms.
#
#This software is governed by the CeCILL  license under French law and
#abiding by the rules of distribution of free software.  You can  use, 
#modify and/ or redistribute the software under the terms of the CeCILL
#license as circulated by CEA, CNRS and INRIA at the following URL
#"http://www.cecill.info". 
#
#As a counterpart to the access to the source code and  rights to copy,
#modify and redistribute granted by the license, users are provided only
#with a limited warranty  and the software's author,  the holder of the
#economic rights,  and the successive licensors  have only  limited
#liability. 
#
#In this respect, the user's attention is drawn to the risks associated
#with loading,  using,  modifying and/or developing or reproducing the
#software by the user in light of its specific status of free software,
#that may mean  that it is complicated to manipulate,  and  that  also
#therefore means  that it is reserved for developers  and  experienced
#professionals having in-depth computer knowledge. Users are therefore
#encouraged to load and test the software's suitability as regards their
#requirements in conditions enabling the security of their systems and/or 
#data to be ensured and,  more generally, to use and operate it in the 
#same conditions as regards security. 
#
#The fact that you are presently reading this means that you have had
#knowledge of the CeCILL license and that you accept its terms.

import numpy as np
from math import *
import pyphase.parallelizer as Parallelizer
import pyphase.propagator as Propagator
import pyphase.tomography as Tomography
#from vendor.EdfFile import EdfFile #TODO: This should not be necessary here!
import matplotlib.pyplot as pyplot
import scipy.ndimage
from pyphase.config import *
from matplotlib.pyplot import pause
from scipy import interpolate
import pickle #TODO: Is there a better way to handle imports? Centralised?
from scipy import ndimage
import pyphase.dataset as Dataset

class PhaseRetrievalAlgorithm2D:
    """Base class for 2D phase retrieval algorithms.

    Parameters
    ----------
    dataset : pyphase.Dataset, optional
        A Dataset type object.
    shape : tuple of ints, optional
        Size of images (ny, nx) for creation of frequency variables etc.
    pixel_size : float or sequence of floats, optional
        In m. A scalar is used for both directions; a pair is read as [x, y].
    distance : list of floats, optional
        Effective propagation distances in m.
    energy : float, optional
        Effective energy in keV.
    alpha : tuple of floats, optional
        Regularisation parameters given as base-10 exponents (see `Alpha`).
        First entry for LF, second for HF. Default [-8, -10], i.e. 1e-8 and
        1e-10.
    pad : int
        Padding factor. If omitted, the subclass must set ``padding`` before
        calling this constructor (the subclasses in this module do).

    Attributes
    ----------
    nx : int
        Number of pixels in horizontal direction.
    ny : int
        Number of pixels in vertical direction.
    pixel_size : numpy array
        Pixel size [x, y] in m.
    ND : int
        Number of positions.
    energy : float
        Energy in keV.
    alpha : tuple of floats
        Base-10 exponents of the regularisation parameter. First entry for
        LF, second for HF. Typically [-8, -10].
    distance : numpy array
        Effective propagation distances in m.
    padding : int
        Padding factor.
    sample_frequency : numpy array
        Reciprocal of pixel size in units of 1/lengthscale, [x, y].
    nfx : int
        Number of samples in Fourier domain (horizontal).
    nfy : int
        Number of samples in Fourier domain (vertical).
    fx : numpy array
        Frequency variable (horizontal), calculated by frequency_variable.
    fy : numpy array
        Frequency variable (vertical).
    alpha_cutoff : float
        Cutoff frequency for regularization parameter in normalized frequency.
    alpha_slope : float
        Slope in regularization parameter.

    Notes
    -----
    Takes either a Dataset type object with keyword dataset (which contains
    all necessary parameters), or parameters as above.
    """

    def __init__(self, dataset=None, **kwargs):
        # Length scale (m) that makes the formulae dimensionless.
        self.lengthscale = 10e-6  # TODO where should this be...
        # Check whether usage is with dataset or external images
        # TODO: Which parameters in constructor?
        if isinstance(dataset, (Dataset.Dataset, Dataset.ESRF)):  # TODO: needs integration
            self.nx = dataset.nx  # TODO: Probably make into an array [x, y]
            self.ny = dataset.ny
            # Dataset pixel size is presumably in µm; converted to m here.
            self.pixel_size = dataset.pixel_size * 1e-6
            self.ND = len(dataset.position)
            self.energy = dataset.energy
            # TODO: position, distance and effective distance are mixed up. Sort out.
            self.distance = np.array(dataset.effective_distance)
        else:
            self.nx = kwargs["shape"][1]
            self.ny = kwargs["shape"][0]
            self.pixel_size = kwargs["pixel_size"]
            distance = kwargs["distance"]
            self.distance = np.array(distance) if isinstance(distance, list) else distance
            energy = kwargs["energy"]
            self.energy = np.array(energy) if isinstance(energy, list) else energy
            self.ND = len(self.distance) if isinstance(self.distance, np.ndarray) else len(self.energy)
        if 'pad' in kwargs:
            self.padding = kwargs['pad']
        self.alpha = kwargs['alpha'] if 'alpha' in kwargs else [-8, -10]

        # Normalise the pixel size to a 2-element [x, y] array. Any scalar
        # (Python int/float, any NumPy scalar) is used for both directions;
        # lists, tuples and arrays are converted to a float array.
        pixel_size = np.asarray(self.pixel_size, dtype=float)
        if pixel_size.ndim == 0:
            pixel_size = np.full(2, float(pixel_size))
        self.pixel_size = pixel_size

        self.sample_frequency = self.lengthscale / self.pixel_size  # TODO: Should be attribute
        self.fx, self.fy = self.frequency_variable(self.nfx, self.nfy, self.sample_frequency)
        self._compute_factors()
        self.alpha_cutoff = .5
        # TODO: should be a property (dynamically calculated from alpha_cutoff)
        self.alpha_cutoff_frequency = self.alpha_cutoff * self.sample_frequency
        self.alpha_slope = .1e3

    def _algorithm(self, image, positions=None):
        """'Pure virtual' method containing purely the algorithm code part.

        Must be defined by each subclass and return ``(phase, attenuation)``.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError(
            "{} does not implement _algorithm".format(type(self).__name__))

    def _compute_factors(self):
        """Compute factors used in phase retrieval.

        The motivation is to save time and memory when reconstructing a series
        of projections on a processor (usually in parallel with others).
        The default computes the sin and cos chirps used in CTF. CTF and Mixed
        approaches extend this method, while TIE methods override it.
        Requires ``fx`` and ``fy`` to be set.
        """
        self.coschirp = np.zeros((self.ND, self.nfy, self.nfx))
        self.sinchirp = np.zeros_like(self.coschirp)
        for distance in range(self.ND):
            scaled = pi * self.Fresnel_number[distance]
            # Chirp argument computed once and shared by cos and sin.
            argument = scaled * (self.fx ** 2) + scaled * (self.fy ** 2)
            self.coschirp[distance] = np.cos(argument)
            self.sinchirp[distance] = np.sin(argument)

    def __getstate__(self):
        """Drop voluminous variables before pickling (used by the parallelizer).

        They are re-calculated by each process in __setstate__(). Keys that a
        subclass never creates (e.g. the chirps in TIE methods) are skipped.
        """
        state = self.__dict__.copy()
        for key in ('fx', 'fy', 'sinchirp', 'coschirp'):
            state.pop(key, None)
        return state

    def __setstate__(self, state):
        """Restore state and re-calculate the variables dropped by __getstate__().

        The frequency variables are rebuilt first because _compute_factors()
        reads them.
        """
        self.__dict__.update(state)
        self.fx, self.fy = self.frequency_variable(self.nfx, self.nfy, self.sample_frequency)
        self._compute_factors()

    @property
    def nfx(self):
        """Number of samples in Fourier domain, horizontal (int)."""
        return self.padding * self.nx

    @property
    def nfy(self):
        """Number of samples in Fourier domain, vertical (int)."""
        return self.padding * self.ny

    @property
    def Lambda(self):
        """Wavelength in m, based on energy in keV (float)."""
        return 12.4e-10 / self.energy

    @property
    def Fresnel_number(self):
        """Fresnel number at each position, calculated from energy and distance (float)."""
        return self.Lambda * self.distance / (self.lengthscale ** 2)

    @property
    def Alpha(self):
        """Image implementation of regularisation parameter (np.array).

        The LF exponent ``alpha[0]`` is blended into the HF exponent
        ``alpha[1]`` by a logistic step of slope ``alpha_slope`` centred at
        normalised radius ``alpha_cutoff``. The result is returned as
        ``10**exponent``.
        """
        x = np.linspace(-1, 1, self.nfx)
        y = np.linspace(-1, 1, self.nfy)
        xv, yv = np.meshgrid(x, y)
        R = np.sqrt(np.square(xv) + np.square(yv))
        R = np.fft.fftshift(R)
        if self.alpha[0] > self.alpha[1]:
            # Logistic function instead of error function (to be seen)
            Alpha = self.alpha[0] - ((self.alpha[0] - self.alpha[1]) / (1 + np.exp(-self.alpha_slope * (R - self.alpha_cutoff))))
        elif self.alpha[0] < self.alpha[1]:
            Alpha = self.alpha[0] + ((self.alpha[1] - self.alpha[0]) / (1 + np.exp(-self.alpha_slope * (R - self.alpha_cutoff))))
        else:
            Alpha = self.alpha[0] * R ** 0
        return 10 ** Alpha

    @Parallelize
    def reconstruct_projections(self, *, dataset, projections):
        """Reconstruct a range of projections (parallelized function).

        Parameters
        ----------
        dataset : pyphase.Dataset
            Dataset to reconstruct.
        projections : list of int
            In the form [start, end] (both inclusive).
        """
        for projection in range(projections[0], projections[1] + 1):
            self.reconstruct_projection(dataset=dataset, projection=projection)

    def frequency_variable(self, nfx, nfy, sample_frequency):
        """Calculate frequency variables.

        Parameters
        ----------
        nfx : int
            Number of samples in x direction.
        nfy : int
            Number of samples in y direction.
        sample_frequency : float or sequence of floats
            Reciprocal of pixel size; a scalar is used for both directions,
            a pair is read as [x, y].

        Returns
        -------
        list of nparray
            Frequency variables [fx, fy], each of shape [nfy, nfx].

        Notes
        -----
        Follows numpy FFT convention (``numpy.fft.fftfreq``). The zero
        frequency is at index 0, indices [1:(n+1)//2] hold the positive
        frequencies and the remaining indices the negative frequencies in
        increasing order. For even n, the Nyquist sample is returned as
        +fs/2 instead of -fs/2.
        """
        sample_frequency = np.broadcast_to(np.asarray(sample_frequency, dtype=float), (2,))

        def _axis(n, fs):
            """FFT-ordered frequencies for n samples at sampling frequency fs."""
            freqs = np.fft.fftfreq(n, d=1 / fs)
            if n > 0 and n % 2 == 0:
                freqs[n // 2] = fs / 2  # keep a positive Nyquist frequency
            return freqs

        return np.meshgrid(_axis(nfx, sample_frequency[0]), _axis(nfy, sample_frequency[1]))

    def simple_propagator(self, pxs, Lambda, z):
        """Create a Fresnel propagator for every position.

        Parameters
        ----------
        pxs : float or sequence of floats
            Pixel size in m (same unit as Lambda and z). A scalar is used for
            both directions; a pair is read as [x, y].
        Lambda : float
            Wavelength in m.
        z : sequence of floats
            Effective propagation distance in m, one per position.

        Returns
        -------
        H : complex nparray
            Fresnel propagators, shape [ND, nfy, nfx], in ifftshift order.

        Notes
        -----
        Temporary implementation by Y. Zhang. Will be integrated with the
        propagator module.
        """
        # TODO: need to be replaced by the propagator class
        pxs = np.broadcast_to(np.asarray(pxs, dtype=float), (2,))
        x = np.arange(-np.fix(self.nfx / 2), np.ceil(self.nfx / 2))
        y = np.arange(-np.fix(self.nfy / 2), np.ceil(self.nfy / 2))
        # pixel_size is [x, y] (as in frequency_variable): x uses pxs[0].
        fx = np.fft.ifftshift(x / (self.nfx * pxs[0]))
        fy = np.fft.ifftshift(y / (self.nfy * pxs[1]))
        Fx, Fy = np.meshgrid(fx, fy)
        f2 = Fx ** 2 + Fy ** 2
        H = np.zeros([self.ND, self.nfy, self.nfx], dtype=complex)
        for distance in range(self.ND):
            H[distance, :, :] = np.exp(-1j * np.pi * Lambda * z[distance] * f2)
        return H

    def reconstruct_projection(self, dataset, projection=0, positions=None, pad=True):
        """Reconstruct one projection from a Dataset object and save the result.

        Parameters
        ----------
        dataset : Dataset
            Dataset object to use (not necessarily the same as initialised).
        projection : int, optional
            Number of projection to reconstruct.
        positions : int or list of ints, optional
            Subset of positions to use for reconstruction.
        pad : bool, optional
            Passed to dataset.get_projection and reconstruct_image.

        Returns
        -------
        phase : np.array
            Reconstructed phase (padded if ``pad``).
        attenuation : np.array
            Reconstructed attenuation (padded if ``pad``).
        """
        # NOTE(review): `Utilities` is not among this file's visible imports;
        # presumably provided by `from pyphase.config import *` — verify.
        ID = np.zeros((self.ND, self.nfy, self.nfx))
        for position in range(self.ND):
            ID[position] = dataset.get_projection(projection=projection, position=position, pad=pad)
        phase, attenuation = self.reconstruct_image(ID, positions=positions, pad=pad)
        # Images are written at the unpadded size.
        dataset.write_image(image=Utilities.resize(phase, [self.ny, self.nx]), projection=projection)
        dataset.write_image(image=Utilities.resize(attenuation, [self.ny, self.nx]), projection=projection, projection_type='attenuation')
        return phase, attenuation

    def reconstruct_image(self, image, positions=None, pad=False):
        """Template for reconstructing an image given as argument.

        Parameters
        ----------
        image : numpy.array
            A phase contrast image or an ndarray of images stacked along the
            first dimension.
        positions : int or sequence of ints, optional
            Subset of positions to use for reconstruction. None (or an empty
            list) means all positions. An integer, including 0, selects that
            single position.
        pad : bool, optional
            If False, results are resized to the unpadded size (ny, nx).

        Note
        ----
        Calls _algorithm (container purely for algorithm part).

        Returns
        -------
        phase : numpy.array
        attenuation : numpy.array
        """
        if image.ndim == 2:
            # If only one image is given, add 3rd dimension for compatibility with loops
            image = image[np.newaxis]
        if positions is None or (isinstance(positions, list) and not positions):
            positions = list(range(self.ND))
        elif np.ndim(positions) == 0:
            positions = [positions]
        else:
            positions = list(positions)
        image = Utilities.resize(image, [self.nfy, self.nfx])
        phase, attenuation = self._algorithm(image, positions=positions)
        if not pad:
            phase = Utilities.resize(phase, [self.ny, self.nx])
            attenuation = Utilities.resize(attenuation, [self.ny, self.nx])
        return phase, attenuation
class TIEHOM(PhaseRetrievalAlgorithm2D):
    """Transport of Intensity Equation for homogeneous objects (or "Paganin's algorithm") [1].

    Parameters
    ----------
    delta_beta : float, optional
        Material dependent ratio delta over beta.

    References
    ----------
    [1] Paganin et al. J. Microsc. 206 (2002) 33
    """

    def __init__(self, dataset=None, delta_beta=500, **kwargs):
        self._delta_beta = delta_beta
        self.padding = 2
        super().__init__(dataset, **kwargs)

    @property
    def delta_beta(self):
        """Material dependent ratio delta over beta (float)."""
        return self._delta_beta

    @delta_beta.setter
    def delta_beta(self, delta_beta):
        """Recalculates dependent factors on setting (float)."""
        if delta_beta != self._delta_beta:
            self._delta_beta = delta_beta
            self._compute_factors()

    def _compute_factors(self):
        """Calculate TIEHOM factors. Overrides PhaseRetrievalAlgorithm2D.

        One filter per position: 1 + Fresnel_number * pi * delta_beta * |f|^2.
        """
        f2 = (self.fx ** 2) + (self.fy ** 2)
        self.TIEHOM_factor = [1 + self.Fresnel_number[distance] * np.pi * self.delta_beta * f2
                              for distance in range(self.ND)]

    def __getstate__(self):
        """Drop voluminous, re-computable variables before pickling.

        TIEHOM has no sin/cos chirps, so only the variables it actually holds
        are removed.
        """
        state = self.__dict__.copy()
        for key in ('fx', 'fy', 'sinchirp', 'coschirp', 'TIEHOM_factor'):
            state.pop(key, None)
        return state

    def __setstate__(self, state):
        """Restore state and recompute frequency variables and TIEHOM factors."""
        self.__dict__.update(state)
        self.fx, self.fy = self.frequency_variable(self.nfx, self.nfy, self.sample_frequency)
        self._compute_factors()

    def _algorithm(self, image, positions=None):
        """Reconstruct one image or a set of images using TIEHOM.

        With several positions, the filtered spectra are combined in the
        least-squares sense: sum(factor * FID) / sum(factor**2).
        """
        # TODO: Needs verification on simpler images
        FID = np.fft.fft2(image)
        if len(positions) == 1:
            position = positions[0]
            filtered = FID[position] / self.TIEHOM_factor[position]
        else:
            numerator_TIEHOM = np.zeros((self.nfy, self.nfx))
            denominator_TIEHOM = numerator_TIEHOM.copy()
            for position in positions:
                numerator_TIEHOM = numerator_TIEHOM + (self.TIEHOM_factor[position] * FID[position])
                denominator_TIEHOM = denominator_TIEHOM + (self.TIEHOM_factor[position]) ** 2
            filtered = numerator_TIEHOM / denominator_TIEHOM
        phase = 1 / 2 * self.delta_beta * np.log(np.real(np.fft.ifft2(filtered)))
        attenuation = -1 / (self.delta_beta) * phase
        return phase, attenuation
class CTFPurePhase(PhaseRetrievalAlgorithm2D):
    """Contrast Transfer Function for pure phase objects [1].

    References
    ----------
    [1] Cloetens et al. J. Phys. D: Appl. Phys 29 (1996) 133
    """

    def __init__(self, dataset=None, **kwargs):
        self.padding = 2
        super().__init__(dataset, **kwargs)

    def _algorithm(self, image, positions=None):
        """Regularised least-squares pure-phase CTF inversion.

        The estimate is sum(sin * FID) / (2 * sum(sin**2) + Alpha). The
        attenuation is identically zero.
        """
        spectra = np.fft.fft2(image)
        # TODO: CTF factors could possibly be generated in the constructor.
        numerator = sum((self.sinchirp[p] * spectra[p] for p in positions),
                        np.zeros((self.nfy, self.nfx)))
        weight = sum((self.sinchirp[p] * self.sinchirp[p] for p in positions),
                     np.zeros((self.nfy, self.nfx)))
        phase = np.real(np.fft.ifft2(numerator / (2 * weight + self.Alpha)))
        return phase, np.zeros_like(phase)
class CTFHOM(PhaseRetrievalAlgorithm2D):
    """Contrast Transfer Function for homogeneous objects [1].

    Parameters
    ----------
    delta_beta : float, optional
        Material dependent ratio delta over beta.

    References
    ----------
    [1] Villanueva-Perez et al. Optics Letters 42 (2017) 1133
    """

    def __init__(self, dataset=None, delta_beta=100, **kwargs):
        self._delta_beta = delta_beta
        self._beta_delta = 1 / delta_beta
        self.padding = 2
        super().__init__(dataset, **kwargs)

    @property
    def delta_beta(self):
        """Material dependent ratio delta over beta (float).

        Setting it updates the beta/delta ratio used by _algorithm. No cached
        factors depend on it, so nothing needs to be recomputed.
        """
        return self._delta_beta

    @delta_beta.setter
    def delta_beta(self, delta_beta):
        self._delta_beta = delta_beta
        self._beta_delta = 1 / delta_beta

    def _algorithm(self, image, positions=None):
        """Regularised least-squares CTF inversion under the homogeneous assumption.

        Each position contributes the factor sin - (beta/delta) * cos. The
        attenuation is the phase scaled by beta/delta.
        """
        FID = np.fft.fft2(image - 1)  # remove the unit background
        # TODO: should possibly be done in constructor
        CTFfactor = np.zeros((self.nfy, self.nfx))
        Delta = np.zeros_like(CTFfactor)
        for distance in positions:
            factor = self.sinchirp[distance] - self._beta_delta * self.coschirp[distance]
            CTFfactor = CTFfactor + factor * FID[distance]
            Delta = Delta + factor * factor
        phase = CTFfactor / (2 * Delta + self.Alpha)
        phase = np.real(np.fft.ifft2(phase))
        attenuation = phase * self._beta_delta
        return phase, attenuation
class CTF(PhaseRetrievalAlgorithm2D):
    """Contrast Transfer Function [1].

    References
    ----------
    [1] Cloetens et al. Appl. Phys. Lett. 75 (1999) 2912
    """

    def __init__(self, dataset=None, **kwargs):
        self.padding = 2
        super().__init__(dataset, **kwargs)

    def _normal_matrix(self, positions):
        """Return the least-squares factors (A, B, C, Delta) over positions.

        A = sum(sin*cos), B = sum(sin**2), C = sum(cos**2) and
        Delta = B*C - A**2.
        """
        A = np.zeros((self.nfy, self.nfx))
        B = A.copy()
        C = A.copy()
        for distance in positions:
            A += self.sinchirp[distance] * self.coschirp[distance]
            B += self.sinchirp[distance] * self.sinchirp[distance]
            C += self.coschirp[distance] * self.coschirp[distance]
        return A, B, C, B * C - A ** 2

    def _compute_factors(self):
        """Compute specific factors for CTF phase retrieval (all positions)."""
        super()._compute_factors()
        self.A, self.B, self.C, self.Delta = self._normal_matrix(range(self.ND))

    def __getstate__(self):
        """Includes specific CTF factors"""
        state = super().__getstate__()
        del state['A'], state['B'], state['C'], state['Delta']
        return state

    def _algorithm(self, image, positions=None):
        """Regularised least-squares joint phase/attenuation CTF inversion."""
        FID = np.fft.fft2(image)
        if sorted(positions) == list(range(self.ND)):
            A, B, C, Delta = self.A, self.B, self.C, self.Delta
        else:
            # The cached factors sum over all positions. A subset needs its own
            # factors to match the numerator sums below.
            A, B, C, Delta = self._normal_matrix(positions)
        sinCTFfactor = np.zeros((self.nfy, self.nfx))
        cosCTFfactor = np.zeros_like(sinCTFfactor)
        for distance in positions:
            sinCTFfactor = sinCTFfactor + self.sinchirp[distance] * FID[distance]
            cosCTFfactor = cosCTFfactor + self.coschirp[distance] * FID[distance]
        # TODO: The removal of the delta is not explicit in the paper
        # but should probably be done
        # s{k}(1,1) -= nf*mf; # remove 1 in real space
        # TODO: verify correct padding
        denominator = 2 * Delta + self.Alpha
        phase = (C * sinCTFfactor - A * cosCTFfactor) / denominator
        attenuation = (A * sinCTFfactor - B * cosCTFfactor) / denominator
        phase = np.real(np.fft.ifft2(phase))
        attenuation = np.real(np.fft.ifft2(attenuation))
        return phase, attenuation
class Mixed(PhaseRetrievalAlgorithm2D):
    """Mixed approach phase retrieval [1, 2].

    Position 0 is used as the reference image (I0). The other positions are
    used through their difference to it.

    Note
    ----
    Legacy code to be aligned with current API.

    References
    ----------
    [1] Guigay et al. Optics Letters 32 (2007) 1617
    [2] Langer et al. IEEE Trans. Image Process. 19 (2010) 2428
    """
    # TODO: Regularisation should somehow be separated

    def __init__(self, dataset):
        self.padding = 2
        super().__init__(dataset)
        # TODO: should there even be a default value? Where should delta_beta live?
        self.delta_beta = 500
        self.sumAD2 = np.zeros((self.nfy, self.nfx))
        # Index 0 (the reference image) is left at 0.
        self.coschirp_dfx = [0 for _ in range(self.ND)]
        self.coschirp_dfy = [0 for _ in range(self.ND)]
        for distance in range(1, self.ND):
            # Denominator Eq. 17
            self.sumAD2 = self.sumAD2 + self.sinchirp[distance] * self.sinchirp[distance]
            # First part of Delta
            self.coschirp_dfx[distance] = self.coschirp[distance] * 1j * self.fx * self.Fresnel_number[distance]
            self.coschirp_dfy[distance] = self.coschirp[distance] * 1j * self.fy * self.Fresnel_number[distance]
        self.R = np.sqrt(np.square(self.fx) + np.square(self.fy))
        self.LP_cutoff = 0.5
        self.LP_slope = .5e3
        self.LPfilter = 1 - 1 / (1 + np.exp(-self.LP_slope * (self.R - self.LP_cutoff)))  # Logistic filter
        self.sigma_I0filter = 10
        self.iterations = 5
        self.prior = 'homogeneous'  # 'homogeneous' 'functional' TODO: put in data file

    @staticmethod
    def _lcurve_filename(dataset):
        """Path of the pickle file holding the L-curve results of dataset."""
        return dataset.path + '/' + dataset.name + '_/lcurve.pickle'  # TODO: refactor

    def get_prior(self, projection, I0=None):
        """Generate a prior estimate of the phase (Fourier domain, low-pass filtered).

        Parameters
        ----------
        projection : int
            Unused; kept for backward compatibility.
        I0 : np.array
            Reference intensity (strictly positive) from which the homogeneous
            prior log(I0) * I0 / 2 is formed.

        Raises
        ------
        ValueError
            If I0 is not given (the original body referenced an undefined I0).
        """
        if I0 is None:
            raise ValueError("get_prior needs the reference intensity I0")
        return self.LPfilter * np.fft.fft2(np.log(I0) * I0 / 2)

    def Lcurve(self, dataset, projection):
        """Calculate the L-curve (for finding regularisation parameter).

        Reconstructs the projection for LF regularisation exponents -9 to -3,
        and records model and regularisation errors. The LF exponent of
        ``self.alpha`` is then set at the point of maximum curvature. Results
        are pickled next to the dataset, written to its parameter file and
        displayed.
        """
        Lcurve_min = -9
        Lcurve_max = -3
        Lcurve_step = 1
        Lcurve_range = np.arange(Lcurve_min, Lcurve_max + 1, Lcurve_step, dtype=float)
        alpha_HF = -10
        FID = [0 for _ in range(self.ND)]
        for distance in range(self.ND):
            FID[distance] = dataset.get_projection(projection, distance + 1, 'Fourier')
            if distance != 0:
                FID[distance] = FID[distance] - FID[0]  # Wouldn't it be better filtered as well?
        FI0_filtered = scipy.ndimage.fourier_gaussian(FID[0], self.sigma_I0filter)
        I0_filtered = np.real(np.fft.ifft2(FI0_filtered))
        dfxI0 = np.real(np.fft.ifft2(2j * np.pi * self.fx * FI0_filtered)) / I0_filtered
        dfyI0 = np.real(np.fft.ifft2(2j * np.pi * self.fy * FI0_filtered)) / I0_filtered
        I0 = np.real(np.fft.ifft2(FID[0]))  # TODO: refactor double code with reconstruct_projection
        if self.prior == 'forward':
            prior = dataset.get_image(projection, 'prior', 'Fourier')
            self.delta_beta = 1  # TODO: Necessary?
            print('forward')
        else:
            prior = np.fft.fft2(np.log(I0) * I0 / 2)
        prior = self.LPfilter * prior

        model_error = np.zeros(Lcurve_range.shape)
        regularisation_error = np.zeros(Lcurve_range.shape)
        for index, log_alpha in enumerate(Lcurve_range):
            print('Alpha: {}, Delta/Beta: {}'.format(log_alpha, self.delta_beta))
            # self.alpha holds base-10 exponents (Alpha applies 10**), so the
            # exponents are stored directly rather than 10**exponent.
            # NOTE(review): dataset.alpha therefore also receives exponents — confirm
            # this matches what WriteParameterFile expects.
            self.alpha = np.array([log_alpha, alpha_HF])
            # reconstruct_projection returns (phase, attenuation); only the phase is used.
            phase, _ = self.reconstruct_projection(dataset, projection)
            # Propagate with mixed (move to propagator)
            phase_dfxI0 = np.fft.fft2(phase * dfxI0)
            phase_dfyI0 = np.fft.fft2(phase * dfyI0)
            phasef = np.fft.fft2(phase)
            for distance in range(1, self.ND):
                mixed_contrast = (2 * self.sinchirp[distance] * phasef
                                  + self.coschirp_dfx[distance] * phase_dfxI0
                                  + self.coschirp_dfy[distance] * phase_dfyI0)
                model_difference = FID[distance] - mixed_contrast
                model_difference[0, 0] = 0  # Disregard offset (necessary?)
                model_difference_r = np.fft.ifft2(model_difference)
                # Central crop, i.e. the unpadded region for padding 2.
                model_difference_rc = np.real(model_difference_r[self.ny // 2:-self.ny // 2, self.nx // 2:-self.nx // 2])
                model_error[index] += np.sum(np.square(model_difference_rc)) / (self.nx * self.ny)
            model_error[index] = model_error[index] / (self.ND - 1)
            regularisation_difference = phasef - self.delta_beta * prior
            regularisation_difference = np.fft.ifft2(regularisation_difference)
            regularisation_difference = np.real(regularisation_difference[self.ny // 2:-self.ny // 2, self.nx // 2:-self.nx // 2])
            regularisation_error[index] = np.sum(np.square(regularisation_difference)) / (self.nx * self.ny)
            print("ME: {} , RE: {}".format(model_error[index], regularisation_error[index]))

        model_error_log = np.log10(model_error)
        regularisation_error_log = np.log10(regularisation_error)
        Loversamp = 10
        t = np.linspace(0, 1, len(model_error_log))
        ts = np.linspace(0, 1, len(model_error_log) * Loversamp)
        M = interpolate.UnivariateSpline(t, model_error_log, s=0)
        R = interpolate.UnivariateSpline(t, regularisation_error_log, s=0)
        Mp = M.derivative()
        Rp = R.derivative()
        Mpp = Mp.derivative()
        Rpp = Rp.derivative()
        K = (Mp(ts) * Rpp(ts) - Rp(ts) * Mpp(ts)) / (Rp(ts) ** 2 + Mp(ts) ** 2) ** 1.5  # Langer 2010 eq. 25
        Mts = M(ts)
        Rts = R(ts)
        Kmax = K.argmax()
        Mmin = Mts.argmin()
        # TODO: Plot function in display
        Lts = np.linspace(Lcurve_range.min(), Lcurve_range.max(), len(Lcurve_range) * Loversamp)
        self.alpha[0] = Lts[Kmax]
        dataset.alpha = self.alpha
        with open(self._lcurve_filename(dataset), 'wb') as f:
            pickle.dump([Lts, model_error_log, regularisation_error_log, Mts, Rts, K, Kmax, Mmin], f, pickle.HIGHEST_PROTOCOL)
        dataset.WriteParameterFile()
        self.display_Lcurve(dataset)

    def display_Lcurve(self, dataset):
        """Display the L-curve stored by Lcurve for dataset."""
        # NOTE: pickle.load can execute arbitrary code from an untrusted file;
        # this reads the file written by Lcurve.
        with open(self._lcurve_filename(dataset), 'rb') as f:
            Lts, model_error_log, regularisation_error_log, Mts, Rts, K, Kmax, Mmin = pickle.load(f)
        # Recompute the curve with interpolating splines for display.
        # NOTE(review): Kmax/Mmin come from the smoothing splines used in Lcurve,
        # while K/Mts/Rts here use k=4 interpolating splines.
        Loversamp = 10
        t = np.linspace(0, 1, len(model_error_log))
        ts = np.linspace(0, 1, len(model_error_log) * Loversamp)
        M = interpolate.InterpolatedUnivariateSpline(t, model_error_log, k=4)
        R = interpolate.InterpolatedUnivariateSpline(t, regularisation_error_log, k=4)
        Mp = M.derivative()
        Rp = R.derivative()
        Mpp = M.derivative(2)
        Rpp = R.derivative(2)
        K = (Mp(ts) * Rpp(ts) - Rp(ts) * Mpp(ts)) / (Rp(ts) ** 2 + Mp(ts) ** 2) ** 1.5  # Langer 2010 eq. 25
        Mts = M(ts)
        Rts = R(ts)
        pyplot.figure()
        pyplot.plot(Lts, K)
        pyplot.show()
        pyplot.figure()
        pyplot.plot(model_error_log, regularisation_error_log, 'rx', Mts, Rts, 'b-',
                    Mts[Kmax], Rts[Kmax], 'go', Mts[Mmin], Rts[Mmin], 'co')
        pyplot.show()
        print('Maximum curvature at: {}'.format(Lts[Kmax]))
        print('Minimum model error at: {}'.format(Lts[Mmin]))

    def _algorithm(self, image, positions=None):
        """Iterative mixed-approach reconstruction.

        Position 0 is the reference image I0; the other positions are replaced
        by their (Fourier-domain) difference to it.

        Raises
        ------
        NotImplementedError
            For the 'forward' prior, which needs a dataset not available here.
        """
        FID = np.fft.fft2(image)
        for distance in positions:
            if distance != 0:  # TODO: Handle case when position_number makes sense
                FID[distance] = FID[distance] - FID[0]  # Wouldn't it be better filtered as well?
        # TODO: Need a utility function for filters
        # Gaussian-filtered reference image, computed once and reused.
        FI0_filtered = scipy.ndimage.fourier_gaussian(FID[0], self.sigma_I0filter)
        I0 = np.real(np.fft.ifft2(FI0_filtered))
        # Gradient of attenuation image, normalised by I0.
        dfxI0 = np.real(np.fft.ifft2(2j * np.pi * self.fx * FI0_filtered)) / I0
        dfyI0 = np.real(np.fft.ifft2(2j * np.pi * self.fy * FI0_filtered)) / I0
        # TODO: priors should be in their own classes/functions (cf. Lcurve).
        if self.prior == 'forward':
            raise NotImplementedError("Mixed._algorithm: the 'forward' prior requires a dataset")
        # Estimate of phase*absorption with delta/beta = 1
        prior = self.LPfilter * np.fft.fft2(np.log(I0) * I0 / 2)
        Alpha = self.Alpha  # invariant over the iterations
        phase = np.zeros((self.nfy, self.nfx))
        for n in range(self.iterations):
            nominator_term = np.zeros((self.nfy, self.nfx))
            phase_dfxI0 = np.fft.fft2(phase * dfxI0)
            phase_dfyI0 = np.fft.fft2(phase * dfyI0)
            for distance in positions[1:]:
                nominator_term = nominator_term + self.sinchirp[distance] * (
                    FID[distance] - self.coschirp_dfx[distance] * phase_dfxI0 - self.coschirp_dfy[distance] * phase_dfyI0)
            # NOTE(review): normalised with ND-1 and sumAD2 (all positions) even
            # when only a subset of positions is used.
            nominator_term = nominator_term / (self.ND - 1)
            phase_n = (nominator_term + (Alpha * self.delta_beta * prior)) / (Alpha + self.sumAD2)
            phase_n = np.real(np.fft.ifft2(phase_n))
            print("Iteration: {} RMS: {}".format(n, np.sqrt(np.sum((phase_n - phase) * (phase_n - phase).conjugate()) / (self.nfx * self.nfy))))
            phase = phase_n
        attenuation = np.zeros_like(phase)
        return phase, attenuation

    def create_multimaterial_prior(self, data, threshold=2.1, delta_beta_soft=1938,
                                   delta_beta_hard=310, delta_beta_segmented=(2480, 711, 132),
                                   median_size=3,
                                   segmentation_filename='/mntdirect/_data_id19_bones01/bones3/max/holodata/knee/lbtoKneeWTOA8weekFeb15/18_OA/18_OA_8weeks_1_slice_pag_db0250_1400_Seg.raw'):
        """Generate a multi-material prior from a tomographic reconstruction of a contact plane scan.

        Writes ``-attenuation * delta_beta / 2`` to ``data.prior_vol_filename``.

        Parameters
        ----------
        data : Dataset
            Provides attenuation_vol_filename and prior_vol_filename.
        threshold : float
            Attenuation threshold between soft and hard material (threshold mode).
        delta_beta_soft, delta_beta_hard : float
            delta/beta below / at or above the threshold (threshold mode).
        delta_beta_segmented : sequence of 3 floats
            delta/beta for segmentation labels < 128, == 127 and > 127 (later
            conditions take precedence in np.piecewise).
        median_size : int
            Median filter size applied to the attenuation (segmentation mode).
        segmentation_filename : str
            uint8 raw segmentation volume. An empty string selects threshold
            mode. The default is the legacy hard-coded file.

        Note
        ----
        Legacy code to be refactored.
        """
        # TODO: Put parameters in parameter file
        print('Creating multi-material prior: db_soft: {}, db_hard: {}, threshold: {}'.format(delta_beta_soft, delta_beta_hard, threshold))
        # NOTE(review): no shape is given, so these memmaps are 1D views of the raw
        # files and the median filter runs along the flattened data — verify.
        attenuation_volume = np.memmap(data.attenuation_vol_filename, dtype=np.float32, mode='r')
        prior = np.memmap(data.prior_vol_filename, dtype=np.float32, mode='w+', shape=attenuation_volume.shape)
        if segmentation_filename:
            print('Using pre-segmented volume')
            segmentation_volume = np.memmap(segmentation_filename, mode='r', dtype=np.uint8, shape=attenuation_volume.shape)
            print('Median filtering attenuation')
            ndimage.median_filter(attenuation_volume, size=median_size, mode='nearest', output=prior)
            prior.flush()
            delta_beta_map = np.piecewise(segmentation_volume[:].astype(np.float32, copy=False),
                                          [segmentation_volume < 128, segmentation_volume == 127, segmentation_volume > 127],
                                          list(delta_beta_segmented))
            prior[:] = -prior[:] * delta_beta_map / 2
            prior.flush()
        else:
            prior[:] = -attenuation_volume[:] * np.piecewise(
                attenuation_volume[:], [attenuation_volume < threshold, attenuation_volume >= threshold],
                [delta_beta_soft, delta_beta_hard]) / 2
        # TODO: Interface for selecting parameters? Functional prior?
        del prior, attenuation_volume
class HIO_ER(PhaseRetrievalAlgorithm2D):
    """
    Sequence of Hybrid Input Output [1] and Error Reduction [2].

    For every processed distance, ``iterations`` global iterations are run,
    each made of ``iterations_hio`` HIO steps followed by ``iterations_er``
    ER steps. The wavefront is carried over from one distance to the next.

    Attributes
    ----------
    retriever : PhaseRetrievalAlgorithm2D
        Algorithm for initialisation
    iterations : int
        Number of global iterations
    iterations_hio : int
        Number of HIO iterations per iteration
    iterations_er : int
        Number of ER iterations per iteration
    step_size_phase : float
        Update step size for the phase
    step_size_attenuation : float
        Update step size for the attenuation

    References
    ----------
    [1] Fienup Appl. Opt. 21 (1982) 2758
    [2] Gerchberg & Saxton Optik 35 (1972) 237
    """
    #TODO: Needs refactoring to conform to coding standards

    def __init__(self, dataset=None, **kwargs):
        self.padding = 1
        # The retriever is created before the base initialiser runs: the
        # ``alpha`` setter forwards to it (presumably the base sets alpha).
        self.retriever = CTFPurePhase(dataset, **kwargs)
        super().__init__(dataset, **kwargs)
        self.iterations = 4  # 4 in Yuhe code
        self.iterations_hio = 45  # 45 in Yuhe code
        self.iterations_er = 5  # 5 in Yuhe code
        self.step_size_phase = 0.2
        self.step_size_attenuation = 0.2
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)  # TODO: But then it's passed around anyway

    def reconstruct_projection(self, dataset, projection=0, positions=None, pad=False):
        """Delegate to the base-class implementation and return its result."""
        return super().reconstruct_projection(dataset=dataset, projection=projection, positions=positions, pad=pad)

    @property
    def retriever(self):
        """Algorithm used to compute the initial guess."""
        return self._retriever

    @retriever.setter
    def retriever(self, retriever):
        self._retriever = retriever
        # Remember the class so the retriever can be rebuilt after unpickling
        self._retriever_class = retriever.__class__

    @property
    def alpha(self):
        """Regularisation parameter, forwarded to the retriever."""
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self.retriever.alpha = value
        self._alpha = value

    def __getstate__(self):
        """Drop attributes rebuilt in ``__setstate__`` (grids, chirps,
        retriever and propagator) from the pickled state."""
        state = self.__dict__.copy()
        del state['fx'], state['fy'], state['sinchirp'], state['coschirp'], state['retriever'], state['propagator']  #TODO: Needs proper separation PRAlg2d
        return state

    def __setstate__(self, state):
        """Restore state and rebuild the retriever and propagator."""
        self.__dict__.update(state)
        self._compute_factors()
        self.retriever = self._retriever_class(dataset=self.dataset)  # only dataset, reconstruct_image not parallellisable ?
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)  #TODO: update to use propagator module

    def amplitude_constraint(self, wavefront, amplitude, propagator, mask=None):
        """Apply amplitude constraint.

        The wavefront is propagated with ``propagator``, its modulus is
        replaced by ``amplitude`` inside ``mask`` (keeping the propagated
        phase), and it is back-propagated with the conjugate propagator.

        Parameters
        ----------
        wavefront : complex np.array
            Wavefront to constrain.
        amplitude : np.array
            Amplitude to impose.
        propagator : complex np.array
            Propagator corresponding to effective distance of amplitude.
        mask : np.array, optional
            Zone to apply constraint. ``None`` (or an empty sequence, kept
            for backward compatibility) applies it everywhere.

        Returns
        -------
        wavefront_constrained : complex np.array
            Wavefront after applied constraint.
        """
        #TODO: Proper handling of padding
        if mask is None or np.size(mask) == 0:
            mask = np.ones_like(amplitude)
        # TODO: Should be done with propagator instead
        wavefront_aux = np.fft.ifft2(np.fft.fft2(wavefront) * propagator)
        wavefront_aux = np.where(mask != 0, amplitude * np.exp(1j * np.angle(wavefront_aux)), wavefront_aux)
        return np.fft.ifft2(np.fft.fft2(wavefront_aux) * np.conj(propagator))

    def hybrid_input_output(self, wavefront, initial_wavefront, support, step_size_attenuation, step_size_phase):
        """
        One iteration of the Hybrid Input Output algorithm.

        Inside the support the modulus is clipped to [0, 1]; outside it both
        phase and modulus receive the HIO feedback update from
        ``initial_wavefront``.

        Parameters
        ----------
        wavefront : complex np.array
            Constrained wavefront.
        initial_wavefront : complex np.array
            Wavefront to update.
        support : np.array
            Support of object.
        step_size_attenuation : float
            Step size for attenuation update.
        step_size_phase : float
            Step size for phase update.

        Returns
        -------
        wavefront_updated : complex np.array
            Updated wavefront.
        """
        # TODO: Should be split out into its own class
        angle = np.angle(wavefront)
        magnitude = np.abs(wavefront)
        # phase constraint
        phase = np.where(support == 0, np.angle(initial_wavefront) - step_size_phase * angle, angle)
        # abs constraint
        attenuation = np.where(magnitude < 1, magnitude, 1)
        attenuation = np.where(magnitude > 0, attenuation, 0)
        attenuation = np.where(support == 0, np.abs(initial_wavefront) - step_size_attenuation * (magnitude - 1), attenuation)
        return attenuation * np.exp(1j * phase)

    def error_reduction(self, wavefront, support):
        """
        One iteration of Error Reduction.

        Inside the support the phase is kept only where negative (else 0) and
        the modulus is clipped to [0, 1]; outside the support the wavefront
        is set to 1.

        Parameters
        ----------
        wavefront : complex np.array
            wavefront to update.
        support : np.array
            Support constraint.

        Returns
        -------
        wavefront_updated : complex np.array
            Updated wavefront.
        """
        angle = np.angle(wavefront)
        magnitude = np.abs(wavefront)
        # phase constraint
        phase = np.where(angle < 0, angle, 0)
        phase = np.where(support == 0, 0, phase)
        # abs constraint
        attenuation = np.where(magnitude < 1, magnitude, 1)
        attenuation = np.where(magnitude > 0, attenuation, 0)
        attenuation = np.where(support == 0, 1, attenuation)
        return attenuation * np.exp(1j * phase)

    def error_estimate(self, wavefront, amplitude, propagator, mask=None):
        """
        Estimate fit to data.

        Parameters
        ----------
        wavefront : complex np.array
            Wavefront for estimation.
        amplitude : np.array
            Amplitude from measured image.
        propagator : complex np.array
            Fresnel propagator corresponding to effective propagation distance
            in measured image.
        mask : np.array, optional
            Region of interest. Currently unused.

        Returns
        -------
        error : float
            Sum of squared amplitude differences divided by ``nfx * nfy``.
        """
        wavefront_aux = np.fft.ifft2(np.fft.fft2(wavefront) * propagator)
        residual = (np.abs(wavefront_aux) - amplitude) ** 2
        return np.sum(residual) / (self.nfx * self.nfy)

    def _algorithm(self, image, positions=None, support=None):
        """Run HIO/ER per distance and average the per-distance results.

        ``positions`` defaults to all distances in ``image``; ``support``
        defaults to the whole field of view. Only processed distances
        contribute to the returned averages.
        """
        phase, attenuation = self.retriever.reconstruct_image(image)
        adjust_offset = True
        if adjust_offset:
            # Shift so the maximum phase is zero
            phase = phase - phase.max()
        if positions is None:
            positions = range(len(image))
        positions = list(positions)
        if support is None:
            support = np.ones_like(image[0])
        support = np.where(support != 0, 1, 0)  # Binarise support

        amplitude = np.sqrt(image)
        wavefront = np.exp(-attenuation) * np.exp(1j * phase)  # initial guess
        reconstruction = np.empty_like(image, dtype=np.complex128)

        for distance in positions:
            print(F'========== processing distance {distance+1} ==========')
            error_count = 0
            for _ in range(self.iterations):
                for jj in range(self.iterations_hio + self.iterations_er):
                    initial_wavefront = wavefront
                    wavefront = self.amplitude_constraint(wavefront, amplitude[distance], self.propagator[distance])
                    if jj < self.iterations_hio:
                        wavefront = self.hybrid_input_output(wavefront, initial_wavefront, support, self.step_size_attenuation, self.step_size_phase)
                    else:
                        wavefront = self.error_reduction(wavefront, support)
                    error_count += 1
                    error_it = self.error_estimate(wavefront, amplitude[distance], self.propagator[distance])
                    print('Iteration {:04d}, error: {:0.2g}'.format(error_count, np.real(error_it)))  #TODO: print every interation?
            reconstruction[distance] = wavefront

        # Average only filled slices; the others are uninitialised memory
        processed = reconstruction[positions]
        phase = np.average(np.angle(processed), axis=0)
        attenuation = np.average(np.abs(processed), axis=0)
        return phase, attenuation
class RAAR(PhaseRetrievalAlgorithm2D):
    """
    Relaxed averaged alternating reflections [1]

    Each iteration sweeps all processed distances in turn. The wavefront
    obtained for each distance in the last iteration is stored, and the
    results are averaged.

    Attributes
    ----------
    retriever : PhaseRetrievalAlgorithm2D
        Algorithm for initialisation
    iterations : int
        Number of RAAR iterations
    step_size_phase : float
        Update step size for the phase
    step_size_attenuation : float
        Update step size for the attenuation

    References
    ----------
    [1] Luke Inverse Problems 21 (2005) 37-50
    """

    def __init__(self, dataset=None, **kwargs):
        self.padding = 1
        # Created before the base initialiser: the ``alpha`` setter forwards to it
        self.retriever = TIEHOM(dataset, 100)  #CTFPurePhase(dataset, **kwargs)
        super().__init__(dataset, **kwargs)
        self.iterations = 200
        self.step_size_phase = 1
        self.step_size_attenuation = 1
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)

    def reconstruct_projection(self, dataset, projection=0, positions=None, pad=False):
        """Delegate to the base-class implementation and return its result."""
        return super().reconstruct_projection(dataset=dataset, projection=projection, positions=positions, pad=pad)

    @property
    def retriever(self):
        """Algorithm used to compute the initial guess."""
        return self._retriever

    @retriever.setter
    def retriever(self, retriever):
        self._retriever = retriever
        # Remember the class so the retriever can be rebuilt after unpickling
        self._retriever_class = retriever.__class__

    @property
    def alpha(self):
        """Regularisation parameter, forwarded to the retriever."""
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self.retriever.alpha = value
        self._alpha = value

    def __getstate__(self):
        """Drop attributes rebuilt in ``__setstate__`` from the pickled state."""
        state = self.__dict__.copy()
        del state['fx'], state['fy'], state['sinchirp'], state['coschirp'], state['retriever'], state['propagator']
        return state

    def __setstate__(self, state):
        """Restore state and rebuild the retriever and propagator."""
        self.__dict__.update(state)
        self._compute_factors()
        self.retriever = self._retriever_class(dataset=self.dataset)
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)

    def amplitude_constraint(self, wavefront, amplitude, propagator, mask=None):
        """Apply amplitude constraint.

        Parameters
        ----------
        wavefront : complex np.array
            Wavefront to constrain.
        amplitude : np.array
            Amplitude to impose.
        propagator : complex np.array
            Propagator corresponding to effective distance of amplitude.
        mask : np.array, optional
            Zone to apply constraint. ``None`` (or an empty sequence, kept
            for backward compatibility) applies it everywhere.

        Returns
        -------
        wavefront_constrained : complex np.array
            Wavefront after applied constraint.
        """
        if mask is None or np.size(mask) == 0:
            mask = np.ones_like(amplitude)
        # TODO: Should be done with propagator instead
        wavefront_aux = np.fft.ifft2(np.fft.fft2(wavefront) * propagator)
        # Impose the measured modulus, keep the propagated phase
        wavefront_aux = np.where(mask != 0, amplitude * np.exp(1j * np.angle(wavefront_aux)), wavefront_aux)
        return np.fft.ifft2(np.fft.fft2(wavefront_aux) * np.conj(propagator))

    def relaxed_averaged_alternating_reflections(self, wavefront, initial_wavefront, support, step_size_attenuation, step_size_phase):
        """
        One iteration of the Relaxed Averaged Alternating Reflections algorithm.

        Parameters
        ----------
        wavefront : complex np.array
            Constrained wavefront.
        initial_wavefront : complex np.array
            Wavefront to update.
        support : np.array
            Support of object.
        step_size_attenuation : float
            Step size for attenuation update.
        step_size_phase : float
            Step size for phase update.

        Returns
        -------
        wavefront_updated : complex np.array
            Updated wavefront.
        """
        angle = np.angle(wavefront)
        angle_initial = np.angle(initial_wavefront)
        magnitude = np.abs(wavefront)
        magnitude_initial = np.abs(initial_wavefront)

        phase = np.where(np.angle(2 * wavefront - initial_wavefront) < 0, angle, 0)
        phase = np.where(support == 0, step_size_phase * angle_initial + (1 - 2 * step_size_phase) * angle, phase)

        attenuation = np.where(2 * (1 - magnitude) - (1 - magnitude_initial) > 0, magnitude, 1)
        attenuation = np.where(support == 0, step_size_attenuation * magnitude_initial + (1 - 2 * step_size_attenuation) * (magnitude - 1), attenuation)
        return attenuation * np.exp(1j * phase)

    def error_estimate(self, wavefront, amplitude, propagator, mask=None):
        """
        Estimate fit to data.

        Parameters
        ----------
        wavefront : complex np.array
            Wavefront for estimation.
        amplitude : np.array
            Amplitude from measured image.
        propagator : complex np.array
            Fresnel propagator corresponding to effective propagation distance
            in measured image.
        mask : np.array, optional
            Region of interest. Currently unused.

        Returns
        -------
        error : float
            Sum of squared amplitude differences divided by ``nfx * nfy``.
        """
        wavefront_aux = np.fft.ifft2(np.fft.fft2(wavefront) * propagator)
        residual = (np.abs(wavefront_aux) - amplitude) ** 2
        return np.sum(residual) / (self.nfx * self.nfy)

    def _algorithm(self, image, positions=None, support=None):
        """Run RAAR sweeping all distances per iteration.

        ``positions`` defaults to all distances in ``image``; ``support``
        defaults to the whole field of view. Only processed distances
        contribute to the returned averages.
        """
        phase, attenuation = self.retriever.reconstruct_image(image)
        ### It seems that giving an initialisation with artefacts in the low frequency range can hardly be removed
        ### Better to initialised with zeros?
        # phase, attenuation = phase * 0, attenuation * 0
        adjust_offset = True
        if adjust_offset:
            # Shift so the maximum phase is zero
            phase = phase - phase.max()
        if positions is None:
            positions = range(len(image))
        positions = list(positions)
        if support is None:
            support = np.ones_like(image[0])
        support = np.where(support != 0, 1, 0)  # Binarise support

        amplitude = np.sqrt(image)
        wavefront = np.exp(-attenuation) * np.exp(1j * phase)  # initial guess
        reconstruction = np.empty_like(image, dtype=np.complex128)
        last_iteration = self.iterations - 1

        for ii in range(self.iterations):
            print(f'========== Iterations {ii+1} ==========')
            for distance in positions:
                initial_wavefront = wavefront
                wavefront = self.amplitude_constraint(wavefront, amplitude[distance], self.propagator[distance])
                wavefront = self.relaxed_averaged_alternating_reflections(wavefront, initial_wavefront, support, self.step_size_attenuation, self.step_size_phase)
                error_it = self.error_estimate(wavefront, amplitude[distance], self.propagator[distance])
                print(f'Distance {distance} error: {np.real(error_it):0.2g}')
                if ii == last_iteration:
                    reconstruction[distance] = wavefront

        # Average only filled slices; the others are uninitialised memory
        processed = reconstruction[positions]
        phase = np.average(np.angle(processed), axis=0)
        attenuation = np.average(np.abs(processed), axis=0)
        return phase, attenuation
class HPR(PhaseRetrievalAlgorithm2D):
    """
    Hybrid Projection Reflection [1]

    Each iteration sweeps all processed distances in turn. The wavefront
    obtained for each distance in the last iteration is stored, and the
    results are averaged.

    Attributes
    ----------
    retriever : PhaseRetrievalAlgorithm2D
        Algorithm for initialisation
    iterations : int
        Number of HPR iterations
    step_size_phase : float
        Update step size for the phase
    step_size_attenuation : float
        Update step size for the attenuation

    References
    ----------
    [1] Bauschke, Combettes & Luke J. Opt. Soc. Am. A 20(6) (2003) 1025-1034
    """

    def __init__(self, dataset=None, **kwargs):
        self.padding = 1
        # Created before the base initialiser: the ``alpha`` setter forwards to it
        self.retriever = CTFPurePhase(dataset, **kwargs)
        super().__init__(dataset, **kwargs)
        self.iterations = 200
        self.step_size_phase = 0.75
        self.step_size_attenuation = 0.75
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)

    def reconstruct_projection(self, dataset, projection=0, positions=None, pad=False):
        """Delegate to the base-class implementation and return its result."""
        return super().reconstruct_projection(dataset=dataset, projection=projection, positions=positions, pad=pad)

    @property
    def retriever(self):
        """Algorithm used to compute the initial guess."""
        return self._retriever

    @retriever.setter
    def retriever(self, retriever):
        self._retriever = retriever
        # Remember the class so the retriever can be rebuilt after unpickling
        self._retriever_class = retriever.__class__

    @property
    def alpha(self):
        """Regularisation parameter, forwarded to the retriever."""
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self.retriever.alpha = value
        self._alpha = value

    def __getstate__(self):
        """Drop attributes rebuilt in ``__setstate__`` from the pickled state."""
        state = self.__dict__.copy()
        del state['fx'], state['fy'], state['sinchirp'], state['coschirp'], state['retriever'], state['propagator']  #TODO: Needs proper separation PRAlg2d
        return state

    def __setstate__(self, state):
        """Restore state and rebuild the retriever and propagator."""
        self.__dict__.update(state)
        self._compute_factors()
        self.retriever = self._retriever_class(dataset=self.dataset)  # only dataset, reconstruct_image not parallellisable ?
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)  # TODO: But then it's passed around anyway

    def amplitude_constraint(self, wavefront, amplitude, propagator, mask=None):
        """Apply amplitude constraint.

        Parameters
        ----------
        wavefront : complex np.array
            Wavefront to constrain.
        amplitude : np.array
            Amplitude to impose.
        propagator : complex np.array
            Propagator corresponding to effective distance of amplitude.
        mask : np.array, optional
            Zone to apply constraint. ``None`` (or an empty sequence, kept
            for backward compatibility) applies it everywhere.

        Returns
        -------
        wavefront_constrained : complex np.array
            Wavefront after applied constraint.
        """
        #TODO: Proper handling of padding
        if mask is None or np.size(mask) == 0:
            mask = np.ones_like(amplitude)
        # TODO: Should be done with propagator instead
        wavefront_aux = np.fft.ifft2(np.fft.fft2(wavefront) * propagator)
        # Impose the measured modulus, keep the propagated phase
        wavefront_aux = np.where(mask != 0, amplitude * np.exp(1j * np.angle(wavefront_aux)), wavefront_aux)
        return np.fft.ifft2(np.fft.fft2(wavefront_aux) * np.conj(propagator))

    def hybrid_projection_reflection(self, wavefront, initial_wavefront, support, step_size_attenuation, step_size_phase):
        """
        One iteration of the Hybrid Projection Reflection.

        Parameters
        ----------
        wavefront : complex np.array
            Constrained wavefront.
        initial_wavefront : complex np.array
            Wavefront to update.
        support : np.array
            Support of object.
        step_size_attenuation : float
            Step size for attenuation update.
        step_size_phase : float
            Step size for phase update.

        Returns
        -------
        wavefront_updated : complex np.array
            Updated wavefront.
        """
        angle = np.angle(wavefront)
        angle_initial = np.angle(initial_wavefront)
        magnitude = np.abs(wavefront)
        magnitude_initial = np.abs(initial_wavefront)

        phase_condition = 2 * angle - angle_initial - (1 - step_size_phase) * angle
        attenuation_condition = 2 * (magnitude - 1) - (magnitude_initial - 1) - (1 - step_size_attenuation) * (magnitude - 1)

        phase = np.where(phase_condition < 0, angle, 0)
        phase = np.where(support == 0, angle_initial - step_size_phase * angle, phase)
        attenuation = np.where(attenuation_condition < 0, magnitude, 1)
        attenuation = np.where(support == 0, magnitude_initial - step_size_attenuation * (magnitude - 1), attenuation)
        return attenuation * np.exp(1j * phase)

    def error_estimate(self, wavefront, amplitude, propagator, mask=None):
        """
        Estimate fit to data.

        Parameters
        ----------
        wavefront : complex np.array
            Wavefront for estimation.
        amplitude : np.array
            Amplitude from measured image.
        propagator : complex np.array
            Fresnel propagator corresponding to effective propagation distance
            in measured image.
        mask : np.array, optional
            Region of interest. Currently unused.

        Returns
        -------
        error : float
            Sum of squared amplitude differences divided by ``nfx * nfy``.
        """
        wavefront_aux = np.fft.ifft2(np.fft.fft2(wavefront) * propagator)
        residual = (np.abs(wavefront_aux) - amplitude) ** 2
        return np.sum(residual) / (self.nfx * self.nfy)

    def _algorithm(self, image, positions=None, support=None):
        """Run HPR sweeping all distances per iteration.

        ``positions`` defaults to all distances in ``image``; ``support``
        defaults to the whole field of view. Only processed distances
        contribute to the returned averages.
        """
        phase, attenuation = self.retriever.reconstruct_image(image)
        ### It seems that giving an initialisation with artefacts in the low frequency range can hardly be removed
        ### Better to initialised with zeros?
        # phase, attenuation = phase * 0, attenuation * 0
        adjust_offset = True
        if adjust_offset:
            # Shift so the maximum phase is zero
            phase = phase - phase.max()
        if positions is None:
            positions = range(len(image))
        positions = list(positions)
        if support is None:
            support = np.ones_like(image[0])
        support = np.where(support != 0, 1, 0)  # Binarise support

        amplitude = np.sqrt(image)
        wavefront = np.exp(-attenuation) * np.exp(1j * phase)  # initial guess
        reconstruction = np.empty_like(image, dtype=np.complex128)
        last_iteration = self.iterations - 1

        for ii in range(self.iterations):
            print(F'========== Iterations {ii+1} ==========')
            for distance in positions:
                initial_wavefront = wavefront
                wavefront = self.amplitude_constraint(wavefront, amplitude[distance], self.propagator[distance])
                wavefront = self.hybrid_projection_reflection(wavefront, initial_wavefront, support, self.step_size_attenuation, self.step_size_phase)
                error_it = self.error_estimate(wavefront, amplitude[distance], self.propagator[distance])
                print(f'Distance {distance} error: {np.real(error_it):0.2g}')
                if ii == last_iteration:
                    reconstruction[distance] = wavefront

        # Average only filled slices; the others are uninitialised memory
        processed = reconstruction[positions]
        phase = np.average(np.angle(processed), axis=0)
        attenuation = np.average(np.abs(processed), axis=0)
        return phase, attenuation
class GradientDescent(PhaseRetrievalAlgorithm2D):
    """Gradient descent algorithm.

    Parameters
    ----------
    PSF : ndarray with the same shape as images, optional
        Point spread function. It multiplies the Fourier transform of the
        calculated intensity, so it is expected in Fourier space.
        ``None`` (or an empty sequence) disables it.
    """

    def __init__(self, dataset=None, PSF=None, **kwargs):
        self.padding = 2
        # Created before the base initialiser: the ``alpha`` setter forwards to it
        self.retriever = CTF(dataset, **kwargs)
        super().__init__(dataset, **kwargs)
        self.PSF = PSF
        self.step_size = 0.5
        self.iterations = 20
        self.propagator = self.simple_propagator(self.pixel_size[0], self.Lambda, self.distance)
        #TODO: iterative algorithms could maybe have a base class?

    @property
    def retriever(self):
        """Algorithm used to compute the initial guess."""
        return self._retriever

    @retriever.setter
    def retriever(self, retriever):
        self._retriever = retriever
        # Remember the class so the retriever can be rebuilt after unpickling
        self._retriever_class = retriever.__class__

    @property
    def alpha(self):
        """Regularisation parameter, forwarded to the retriever."""
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self.retriever.alpha = value
        self._alpha = value

    def __getstate__(self):
        """Drop attributes rebuilt in ``__setstate__`` from the pickled state."""
        state = self.__dict__.copy()
        del state['fx'], state['fy'], state['sinchirp'], state['coschirp'], state['retriever'], state['propagator']  #TODO: Needs proper separation PRAlg2d
        return state

    def __setstate__(self, state):
        """Restore state and rebuild the retriever and propagator."""
        self.__dict__.update(state)
        self._compute_factors()
        self.retriever = self._retriever_class(dataset=self.dataset)  # only dataset, reconstruct_image not parallellisable ?
        # Same arguments as in __init__, so an unpickled object gets the same propagator
        self.propagator = self.simple_propagator(self.pixel_size[0], self.Lambda, self.distance)  #TODO: update to use propagator module

    def _has_psf(self):
        """Return True when a non-empty PSF was supplied."""
        return self.PSF is not None and np.size(self.PSF) > 0

    def _algorithm(self, image, positions=False):
        """Sequential (per-distance) gradient updates of the object spectrum.

        Returns the phase and modulus of the final object.
        """
        # TODO: Correct handling of inoutting images
        # TODO: Actually, this is steepest descent, not CG
        print('Initialising')
        phase, attenuation = self.retriever.reconstruct_image(image, pad=True)
        adjust_offset = True
        if adjust_offset:
            # Shift so the maximum phase is zero
            phase = phase - phase.max()
        object_FT = np.fft.fft2(np.exp(-attenuation) * np.exp(1j * phase))
        use_psf = self._has_psf()

        for iteration in range(self.iterations):
            error = 0
            print(F'Iteration {iteration + 1} of {self.iterations}')
            for distance in range(self.ND):
                field = np.fft.ifft2(self.propagator[distance, :, :] * object_FT)
                intensity_calculated = np.real(field * np.conj(field))
                if use_psf:
                    intensity_calculated = np.real(np.fft.ifft2(np.fft.fft2(intensity_calculated) * self.PSF))
                # Normalise the calculated intensity to a mean of 1 over the padded grid
                intensity_calculated = intensity_calculated / np.sum(intensity_calculated) * self.nfx * self.nfy
                intensity_difference = image[distance] - intensity_calculated
                error = error + np.std(intensity_difference * intensity_difference) / self.ND
                # NOTE(review): original TODO questions this update ("This just can be right")
                update = np.conj(self.propagator[distance]) * np.fft.fft2(intensity_difference * field)
                object_FT = object_FT + (self.step_size / self.ND) * update
            print(F'Error for iteration {iteration + 1} is {error:.4g}')

        field = np.fft.ifft2(object_FT)
        return np.angle(field), np.abs(field)
class NLCG(PhaseRetrievalAlgorithm2D):
    """Non-Linear Conjugate Gradient algorithm (Polak-Ribiere+ update, fixed step).

    Parameters
    ----------
    PSF : ndarray with the same shape as images, optional
        Point spread function. It multiplies the Fourier transform of the
        calculated intensity, so it is expected in Fourier space.
        ``None`` (or an empty sequence) disables it.

    Attributes
    ----------
    step_size : float
        Fixed step size along the conjugate direction.
    iterations : int
        Number of iterations.
    limit_to_fov : bool
        Compute the Polak-Ribiere inner products on the field of view only.
    """

    def __init__(self, dataset=None, PSF=None, **kwargs):
        self.padding = 2
        # Created before the base initialiser: the ``alpha`` setter forwards to it
        self.retriever = CTFPurePhase(dataset, **kwargs)
        super().__init__(dataset, **kwargs)
        self.PSF = PSF
        self.step_size = 0.5  # For fixed step size
        self.iterations = 10
        self.limit_to_fov = True  # Only consider FOV steepest direction & update of conjugate
        self.propagator = self.simple_propagator(self.pixel_size[0], self.Lambda, self.distance)
        #TODO: iterative algorithms could maybe have a base class?

    @property
    def retriever(self):
        """Algorithm used to compute the initial guess."""
        return self._retriever

    @retriever.setter
    def retriever(self, retriever):
        self._retriever = retriever
        # Remember the class so the retriever can be rebuilt after unpickling
        self._retriever_class = retriever.__class__

    @property
    def alpha(self):
        """Regularisation parameter, forwarded to the retriever."""
        return self._alpha

    @alpha.setter
    def alpha(self, value):
        self.retriever.alpha = value
        self._alpha = value

    def __getstate__(self):
        """Drop attributes rebuilt in ``__setstate__`` from the pickled state."""
        state = self.__dict__.copy()
        del state['fx'], state['fy'], state['sinchirp'], state['coschirp'], state['retriever'], state['propagator']  #TODO: Needs proper separation PRAlg2d
        return state

    def __setstate__(self, state):
        """Restore state and rebuild the retriever and propagator."""
        self.__dict__.update(state)
        self._compute_factors()
        self.retriever = self._retriever_class(dataset=self.dataset)  # only dataset, reconstruct_image not parallellisable ?
        # Same arguments as in __init__, so an unpickled object gets the same propagator
        self.propagator = self.simple_propagator(self.pixel_size[0], self.Lambda, self.distance)  #TODO: update to use propagator module

    def _has_psf(self):
        """Return True when a non-empty PSF was supplied."""
        return self.PSF is not None and np.size(self.PSF) > 0

    def _algorithm(self, image, positions=False):
        """Run the non-linear CG iterations on the object spectrum.

        Returns the phase and modulus of the final object.
        """
        # TODO: Correct handling of inoutting images
        self.beta = np.zeros(self.iterations)  # Weight for update of conjugate direction
        print('Initialising')
        phase, attenuation = self.retriever.reconstruct_image(image, pad=True)
        adjust_offset = False
        if adjust_offset:
            phase = phase - phase.max()
        current_solution = np.fft.fft2(np.exp(-attenuation) * np.exp(1j * phase))
        mask = Utilities.resize(np.ones((self.ny, self.nx)), (self.nfy, self.nfx))
        use_psf = self._has_psf()

        for iteration in range(self.iterations):
            objective_cost = 0
            steepest_direction = np.zeros(current_solution.shape)
            print(F'Iteration {iteration + 1} of {self.iterations}')
            for position in range(self.ND):
                field_detector = np.fft.ifft2(self.propagator[position] * current_solution)  # Propagate
                intensity_calculated = np.real(field_detector * np.conj(field_detector))
                if use_psf:
                    intensity_calculated = np.real(np.fft.ifft2(np.fft.fft2(intensity_calculated) * self.PSF))
                intensity_difference = intensity_calculated - image[position]
                # NOTE(review): divisor is a sum, not ND*nx*ny; only affects the printed cost — confirm intent
                objective_cost = objective_cost + (np.sum(Utilities.resize(intensity_difference, [self.ny, self.nx]) ** 2) / (self.ND + self.nx + self.ny))
                steepest_direction = steepest_direction - (1/self.ND) * np.conj(self.propagator[position]) * np.fft.fft2(intensity_difference * field_detector * mask)

            steepest_direction_SD = Utilities.resize(np.fft.ifft2(steepest_direction), (self.ny, self.nx))
            # Direction used for the Polak-Ribiere inner products
            direction = steepest_direction_SD if self.limit_to_fov else steepest_direction
            if iteration == 0:
                conjugate_direction = steepest_direction
                delta_old = np.real(np.sum(np.conj(direction) * direction))
            else:
                delta_new = np.real(np.sum(direction * np.conj(direction)))
                delta_mid = np.real(np.sum(np.conj(direction) * steepest_direction_old))  #TODO: Should it really be real? Or abs?
                # Polak-Ribiere+: clip beta at zero (restart on negative values)
                self.beta[iteration] = max((delta_new - delta_mid) / delta_old, 0.0)
                conjugate_direction = steepest_direction + self.beta[iteration] * conjugate_direction
                delta_old = delta_new
            steepest_direction_old = direction

            ### Update object with fixed step size ###
            current_solution = current_solution + self.step_size * conjugate_direction
            relative_change = np.real(self.step_size * np.sqrt(np.sum(conjugate_direction * np.conj(conjugate_direction)) / np.sum(current_solution * np.conj(current_solution))))
            print(F'Relative change: {relative_change:.4g}')
            print(F'Cost for iteration {iteration + 1}: {objective_cost:.4g}')

        field = np.fft.ifft2(current_solution)
        return np.angle(field), np.abs(field)
class GaussNewton(PhaseRetrievalAlgorithm2D):
    """
    Iteratively Regularized Gauss Newton Method [1]

    Parameters
    ----------
    GaussNewton_iterations : int
        Number of Gauss-Newton iterations
    ConjugateGradient_iterations : int
        Number of Conjugate Gradient iterations per Gauss-Newton iteration
    threshold_CG : float
        Threshold for Conjugate Gradient method
    tau : float
        Step size for Gauss-Newton
    alpha_reduce_factor : float
        Reduction factor for Tikhonov regularization
    omega : float
        Penalization for positivity regularization
    omega_augment_factor : float
        Multiplicative factor for positivity regularization
    sobolev_exponent : float
        Exponent of Sobolev norm for regularization
    phys : int
        0 : No physical constraints
        1 : Positivity of attenuation and phase as a regularization

    The ``PSF`` constructor argument is accepted for interface compatibility
    and is currently unused.

    References
    ----------
    [1] Maretzke et al. Opt. Express 24(6) (2016) 6490-6506
    """

    def __init__(self, dataset=None, PSF=None, **kwargs):
        self.padding = 2  # TODO: Padding needs to be handled cleaner. At least a default is needed.
        super().__init__(dataset, **kwargs)
        self.GaussNewton_iterations = 10
        self.ConjugateGradient_iterations = 10
        self.threshold_CG = 1e-10
        self.tau = 1/2
        self.alpha_reduce_factor = 2/3
        self.omega = 0.1
        self.omega_augment_factor = 2
        self.sobolev_exponent = 1/2
        self.phys = 1
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)
        self.retriever = CTF(dataset, **kwargs)
def Frechet_derivative(self, f, P, epsilon, Gx, ND):
    """
    Frechet derivative of the forward (intensity) operator.

    For each distance ``d`` computes
    ``Im(conj(ifft2(fft2(w) * P[d])) * ifft2(fft2(epsilon * w) * P[d]))``
    with ``w = exp(-1j * f)``.

    Parameters
    ----------
    f : complex np.array
        f = φ - iB parametrizes the phase shifts φ and attenuation B,
        position the derivative is to be computed
    P : complex np.array
        Fresnel propagators, indexed by distance along the first axis.
    epsilon : complex np.array
        Input of Frechet derivative
    Gx : real np.array
        Gramian matrix of regularization term (unused here, kept for a
        uniform signature)
    ND : int
        Number of positions

    Returns
    -------
    frechet : np.array
        Frechet derivative with respect to f at epsilon, same shape and dtype
        as ``P`` (values are real; imaginary part is zero for complex ``P``).
    """
    grad = np.zeros_like(P)
    wave = np.exp(-1j * f)
    wave_FT = np.fft.fft2(wave)
    # Independent of the distance: computed once instead of once per distance
    epsilon_wave_FT = np.fft.fft2(epsilon * wave)
    for d in range(ND):
        P_d = P[d, :, :]
        t1 = np.fft.ifft2(wave_FT * P_d)
        t2 = np.fft.ifft2(epsilon_wave_FT * P_d)
        grad[d] = np.imag(np.conj(t1) * t2)
    return grad
def adjoint_Frechet_derivative(self, f, P, epsilon, Gx, ND, support, quotient_beta_delta=None):
    """
    Adjoint of the Frechet derivative of the forward operator.

    Parameters
    ----------
    f : complex np.array
        f = φ - iB parametrizes the phase shifts φ and attenuation B,
        position the derivative is to be computed
    P : complex np.array
        Fresnel propagators, indexed by distance along the first axis.
    epsilon : np.array
        Input, indexed by distance along the first axis.
    Gx : real np.array
        Gramian matrix of regularization term (applied as ``1/Gx`` in
        Fourier space)
    ND : int
        Number of positions
    support : np.array
        Support of object, indexed by distance along the first axis.
    quotient_beta_delta : float, optional
        Beta over Delta factor if homogeneous object. ``None`` (or an empty
        sequence, kept for backward compatibility) disables the homogeneous
        projection.

    Returns
    -------
    adjoint_frechet : complex np.array
        Adjoint of the Frechet derivative with respect to f at epsilon
    """
    adjoint_frechet = 0
    wave = np.exp(-1j * f)
    wave_FT = np.fft.fft2(wave)
    for d in range(ND):
        P_d = P[d, :, :]
        # Apply the inverse Gramian in Fourier space
        epsilon_d = np.fft.ifft2((1/Gx) * np.fft.fft2(epsilon[d, :, :]))
        t1 = 1j * epsilon_d * np.fft.ifft2(wave_FT * P_d)
        t2 = np.fft.ifft2(np.conj(P_d) * np.fft.fft2(t1))
        adjoint_frechet = adjoint_frechet + (1/ND) * (np.conj(wave) * t2)
        adjoint_frechet = np.where(support[d] == 0, 0, adjoint_frechet)
    # Explicit check: ``x != []`` is ambiguous/erroneous for NumPy scalars
    if quotient_beta_delta is not None and np.size(quotient_beta_delta) > 0:
        ### Homogeneous
        adjoint_frechet = np.real(adjoint_frechet) * (1 - 1.0j * quotient_beta_delta)
    return adjoint_frechet
def operator_to_inverse(self, f, P, epsilon, alpha_reg, gamma, Gx, ND, support):
    """
    (Linear) Operator to inverse with conjugate gradient method.

    Computes ``F'^*(F'(epsilon)) + alpha_reg * epsilon + gamma * R(epsilon)``,
    where ``R`` is the linearised positivity term (only when ``self.phys == 1``).

    Parameters
    ----------
    f : complex np.array
        f = φ - iB parametrizes the phase shifts φ and attenuation B,
        position the derivative is to be computed
    P : complex np.array
        Fresnel propagator.
    epsilon : complex np.array
        Input of the operator
    alpha_reg : float
        Penalization for Tikhonov regularization
    gamma : float
        Penalization for positivity regularization
    Gx : real np.array
        Gramian matrix of regularization term
    ND : int
        Number of positions
    support : np.array
        Support of object.

    Returns
    -------
    adjoint_gradient_inverse : complex np.array
        Operator to inverse with respect to f at epsilon

    Raises
    ------
    ValueError
        If ``self.phys`` is neither 0 nor 1.
    """
    if self.phys == 0:
        regularization_positivity = 0
    elif self.phys == 1:
        #### Eq. (9)
        phase_old = np.real(f)
        attenuation_old = -np.imag(f)
        phase_new = np.real(epsilon)
        attenuation_new = -np.imag(epsilon)
        # Penalise only where the current phase / attenuation is negative
        regularization_positivity = (np.maximum(0, -np.sign(phase_old)) * phase_new
                                     - 1j * np.maximum(0, -np.sign(attenuation_old)) * attenuation_new)
        regularization_positivity = np.fft.ifft2((1/Gx) * np.fft.fft2(regularization_positivity))
    else:
        raise ValueError(f'phys must be 0 or 1, got {self.phys!r}')

    gradient = self.Frechet_derivative(f, P, epsilon, Gx, ND)
    adjoint_gradient = self.adjoint_Frechet_derivative(f, P, gradient, Gx, ND, support)
    return adjoint_gradient + alpha_reg * epsilon + gamma * regularization_positivity
def _algorithm(self, image, positions=False, support=None):
    """
    Gauss-Newton iterations; each linearised subproblem is solved
    approximately with conjugate gradient (CG).

    Parameters
    ----------
    image : np.array
        Measured intensities, one image per position (``image[d]``).
    positions : optional
        Unused here; kept for interface compatibility.
    support : np.array, optional
        Object support indexed ``support[d]``; defaults to all ones.

    Returns
    -------
    phase, attenuation : np.array
        Note the sign convention: phase = -Re(f), attenuation = Im(f).
    """
    # =========================================================================
    # Initialisation of f = phi - iB
    # =========================================================================
    phase_initial, attenuation_initial = self.retriever.reconstruct_image(image, pad=True)
    # np.complex128 instead of the 'complex_' alias removed in NumPy 2.
    field_detector = np.zeros_like(image, dtype=np.complex128)
    gramian = (1 + self.fx**2 + self.fy**2)**self.sobolev_exponent
    f0 = phase_initial - 1.0j * attenuation_initial
    f = f0
    # `if not support:` raised ValueError for any ndarray support.
    if support is None:
        support = np.ones_like(image)

    # =========================================================================
    # Initialisation of alpha regularization
    # =========================================================================
    F_prime_etoile = self.adjoint_Frechet_derivative(f0, self.propagator, image, gramian, self.ND, support)
    F_prime = self.Frechet_derivative(f0, self.propagator, F_prime_etoile, gramian, self.ND)
    numerator = np.sum(F_prime * np.conj(F_prime))
    denominator = np.fft.ifft2(np.sqrt(gramian) * np.fft.fft2(F_prime_etoile))
    denominator = np.sum(denominator * np.conj(denominator))
    alpha_k = numerator / denominator  ### initial alpha
    omega_k = self.omega  ### initial omega
    print(f"Alpha initial : {alpha_k}")

    # =========================================================================
    # Iterations of Gauss-Newton
    # =========================================================================
    for ii in range(self.GaussNewton_iterations):
        print(f'========== Iteration GN {ii+1} ==========')
        attenuation = -np.imag(f)
        phase = np.real(f)

        # Wave fields at every position and residual against the data
        wave_FT = np.fft.fft2(np.exp(-1j * f))
        for d in range(self.ND):
            field_detector[d] = np.fft.ifft2(wave_FT * self.propagator[d])
        intensity_calculated = np.abs(field_detector)**2
        current_error = image - intensity_calculated

        adjoint_frechet = self.adjoint_Frechet_derivative(
            f, self.propagator, current_error, gramian, self.ND, support)

        if self.phys == 1:
            regularization_positivy = (np.maximum(0, -np.sign(phase)) * phase
                                       - 1j * np.maximum(0, -np.sign(attenuation)) * attenuation)
            regularization_positivy = np.fft.ifft2((1 / gramian) * np.fft.fft2(regularization_positivy))
        else:
            regularization_positivy = 0

        regularization_term = alpha_k * (f0 - f) - omega_k * regularization_positivy

        # =====================================================================
        # Conjugate gradient on  A x = b
        # =====================================================================
        b = adjoint_frechet + regularization_term
        x0 = np.zeros_like(f)
        x = np.zeros_like(f)
        Ax0 = self.operator_to_inverse(f, self.propagator, x0, alpha_k, omega_k, gramian, self.ND, support)
        r = b - Ax0
        p = r
        # Real squared residual norm (the complex sum only has round-off imaginary part).
        norm_r = np.abs(np.sum(r * np.conj(r)))

        for i in range(self.ConjugateGradient_iterations):
            # Stop once converged; the original test used `pass` and never stopped,
            # and a zero residual led to a 0/0 step below.
            if np.sqrt(norm_r) < self.threshold_CG:
                break
            Ap = self.operator_to_inverse(f, self.propagator, p, alpha_k, omega_k, gramian, self.ND, support)
            norm_Ap = np.sum(np.abs(np.conj(p) * Ap))
            if norm_Ap == 0:
                break
            alphaCG = norm_r / norm_Ap
            x = x + alphaCG * p
            r = r - alphaCG * Ap
            norm_r_new = np.abs(np.sum(r * np.conj(r)))
            beta = norm_r_new / norm_r
            p = r + beta * p
            norm_r = norm_r_new

        # =====================================================================
        # Back to Gauss-Newton iterations
        # =====================================================================
        f = f + self.tau * x
        alpha_k = self.alpha_reduce_factor * alpha_k
        omega_k = omega_k * self.omega_augment_factor

    attenuation = np.imag(f)
    phase = -np.real(f)
    return phase, attenuation
class NLPDHGM(PhaseRetrievalAlgorithm2D):
    """
    Exact and Linearised : NonLinear Primal Dual Hybrid Gradient Method [1]

    Parameters
    ----------
    iterations : int
        Number of NL-PDHGM iterations
    sigma : float
        Step size for dual variables
    tau : float
        Step size for primal variables
    alpha_TGV, beta_TGV : float
        Penalizations for Second order Total Generalized Variation (TGV) for attenuation
    delta_TV : float
        Penalization for Total Variation (TV) for phase
    gamma : float in [0, 1]
        Over-relaxation parameter for primal variables
    omega : float
        Initial penalization for positivity regularization
    omega_augment_factor : float
        Multiplicative factor for positivity regularization
    phys : int
        0 : No physical constraints
        1 : Positivity of attenuation and phase as a constraint
        2 : Positivity of attenuation and phase as a regularization

    References
    ----------
    [1] Valkonen Inverse Problems 30 (2014) 055012
    """

    def __init__(self, dataset=None, PSF=None, **kwargs):
        super().__init__(dataset, **kwargs)
        self.iterations = 100
        self.sigma = 0.35
        self.tau = 0.35
        self.alpha_TGV = 0.01
        self.beta_TGV = 0.005
        self.delta_TV = 0.01
        self.gamma = 1
        self.omega = 0.1
        self.omega_augment_factor = 3 / 2
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)
        self.phys = 2
        self.retriever = CTFPurePhase(dataset, **kwargs)

    def grad(self, M):
        """
        Gradient operator (forward differences, last row/column repeated).

        Parameters
        ----------
        M : complex np.array
            Image whose gradient is to be computed.

        Returns
        -------
        gradient : complex np.array
            Gradient of image M, stacked as (d/dx, d/dy) along axis 0.
        """
        nx = np.shape(M)[0]
        fx = M[np.hstack((np.arange(1, nx), [nx - 1])), :] - M
        ny = np.shape(M)[1]
        fy = M[:, np.hstack((np.arange(1, ny), [ny - 1]))] - M
        return np.concatenate((fx[np.newaxis, :, :], fy[np.newaxis, :, :]), axis=0)

    def div(self, P):
        """
        Divergence operator (backward differences with boundary rows/columns).

        Parameters
        ----------
        P : complex np.array
            Vector field (x component ``P[0]``, y component ``P[1]``).

        Returns
        -------
        divergence : complex np.array
            Divergence of P.
        """
        Px = P[0, :, :]
        Py = P[1, :, :]
        nx = np.shape(Px)[0]
        fx = Px - Px[np.hstack(([0], np.arange(0, nx - 1))), :]
        fx[0, :] = Px[0, :]  # boundary
        fx[nx - 1, :] = -Px[nx - 2, :]
        ny = np.shape(Py)[1]
        fy = Py - Py[:, np.hstack(([0], np.arange(0, ny - 1)))]
        fy[:, 0] = Py[:, 0]  # boundary
        fy[:, ny - 1] = -Py[:, ny - 2]
        return fx + fy

    def Frechet_derivative(self, f, P, epsilon):
        """
        Frechet derivative of forward operator.

        Parameters
        ----------
        f : complex np.array
            f = φ - iB parametrizes the phase shifts φ and attenuation B,
            position the derivative is to be computed
        P : complex np.array
            Fresnel propagator.
        epsilon : complex np.array
            Input of Frechet derivative

        Returns
        -------
        frechet : real np.array
            Frechet derivative with respect to f at epsilon
        """
        wave = np.exp(-1j * f)
        propagated = np.fft.ifft2(np.fft.fft2(wave) * P)
        t1 = epsilon * propagated
        return np.imag(-np.conj(t1) * propagated)

    def adjoint_Frechet_derivative(self, f, P, epsilon):
        """
        Adjoint of the Frechet derivative of forward operator.

        Parameters
        ----------
        f : complex np.array
            f = φ - iB parametrizes the phase shifts φ and attenuation B,
            position the derivative is to be computed
        P : complex np.array
            Fresnel propagator.
        epsilon : real np.array
            Input of Adjoint of Frechet derivative

        Returns
        -------
        adjoint_frechet : complex np.array
            Adjoint of the Frechet derivative with respect to f at epsilon
        """
        wave = np.exp(-1j * f)
        wave_FT = np.fft.fft2(wave)
        t1 = 1j * epsilon * np.fft.ifft2(wave_FT * P)
        t2 = np.fft.ifft2(np.conj(P) * np.fft.fft2(t1))
        return np.conj(wave) * t2

    @staticmethod
    def _clip_magnitude(q, radius):
        """Rescale entries of q whose magnitude exceeds radius down to radius."""
        return q / (np.maximum(radius, np.abs(q)) + 1e-14) * radius

    def _algorithm(self, image, positions=False):
        """
        Run NL-PDHGM.

        Layout of the variables (from the indexing below):
        primal[0] feeds the TGV term (attenuation), primal[1] the TV term
        (phase), primal[2:4] is the auxiliary TGV field v. The dual has
        ND + 8 channels: [0:ND] data term, [ND:ND+4] TGV second order,
        [ND+4:ND+6] TGV first order, [ND+6:ND+8] TV of the phase.
        """
        # =====================================================================
        # Initialisations of the primal & dual variables
        # =====================================================================
        phase_initial, attenuation_initial = self.retriever.reconstruct_image(image, pad=True)
        field_detector = np.zeros_like(image, dtype=np.complex128)
        # The iterations start from zero; the retriever output only fixes the shapes.
        phase_initial = phase_initial * 0
        attenuation_initial = attenuation_initial * 0

        primal_1 = np.stack((phase_initial, attenuation_initial), axis=0)
        primal_2 = np.zeros_like(primal_1)  # auxiliary variable v for TGV
        primal = np.concatenate((primal_1, primal_2), axis=0)
        primal_relax = primal.copy()

        f_initial = phase_initial - 1j * attenuation_initial
        Grad = self.grad(f_initial)
        GradGrad = np.concatenate((self.grad(Grad[0]), self.grad(Grad[1])))
        dual = np.concatenate((image, GradGrad, Grad, Grad)) * 0

        # Local copy: mutating self.omega made repeated runs start from a grown penalty.
        omega = self.omega

        for ii in range(0, self.iterations):
            print(f'========== Iteration NL-PDHGM {ii+1} ==========')
            f = primal[1] - 1j * primal[0]
            f_relax = primal_relax[1] - 1j * primal_relax[0]

            # =================================================================
            # Dual update of the data term (exact NL-PDHG)
            # =================================================================
            wave_FT = np.fft.fft2(np.exp(-1j * f_relax))
            for d in range(self.ND):
                field_detector[d] = np.fft.ifft2(wave_FT * self.propagator[d])
            intensity_calculated = np.abs(field_detector)**2
            dual[0:self.ND] = 2 * (dual[0:self.ND] + self.sigma * (intensity_calculated - image)) / (self.sigma + 2)
            # Linearised NL-PDHG variant (disabled): replace intensity_calculated by
            #   intensity_calculated[d] + self.Frechet_derivative(f, self.propagator[d], f_relax - f)

            # =================================================================
            # Dual variables for regularization
            # =================================================================
            # D(Dv)
            DDv = np.concatenate((self.grad(primal_relax[2]), self.grad(primal_relax[3])))
            dual[self.ND:self.ND + 4] = self._clip_magnitude(
                dual[self.ND:self.ND + 4] + self.sigma * DDv, self.alpha_TGV)
            # Df - v
            Df = self.grad(-np.imag(f_relax)) - primal_relax[2:]
            dual[self.ND + 4:self.ND + 6] = self._clip_magnitude(
                dual[self.ND + 4:self.ND + 6] + self.sigma * Df, self.beta_TGV)
            # TV of the phase
            gradient = self.grad(np.real(f_relax))
            dual[self.ND + 6:] = self._clip_magnitude(
                dual[self.ND + 6:] + self.sigma * gradient, self.delta_TV)

            # =================================================================
            # Primal update
            # =================================================================
            mean_adjoint = 0
            for d in range(self.ND):
                v = self.adjoint_Frechet_derivative(f, self.propagator[d, :, :], dual[d, :, :])
                mean_adjoint = mean_adjoint + (1 / self.ND) * v

            # Copy: the updates below are in place, so a bare alias made the
            # relaxation step `primal - primal_tmp` identically zero.
            primal_tmp = primal.copy()
            TGV_regularization = self.div(dual[self.ND + 4:self.ND + 6])
            TV_regularization = self.div(dual[self.ND + 6:])
            primal[0] = primal[0] - self.tau * (-np.imag(mean_adjoint) - TGV_regularization)
            primal[1] = primal[1] - self.tau * (np.real(mean_adjoint) - TV_regularization)

            if self.phys == 1:  # positivity as a constraint
                primal[0] = np.where(primal[0] < 0, 0, primal[0])
                primal[1] = np.where(primal[1] < 0, 0, primal[1])
            if self.phys == 2:  # positivity as a regularization
                primal[0] = primal[0] / (2 * self.tau * omega * np.maximum(0, -np.real(primal[0]))**2 + 1)
                primal[1] = primal[1] / (2 * self.tau * omega * np.maximum(0, -np.real(primal[1]))**2 + 1)
                omega = omega * self.omega_augment_factor

            # Auxiliary primal variable v for TGV
            divergence1 = self.div(dual[self.ND:self.ND + 2])
            divergence2 = self.div(dual[self.ND + 2:self.ND + 4])
            primal[2:] = primal[2:] + self.tau * (
                dual[self.ND + 4:self.ND + 6] + np.stack((divergence1, divergence2), axis=0))

            # =================================================================
            # Relaxation
            # =================================================================
            primal_relax = primal + self.gamma * (primal - primal_tmp)

        attenuation = -np.real(primal[0])
        phase = -np.real(primal[1])
        return phase, attenuation
class GDTV(PhaseRetrievalAlgorithm2D):
    """
    Gradient Descent with (smooth) total variation regularization

    Parameters
    ----------
    iterations : int
        Number of GD iterations
    epsilon : float
        Smoothing factor of Total Variation (TV) regularization
    alpha : np.array of 2 floats
        Penalization for TV regularization; ``alpha[0]`` is used for the
        imaginary part of f (attenuation), ``alpha[1]`` for the real part (phase)
    tau : np.array of 2 floats
        Step sizes, same ordering as ``alpha``
    omega : float
        Initial penalization for positivity regularization
    omega_augment_factor : float
        Multiplicative factor for positivity regularization

    References
    ----------
    [1]
    """

    def __init__(self, dataset=None, PSF=None, **kwargs):
        self.padding = 2
        super().__init__(dataset, **kwargs)
        self.iterations = 100
        self.epsilon = 1e-3
        self.alpha = np.array([5e-2, 1e-3])
        self.tau = 1.9 / (1 + self.alpha * 8 / self.epsilon)
        self.omega = 0.01
        self.omega_augment_factor = 3 / 2
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)
        self.retriever = CTFPurePhase(dataset, **kwargs)
        # self.retriever = TIEHOM(dataset, **kwargs)

    def grad(self, M):
        """
        Gradient operator (forward differences, last row/column repeated).

        Parameters
        ----------
        M : complex np.array
            Image whose gradient is to be computed.

        Returns
        -------
        gradient : complex np.array
            Gradient of image M, stacked as (d/dx, d/dy) along axis 0.
        """
        nx = np.shape(M)[0]
        fx = M[np.hstack((np.arange(1, nx), [nx - 1])), :] - M
        ny = np.shape(M)[1]
        fy = M[:, np.hstack((np.arange(1, ny), [ny - 1]))] - M
        return np.concatenate((fx[np.newaxis, :, :], fy[np.newaxis, :, :]), axis=0)

    def div(self, P):
        """
        Divergence operator (backward differences with boundary rows/columns).

        Parameters
        ----------
        P : complex np.array
            Vector field (x component ``P[0]``, y component ``P[1]``).

        Returns
        -------
        divergence : complex np.array
            Divergence of P.
        """
        Px = P[0, :, :]
        Py = P[1, :, :]
        nx = np.shape(Px)[0]
        fx = Px - Px[np.hstack(([0], np.arange(0, nx - 1))), :]
        fx[0, :] = Px[0, :]  # boundary
        fx[nx - 1, :] = -Px[nx - 2, :]
        ny = np.shape(Py)[1]
        fy = Py - Py[:, np.hstack(([0], np.arange(0, ny - 1)))]
        fy[:, 0] = Py[:, 0]  # boundary
        fy[:, ny - 1] = -Py[:, ny - 2]
        return fx + fy

    def adjoint_Frechet_derivative(self, f, P, epsilon):
        """
        Adjoint of the Frechet derivative of forward operator.

        Parameters
        ----------
        f : complex np.array
            f = φ - iB parametrizes the phase shifts φ and attenuation B,
            position the derivative is to be computed
        P : complex np.array
            Fresnel propagator.
        epsilon : real np.array
            Input of Adjoint of Frechet derivative

        Returns
        -------
        adjoint_frechet : complex np.array
            Adjoint of the Frechet derivative with respect to f at epsilon
        """
        wave = np.exp(-1j * f)
        wave_FT = np.fft.fft2(wave)
        t1 = 1j * epsilon * np.fft.ifft2(wave_FT * P)
        t2 = np.fft.ifft2(np.conj(P) * np.fft.fft2(t1))
        return np.conj(wave) * t2

    def _algorithm(self, image, positions=False):
        """Run the semi-implicit gradient descent; returns (phase, attenuation)."""
        # =====================================================================
        # Initialisations
        # =====================================================================
        phase_initial, attenuation_initial = self.retriever.reconstruct_image(image, pad=True)
        f = phase_initial - 1j * attenuation_initial
        field_detector = np.zeros_like(image, dtype=np.complex128)
        # Local copy: mutating self.omega made repeated runs start from a grown penalty.
        omega = self.omega

        for ii in range(0, self.iterations):
            print(f'========== Iteration GD-TV {ii+1} ==========')
            wave_FT = np.fft.fft2(np.exp(-1j * f))
            for d in range(self.ND):
                field_detector[d] = np.fft.ifft2(wave_FT * self.propagator[d])
            intensity_calculated = np.abs(field_detector)**2
            current_error = intensity_calculated - image

            # Data-term gradient, averaged over all distances
            mean_adjoint = 0
            for d in range(self.ND):
                v = self.adjoint_Frechet_derivative(f, self.propagator[d, :, :], current_error[d])
                mean_adjoint = mean_adjoint + (1 / self.ND) * v

            # Smooth TV: div(grad f / sqrt(eps^2 + |grad f|^2)). f is complex, so
            # the modulus is required; Gr**2 gave a complex, non-norm radicand.
            Gr = self.grad(f)
            tv_norm = np.sqrt(self.epsilon**2 + np.sum(np.abs(Gr)**2, axis=0))
            G = self.div(Gr / np.stack((tv_norm, tv_norm), axis=0))

            # Positivity regularization, active where phase/attenuation are negative
            phase = np.real(f)
            attenuation = -np.imag(f)
            regularization_positivy = (np.maximum(0, -np.sign(phase)) * phase
                                       - 1j * np.maximum(0, -np.sign(attenuation)) * attenuation)

            # Semi-implicit scheme: positivity implicit, TV explicit, with separate
            # step sizes/penalizations for the attenuation (index 0) and phase (index 1).
            # (A fully explicit scheme and a shared-parameter variant were found less suitable.)
            step_attenuation = (1 / (1 - self.tau[0] * omega * regularization_positivy)
                                * (f - self.tau[0] * (mean_adjoint - self.alpha[0] * G)))
            step_phase = (1 / (1 - self.tau[1] * omega * regularization_positivy)
                          * (f - self.tau[1] * (mean_adjoint - self.alpha[1] * G)))
            imag = -np.imag(step_attenuation)
            real = np.real(step_phase)
            f = real - 1.0j * imag
            omega = omega * self.omega_augment_factor

        attenuation = np.imag(f)
        phase = -np.real(f)
        return phase, attenuation
class PDHGM_CTF(PhaseRetrievalAlgorithm2D):
    """
    Chambolle-Pock [1] aka Primal Dual Hybrid Gradient Method (PDHGM)
    CTF-linearization based

    Parameters
    ----------
    iterations : int
        Number of PDHGM iterations
    sigma : float
        Step size for dual variables
    tau : float
        Step size for primal variables
    alpha, beta : float
        Penalization for Second order Total Generalized Variation (TGV) for attenuation
    delta : float
        Penalization for Total Variation (TV) for phase
    gamma : float in [0, 1]
        Over-relaxation parameter for primal variables
    omega : float
        Initial penalization for positivity regularization
    omega_augment_factor : float
        Multiplicative factor for positivity regularization
    phys : int
        0 : No physical constraints
        1 : Positivity of attenuation and phase as a constraint
        2 : Positivity of attenuation and phase as a regularization

    References
    ----------
    [1] Chambolle & Pock : Journal of Mathematical Imaging and Vision 40 (2011) 120–145
    """

    def __init__(self, dataset=None, PSF=None, **kwargs):
        self.iterations = 100  # Default 100
        self.sigma = 0.35
        self.tau = 0.35
        self.beta = 0.05
        self.delta = 0.01
        self.gamma = 1
        self.omega = 0.1
        self.omega_augment_factor = 3 / 2
        self.phys = 2
        self.padding = 2
        self.retriever = CTF(dataset, **kwargs)
        # self.retriever = TIEHOM(dataset, **kwargs)
        super().__init__(dataset, **kwargs)
        self.propagator = self.simple_propagator(self.pixel_size, self.Lambda, self.distance)
        self.alpha = 0.1

    def grad(self, M):
        """
        Gradient operator (forward differences, last row/column repeated).

        Parameters
        ----------
        M : complex np.array
            Image whose gradient is to be computed.

        Returns
        -------
        gradient : complex np.array
            Gradient of image M, stacked as (d/dx, d/dy) along axis 0.
        """
        nx = np.shape(M)[0]
        fx = M[np.hstack((np.arange(1, nx), [nx - 1])), :] - M
        ny = np.shape(M)[1]
        fy = M[:, np.hstack((np.arange(1, ny), [ny - 1]))] - M
        return np.concatenate((fx[np.newaxis, :, :], fy[np.newaxis, :, :]), axis=0)

    def div(self, P):
        """
        Divergence operator (backward differences with boundary rows/columns).

        Parameters
        ----------
        P : complex np.array
            Vector field (x component ``P[0]``, y component ``P[1]``).

        Returns
        -------
        divergence : complex np.array
            Divergence of P.
        """
        Px = P[0, :, :]
        Py = P[1, :, :]
        nx = np.shape(Px)[0]
        fx = Px - Px[np.hstack(([0], np.arange(0, nx - 1))), :]
        fx[0, :] = Px[0, :]  # boundary
        fx[nx - 1, :] = -Px[nx - 2, :]
        ny = np.shape(Py)[1]
        fy = Py - Py[:, np.hstack(([0], np.arange(0, ny - 1)))]
        fy[:, 0] = Py[:, 0]  # boundary
        fy[:, ny - 1] = -Py[:, ny - 2]
        return fx + fy

    @staticmethod
    def _clip_magnitude(q, radius):
        """Rescale entries of q whose magnitude exceeds radius down to radius."""
        return q / (np.maximum(radius, np.abs(q)) + 1e-14) * radius

    def _algorithm(self, image, positions=False):
        """
        Run CTF-linearised PDHGM; returns (phase, attenuation).

        primal[0] enters the TGV term and the cos-chirp, primal[1] the TV term
        and the sin-chirp, primal[2:4] is the auxiliary TGV field v. The dual
        has ND + 8 channels: [0:ND] data, [ND:ND+4] TGV second order,
        [ND+4:ND+6] TGV first order, [ND+6:ND+8] TV.
        """
        # =====================================================================
        # Initialisations of the primal & dual variables
        # =====================================================================
        phase_initial, attenuation_initial = self.retriever.reconstruct_image(image, pad=True)
        # NOTE(review): primal[0] is treated below as attenuation and primal[1] as
        # phase, but they are initialised with (phase_initial, attenuation_initial)
        # in that order — confirm the intended ordering and signs.
        primal_1 = np.stack((phase_initial, attenuation_initial))
        primal_2 = np.zeros_like(primal_1)  # auxiliary variable v for TGV
        primal = np.concatenate((primal_1, primal_2), axis=0)
        primal_relax = primal.copy()

        f_initial = phase_initial - 1j * attenuation_initial
        Grad = self.grad(f_initial)
        GradGrad = np.concatenate((self.grad(Grad[0]), self.grad(Grad[1])))
        dual = np.concatenate((image, GradGrad, Grad, Grad)) * 0

        # The default positions=False is not iterable; fall back to all distances.
        distances = range(self.ND) if positions is False or positions is None else positions
        # Local copy: mutating self.omega made repeated runs start from a grown penalty.
        omega = self.omega

        for ii in range(0, self.iterations):
            print(f'========== Iteration PDHGM-CTF {ii+1} ==========')
            f_relax = primal_relax[1] - 1j * primal_relax[0]

            # =================================================================
            # Dual update of the data term (CTF model). Uses the relaxed primal,
            # like the regularization dual updates below.
            # =================================================================
            intensity_calculated = np.zeros_like(image)
            for distance in distances:
                intensity_calculated[distance] = 1 - 2 * np.real(np.fft.ifft2(
                    self.sinchirp[distance] * np.fft.fft2(primal_relax[1])
                    + self.coschirp[distance] * np.fft.fft2(primal_relax[0])))
            dual[0:self.ND] = 2 * (dual[0:self.ND] + self.sigma * (intensity_calculated - image)) / (self.sigma + 2)

            # Dual variables for regularization
            DDv = np.concatenate((self.grad(primal_relax[2]), self.grad(primal_relax[3])))
            dual[self.ND:self.ND + 4] = self._clip_magnitude(
                dual[self.ND:self.ND + 4] + self.sigma * DDv, self.alpha)
            Df = self.grad(-np.imag(f_relax)) - primal_relax[2:]
            dual[self.ND + 4:self.ND + 6] = self._clip_magnitude(
                dual[self.ND + 4:self.ND + 6] + self.sigma * Df, self.beta)
            gradient = self.grad(np.real(f_relax))
            dual[self.ND + 6:] = self._clip_magnitude(
                dual[self.ND + 6:] + self.sigma * gradient, self.delta)

            # =================================================================
            # Primal updates: adjoint of the CTF operator summed over distances
            # =================================================================
            Astar1, Astar2 = 0, 0
            for d in range(self.ND):
                dual_FT = np.fft.fft2(dual[d])
                Astar1 = Astar1 - 2 * self.coschirp[d] * dual_FT
                Astar2 = Astar2 - 2 * self.sinchirp[d] * dual_FT
            Astar1 = np.real(np.fft.ifft2(Astar1))
            Astar2 = np.real(np.fft.ifft2(Astar2))

            # Copy: the updates below are in place, so a bare alias made the
            # relaxation step `primal - primal_tmp` identically zero.
            primal_tmp = primal.copy()
            TGV_regularization = self.div(dual[self.ND + 4:self.ND + 6])
            TV_regularization = self.div(dual[self.ND + 6:])
            primal[0] = primal[0] - self.tau * (Astar1 - TGV_regularization)
            primal[1] = primal[1] - self.tau * (Astar2 - TV_regularization)

            if self.phys == 1:  # positivity as a constraint
                primal[0] = np.where(primal[0] < 0, 0, primal[0])
                primal[1] = np.where(primal[1] < 0, 0, primal[1])
            if self.phys == 2:  # positivity as a regularization
                primal[0] = primal[0] / (2 * self.tau * omega * np.maximum(0, -np.real(primal[0]))**2 + 1)
                primal[1] = primal[1] / (2 * self.tau * omega * np.maximum(0, -np.real(primal[1]))**2 + 1)
                omega = omega * self.omega_augment_factor

            # Auxiliary primal variable v for TGV
            divergence1 = self.div(dual[self.ND:self.ND + 2])
            divergence2 = self.div(dual[self.ND + 2:self.ND + 4])
            primal[2:] = primal[2:] + self.tau * (
                dual[self.ND + 4:self.ND + 6] + np.stack((divergence1, divergence2), axis=0))

            # Relaxation
            primal_relax = primal + self.gamma * (primal - primal_tmp)

        attenuation = -np.real(primal[0])
        phase = -np.real(primal[1])
        return phase, attenuation
class ADMM_CTFhomo(PhaseRetrievalAlgorithm2D):
    """
    Alternating Direction Method of Multipliers based on CTF (homogeneous) linearization [1]

    Parameters
    ----------
    ADMM_iterations : int
        Number of ADMM-CTF iterations
    tau : float
        Penalty parameter of augmented Lagrangian
    alpha : float
        Penalty parameter of Total Variation (TV) regularization
    beta_over_delta : float
        Refractive index ratio between beta and delta
    phys : int
        0 : No physical constraints
        1 : Sign constraint on the retrieved phase (values > 0 are set to 0)

    References
    ----------
    [1] Villanueva-Perez Optics Letters 42(6) (2017)
    """

    def __init__(self, dataset=None, PSF=None, **kwargs):
        super().__init__(dataset, **kwargs)
        self.ADMM_iterations = 50
        self.tau = 5e-5
        self.alpha = self.tau * 0.01
        self.beta_over_delta = 0.25
        self.phys = 0
        self.retriever = CTFPurePhase(dataset, **kwargs)

    @staticmethod
    def _is_unset(value):
        """True for None or an empty list/tuple (the historical 'not given' sentinel)."""
        return value is None or (isinstance(value, (list, tuple)) and len(value) == 0)

    def grad(self, M):
        """
        Gradient operator (forward differences, last row/column repeated).

        Parameters
        ----------
        M : real np.array
            Image whose gradient is to be computed.

        Returns
        -------
        gradient : real np.array
            Gradient of image M, stacked as (d/dx, d/dy) along axis 0.
        """
        nx = np.shape(M)[0]
        fx = M[np.hstack((np.arange(1, nx), [nx - 1])), :] - M
        ny = np.shape(M)[1]
        fy = M[:, np.hstack((np.arange(1, ny), [ny - 1]))] - M
        return np.concatenate((fx[np.newaxis, :, :], fy[np.newaxis, :, :]), axis=0)

    def grad_adj(self, P):
        """
        Adjoint of gradient operator (minus the divergence).

        Parameters
        ----------
        P : real np.array
            Vector field (x component ``P[0]``, y component ``P[1]``).

        Returns
        -------
        grad_adj : real np.array
            Adjoint gradient of P.
        """
        Px = P[0, :, :]
        Py = P[1, :, :]
        nx = np.shape(Px)[0]
        fx = Px - Px[np.hstack(([0], np.arange(0, nx - 1))), :]
        fx[0, :] = Px[0, :]  # boundary
        fx[nx - 1, :] = -Px[nx - 2, :]
        ny = np.shape(Py)[1]
        fy = Py - Py[:, np.hstack(([0], np.arange(0, ny - 1)))]
        fy[:, 0] = Py[:, 0]  # boundary
        fy[:, ny - 1] = -Py[:, ny - 2]
        return -(fx + fy)

    def shrinkage(self, u, kappa):
        """
        Shrinkage (soft-thresholding) operation.

        Parameters
        ----------
        u : real np.array
            Input image.
        kappa : float
            Threshold.

        Returns
        -------
        u : real np.array
            max(0, |u| - kappa) * sign(u)
        """
        return np.maximum(0, u - kappa) - np.maximum(0, -u - kappa)

    def operator_ctf_adjoint(self, b, beta_over_delta, FPSF=None):
        """
        Compute adjoint of CTF operator assuming beta over delta.

        Parameters
        ----------
        b : real np.array
            Input images, one per distance (``b[d]``).
        beta_over_delta : float
            Refractive index ratio between beta and delta
        FPSF : np.array, optional
            Multiplier applied in Fourier space; None (or an empty list) skips it.

        Returns
        -------
        numerator : real np.array
            Adjoint of CTF operator applied to ``b - 1``
        """
        numerator = np.zeros((self.nfx, self.nfy))
        for d in range(self.ND):
            bf = np.fft.fft2(b[d] - 1)
            numerator = numerator + 2 * (self.sinchirp[d] + beta_over_delta * self.coschirp[d]) * bf
        # `FPSF != []` raised for array inputs on recent NumPy; test explicitly.
        if not self._is_unset(FPSF):
            numerator = numerator * FPSF
        return np.real(np.fft.ifft2(numerator))

    def inv_block_toeplitz_ctf_betaoverdelta(self, b, tau, beta_over_delta, OTF=None):
        """
        Solve the normal equations of the CTF operator (assuming beta over
        delta) plus the augmented-Lagrangian term, in Fourier space.

        Parameters
        ----------
        b : real np.array
            Right-hand side image
        tau : float
            Penalty parameter of augmented Lagrangian
        beta_over_delta : float
            Refractive index ratio between beta and delta
        OTF : np.array, optional
            Multiplier of the CTF part; None (or an empty list) skips it.

        Returns
        -------
        x : real np.array
            Solution ``ifft2(fft2(b) / S)``
        """
        denominator = np.zeros((self.nfx, self.nfy))
        for d in range(self.ND):
            denominator = denominator + (2 * (self.sinchirp[d] + beta_over_delta * self.coschirp[d]))**2
        if not self._is_unset(OTF):
            denominator = denominator * OTF
        # NOTE(review): tau * (fx^2 + fy^2) stands in for the symbol of
        # grad_adj(grad(.)); it is not the exact symbol of the finite differences.
        S = denominator + tau * (self.fx**2 + self.fy**2) + 1e-14
        return np.real(np.fft.ifft2(np.fft.fft2(b) / S))

    def _algorithm(self, image, positions=False):
        """Run ADMM; returns (phase, attenuation = beta_over_delta * phase)."""
        # The initialisation does not change the final result: start from zero.
        x = np.zeros((self.nfx, self.nfy))  # primal variable
        u = np.zeros((2, self.nfx, self.nfy))  # split variable (gradient of x)
        m = u.copy()  # Lagrange multiplier

        # A^T b does not depend on the iterates: compute it once.
        ctf_adjoint = self.operator_ctf_adjoint(image, self.beta_over_delta)

        for t in range(0, self.ADMM_iterations):
            print(f'========== Iteration ADMM-CTFhomo {t+1} ==========')
            tmp = ctf_adjoint + self.grad_adj(self.tau * u - m)
            x = self.inv_block_toeplitz_ctf_betaoverdelta(tmp, self.tau, self.beta_over_delta)  # update primal
            if self.phys == 1:
                x[x > 0] = 0.0
            grad_x = self.grad(x)
            u = self.shrinkage(grad_x + m / self.tau, self.alpha / self.tau)  # update split variable
            m = m + self.tau * (grad_x - u)  # update Lagrange multiplier

        phase = x
        attenuation = self.beta_over_delta * phase
        return phase, attenuation
# class Iterative(): # def __init__(self): # pass # def Reconstruct(self, dataset, iterations, parameter, options=''): # #initialise # print('Starting iterative reconstruction, {} iterations'.format(iterations)) # retriever = CTF(dataset) # retriever.alpha=parameter # propagator = Propagator.CTF(dataset) # #TODO: I guess these are the kind of things that should be configured rather # tomography = Tomography.PyHST() # if not 'no_reinitialisation' in options: # print('initialising') # #TODO: I guess tomo should be without filter? # retriever.Reconstruct(dataset) # tomography.Reconstruct(dataset, volume='phase') # tomography.Reconstruct(dataset, volume='retrieved_attenuation') # #TODO: should have a convergence criteron also # for iteration in range(iterations): # print('----- Iteration {} / {} -----'.format(iteration+1, iterations)) # tomography.ForwardProject(dataset, 'phase') # tomography.ForwardProject(dataset, 'retrieved_attenuation') # propagator.Propagate(dataset) # dataset.Difference() # retriever.ReconstructDifference(dataset) # tomography.Reconstruct(dataset,volume='update') # tomography.Reconstruct(dataset,volume='attenuation_update') # dataset.UpdatePhase() # dataset.UpdateAttenuation()