From 18b6d25f7e4f0943b3592f3bb4f6ca5ed9c285d3 Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Fri, 19 Jun 2015 22:28:06 +0200 Subject: Add support for Python algorithm plugins --- samples/python/s018_plugin.py | 138 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 samples/python/s018_plugin.py (limited to 'samples/python') diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py new file mode 100644 index 0000000..6677930 --- /dev/null +++ b/samples/python/s018_plugin.py @@ -0,0 +1,138 @@ +#----------------------------------------------------------------------- +#Copyright 2015 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +import astra +import numpy as np +import six + +# Define the plugin class (has to subclass astra.plugin.base) +# Note that usually, these will be defined in a separate package/module +class SIRTPlugin(astra.plugin.base): + """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. + + Optional options: + + 'rel_factor': relaxation factor + """ + required_options=[] + optional_options=['rel_factor'] + + def initialize(self,cfg): + self.W = astra.OpTomo(cfg['ProjectorId']) + self.vid = cfg['ReconstructionDataId'] + self.sid = cfg['ProjectionDataId'] + try: + self.rel = cfg['option']['rel_factor'] + except KeyError: + self.rel = 1 + + def run(self, its): + v = astra.data2d.get_shared(self.vid) + s = astra.data2d.get_shared(self.sid) + W = self.W + for i in range(its): + v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size + +if __name__=='__main__': + + vol_geom = astra.create_vol_geom(256, 256) + proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False)) + + # As before, create a sinogram from a phantom + import scipy.io + P = scipy.io.loadmat('phantom.mat')['phantom256'] + proj_id = astra.create_projector('cuda',proj_geom,vol_geom) + + # construct the OpTomo object + W = astra.OpTomo(proj_id) + + sinogram = W * P + sinogram = sinogram.reshape([180, 384]) + + # Register the plugin with ASTRA + # A default set of plugins to load can be defined in: + # - /etc/astra-toolbox/plugins.txt + # - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt + # - [USER_HOME_PATH]/.astra-toolbox/plugins.txt + # - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt + # In these files, create a separate line for each plugin with: + # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS] + # + # So in this case, it would be a line: + # SIRT-PLUGIN s018_plugin.SIRTPlugin + # + astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin') + + # To get help on a registered plugin, use get_help + six.print_(astra.plugin.get_help('SIRT-PLUGIN')) + + # Create data structures + sid = astra.data2d.create('-sino', proj_geom, sinogram) + vid = astra.data2d.create('-vol', vol_geom) + + # Create config using plugin name + cfg = astra.astra_dict('SIRT-PLUGIN') + cfg['ProjectorId'] = proj_id + cfg['ProjectionDataId'] = sid + cfg['ReconstructionDataId'] = vid + + # Create algorithm object + alg_id = astra.algorithm.create(cfg) + + # Run algorithm for 100 iterations + astra.algorithm.run(alg_id, 100) + + # Get reconstruction + rec = astra.data2d.get(vid) + + # Options for the plugin go in cfg['option'] + cfg = astra.astra_dict('SIRT-PLUGIN') + cfg['ProjectorId'] = proj_id + cfg['ProjectionDataId'] = sid + cfg['ReconstructionDataId'] = vid + cfg['option'] = {} + cfg['option']['rel_factor'] = 1.5 + alg_id_rel = astra.algorithm.create(cfg) + astra.algorithm.run(alg_id_rel, 100) + rec_rel = astra.data2d.get(vid) + + # We can also use OpTomo to call the plugin + rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) + + import pylab as pl + pl.gray() + pl.figure(1) + pl.imshow(rec,vmin=0,vmax=1) + pl.figure(2) + pl.imshow(rec_rel,vmin=0,vmax=1) + pl.figure(3) + pl.imshow(rec_op,vmin=0,vmax=1) + pl.show() + + # Clean up. + astra.projector.delete(proj_id) + astra.algorithm.delete([alg_id, alg_id_rel]) + astra.data2d.delete([vid, sid]) -- cgit v1.2.3 From 11af4b554df9a8a5c31d9dcbc1ea849b32394ba3 Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Wed, 24 Jun 2015 18:36:03 +0200 Subject: Better way of passing options to Python plugin using inspect --- python/astra/plugin.py | 24 ++++++++++++------------ samples/python/s018_plugin.py | 13 ++++--------- 2 files changed, 16 insertions(+), 21 deletions(-) (limited to 'samples/python') diff --git a/python/astra/plugin.py b/python/astra/plugin.py index ccdb2cb..891f6c9 100644 --- a/python/astra/plugin.py +++ b/python/astra/plugin.py @@ -26,22 +26,20 @@ from . import plugin_c as p from . import log +import inspect class base(object): def astra_init(self, cfg): try: - try: - req = self.required_options - except AttributeError: - log.warn("Plugin '" + self.__class__.__name__ + "' does not specify required options") - req = {} - - try: - opt = self.optional_options - except AttributeError: - log.warn("Plugin '" + self.__class__.__name__ + "' does not specify optional options") - opt = {} + args, varargs, varkw, defaults = inspect.getargspec(self.initialize) + nopt = len(defaults) + if nopt>0: + req = args[2:-nopt] + opt = args[-nopt:] + else: + req = args[2:] + opt = [] try: optDict = cfg['options'] @@ -60,7 +58,9 @@ class base(object): if not cfgKeys.issubset(reqKeys | optKeys): log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) - self.initialize(cfg) + args = [optDict[k] for k in req] + kwargs = dict((k,optDict[k]) for k in opt if k in optDict) + self.initialize(cfg, *args, **kwargs) except Exception as e: log.error(str(e)) raise diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 6677930..90e09ac 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -33,21 +33,16 @@ import six class SIRTPlugin(astra.plugin.base): """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. - Optional options: + Options: - 'rel_factor': relaxation factor + 'rel_factor': relaxation factor (optional) """ - required_options=[] - optional_options=['rel_factor'] - def initialize(self,cfg): + def initialize(self,cfg, rel_factor = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] self.sid = cfg['ProjectionDataId'] - try: - self.rel = cfg['option']['rel_factor'] - except KeyError: - self.rel = 1 + self.rel = rel_factor def run(self, its): v = astra.data2d.get_shared(self.vid) -- cgit v1.2.3 From d91b51f6d58003de84a9d6dd8189fceba0e81a5a Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Mon, 20 Jul 2015 14:07:21 +0200 Subject: Allow registering plugins without explicit name, and fix exception handling when running in Matlab --- include/astra/PluginAlgorithm.h | 3 ++ matlab/mex/astra_mex_plugin_c.cpp | 23 ++++------ python/astra/plugin.py | 71 ++++++++++++----------------- python/astra/plugin_c.pyx | 14 ++++-- samples/python/s018_plugin.py | 23 +++++----- src/PluginAlgorithm.cpp | 95 +++++++++++++++++++++++++++++++-------- 6 files changed, 138 insertions(+), 91 deletions(-) (limited to 'samples/python') diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h index a82c579..b56228e 100644 --- a/include/astra/PluginAlgorithm.h +++ b/include/astra/PluginAlgorithm.h @@ -64,9 +64,12 @@ public: CPluginAlgorithm * getPlugin(std::string name); bool registerPlugin(std::string name, std::string className); + bool registerPlugin(std::string className); bool registerPluginClass(std::string name, PyObject * className); + bool registerPluginClass(PyObject * className); PyObject * getRegistered(); + std::map getRegisteredMap(); std::string getHelp(std::string name); diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp index 2d9b9a0..177fcf4 100644 --- a/matlab/mex/astra_mex_plugin_c.cpp +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -37,9 +37,6 @@ $Id$ #include "astra/PluginAlgorithm.h" -#include "Python.h" -#include "bytesobject.h" - using namespace std; using namespace astra; @@ -52,29 +49,25 @@ using namespace astra; void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); - PyObject *dict = fact->getRegistered(); - PyObject *key, *value; - Py_ssize_t pos = 0; - while (PyDict_Next(dict, &pos, &key, &value)) { - mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value)); + std::map mp = fact->getRegisteredMap(); + for(std::map::iterator it=mp.begin();it!=mp.end();it++){ + mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str()); } - Py_DECREF(dict); } //----------------------------------------------------------------------------------------- -/** astra_mex_plugin('register', name, class_name); +/** astra_mex_plugin('register', class_name); * * Register plugin. */ void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { - if (3 <= nrhs) { - string name = mexToString(prhs[1]); - string class_name = mexToString(prhs[2]); + if (2 <= nrhs) { + string class_name = mexToString(prhs[1]); astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); - fact->registerPlugin(name, class_name); + fact->registerPlugin(class_name); }else{ - mexPrintf("astra_mex_plugin('register', name, class_name);\n"); + mexPrintf("astra_mex_plugin('register', class_name);\n"); } } diff --git a/python/astra/plugin.py b/python/astra/plugin.py index f8fc3bd..4b32e6e 100644 --- a/python/astra/plugin.py +++ b/python/astra/plugin.py @@ -32,60 +32,47 @@ import traceback class base(object): def astra_init(self, cfg): - try: - args, varargs, varkw, defaults = inspect.getargspec(self.initialize) - if not defaults is None: - nopt = len(defaults) - else: - nopt = 0 - if nopt>0: - req = args[2:-nopt] - opt = args[-nopt:] - else: - req = args[2:] - opt = [] + args, varargs, varkw, defaults = inspect.getargspec(self.initialize) + if not defaults is None: + nopt = len(defaults) + else: + nopt = 0 + if nopt>0: + req = args[2:-nopt] + opt = args[-nopt:] + else: + req = args[2:] + opt = [] - try: - optDict = cfg['options'] - except KeyError: - optDict = {} + try: + optDict = cfg['options'] + except KeyError: + optDict = {} - cfgKeys = set(optDict.keys()) - reqKeys = set(req) - optKeys = set(opt) + cfgKeys = set(optDict.keys()) + reqKeys = set(req) + optKeys = set(opt) - if not reqKeys.issubset(cfgKeys): - for key in reqKeys.difference(cfgKeys): - log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") - raise ValueError("Missing required options") + if not reqKeys.issubset(cfgKeys): + for key in reqKeys.difference(cfgKeys): + log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") + raise ValueError("Missing required options") - if not cfgKeys.issubset(reqKeys | optKeys): - log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) + if not cfgKeys.issubset(reqKeys | optKeys): + log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) - args = [optDict[k] for k in req] - kwargs = dict((k,optDict[k]) for k in opt if k in optDict) - self.initialize(cfg, *args, **kwargs) - except Exception: - log.error(traceback.format_exc().replace("%","%%")) - raise + args = [optDict[k] for k in req] + kwargs = dict((k,optDict[k]) for k in opt if k in optDict) + self.initialize(cfg, *args, **kwargs) - def astra_run(self, its): - try: - self.run(its) - except Exception: - log.error(traceback.format_exc().replace("%","%%")) - raise - -def register(name, className): +def register(className): """Register plugin with ASTRA. - :param name: Plugin name to register - :type name: :class:`str` :param className: Class name or class object to register :type className: :class:`str` or :class:`class` """ - p.register(name,className) + p.register(className) def get_registered(): """Get dictionary of registered plugins. diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx index 91b3cd5..8d6816b 100644 --- a/python/astra/plugin_c.pyx +++ b/python/astra/plugin_c.pyx @@ -38,7 +38,9 @@ from . import utils cdef extern from "astra/PluginAlgorithm.h" namespace "astra": cdef cppclass CPluginAlgorithmFactory: + bool registerPlugin(string className) bool registerPlugin(string name, string className) + bool registerPluginClass(object className) bool registerPluginClass(string name, object className) object getRegistered() string getHelp(string name) @@ -46,11 +48,17 @@ cdef extern from "astra/PluginAlgorithm.h" namespace "astra": cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory": cdef CPluginAlgorithmFactory* getSingletonPtr() -def register(name, className): +def register(className, name=None): if inspect.isclass(className): - fact.registerPluginClass(six.b(name), className) + if name==None: + fact.registerPluginClass(className) + else: + fact.registerPluginClass(six.b(name), className) else: - fact.registerPlugin(six.b(name), six.b(className)) + if name==None: + fact.registerPlugin(six.b(className)) + else: + fact.registerPlugin(six.b(name), six.b(className)) def get_registered(): return fact.getRegistered() diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 90e09ac..31cca95 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -38,6 +38,10 @@ class SIRTPlugin(astra.plugin.base): 'rel_factor': relaxation factor (optional) """ + # The astra_name variable defines the name to use to + # call the plugin from ASTRA + astra_name = "SIRT-PLUGIN" + def initialize(self,cfg, rel_factor = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] @@ -68,18 +72,13 @@ if __name__=='__main__': sinogram = sinogram.reshape([180, 384]) # Register the plugin with ASTRA - # A default set of plugins to load can be defined in: - # - /etc/astra-toolbox/plugins.txt - # - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt - # - [USER_HOME_PATH]/.astra-toolbox/plugins.txt - # - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt - # In these files, create a separate line for each plugin with: - # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS] - # - # So in this case, it would be a line: - # SIRT-PLUGIN s018_plugin.SIRTPlugin - # - astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin') + # First we import the package that contains the plugin + import s018_plugin + # Then, we register the plugin class with ASTRA + astra.plugin.register(s018_plugin.SIRTPlugin) + + # Get a list of registered plugins + six.print_(astra.plugin.get_registered()) # To get help on a registered plugin, use get_help six.print_(astra.plugin.get_help('SIRT-PLUGIN')) diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp index d6cf731..7f7ff61 100644 --- a/src/PluginAlgorithm.cpp +++ b/src/PluginAlgorithm.cpp @@ -100,7 +100,10 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){ PyObject *cfgDict = XMLNode2dict(_cfg.self); PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict); Py_DECREF(cfgDict); - if(retVal==NULL) return false; + if(retVal==NULL){ + logPythonError(); + return false; + } m_bIsInitialized = true; Py_DECREF(retVal); return m_bIsInitialized; @@ -108,8 +111,11 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){ void CPluginAlgorithm::run(int _iNrIterations){ if(instance==NULL) return; - PyObject *retVal = PyObject_CallMethod(instance, "astra_run", "i",_iNrIterations); - if(retVal==NULL) return; + PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations); + if(retVal==NULL){ + logPythonError(); + return; + } Py_DECREF(retVal); } @@ -157,18 +163,6 @@ CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){ if(six!=NULL) Py_DECREF(six); } -bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ - PyObject *str = PyBytes_FromString(className.c_str()); - PyDict_SetItemString(pluginDict, name.c_str(), str); - Py_DECREF(str); - return true; -} - -bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ - PyDict_SetItemString(pluginDict, name.c_str(), className); - return true; -} - PyObject * getClassFromString(std::string str){ std::vector items; boost::split(items, str, boost::is_any_of(".")); @@ -190,6 +184,43 @@ PyObject * getClassFromString(std::string str){ return pyclass; } +bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ + PyObject *str = PyBytes_FromString(className.c_str()); + PyDict_SetItemString(pluginDict, name.c_str(), str); + Py_DECREF(str); + return true; +} + +bool CPluginAlgorithmFactory::registerPlugin(std::string className){ + PyObject *pyclass = getClassFromString(className); + if(pyclass==NULL) return false; + bool ret = registerPluginClass(pyclass); + Py_DECREF(pyclass); + return ret; +} + +bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ + PyDict_SetItemString(pluginDict, name.c_str(), className); + return true; +} + +bool CPluginAlgorithmFactory::registerPluginClass(PyObject * className){ + PyObject *astra_name = PyObject_GetAttrString(className,"astra_name"); + if(astra_name==NULL){ + logPythonError(); + return false; + } + PyObject *retb = PyObject_CallMethod(six,"b","O",astra_name); + if(retb!=NULL){ + PyDict_SetItemString(pluginDict,PyBytes_AsString(retb),className); + Py_DECREF(retb); + }else{ + logPythonError(); + } + Py_DECREF(astra_name); + return true; +} + CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){ PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); if(className==NULL) return NULL; @@ -212,12 +243,34 @@ PyObject * CPluginAlgorithmFactory::getRegistered(){ return pluginDict; } +std::map CPluginAlgorithmFactory::getRegisteredMap(){ + std::map ret; + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(pluginDict, &pos, &key, &value)) { + PyObject * keyb = PyObject_Bytes(key); + PyObject * valb = PyObject_Bytes(value); + ret[PyBytes_AsString(keyb)] = PyBytes_AsString(valb); + Py_DECREF(keyb); + Py_DECREF(valb); + } + return ret; +} + std::string CPluginAlgorithmFactory::getHelp(std::string name){ PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); - if(className==NULL) return ""; - std::string str = std::string(PyBytes_AsString(className)); + if(className==NULL){ + ASTRA_ERROR("Plugin %s not found!",name.c_str()); + return ""; + } std::string ret = ""; - PyObject *pyclass = getClassFromString(str); + PyObject *pyclass; + if(PyBytes_Check(className)){ + std::string str = std::string(PyBytes_AsString(className)); + pyclass = getClassFromString(str); + }else{ + pyclass = className; + } if(pyclass==NULL) return ""; if(inspect!=NULL && six!=NULL){ PyObject *retVal = PyObject_CallMethod(inspect,"getdoc","O",pyclass); @@ -228,9 +281,13 @@ std::string CPluginAlgorithmFactory::getHelp(std::string name){ ret = std::string(PyBytes_AsString(retb)); Py_DECREF(retb); } + }else{ + logPythonError(); } } - Py_DECREF(pyclass); + if(PyBytes_Check(className)){ + Py_DECREF(pyclass); + } return ret; } -- cgit v1.2.3 From e07449189a05e3bcdc8ad4a9fbb95c0751f567bb Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Fri, 4 Dec 2015 15:15:16 +0100 Subject: Add sample for experimental composite geometry code --- python/astra/PyIncludes.pxd | 2 + python/astra/experimental.pyx | 84 ++++++++++++++++++++++++++++ samples/python/s018_experimental_multires.py | 84 ++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+) create mode 100644 python/astra/experimental.pyx create mode 100644 samples/python/s018_experimental_multires.py (limited to 'samples/python') diff --git a/python/astra/PyIncludes.pxd b/python/astra/PyIncludes.pxd index 35dea5f..e9e2bdb 100644 --- a/python/astra/PyIncludes.pxd +++ b/python/astra/PyIncludes.pxd @@ -224,6 +224,7 @@ cdef extern from "astra/Float32VolumeData3DMemory.h" namespace "astra": int getRowCount() int getColCount() int getSliceCount() + bool isInitialized() @@ -255,6 +256,7 @@ cdef extern from "astra/Float32ProjectionData3DMemory.h" namespace "astra": int getDetectorColCount() int getDetectorRowCount() int getAngleCount() + bool isInitialized() cdef extern from "astra/Float32Data3D.h" namespace "astra": cdef cppclass CFloat32Data3D: diff --git a/python/astra/experimental.pyx b/python/astra/experimental.pyx new file mode 100644 index 0000000..da27504 --- /dev/null +++ b/python/astra/experimental.pyx @@ -0,0 +1,84 @@ +#----------------------------------------------------------------------- +# Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +# 2014-2015, CWI, Amsterdam +# +# Contact: astra@uantwerpen.be +# Website: http://sf.net/projects/astra-toolbox +# +# This file is part of the ASTRA Toolbox. +# +# +# The ASTRA Toolbox is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# The ASTRA Toolbox is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +# distutils: language = c++ +# distutils: libraries = astra + +include "config.pxi" + +import six +from .PyIncludes cimport * +from libcpp.vector cimport vector + +cdef extern from "astra/CompositeGeometryManager.h" namespace "astra": + cdef cppclass CCompositeGeometryManager: + bool doFP(CProjector3D *, vector[CFloat32VolumeData3DMemory *], vector[CFloat32ProjectionData3DMemory *]) + bool doBP(CProjector3D *, vector[CFloat32VolumeData3DMemory *], vector[CFloat32ProjectionData3DMemory *]) + +cdef extern from *: + CFloat32VolumeData3DMemory * dynamic_cast_vol_mem "dynamic_cast" (CFloat32Data3D * ) except NULL + CFloat32ProjectionData3DMemory * dynamic_cast_proj_mem "dynamic_cast" (CFloat32Data3D * ) except NULL + +cimport PyProjector3DManager +from .PyProjector3DManager cimport CProjector3DManager +cimport PyData3DManager +from .PyData3DManager cimport CData3DManager + +cdef CProjector3DManager * manProj = PyProjector3DManager.getSingletonPtr() +cdef CData3DManager * man3d = PyData3DManager.getSingletonPtr() + +def do_composite(projector_id, vol_ids, proj_ids, t): + cdef vector[CFloat32VolumeData3DMemory *] vol + cdef CFloat32VolumeData3DMemory * pVolObject + cdef CFloat32ProjectionData3DMemory * pProjObject + for v in vol_ids: + pVolObject = dynamic_cast_vol_mem(man3d.get(v)) + if pVolObject == NULL: + raise Exception("Data object not found") + if not pVolObject.isInitialized(): + raise Exception("Data object not initialized properly") + vol.push_back(pVolObject) + cdef vector[CFloat32ProjectionData3DMemory *] proj + for v in proj_ids: + pProjObject = dynamic_cast_proj_mem(man3d.get(v)) + if pProjObject == NULL: + raise Exception("Data object not found") + if not pProjObject.isInitialized(): + raise Exception("Data object not initialized properly") + proj.push_back(pProjObject) + cdef CCompositeGeometryManager m + cdef CProjector3D * projector = manProj.get(projector_id) # may be NULL + if t == "FP": + if not m.doFP(projector, vol, proj): + raise Exception("Failed to perform FP") + else: + if not m.doBP(projector, vol, proj): + raise Exception("Failed to perform BP") + +def do_composite_FP(projector_id, vol_ids, proj_ids): + do_composite(projector_id, vol_ids, proj_ids, "FP") + +def do_composite_BP(projector_id, vol_ids, proj_ids): + do_composite(projector_id, vol_ids, proj_ids, "BP") diff --git a/samples/python/s018_experimental_multires.py b/samples/python/s018_experimental_multires.py new file mode 100644 index 0000000..cf38e53 --- /dev/null +++ b/samples/python/s018_experimental_multires.py @@ -0,0 +1,84 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +import astra +import numpy as np +from astra.experimental import do_composite_FP + +astra.log.setOutputScreen(astra.log.STDERR, astra.log.DEBUG) + +# low res part (voxels of 4x4x4) +vol_geom1 = astra.create_vol_geom(32, 16, 32, -64, 0, -64, 64, -64, 64) + +# high res part (voxels of 1x1x1) +vol_geom2 = astra.create_vol_geom(128, 64, 128, 0, 64, -64, 64, -64, 64) + + +# Split the output in two parts as well, for demonstration purposes +angles1 = np.linspace(0, np.pi/2, 90, False) +angles2 = np.linspace(np.pi/2, np.pi, 90, False) +proj_geom1 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles1) +proj_geom2 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles2) + +# Create a simple hollow cube phantom +cube1 = np.zeros((32,32,16)) +cube1[4:28,4:28,4:16] = 1 + +cube2 = np.zeros((128,128,64)) +cube2[16:112,16:112,0:112] = 1 +cube2[33:97,33:97,4:28] = 0 + +vol1 = astra.data3d.create('-vol', vol_geom1, cube1) +vol2 = astra.data3d.create('-vol', vol_geom2, cube2) + +proj1 = astra.data3d.create('-proj3d', proj_geom1, 0) +proj2 = astra.data3d.create('-proj3d', proj_geom2, 0) + +# The actual geometries don't matter for this composite FP/BP case +projector = astra.create_projector('cuda3d', proj_geom1, vol_geom1) + +do_composite_FP(projector, [vol1, vol2], [proj1, proj2]) + +proj_data1 = astra.data3d.get(proj1) +proj_data2 = astra.data3d.get(proj2) + +# Display a single projection image +import pylab +pylab.gray() +pylab.figure(1) +pylab.imshow(proj_data1[:,0,:]) +pylab.figure(2) +pylab.imshow(proj_data2[:,0,:]) +pylab.show() + + +# Clean up. Note that GPU memory is tied up in the algorithm object, +# and main RAM in the data objects. +astra.data3d.delete(vol1) +astra.data3d.delete(vol2) +astra.data3d.delete(proj1) +astra.data3d.delete(proj2) +astra.projector3d.delete(projector) -- cgit v1.2.3 From f2227eaca7248b6b01a5776f0cb750cff4c0279a Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Thu, 7 Jan 2016 13:29:37 +0100 Subject: Rename sample with conflicting filename --- samples/python/s018_experimental_multires.py | 84 ---------------------------- samples/python/s019_experimental_multires.py | 84 ++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 84 deletions(-) delete mode 100644 samples/python/s018_experimental_multires.py create mode 100644 samples/python/s019_experimental_multires.py (limited to 'samples/python') diff --git a/samples/python/s018_experimental_multires.py b/samples/python/s018_experimental_multires.py deleted file mode 100644 index cf38e53..0000000 --- a/samples/python/s018_experimental_multires.py +++ /dev/null @@ -1,84 +0,0 @@ -#----------------------------------------------------------------------- -#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam -# -#Author: Daniel M. Pelt -#Contact: D.M.Pelt@cwi.nl -#Website: http://dmpelt.github.io/pyastratoolbox/ -# -# -#This file is part of the Python interface to the -#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). -# -#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify -#it under the terms of the GNU General Public License as published by -#the Free Software Foundation, either version 3 of the License, or -#(at your option) any later version. -# -#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, -#but WITHOUT ANY WARRANTY; without even the implied warranty of -#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -#GNU General Public License for more details. -# -#You should have received a copy of the GNU General Public License -#along with the Python interface to the ASTRA Toolbox. If not, see . -# -#----------------------------------------------------------------------- - -import astra -import numpy as np -from astra.experimental import do_composite_FP - -astra.log.setOutputScreen(astra.log.STDERR, astra.log.DEBUG) - -# low res part (voxels of 4x4x4) -vol_geom1 = astra.create_vol_geom(32, 16, 32, -64, 0, -64, 64, -64, 64) - -# high res part (voxels of 1x1x1) -vol_geom2 = astra.create_vol_geom(128, 64, 128, 0, 64, -64, 64, -64, 64) - - -# Split the output in two parts as well, for demonstration purposes -angles1 = np.linspace(0, np.pi/2, 90, False) -angles2 = np.linspace(np.pi/2, np.pi, 90, False) -proj_geom1 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles1) -proj_geom2 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles2) - -# Create a simple hollow cube phantom -cube1 = np.zeros((32,32,16)) -cube1[4:28,4:28,4:16] = 1 - -cube2 = np.zeros((128,128,64)) -cube2[16:112,16:112,0:112] = 1 -cube2[33:97,33:97,4:28] = 0 - -vol1 = astra.data3d.create('-vol', vol_geom1, cube1) -vol2 = astra.data3d.create('-vol', vol_geom2, cube2) - -proj1 = astra.data3d.create('-proj3d', proj_geom1, 0) -proj2 = astra.data3d.create('-proj3d', proj_geom2, 0) - -# The actual geometries don't matter for this composite FP/BP case -projector = astra.create_projector('cuda3d', proj_geom1, vol_geom1) - -do_composite_FP(projector, [vol1, vol2], [proj1, proj2]) - -proj_data1 = astra.data3d.get(proj1) -proj_data2 = astra.data3d.get(proj2) - -# Display a single projection image -import pylab -pylab.gray() -pylab.figure(1) -pylab.imshow(proj_data1[:,0,:]) -pylab.figure(2) -pylab.imshow(proj_data2[:,0,:]) -pylab.show() - - -# Clean up. Note that GPU memory is tied up in the algorithm object, -# and main RAM in the data objects. -astra.data3d.delete(vol1) -astra.data3d.delete(vol2) -astra.data3d.delete(proj1) -astra.data3d.delete(proj2) -astra.projector3d.delete(projector) diff --git a/samples/python/s019_experimental_multires.py b/samples/python/s019_experimental_multires.py new file mode 100644 index 0000000..cf38e53 --- /dev/null +++ b/samples/python/s019_experimental_multires.py @@ -0,0 +1,84 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +import astra +import numpy as np +from astra.experimental import do_composite_FP + +astra.log.setOutputScreen(astra.log.STDERR, astra.log.DEBUG) + +# low res part (voxels of 4x4x4) +vol_geom1 = astra.create_vol_geom(32, 16, 32, -64, 0, -64, 64, -64, 64) + +# high res part (voxels of 1x1x1) +vol_geom2 = astra.create_vol_geom(128, 64, 128, 0, 64, -64, 64, -64, 64) + + +# Split the output in two parts as well, for demonstration purposes +angles1 = np.linspace(0, np.pi/2, 90, False) +angles2 = np.linspace(np.pi/2, np.pi, 90, False) +proj_geom1 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles1) +proj_geom2 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles2) + +# Create a simple hollow cube phantom +cube1 = np.zeros((32,32,16)) +cube1[4:28,4:28,4:16] = 1 + +cube2 = np.zeros((128,128,64)) +cube2[16:112,16:112,0:112] = 1 +cube2[33:97,33:97,4:28] = 0 + +vol1 = astra.data3d.create('-vol', vol_geom1, cube1) +vol2 = astra.data3d.create('-vol', vol_geom2, cube2) + +proj1 = astra.data3d.create('-proj3d', proj_geom1, 0) +proj2 = astra.data3d.create('-proj3d', proj_geom2, 0) + +# The actual geometries don't matter for this composite FP/BP case +projector = astra.create_projector('cuda3d', proj_geom1, vol_geom1) + +do_composite_FP(projector, [vol1, vol2], [proj1, proj2]) + +proj_data1 = astra.data3d.get(proj1) +proj_data2 = astra.data3d.get(proj2) + +# Display a single projection image +import pylab +pylab.gray() +pylab.figure(1) +pylab.imshow(proj_data1[:,0,:]) +pylab.figure(2) +pylab.imshow(proj_data2[:,0,:]) +pylab.show() + + +# Clean up. Note that GPU memory is tied up in the algorithm object, +# and main RAM in the data objects. +astra.data3d.delete(vol1) +astra.data3d.delete(vol2) +astra.data3d.delete(proj1) +astra.data3d.delete(proj2) +astra.projector3d.delete(projector) -- cgit v1.2.3 From b529ff854f0e1191108e31d6be294d31b50c666e Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Thu, 7 Jan 2016 13:36:02 +0100 Subject: Add multi-GPU sample --- samples/matlab/s020_3d_multiGPU.m | 38 +++++++++++++++++++++++++ samples/python/s020_3d_multiGPU.py | 57 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 samples/matlab/s020_3d_multiGPU.m create mode 100644 samples/python/s020_3d_multiGPU.py (limited to 'samples/python') diff --git a/samples/matlab/s020_3d_multiGPU.m b/samples/matlab/s020_3d_multiGPU.m new file mode 100644 index 0000000..bade325 --- /dev/null +++ b/samples/matlab/s020_3d_multiGPU.m @@ -0,0 +1,38 @@ +% ----------------------------------------------------------------------- +% This file is part of the ASTRA Toolbox +% +% Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +% 2014-2015, CWI, Amsterdam +% License: Open Source under GPLv3 +% Contact: astra@uantwerpen.be +% Website: http://sf.net/projects/astra-toolbox +% ----------------------------------------------------------------------- + + +% Set up multi-GPU usage. +% This only works for 3D GPU forward projection and back projection. +astra_mex('set_gpu_index', [0 1]); + +% Optionally, you can also restrict the amount of GPU memory ASTRA will use. +% The line commented below sets this to 1GB. +%astra_mex('set_gpu_index', [0 1], 'memory', 1024*1024*1024); + +vol_geom = astra_create_vol_geom(1024, 1024, 1024); + +angles = linspace2(0, pi, 1024); +proj_geom = astra_create_proj_geom('parallel3d', 1.0, 1.0, 1024, 1024, angles); + +% Create a simple hollow cube phantom +cube = zeros(1024,1024,1024); +cube(129:896,129:896,129:896) = 1; +cube(257:768,257:768,257:768) = 0; + +% Create projection data from this +[proj_id, proj_data] = astra_create_sino3d_cuda(cube, proj_geom, vol_geom); + +% Backproject projection data +[bproj_id, bproj_data] = astra_create_backprojection3d_cuda(proj_data, proj_geom, vol_geom); + +astra_mex_data3d('delete', proj_id); +astra_mex_data3d('delete', bproj_id); + diff --git a/samples/python/s020_3d_multiGPU.py b/samples/python/s020_3d_multiGPU.py new file mode 100644 index 0000000..d6799c4 --- /dev/null +++ b/samples/python/s020_3d_multiGPU.py @@ -0,0 +1,57 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +import astra +import numpy as np + +# Set up multi-GPU usage. +# This only works for 3D GPU forward projection and back projection. +astra.astra.set_gpu_index([0,1]) + +# Optionally, you can also restrict the amount of GPU memory ASTRA will use. +# The line commented below sets this to 1GB. +#astra.astra.set_gpu_index([0,1], memory=1024*1024*1024) + +vol_geom = astra.create_vol_geom(1024, 1024, 1024) + +angles = np.linspace(0, np.pi, 1024,False) +proj_geom = astra.create_proj_geom('parallel3d', 1.0, 1.0, 1024, 1024, angles) + +# Create a simple hollow cube phantom +cube = np.zeros((1024,1024,1024)) +cube[128:895,128:895,128:895] = 1 +cube[256:767,256:767,256:767] = 0 + +# Create projection data from this +proj_id, proj_data = astra.create_sino3d_gpu(cube, proj_geom, vol_geom) + +# Backproject projection data +bproj_id, bproj_data = astra.create_backprojection3d_gpu(proj_data, proj_geom, vol_geom) + +# Clean up. Note that GPU memory is tied up in the algorithm object, +# and main RAM in the data objects. +astra.data3d.delete(proj_id) +astra.data3d.delete(bproj_id) -- cgit v1.2.3 From 1e26f7602b6685c584fd4d857353f390622e3a34 Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Mon, 25 Apr 2016 10:47:59 +0200 Subject: Change flatten to ravel in Python code --- python/astra/optomo.py | 4 ++-- samples/python/s009_projection_matrix.py | 2 +- samples/python/s015_fp_bp.py | 6 +++--- samples/python/s017_OpTomo.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) (limited to 'samples/python') diff --git a/python/astra/optomo.py b/python/astra/optomo.py index dd10713..5a92998 100644 --- a/python/astra/optomo.py +++ b/python/astra/optomo.py @@ -125,7 +125,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator): algorithm.delete(fp_id) self.data_mod.delete([vid,sid]) - return s.flatten() + return s.ravel() def rmatvec(self,s): """Implements the transpose operator. @@ -147,7 +147,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator): algorithm.delete(bp_id) self.data_mod.delete([vid,sid]) - return v.flatten() + return v.ravel() def __mul__(self,v): """Provides easy forward operator by *. diff --git a/samples/python/s009_projection_matrix.py b/samples/python/s009_projection_matrix.py index c4c4557..e20d58c 100644 --- a/samples/python/s009_projection_matrix.py +++ b/samples/python/s009_projection_matrix.py @@ -46,7 +46,7 @@ W = astra.matrix.get(matrix_id) # Manually use this projection matrix to do a projection: import scipy.io P = scipy.io.loadmat('phantom.mat')['phantom256'] -s = W.dot(P.flatten()) +s = W.dot(P.ravel()) s = np.reshape(s, (len(proj_geom['ProjectionAngles']),proj_geom['DetectorCount'])) import pylab diff --git a/samples/python/s015_fp_bp.py b/samples/python/s015_fp_bp.py index fa0bf86..ff0b30a 100644 --- a/samples/python/s015_fp_bp.py +++ b/samples/python/s015_fp_bp.py @@ -46,12 +46,12 @@ class astra_wrap(object): def matvec(self,v): sid, s = astra.create_sino(np.reshape(v,(vol_geom['GridRowCount'],vol_geom['GridColCount'])),self.proj_id) astra.data2d.delete(sid) - return s.flatten() + return s.ravel() def rmatvec(self,v): bid, b = astra.create_backprojection(np.reshape(v,(len(proj_geom['ProjectionAngles']),proj_geom['DetectorCount'],)),self.proj_id) astra.data2d.delete(bid) - return b.flatten() + return b.ravel() vol_geom = astra.create_vol_geom(256, 256) proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False)) @@ -65,7 +65,7 @@ proj_id = astra.create_projector('cuda',proj_geom,vol_geom) sinogram_id, sinogram = astra.create_sino(P, proj_id) # Reshape the sinogram into a vector -b = sinogram.flatten() +b = sinogram.ravel() # Call lsqr with ASTRA FP and BP import scipy.sparse.linalg diff --git a/samples/python/s017_OpTomo.py b/samples/python/s017_OpTomo.py index 967fa64..214e9a7 100644 --- a/samples/python/s017_OpTomo.py +++ b/samples/python/s017_OpTomo.py @@ -50,7 +50,7 @@ pylab.figure(2) pylab.imshow(sinogram) # Run the lsqr linear solver -output = scipy.sparse.linalg.lsqr(W, sinogram.flatten(), iter_lim=150) +output = scipy.sparse.linalg.lsqr(W, sinogram.ravel(), iter_lim=150) rec = output[0].reshape([256, 256]) pylab.figure(3) -- cgit v1.2.3 From ed717202a0c917958892e26322d6ea5173f7b32c Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Mon, 25 Apr 2016 17:04:39 +0200 Subject: Use FP/BP out argument in sample plugin --- samples/python/s018_plugin.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) (limited to 'samples/python') diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 31cca95..85b5486 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -30,30 +30,38 @@ import six # Define the plugin class (has to subclass astra.plugin.base) # Note that usually, these will be defined in a separate package/module -class SIRTPlugin(astra.plugin.base): - """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. +class LandweberPlugin(astra.plugin.base): + """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm. Options: - 'rel_factor': relaxation factor (optional) + 'Relaxation': relaxation factor (optional) """ # The astra_name variable defines the name to use to # call the plugin from ASTRA - astra_name = "SIRT-PLUGIN" + astra_name = "LANDWEBER-PLUGIN" - def initialize(self,cfg, rel_factor = 1): + def initialize(self,cfg, Relaxation = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] self.sid = cfg['ProjectionDataId'] - self.rel = rel_factor + self.rel = Relaxation def run(self, its): v = astra.data2d.get_shared(self.vid) s = astra.data2d.get_shared(self.sid) + tv = np.zeros(v.shape, dtype=np.float32) + ts = np.zeros(s.shape, dtype=np.float32) W = self.W for i in range(its): - v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size + W.FP(v,out=ts) + ts -= s # ts = W*v - s + + W.BP(ts,out=tv) + tv *= self.rel / s.size + + v -= tv # v = v - rel * W'*(W*v-s) / s.size if __name__=='__main__': @@ -75,20 +83,20 @@ if __name__=='__main__': # First we import the package that contains the plugin import s018_plugin # Then, we register the plugin class with ASTRA - astra.plugin.register(s018_plugin.SIRTPlugin) + astra.plugin.register(s018_plugin.LandweberPlugin) # Get a list of registered plugins six.print_(astra.plugin.get_registered()) # To get help on a registered plugin, use get_help - six.print_(astra.plugin.get_help('SIRT-PLUGIN')) + six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN')) # Create data structures sid = astra.data2d.create('-sino', proj_geom, sinogram) vid = astra.data2d.create('-vol', vol_geom) # Create config using plugin name - cfg = astra.astra_dict('SIRT-PLUGIN') + cfg = astra.astra_dict('LANDWEBER-PLUGIN') cfg['ProjectorId'] = proj_id cfg['ProjectionDataId'] = sid cfg['ReconstructionDataId'] = vid @@ -103,18 +111,18 @@ if __name__=='__main__': rec = astra.data2d.get(vid) # Options for the plugin go in cfg['option'] - cfg = astra.astra_dict('SIRT-PLUGIN') + cfg = astra.astra_dict('LANDWEBER-PLUGIN') cfg['ProjectorId'] = proj_id cfg['ProjectionDataId'] = sid cfg['ReconstructionDataId'] = vid cfg['option'] = {} - cfg['option']['rel_factor'] = 1.5 + cfg['option']['Relaxation'] = 1.5 alg_id_rel = astra.algorithm.create(cfg) astra.algorithm.run(alg_id_rel, 100) rec_rel = astra.data2d.get(vid) # We can also use OpTomo to call the plugin - rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) + rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5}) import pylab as pl pl.gray() -- cgit v1.2.3