From 3c2815ec1d0ddd9d00a5c1f454fcecc060126623 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Tue, 17 Oct 2017 09:35:11 +0100
Subject: Added many methods

---
 src/Python/ccpi/fista/FISTAReconstructor.py | 184 ++++++++++++++++++++++++----
 1 file changed, 160 insertions(+), 24 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
index cbd27da..8318ea6 100644
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ b/src/Python/ccpi/fista/FISTAReconstructor.py
@@ -78,19 +78,28 @@ class FISTAReconstructor():
         # handle parmeters:
         # obligatory parameters
         self.pars = dict()
-        self.pars['projector_geometry'] = projector_geometry
-        self.pars['output_geometry'] = output_geometry
-        self.pars['input_sinogram'] = input_sinogram
+        self.pars['projector_geometry'] = projector_geometry # proj_geom
+        self.pars['output_geometry'] = output_geometry       # vol_geom
+        self.pars['input_sinogram'] = input_sinogram         # sino
         detectors, nangles, sliceZ = numpy.shape(input_sinogram)
         self.pars['detectors'] = detectors
-        self.pars['number_og_angles'] = nangles
+        self.pars['number_of_angles'] = nangles
         self.pars['SlicesZ'] = sliceZ
 
         print (self.pars)
         # handle optional input parameters (at instantiation)
         
         # Accepted input keywords
-        kw = ('number_of_iterations', 
+        kw = (
+              # mandatory fields
+              'projector_geometry',
+              'output_geometry',
+              'input_sinogram',
+              'detectors',
+              'number_of_angles',
+              'SlicesZ',
+              # optional fields
+              'number_of_iterations', 
               'Lipschitz_constant' , 
               'ideal_image' ,
               'weights' , 
@@ -98,8 +107,9 @@ class FISTAReconstructor():
               'initialize' , 
               'regularizer' , 
               'ring_lambda_R_L1',
-              'ring_alpha')
-        self.acceptedInputKeywords = kw
+              'ring_alpha',
+              'subsets')
+        self.acceptedInputKeywords = list(kw)
         
         # handle keyworded parameters
         if kwargs is not None:
@@ -122,8 +132,7 @@ class FISTAReconstructor():
         if 'Lipschitz_constant' in kwargs.keys():
             self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
         else:
-            self.pars['Lipschitz_constant'] = \
-                            self.calculateLipschitzConstantWithPowerMethod()
+            self.pars['Lipschitz_constant'] = None
         
         if not 'ideal_image' in kwargs.keys():
             self.pars['ideal_image'] = None
@@ -143,31 +152,44 @@ class FISTAReconstructor():
                 self.pars['ring_lambda_R_L1'] = 0
             if not 'ring_alpha' in kwargs.keys():
                 self.pars['ring_alpha'] = 1
-        
+
+        if not 'subsets' in kwargs.keys():
+            self.pars['subsets'] = 0
+        else:
+            self.createOrderedSubsets()
+
+        if not 'initialize' in kwargs.keys():
+            self.pars['initialize'] = False
             
             
     def setParameter(self, **kwargs):
-        '''set named parameter for the regularization engine
+        '''set named parameter for the reconstructor engine
         
         raises Exception if the named parameter is not recognized
-        Typical usage is:
-            
-        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
-        reg.setParameter(input=u0)    
-        reg.setParameter(regularization_parameter=10.)
         
-        it can be also used as
-        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
-        reg.setParameter(input=u0 , regularization_parameter=10.)
         '''
-        
         for key , value in kwargs.items():
-            if key in self.acceptedInputKeywords.keys():
+            if key in self.acceptedInputKeywords:
                 self.pars[key] = value
             else:
-                raise Exception('Wrong parameter {0} for '.format(key) + 
-                                'Reconstruction algorithm')
+                raise Exception('Wrong parameter {0} for '.format(key) +
+                                'reconstructor')
     # setParameter
+
+    def getParameter(self, key):
+        if type(key) is str:
+            if key in self.acceptedInputKeywords:
+                return self.pars[key]
+            else:
+                raise Exception('Unrecongnised parameter: {0} '.format(key) )
+        elif type(key) is list:
+            outpars = []
+            for k in key:
+                outpars.append(self.getParameter(k))
+            return outpars
+        else:
+            raise Exception('Unhandled input {0}' .format(str(type(key))))
+            
     
     def calculateLipschitzConstantWithPowerMethod(self):
         ''' using Power method (PM) to establish L constant'''
@@ -289,5 +311,119 @@ class FISTAReconstructor():
         if regularizer is not None:
             self.pars['regularizer'] = regularizer
         
+
+    def initialize(self):
+        # convenience variable storage
+        proj_geom = self.pars['projector_geometry']
+        vol_geom = self.pars['output_geometry']
+        sino = self.pars['input_sinogram']
+        
+        # a 'warm start' with SIRT method
+        # Create a data object for the reconstruction
+        rec_id = astra.matlab.data3d('create', '-vol',
+                                    vol_geom);
+        
+        #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino);
+        sinogram_id = astra.matlab.data3d('create', '-proj3d',
+                                          proj_geom,
+                                          sino)
+
+        sirt_config = astra.astra_dict('SIRT3D_CUDA')
+        sirt_config['ReconstructionDataId' ] = rec_id
+        sirt_config['ProjectionDataId'] = sinogram_id
+
+        sirt = astra.algorithm.create(sirt_config)
+        astra.algorithm.run(sirt, iterations=35)
+        X = astra.matlab.data3d('get', rec_id)
+
+        # clean up memory
+        astra.matlab.data3d('delete', rec_id)
+        astra.matlab.data3d('delete', sinogram_id)
+        astra.algorithm.delete(sirt)
+
+        
+
+        return X
+
+    def createOrderedSubsets(self, subsets=None):
+        if subsets is None:
+            try:
+                subsets = self.getParameter('subsets')
+            except Exception():
+                subsets = 0
+            #return subsets
+
+        angles = self.getParameter('projector_geometry')['ProjectionAngles'] 
+            
+            
+            
+        
+            
+
+    def prepareForIteration(self):
+        self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
+        self.objective = numpy.zeros((self.pars['number_of_iterations']))
+
+        #2D array (for 3D data) of sparse "ring" 
+        detectors, nangles, sliceZ  = numpy.shape(self.pars['input_sinogram'])
+        self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float)
+        # another ring variable
+        self.rx = self.r.copy()
+
+        self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
+        
+        if self.getParameter('Lipschitz_constant') is None:
+            self.pars['Lipschitz_constant'] = \
+                            self.calculateLipschitzConstantWithPowerMethod()
+
+    # prepareForIteration
+
+    def iterate(self, Xin=None):
+        # convenience variable storage
+        proj_geom , vol_geom, sino , \
+                  SlicesZ = self.getParameter(['projector_geometry' ,
+                                                        'output_geometry',
+                                                        'input_sinogram',
+                                                        'SlicesZ'])
+                        
+        t = 1
+        if Xin is None:    
+            if self.getParameter('initialize'):
+                X = self.initialize()
+            else:
+                N = vol_geom['GridColCount']
+                X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+        else:
+            X = Xin.copy()
+
+        X_t = X.copy()
+        
+        for i in range(self.getParameter('number_of_iterations')):
+            X_old = X.copy()
+            t_old = t
+            r_old = self.r.copy()
+            if self.pars['projector_geometry']['type'] == 'parallel' or \
+               self.pars['projector_geometry']['type'] == 'parallel3d':
+                # if the geometry is parallel use slice-by-slice
+                # projection-backprojection routine
+                #sino_updt = zeros(size(sino),'single');
+                sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+                
+                #for kkk = 1:SlicesZ
+                #    [sino_id, sino_updt(:,:,kkk)] =
+                # astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT);
+                #    astra_mex_data3d('delete', sino_id);
+                for kkk in range(SlicesZ):
+                    sino_id, sino_updt[kkk] = \
+                             astra.creators.create_sino3d_gpu(
+                                 X_t[kkk], proj_geomT, vol_geomT)
+                    
+            else:
+                # for divergent 3D geometry (watch GPU memory overflow in
+                # Astra < 1.8
+                sino_id, y = astra.creators.create_sino3d_gpu(X_t, 
+                                                              proj_geom, 
+                                                              vol_geom)
     
-    
+
+            
-- 
cgit v1.2.3