# -*- coding: utf-8 -*-
#This file is part of the PyPhase software.
#
#Copyright (c) Max Langer (2019)
#
#max.langer@creatis.insa-lyon.fr
#
#This software is a computer program whose purpose is to allow development,
#implementation, and deployment of phase retrieval algorithms.
#
#This software is governed by the CeCILL license under French law and
#abiding by the rules of distribution of free software. You can use,
#modify and/ or redistribute the software under the terms of the CeCILL
#license as circulated by CEA, CNRS and INRIA at the following URL
#"http://www.cecill.info".
#
#As a counterpart to the access to the source code and rights to copy,
#modify and redistribute granted by the license, users are provided only
#with a limited warranty and the software's author, the holder of the
#economic rights, and the successive licensors have only limited
#liability.
#
#In this respect, the user's attention is drawn to the risks associated
#with loading, using, modifying and/or developing or reproducing the
#software by the user in light of its specific status of free software,
#that may mean that it is complicated to manipulate, and that also
#therefore means that it is reserved for developers and experienced
#professionals having in-depth computer knowledge. Users are therefore
#encouraged to load and test the software's suitability as regards their
#requirements in conditions enabling the security of their systems and/or
#data to be ensured and, more generally, to use and operate it in the
#same conditions as regards security.
#
#The fact that you are presently reading this means that you have had
#knowledge of the CeCILL license and that you accept its terms.
import sys
import vendor.pyelastix as PyElastix
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import scipy.ndimage as ndimage
# update(title, position, target): Displays or updates a console progress bar
## Progress is shown as position / target, as a bar and a percentage.
## A negative position represents a 'halt'.
## A position at or beyond target represents 100% ('done').
def update(title, position, target):
    """
    Display or update a single-line console progress bar.

    The line starts with a carriage return, so repeated calls redraw the
    bar in place on standard output.

    Parameters
    ----------
    title : str
        Label printed before the bar.
    position : int or float
        Current position. A negative value is reported as a halt (0 %).
    target : int or float
        Position corresponding to completion. A position at or beyond it is
        reported as done (100 %) and terminates the line.
    """
    bar_length = 20  # Modify this to change the length of the progress bar
    # The three cases are mutually exclusive: a halt must not fall through
    # to the ratio computation, which would yield a negative progress.
    if position < 0:
        progress = 0
        status = "Halt...\r\n"
    elif position >= target:
        progress = 1
        status = "Done...\r\n"
    else:
        progress = position / target
        status = ""
    block = int(round(bar_length * progress))
    text = "\r" + title + ": [{}] {} / {} {:0.2%} {}".format(
        "#" * block + "-" * (bar_length - block),
        position, target, progress, status)
    sys.stdout.write(text)
    sys.stdout.flush()
def resize(image, shape, pad_type='edge'):
    """
    Resize an image, or a stack of images, by centre-cropping or padding.

    Each of the two image axes is handled independently:
    - An axis smaller than the requested size is padded. The padding is split
      as evenly as possible, and any odd extra pixel goes after the image.
    - An axis larger than the requested size is centre-cropped.
    Mixed cases (one axis smaller, the other larger) are therefore supported.

    Parameters
    ----------
    image : ndarray
        A 2D image, or a 3D stack of images stored along the first dimension.
    shape : sequence of int
        Target (rows, columns) of each image.
    pad_type : str, optional
        Padding mode passed to ``numpy.pad``. Default is 'edge'.

    Returns
    -------
    ndarray
        The resized image(s), with the same number of dimensions as ``image``.
    """
    # Work on a 3D stack internally so 2D and 3D inputs share one code path.
    squeeze = image.ndim == 2
    if squeeze:
        image = image[np.newaxis]

    pad_width = [(0, 0)]  # never pad or crop along the stack axis
    crop = [slice(None)]
    for size, target in zip(image.shape[1:3], shape[:2]):
        if size < target:
            before = (target - size) // 2
            pad_width.append((before, target - size - before))
            crop.append(slice(None))
        else:
            start = (size - target) // 2
            pad_width.append((0, 0))
            crop.append(slice(start, start + target))

    image = image[tuple(crop)]
    if any(width != (0, 0) for width in pad_width):
        image = np.pad(image, pad_width, pad_type)

    return image[0] if squeeze else image
class RegistrationAlgorithm:
    """
    Abstract base class for registration algorithms.

    Notes
    -----
    Subclasses must set ``self.parameters``, the parameter object passed to
    ``PyElastix.register``, before ``register`` is called. The Elastix
    subclasses also set ``self.number_of_parameters``.
    """

    def __init__(self):
        pass

    def register(self, moving_image, stationary_image):
        """
        Register moving_image to stationary_image.

        Parameters
        ----------
        moving_image : ndarray
            The image to register.
        stationary_image : ndarray
            The image to register to.

        Returns
        -------
        field : ndarray
            The calculated deformation field.
        transformed_moving_image : ndarray
            The deformed moving image.
        transform_parameters : array
            The calculated transform parameters. Length varies with the number
            of parameters in the chosen algorithm (number_of_parameters).
        """
        # Images are made C-contiguous before being handed to PyElastix.
        transformed_moving_image, field, transform_parameters = PyElastix.register(
            np.ascontiguousarray(moving_image),
            np.ascontiguousarray(stationary_image),
            self.parameters,
            verbose=0)
        # Note the order differs from PyElastix's: field comes first here.
        return field, transformed_moving_image, transform_parameters
class ElastixRigid(RegistrationAlgorithm):
    """
    Rigid registration using Elastix.

    Attributes
    ----------
    parameters : Parameters
        PyElastix default parameters for the 'RIGID' transform, changed to use
        6 resolutions and at most 500 iterations.
    number_of_parameters : int, default=3
        Number of parameters in the transform.
    """

    def __init__(self):
        elastix_params = PyElastix.get_default_params(type='RIGID')
        elastix_params.NumberOfResolutions = 6
        elastix_params.MaximumNumberOfIterations = 500
        self.parameters = elastix_params
        self.number_of_parameters = 3
class ElastixAffine(RegistrationAlgorithm):
    """
    Affine registration using Elastix.

    Attributes
    ----------
    parameters : Parameters
        PyElastix default parameters for the 'AFFINE' transform, with the
        iteration limit set to 500. The number of resolutions is left at the
        PyElastix default.
    number_of_parameters : int, default=6
        Number of parameters in the transform.
    """

    def __init__(self):
        elastix_params = PyElastix.get_default_params(type='AFFINE')
        elastix_params.MaximumNumberOfIterations = 500
        self.parameters = elastix_params
        self.number_of_parameters = 6
class ElastixSimilar(RegistrationAlgorithm):
    """
    Similarity-transform registration using Elastix.

    Attributes
    ----------
    parameters : Parameters
        PyElastix default parameters for the 'SIMILAR' transform, changed to
        use 6 resolutions and at most 500 iterations.
    number_of_parameters : int, default=4
        Number of parameters in the transform.
    """

    def __init__(self):
        elastix_params = PyElastix.get_default_params(type='SIMILAR')
        elastix_params.NumberOfResolutions = 6
        elastix_params.MaximumNumberOfIterations = 500
        self.parameters = elastix_params
        self.number_of_parameters = 4
class ImageDisplayer:
    """
    Base class for image display back-ends.

    Holds no state; concrete displayers derive from it.
    """
    # TODO: one viewer class; consider making this an abstract base class.

    def __init__(self):
        """Create a displayer. The base class has nothing to initialise."""
class PyplotImageDisplayer(ImageDisplayer):
    """
    Interface to Pyplot for image display.

    Notes
    -----
    Kept behind a wrapper so that the choice of display package stays
    flexible.
    """
    # TODO: rename Pyplot? Develop interface ImageJ?

    def display(self, image, title='', vmin=None, vmax=None):
        """
        Display an image in grey scale with a colour bar.

        Parameters
        ----------
        image : ndarray
            The image to be displayed.
        title : str, optional
            Title of figure.
        vmin : optional
            Lower limit of contrast range.
        vmax : optional
            Upper limit of contrast range.
        """
        # TODO: should include all the different possibilities. Should be kwargs, not positional
        fig, ax = plt.subplots()
        fig.suptitle(title)
        im = ax.imshow(image, cmap='gray', vmin=vmin, vmax=vmax)
        ax.set_axis_off()
        # Colour bar in its own axes, 5 % of the image axes' width, on the right.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
        im.axes.figure.canvas.draw()
        plt.show()

    def display_stack(self, stack):
        """
        Display a stack of images, browsable with the left/right arrow keys.

        Parameters
        ----------
        stack : ndarray
            Images stored along the first dimension.
        """
        # TODO: merge into display
        fig, ax = plt.subplots(1, 1)
        viewer = StackViewer(stack, ax)
        # Matplotlib keeps only weak references to bound-method callbacks.
        # Without a strong reference the viewer would be garbage-collected as
        # soon as this method returns (e.g. with a non-blocking plt.show()),
        # and the key handler would silently stop working. Tie its lifetime
        # to the figure.
        fig._stack_viewer = viewer
        fig.canvas.mpl_connect('key_press_event', viewer._on_key)
        plt.show()

    def close_all(self):
        """Close all open Pyplot figures."""
        plt.close('all')
# maybe should be displayShifts?
class StackViewer(object):
    """
    Browse a stack of images in a Matplotlib axes with the arrow keys.

    Parameters
    ----------
    X : ndarray
        Image stack of shape (slices, rows, cols); images are stored along the
        first dimension.
    ax : Axes
        Axes in which the current image is drawn.

    Notes
    -----
    Connect ``_on_key`` to the figure's 'key_press_event' to enable
    navigation. Right/left moves to the next/previous image, wrapping around
    at either end.
    """
    # TODO: Should be with an _ to not be imported (internal functionality?)

    def __init__(self, X, ax):
        self.ax = ax
        self.ax.set_title('use left/right arrow to navigate images')
        self.X = X
        # The 3-way unpacking also rejects anything that is not a 3D stack.
        self.slices, _, _ = X.shape
        self.ind = 0
        self.im = self.ax.imshow(self.X[self.ind], cmap='gray')
        self._update()

    def _on_key(self, event):
        """Step to the next/previous image on a right/left arrow key press."""
        if event.key == 'right':
            self.ind = (self.ind + 1) % self.slices
        elif event.key == 'left':
            self.ind = (self.ind - 1) % self.slices
        else:
            return  # other keys do not change the view; skip the redraw
        self._update()

    def _update(self):
        """Show image ``self.ind`` of the stack and label the axes with it."""
        # Index along the first axis, consistent with __init__ and self.slices.
        self.im.set_data(self.X[self.ind])
        self.ax.set_ylabel('image %s' % self.ind)
        self.im.axes.figure.canvas.draw()