From ed717202a0c917958892e26322d6ea5173f7b32c Mon Sep 17 00:00:00 2001
From: Willem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>
Date: Mon, 25 Apr 2016 17:04:39 +0200
Subject: Use FP/BP out argument in sample plugin

---
 samples/python/s018_plugin.py | 34 +++++++++++++++++++++-------------
 1 file changed, 21 insertions(+), 13 deletions(-)

(limited to 'samples/python')

diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py
index 31cca95..85b5486 100644
--- a/samples/python/s018_plugin.py
+++ b/samples/python/s018_plugin.py
@@ -30,30 +30,38 @@ import six
 
 # Define the plugin class (has to subclass astra.plugin.base)
 # Note that usually, these will be defined in a separate package/module
-class SIRTPlugin(astra.plugin.base):
-    """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm.
+class LandweberPlugin(astra.plugin.base):
+    """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm.
 
     Options:
 
-    'rel_factor': relaxation factor (optional)
+    'Relaxation': relaxation factor (optional)
     """
 
     # The astra_name variable defines the name to use to
     # call the plugin from ASTRA
-    astra_name = "SIRT-PLUGIN"
+    astra_name = "LANDWEBER-PLUGIN"
 
-    def initialize(self,cfg, rel_factor = 1):
+    def initialize(self,cfg, Relaxation = 1):
         self.W = astra.OpTomo(cfg['ProjectorId'])
         self.vid = cfg['ReconstructionDataId']
         self.sid = cfg['ProjectionDataId']
-        self.rel = rel_factor
+        self.rel = Relaxation
 
     def run(self, its):
         v = astra.data2d.get_shared(self.vid)
         s = astra.data2d.get_shared(self.sid)
+        tv = np.zeros(v.shape, dtype=np.float32)
+        ts = np.zeros(s.shape, dtype=np.float32)
         W = self.W
         for i in range(its):
-            v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size
+            W.FP(v,out=ts)
+            ts -= s # ts = W*v - s
+
+            W.BP(ts,out=tv)
+            tv *= self.rel / s.size
+
+            v -= tv # v = v - rel * W'*(W*v-s) / s.size
 
 if __name__=='__main__':
 
@@ -75,20 +83,20 @@ if __name__=='__main__':
     # First we import the package that contains the plugin
     import s018_plugin
     # Then, we register the plugin class with ASTRA
-    astra.plugin.register(s018_plugin.SIRTPlugin)
+    astra.plugin.register(s018_plugin.LandweberPlugin)
 
     # Get a list of registered plugins
     six.print_(astra.plugin.get_registered())
 
     # To get help on a registered plugin, use get_help
-    six.print_(astra.plugin.get_help('SIRT-PLUGIN'))
+    six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN'))
 
     # Create data structures
     sid = astra.data2d.create('-sino', proj_geom, sinogram)
     vid = astra.data2d.create('-vol', vol_geom)
 
     # Create config using plugin name
-    cfg = astra.astra_dict('SIRT-PLUGIN')
+    cfg = astra.astra_dict('LANDWEBER-PLUGIN')
     cfg['ProjectorId'] = proj_id
     cfg['ProjectionDataId'] = sid
     cfg['ReconstructionDataId'] = vid
@@ -103,18 +111,18 @@ if __name__=='__main__':
     rec = astra.data2d.get(vid)
 
     # Options for the plugin go in cfg['option']
-    cfg = astra.astra_dict('SIRT-PLUGIN')
+    cfg = astra.astra_dict('LANDWEBER-PLUGIN')
     cfg['ProjectorId'] = proj_id
     cfg['ProjectionDataId'] = sid
     cfg['ReconstructionDataId'] = vid
     cfg['option'] = {}
-    cfg['option']['rel_factor'] = 1.5
+    cfg['option']['Relaxation'] = 1.5
     alg_id_rel = astra.algorithm.create(cfg)
     astra.algorithm.run(alg_id_rel, 100)
     rec_rel = astra.data2d.get(vid)
 
     # We can also use OpTomo to call the plugin
-    rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5})
+    rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5})
 
     import pylab as pl
     pl.gray()
-- 
cgit v1.2.3