# Source code for propagator

#This file is part of the PyPhase software.
#
#Copyright (c) Max Langer (2019) 
#
#max.langer@creatis.insa-lyon.fr
#
#This software is a computer program whose purpose is to allow development,
#implementation, and deployment of phase retrieval algorithms.
#
#This software is governed by the CeCILL  license under French law and
#abiding by the rules of distribution of free software.  You can  use, 
#modify and/ or redistribute the software under the terms of the CeCILL
#license as circulated by CEA, CNRS and INRIA at the following URL
#"http://www.cecill.info". 
#
#As a counterpart to the access to the source code and  rights to copy,
#modify and redistribute granted by the license, users are provided only
#with a limited warranty  and the software's author,  the holder of the
#economic rights,  and the successive licensors  have only  limited
#liability. 
#
#In this respect, the user's attention is drawn to the risks associated
#with loading,  using,  modifying and/or developing or reproducing the
#software by the user in light of its specific status of free software,
#that may mean  that it is complicated to manipulate,  and  that  also
#therefore means  that it is reserved for developers  and  experienced
#professionals having in-depth computer knowledge. Users are therefore
#encouraged to load and test the software's suitability as regards their
#requirements in conditions enabling the security of their systems and/or 
#data to be ensured and,  more generally, to use and operate it in the 
#same conditions as regards security. 
#
#The fact that you are presently reading this means that you have had
#knowledge of the CeCILL license and that you accept its terms.

import numpy as np
from vendor.EdfFile import EdfFile #TODO: This should not be necessary here!
import scipy.ndimage as ndimage
from math import *
import pyphase.parallelizer as Parallelizer
import pyphase.utilities as Utilities

class Propagator:
    """Base class holding the sampling geometry shared by the propagators.

    Builds the real-space grid (x, y, xx, yy) at the oversampled pixel size
    and the padded, oversampled frequency grid (f, g, ff, gg). Concrete
    subclasses must provide ``_propagate(wave, position_number)``, which
    ``propagate_image`` calls.
    """

    def __init__(self, *, dataset=None, shape=None, energy=None, pixel_size=None,
                 distance=None, pad=2, oversampling=4):
        """Takes either a dataset object or necessary parameters as keyword arguments.

        Parameters
        ----------
        dataset : pyphase.Dataset, optional
            Dataset object for parameters. When given, it is used instead of
            ``shape``, ``energy``, ``pixel_size`` and ``distance``.
        shape : tuple of ints, optional
            Size of images (ny, nx) for creation of frequency variables etc.
        energy : float, optional
            Effective energy in keV.
        pixel_size : float, optional
            In m.
        distance : list of floats, optional
            Effective propagation distances in m.
        pad : int, optional
            Padding factor for propagator.
        oversampling : int, optional
            Oversampling factor of projections.
        """
        self.padding = pad
        self.oversampling = oversampling
        self.length_scale = 10e-6  # where should this be...

        if dataset is not None:
            # dataset.pixel_size is scaled by 1e-6 (presumably um -> m; confirm)
            self.pixel_size = dataset.pixel_size * 1e-6
            self.nx = dataset.nx
            self.ny = dataset.ny
            self.Lambda = dataset.Lambda
            self.distance = dataset.distance
        else:
            self.nx = shape[1]
            self.ny = shape[0]
            # Wavelength in m from energy in keV (lambda ~ 12.4 A / E[keV])
            self.Lambda = 12.4e-10 / energy
            self.pixel_size = pixel_size
            self.distance = distance

        self.nfx = self.nx * pad
        self.nfy = self.ny * pad

        # Frequencies below are stored multiplied by length_scale
        self.sample_frequency = self.length_scale / self.pixel_size
        self.pixel_size_oversampled = self.pixel_size / self.oversampling
        self.fs_oversampled = self.length_scale / self.pixel_size_oversampled
        self.nx_oversampled = self.nx * self.oversampling
        self.ny_oversampled = self.ny * self.oversampling

        # Real-space grid on the oversampled pixels
        self.x = np.linspace((-self.pixel_size_oversampled * self.nx_oversampled / 2),
                             self.pixel_size_oversampled * (self.nx_oversampled / 2 - 1),
                             self.nx_oversampled)
        self.y = np.linspace((-self.pixel_size_oversampled * self.ny_oversampled / 2),
                             self.pixel_size_oversampled * (self.ny_oversampled / 2 - 1),
                             self.ny_oversampled)
        self.xx, self.yy = np.meshgrid(self.x, self.y)

        # Centred frequency grid on the padded, oversampled field
        self.f = np.linspace(-self.fs_oversampled / 2,
                             self.fs_oversampled / 2 - self.fs_oversampled / (self.nx_oversampled * self.padding),
                             self.nx_oversampled * self.padding)
        self.g = np.linspace(-self.fs_oversampled / 2,
                             self.fs_oversampled / 2 - self.fs_oversampled / (self.ny_oversampled * self.padding),
                             self.ny_oversampled * self.padding)
        self.ff, self.gg = np.meshgrid(self.f, self.g)

    def propagate_projection(self, dataset, position_number=None, projection=None, oversampled=False):
        """Propagate one projection and store the resulting intensity.

        Parameters
        ----------
        dataset : pyphase.Dataset
            Dataset providing the 'generated_phase' and 'generated_attenuation'
            images and receiving the propagated 'intensity' image.
        position_number : int, optional
            Index of the propagation distance to propagate to.
        projection : int, optional
            Which projection to propagate.
        oversampled : bool
            True if input images are already oversampled.

        Returns
        -------
        ndarray
            The propagated intensity (also written to ``dataset``).
        """
        # TODO: Should check if phase and attenuation are calculated
        # TODO: Split in image/projection like phaseretrieval
        # TODO: generated or not should be a choice. Not yet implemented
        phase = dataset.get_image(projection=projection, image_type='generated_phase')
        attenuation = dataset.get_image(projection=projection, image_type='generated_attenuation')
        # The attenuation image is the amplitude term of propagate_image
        # (previously an undefined name `amplitude` was passed here).
        Id = self.propagate_image(attenuation, phase, position_number=position_number,
                                  oversampled=oversampled)
        dataset.write_image(Id, 'intensity', projection, position_number)
        return Id

    def propagate_image(self, amplitude, phase, position_number=None, oversampled=False):
        """Propagate a wavefront created from two images.

        The wave is formed as ``exp(-amplitude + 1j*phase)``. It is brought to
        the padded, oversampled size with ``Utilities.resize`` and passed to
        ``self._propagate``. The resulting intensity is resized back to the
        oversampled image size and zoomed by ``1/self.oversampling``.

        Parameters
        ----------
        amplitude : ndarray
            Attenuation-type image entering the wave as ``exp(-amplitude)``.
        phase : ndarray
            Phase image (same shape as ``amplitude``).
        position_number : int, optional
            Index of the propagation distance, forwarded to ``_propagate``.
        oversampled : bool
            True if the input images are already oversampled. Note that the
            final ``1/oversampling`` zoom is applied in either case.

        Returns
        -------
        ndarray
            Propagated intensity.
        """
        if self.oversampling and not oversampled:
            phase = ndimage.zoom(phase, self.oversampling)
            # Zoom the amplitude too, so it matches the phase shape (the
            # zoomed result used to be discarded into an unused variable).
            amplitude = ndimage.zoom(amplitude, self.oversampling)
        wave = np.exp(-amplitude + 1j * phase)  # TODO: Decide form of projection. If attenuation in mu, no square?
        # NOTE(review): Utilities.resize presumably pads/crops to the given shape; confirm
        wave = Utilities.resize(wave, [self.ny * self.padding * self.oversampling,
                                       self.nx * self.padding * self.oversampling])
        Id = self._propagate(wave, position_number)
        # TODO: THIS PADDING IS DANGEROUS NO?! VERIFY!
        Id = Utilities.resize(Id, [self.ny * self.oversampling, self.nx * self.oversampling])
        Id = ndimage.zoom(Id, 1 / self.oversampling)
        return Id
class Fresnel(Propagator):
    """Free-space propagator applying the Fresnel transfer function in Fourier space."""

    def _propagate(self, wave, position_number):
        """Return the intensity of ``wave`` after propagation to a given distance.

        Parameters
        ----------
        wave : ndarray
            Complex wavefront. Presumably already resized to the shape of
            ``self.ff`` by the caller; confirm.
        position_number : int
            Index into ``self.distance`` selecting the propagation distance.

        Returns
        -------
        ndarray
            ``|ifft2(fft2(wave) * P)|**2``, where ``P`` is the Fresnel chirp
            ``exp(-i*pi*Lambda*z*(f**2 + g**2))``.
        """
        z = self.distance[position_number]
        spectrum = np.fft.fft2(wave)
        # ff/gg are stored multiplied by length_scale, hence the division.
        chirp_argument = -1j*pi*self.Lambda*z*(self.ff**2+self.gg**2)/(self.length_scale**2)
        # The frequency grid is centred, so move the transfer function to FFT order.
        transfer = np.fft.ifftshift(np.exp(chirp_argument))
        field = np.fft.ifft2(spectrum * transfer)
        return np.abs(field)**2
class CTF(Propagator):
    """Propagates using the CTF. Legacy code to be aligned with Fresnel."""

    def PropagateProjection(self, dataset, projection, distance):
        """Propagate one projection with a CTF model and write the result to EDF.

        Reads the phase from
        ``<dataset.phase_forward_prefix>_<projection:04d>.edf`` and the
        attenuation from
        ``<dataset.retrieved_attenuation_forward_prefix>_<projection:04d>.edf``.
        Both are edge-padded by half the image size on each side. The method
        then forms ``1 + Re(ifft2(2*cos(chi)*FT(att) - 2*sin(chi)*FT(phase)))``
        and crops it back. The result is written to
        ``<dataset.propagated_prefix>_<distance>_<projection:04d>.edf``, reusing
        the header of the phase file.

        Parameters
        ----------
        dataset : pyphase.Dataset
            Provides pixel_size, nx, ny, Fresnel_number and the file prefixes.
        projection : int
            Projection number (zero-filled to 4 digits in file names).
        distance : int
            Distance index. Used as ``Fresnel_number[distance-1]``, so it
            appears to be 1-based (confirm). Also used in the output file name.

        Notes
        -----
        NOTE(review): ``self.fx`` / ``self.fy`` are not set by
        ``Propagator.__init__`` in this file. Unless they are assigned
        elsewhere, this raises AttributeError. The local ``ff``/``gg`` are in
        1/m (``fs=1/ps``), so whether they can replace fx/fy depends on how
        ``Fresnel_number`` is normalised; confirm before changing.
        NOTE(review): the chirp must have the padded shape
        ``(ny + 2*(ny//2), nx + 2*(nx//2))``. For ny == 1 or nx == 1 the final
        crop ``[0:-0]`` would produce an empty array.
        """
        length_scale=10e-6  # NOTE(review): unused in this method
        oversampling=1
        padding=2
        ps=(dataset.pixel_size/oversampling)*1e-6
        fs=1/ps
        nx = dataset.nx*oversampling
        ny = dataset.ny*oversampling
        # NOTE(review): x, y, xx, yy, ff and gg are computed but not used below
        x=np.linspace((-ps*nx/2), ps*(nx/2-1), nx)
        y=np.linspace((-ps*ny/2), ps*(ny/2-1), ny)
        xx, yy = np.meshgrid(x, y)
        f=np.linspace(-fs/2, fs/2-fs/(nx*padding), nx*padding)
        g=np.linspace(-fs/2, fs/2-fs/(ny*padding), ny*padding)
        ff, gg = np.meshgrid(f, g)
        #TODO: should have getters and setters for all the images. How? one per type? one for all w arg?
        # Phase image; its header is kept for the output file
        fname = dataset.phase_forward_prefix+'_'+str(projection).zfill(4)+'.edf'
        imEDF = EdfFile(fname)
        phase = imEDF.GetData(0)
        EDFHeader = imEDF.GetHeader(0)
        phase = ndimage.zoom(phase, oversampling)
        # Edge-pad by half the image size on each side
        phase = np.pad(phase, ((ny//2, ny//2), (nx//2, nx//2)), 'edge')
        # Attenuation image, padded the same way
        fname = dataset.retrieved_attenuation_forward_prefix+'_'+str(projection).zfill(4)+'.edf'
        imEDF = EdfFile(fname)
        attenuation = imEDF.GetData(0)
        attenuation = ndimage.zoom(attenuation, oversampling)
        attenuation = np.pad(attenuation, ((ny//2, ny//2), (nx//2, nx//2)), 'edge')
        # TODO: Creation of propagator should probably be in constructor? So that one can get it out
        FN = dataset.Fresnel_number[distance-1]
        coschirp = np.cos((pi*FN) * (self.fx**2) + (pi*FN) * (self.fy**2))
        sinchirp = np.sin((pi*FN) * (self.fx**2) + (pi*FN) * (self.fy**2))
        # Linear CTF model: Fourier-domain combination of attenuation and phase
        FId = 2 * coschirp * np.fft.fft2(attenuation) - 2 * sinchirp * np.fft.fft2(phase)
        Id = 1 + np.real(np.fft.ifft2(FId))
        # Remove the padding added above
        Id = Id[ny//2:-ny//2, nx//2:-nx//2]
        Id = ndimage.zoom(Id,1/oversampling)
        fname=dataset.propagated_prefix+'_'+str(distance)+'_'+str(projection).zfill(4)+'.edf'
        EDF = EdfFile(fname)
        # TODO: verify what should go into the header...
        EDF.WriteImage(EDFHeader, Id, 0, "Float")
        pass