From a5ee66a4aee472ab72d204783b5e3da4b4f65beb Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Tue, 31 Oct 2017 11:45:39 +0000
Subject: fixed setParameter

fixed setParameter
allows regularizer to output simply the image rather than list.
---
 src/Python/ccpi/imaging/Regularizer.py | 42 ++++++++++++++++++++++------------
 1 file changed, 27 insertions(+), 15 deletions(-)

(limited to 'src/Python/ccpi')

diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py
index 8ab6c6a..23799d6 100644
--- a/src/Python/ccpi/imaging/Regularizer.py
+++ b/src/Python/ccpi/imaging/Regularizer.py
@@ -108,6 +108,8 @@ class Regularizer():
             
         else:
             raise Exception('Unknown regularizer algorithm')
+
+        self.acceptedInputKeywords = pars.keys()
             
         return pars
     # parsForAlgorithm
@@ -134,17 +136,24 @@ class Regularizer():
                 raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
     # setParameter
 	
-    def getParameter(self, **kwargs):
-        ret = {}
-        for key , value in kwargs.items():
-            if key in self.pars.keys():
-                ret[key] = self.pars[key]
+    def getParameter(self, key):
+        if type(key) is str:
+            if key in self.acceptedInputKeywords:
+                return self.pars[key]
+            else:
+                raise Exception('Unrecongnised parameter: {0} '.format(key) )
+        elif type(key) is list:
+            outpars = []
+            for k in key:
+                outpars.append(self.getParameter(k))
+            return outpars
         else:
-            raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
-    # setParameter
+            raise Exception('Unhandled input {0}' .format(str(type(key))))
+        # getParameter
 	
         
-    def __call__(self, input = None, regularization_parameter = None, **kwargs):
+    def __call__(self, input = None, regularization_parameter = None,
+                 output_all = False, **kwargs):
         '''Actual call for the regularizer. 
         
         One can either set the regularization parameters first and then call the
@@ -179,19 +188,19 @@ class Regularizer():
         input = self.pars['input']
         regularization_parameter = self.pars['regularization_parameter']
         if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
-            return self.algorithm(input, regularization_parameter,
+            ret = self.algorithm(input, regularization_parameter,
                               self.pars['number_of_iterations'],
                               self.pars['tolerance_constant'],
                               self.pars['TV_penalty'].value )    
         elif self.algorithm == Regularizer.Algorithm.FGP_TV :
-            return self.algorithm(input, regularization_parameter,
+            ret = self.algorithm(input, regularization_parameter,
                               self.pars['number_of_iterations'],
                               self.pars['tolerance_constant'],
                               self.pars['TV_penalty'].value )
         elif self.algorithm == Regularizer.Algorithm.LLT_model :
             #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
             # no default
-            return self.algorithm(input, 
+            ret = self.algorithm(input, 
                               regularization_parameter,
                               self.pars['time_step'] , 
                               self.pars['number_of_iterations'],
@@ -200,7 +209,7 @@ class Regularizer():
         elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :
             #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
             # no default
-            return self.algorithm(input, regularization_parameter,
+            ret = self.algorithm(input, regularization_parameter,
                                   self.pars['searching_window_ratio'] , 
                                   self.pars['similarity_window_ratio'] , 
                                   self.pars['PB_filtering_parameter'])
@@ -208,7 +217,7 @@ class Regularizer():
             #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
             # no default
             if len(np.shape(input)) == 2:
-                return self.algorithm(input, regularization_parameter,
+                ret = self.algorithm(input, regularization_parameter,
                                   self.pars['first_order_term'] , 
                                   self.pars['second_order_term'] , 
                                   self.pars['number_of_iterations'])
@@ -227,11 +236,14 @@ class Regularizer():
                 output = [out3d]
                 for i in range(1,len(out)):
                     output.append(out[i])
-                return output
+                ret = output
                 
                 
             
-            
+        if output_all:
+            return ret
+        else:
+            return ret[0]
         
     # __call__
     
-- 
cgit v1.2.3