From 233331b4a192c0149f58af1d4c89526260cd3a58 Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Tue, 23 Jun 2015 12:18:47 +0200 Subject: Update sample --- samples/matlab/s010_supersampling.m | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) (limited to 'samples') diff --git a/samples/matlab/s010_supersampling.m b/samples/matlab/s010_supersampling.m index 80f6f56..148f6ad 100644 --- a/samples/matlab/s010_supersampling.m +++ b/samples/matlab/s010_supersampling.m @@ -12,23 +12,15 @@ vol_geom = astra_create_vol_geom(256, 256); proj_geom = astra_create_proj_geom('parallel', 3.0, 128, linspace2(0,pi,180)); P = phantom(256); -% Because the astra_create_sino_gpu wrapper does not have support for -% all possible algorithm options, we manually create a sinogram -phantom_id = astra_mex_data2d('create', '-vol', vol_geom, P); -sinogram_id = astra_mex_data2d('create', '-sino', proj_geom); -cfg = astra_struct('FP_CUDA'); -cfg.VolumeDataId = phantom_id; -cfg.ProjectionDataId = sinogram_id; +% We create a projector set up to use 3 rays per detector element +cfg_proj = astra_struct('cuda'); +cfg_proj.option.DetectorSuperSampling = 3; +cfg_proj.ProjectionGeometry = proj_geom; +cfg_proj.VolumeGeometry = vol_geom; +proj_id = astra_mex_projector('create', cfg_proj); -% Set up 3 rays per detector element -cfg.option.DetectorSuperSampling = 3; -alg_id = astra_mex_algorithm('create', cfg); -astra_mex_algorithm('run', alg_id); -astra_mex_algorithm('delete', alg_id); -astra_mex_data2d('delete', phantom_id); - -sinogram3 = astra_mex_data2d('get', sinogram_id); +[sinogram3 sinogram_id] = astra_create_sino(P, proj_id); figure(1); imshow(P, []); figure(2); imshow(sinogram3, []); @@ -39,14 +31,14 @@ rec_id = astra_mex_data2d('create', '-vol', vol_geom); cfg = astra_struct('SIRT_CUDA'); cfg.ReconstructionDataId = rec_id; cfg.ProjectionDataId = sinogram_id; -% Set up 3 rays per detector element -cfg.option.DetectorSuperSampling = 3; +cfg.ProjectorId = proj_id; + % There is also an option for supersampling during the backprojection step. % This should be used if your detector pixels are smaller than the voxels. % Set up 2 rays per image pixel dimension, for 4 rays total per image pixel. -% cfg.option.PixelSuperSampling = 2; +% cfg_proj.option.PixelSuperSampling = 2; alg_id = astra_mex_algorithm('create', cfg); -- cgit v1.2.3 From 18b6d25f7e4f0943b3592f3bb4f6ca5ed9c285d3 Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Fri, 19 Jun 2015 22:28:06 +0200 Subject: Add support for Python algorithm plugins --- build/linux/Makefile.in | 16 +- include/astra/AstraObjectFactory.h | 13 ++ include/astra/PluginAlgorithm.h | 85 +++++++++++ matlab/mex/astra_mex_plugin_c.cpp | 139 ++++++++++++++++++ python/astra/__init__.py | 1 + python/astra/plugin.py | 95 ++++++++++++ python/astra/plugin_c.pyx | 59 ++++++++ python/astra/utils.pyx | 72 +-------- python/docSRC/index.rst | 1 + python/docSRC/plugins.rst | 8 + samples/python/s018_plugin.py | 138 +++++++++++++++++ src/PluginAlgorithm.cpp | 294 +++++++++++++++++++++++++++++++++++++ 12 files changed, 851 insertions(+), 70 deletions(-) create mode 100644 include/astra/PluginAlgorithm.h create mode 100644 matlab/mex/astra_mex_plugin_c.cpp create mode 100644 python/astra/plugin.py create mode 100644 python/astra/plugin_c.pyx create mode 100644 python/docSRC/plugins.rst create mode 100644 samples/python/s018_plugin.py create mode 100644 src/PluginAlgorithm.cpp (limited to 'samples') diff --git a/build/linux/Makefile.in b/build/linux/Makefile.in index 2d862f2..e209fa7 100644 --- a/build/linux/Makefile.in +++ b/build/linux/Makefile.in @@ -50,11 +50,17 @@ LDFLAGS+=-fopenmp endif ifeq ($(python),yes) -PYCPPFLAGS = ${CPPFLAGS} +PYTHON = @PYTHON@ +PYLIBDIR = $(shell $(PYTHON) -c 'from distutils.sysconfig import get_config_var; import six; six.print_(get_config_var("LIBDIR"))') +PYINCDIR = $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; import six; six.print_(get_python_inc())') +PYLIBVER = `basename $(PYINCDIR)` +CPPFLAGS += -DASTRA_PYTHON -I$(PYINCDIR) +PYCPPFLAGS = $(CPPFLAGS) PYCPPFLAGS += -I../include -PYLDFLAGS = ${LDFLAGS} +PYLDFLAGS = $(LDFLAGS) PYLDFLAGS += -L../build/linux/.libs -PYTHON = @PYTHON@ +LIBS += -l$(PYLIBVER) +LDFLAGS += -L$(PYLIBDIR) endif BOOST_CPPFLAGS= @@ -234,6 +240,10 @@ MATLAB_MEX=\ matlab/mex/astra_mex_log_c.$(MEXSUFFIX) \ matlab/mex/astra_mex_data3d_c.$(MEXSUFFIX) +ifeq ($(python),yes) +ALL_OBJECTS+=src/PluginAlgorithm.lo +MATLAB_MEX+=matlab/mex/astra_mex_plugin_c.$(MEXSUFFIX) +endif OBJECT_DIRS = src/ tests/ cuda/2d/ cuda/3d/ matlab/mex/ ./ DEPDIRS = $(addsuffix $(DEPDIR),$(OBJECT_DIRS)) diff --git a/include/astra/AstraObjectFactory.h b/include/astra/AstraObjectFactory.h index 356acf9..325989e 100644 --- a/include/astra/AstraObjectFactory.h +++ b/include/astra/AstraObjectFactory.h @@ -40,6 +40,10 @@ $Id$ #include "AlgorithmTypelist.h" +#ifdef ASTRA_PYTHON +#include "PluginAlgorithm.h" +#endif + namespace astra { @@ -147,6 +151,15 @@ T* CAstraObjectFactory::create(const Config& _cfg) */ class _AstraExport CAlgorithmFactory : public CAstraObjectFactory {}; +#ifdef ASTRA_PYTHON +template <> +inline CAlgorithm* CAstraObjectFactory::findPlugin(std::string _sType) + { + CPluginAlgorithmFactory *fac = CPluginAlgorithmFactory::getSingletonPtr(); + return fac->getPlugin(_sType); + } +#endif + /** * Class used to create 2D projectors from a string or a config object */ diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h new file mode 100644 index 0000000..7d6c64a --- /dev/null +++ b/include/astra/PluginAlgorithm.h @@ -0,0 +1,85 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp + 2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see . + +----------------------------------------------------------------------- +$Id$ +*/ + +#ifndef _INC_ASTRA_PLUGINALGORITHM +#define _INC_ASTRA_PLUGINALGORITHM + +#ifdef ASTRA_PYTHON + +#include +#include "bytesobject.h" +#include "astra/Algorithm.h" +#include "astra/Singleton.h" +#include "astra/XMLDocument.h" +#include "astra/XMLNode.h" + +namespace astra { +class _AstraExport CPluginAlgorithm : public CAlgorithm { + +public: + + CPluginAlgorithm(PyObject* pyclass); + ~CPluginAlgorithm(); + + bool initialize(const Config& _cfg); + void run(int _iNrIterations); + +private: + PyObject * instance; + +}; + +class _AstraExport CPluginAlgorithmFactory : public Singleton { + +public: + + CPluginAlgorithmFactory(); + ~CPluginAlgorithmFactory(); + + CPluginAlgorithm * getPlugin(std::string name); + + bool registerPlugin(std::string name, std::string className); + bool registerPluginClass(std::string name, PyObject * className); + + PyObject * getRegistered(); + + std::string getHelp(std::string name); + +private: + PyObject * pluginDict; + PyObject *ospath, *inspect, *six, *astra; + std::vector getPluginPathList(); +}; + +PyObject* XMLNode2dict(XMLNode node); + +} + +#endif + +#endif \ No newline at end of file diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp new file mode 100644 index 0000000..2d9b9a0 --- /dev/null +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -0,0 +1,139 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp + 2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see . + +----------------------------------------------------------------------- +$Id$ +*/ + +/** \file astra_mex_plugin_c.cpp + * + * \brief Manages Python plugins. + */ + +#include +#include "mexHelpFunctions.h" +#include "mexInitFunctions.h" + +#include "astra/PluginAlgorithm.h" + +#include "Python.h" +#include "bytesobject.h" + +using namespace std; +using namespace astra; + + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('get_registered'); + * + * Print registered plugins. + */ +void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ + astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); + PyObject *dict = fact->getRegistered(); + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(dict, &pos, &key, &value)) { + mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value)); + } + Py_DECREF(dict); +} + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('register', name, class_name); + * + * Register plugin. + */ +void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ + if (3 <= nrhs) { + string name = mexToString(prhs[1]); + string class_name = mexToString(prhs[2]); + astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); + fact->registerPlugin(name, class_name); + }else{ + mexPrintf("astra_mex_plugin('register', name, class_name);\n"); + } +} + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('get_help', name); + * + * Get help about plugin. + */ +void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ + if (2 <= nrhs) { + string name = mexToString(prhs[1]); + astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); + mexPrintf((fact->getHelp(name)+"\n").c_str()); + }else{ + mexPrintf("astra_mex_plugin('get_help', name);\n"); + } +} + + +//----------------------------------------------------------------------------------------- + +static void printHelp() +{ + mexPrintf("Please specify a mode of operation.\n"); + mexPrintf(" Valid modes: register, get_registered, get_help\n"); +} + +//----------------------------------------------------------------------------------------- +/** + * ... = astra_mex(type,...); + */ +void mexFunction(int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) +{ + + // INPUT0: Mode + string sMode = ""; + if (1 <= nrhs) { + sMode = mexToString(prhs[0]); + } else { + printHelp(); + return; + } + + initASTRAMex(); + + // SWITCH (MODE) + if (sMode == std::string("get_registered")) { + astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs); + }else if (sMode == std::string("get_help")) { + astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs); + }else if (sMode == std::string("register")) { + astra_mex_plugin_register(nlhs, plhs, nrhs, prhs); + } else { + printHelp(); + } + + return; +} + + diff --git a/python/astra/__init__.py b/python/astra/__init__.py index 6c15d30..10ed74d 100644 --- a/python/astra/__init__.py +++ b/python/astra/__init__.py @@ -34,6 +34,7 @@ from . import algorithm from . import projector from . import projector3d from . import matrix +from . import plugin from . import log from .optomo import OpTomo diff --git a/python/astra/plugin.py b/python/astra/plugin.py new file mode 100644 index 0000000..ccdb2cb --- /dev/null +++ b/python/astra/plugin.py @@ -0,0 +1,95 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +from . import plugin_c as p +from . import log + +class base(object): + + def astra_init(self, cfg): + try: + try: + req = self.required_options + except AttributeError: + log.warn("Plugin '" + self.__class__.__name__ + "' does not specify required options") + req = {} + + try: + opt = self.optional_options + except AttributeError: + log.warn("Plugin '" + self.__class__.__name__ + "' does not specify optional options") + opt = {} + + try: + optDict = cfg['options'] + except KeyError: + optDict = {} + + cfgKeys = set(optDict.keys()) + reqKeys = set(req) + optKeys = set(opt) + + if not reqKeys.issubset(cfgKeys): + for key in reqKeys.difference(cfgKeys): + log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") + raise ValueError("Missing required options") + + if not cfgKeys.issubset(reqKeys | optKeys): + log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) + + self.initialize(cfg) + except Exception as e: + log.error(str(e)) + raise + +def register(name, className): + """Register plugin with ASTRA. + + :param name: Plugin name to register + :type name: :class:`str` + :param className: Class name or class object to register + :type className: :class:`str` or :class:`class` + + """ + p.register(name,className) + +def get_registered(): + """Get dictionary of registered plugins. + + :returns: :class:`dict` -- Registered plugins. + + """ + return p.get_registered() + +def get_help(name): + """Get help for registered plugin. + + :param name: Plugin name to get help for + :type name: :class:`str` + :returns: :class:`str` -- Help string (docstring). + + """ + return p.get_help(name) \ No newline at end of file diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx new file mode 100644 index 0000000..91b3cd5 --- /dev/null +++ b/python/astra/plugin_c.pyx @@ -0,0 +1,59 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- +# distutils: language = c++ +# distutils: libraries = astra + +import six +import inspect + +from libcpp.string cimport string +from libcpp cimport bool + +cdef CPluginAlgorithmFactory *fact = getSingletonPtr() + +from . import utils + +cdef extern from "astra/PluginAlgorithm.h" namespace "astra": + cdef cppclass CPluginAlgorithmFactory: + bool registerPlugin(string name, string className) + bool registerPluginClass(string name, object className) + object getRegistered() + string getHelp(string name) + +cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory": + cdef CPluginAlgorithmFactory* getSingletonPtr() + +def register(name, className): + if inspect.isclass(className): + fact.registerPluginClass(six.b(name), className) + else: + fact.registerPlugin(six.b(name), six.b(className)) + +def get_registered(): + return fact.getRegistered() + +def get_help(name): + return utils.wrap_from_bytes(fact.getHelp(six.b(name))) diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx index ddb37aa..3746b8e 100644 --- a/python/astra/utils.pyx +++ b/python/astra/utils.pyx @@ -30,7 +30,6 @@ cimport numpy as np import numpy as np import six from libcpp.string cimport string -from libcpp.list cimport list from libcpp.vector cimport vector from cython.operator cimport dereference as deref, preincrement as inc from cpython.version cimport PY_MAJOR_VERSION @@ -40,6 +39,9 @@ from .PyXMLDocument cimport XMLDocument from .PyXMLDocument cimport XMLNode from .PyIncludes cimport * +cdef extern from "astra/PluginAlgorithm.h" namespace "astra": + object XMLNode2dict(XMLNode) + cdef Config * dictToConfig(string rootname, dc): cdef Config * cfg = new Config() @@ -91,6 +93,8 @@ cdef void readDict(XMLNode root, _dc): dc = convert_item(_dc) for item in dc: val = dc[item] + if isinstance(val, list): + val = np.array(val,dtype=np.float64) if isinstance(val, np.ndarray): if val.size == 0: break @@ -142,69 +146,3 @@ cdef void readOptions(XMLNode node, dc): cdef configToDict(Config *cfg): return XMLNode2dict(cfg.self) -def castString3(input): - return input.decode('utf-8') - -def castString2(input): - return input - -if six.PY3: - castString = castString3 -else: - castString = castString2 - -def stringToPythonValue(inputIn): - input = castString(inputIn) - # matrix - if ';' in input: - row_strings = input.split(';') - col_strings = row_strings[0].split(',') - nRows = len(row_strings) - nCols = len(col_strings) - - out = np.empty((nRows,nCols)) - for ridx, row in enumerate(row_strings): - col_strings = row.split(',') - for cidx, col in enumerate(col_strings): - out[ridx,cidx] = float(col) - return out - - # vector - if ',' in input: - items = input.split(',') - out = np.empty(len(items)) - for idx,item in enumerate(items): - out[idx] = float(item) - return out - - try: - # integer - return int(input) - except ValueError: - try: - #float - return float(input) - except ValueError: - # string - return str(input) - - -cdef XMLNode2dict(XMLNode node): - cdef XMLNode subnode - cdef list[XMLNode] nodes - cdef list[XMLNode].iterator it - dct = {} - opts = {} - if node.hasAttribute(six.b('type')): - dct['type'] = castString(node.getAttribute(six.b('type'))) - nodes = node.getNodes() - it = nodes.begin() - while it != nodes.end(): - subnode = deref(it) - if castString(subnode.getName())=="Option": - opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getAttribute('value')) - else: - dct[castString(subnode.getName())] = stringToPythonValue(subnode.getContent()) - inc(it) - if len(opts)>0: dct['options'] = opts - return dct diff --git a/python/docSRC/index.rst b/python/docSRC/index.rst index b7cc6d6..dcc6590 100644 --- a/python/docSRC/index.rst +++ b/python/docSRC/index.rst @@ -19,6 +19,7 @@ Contents: creators functions operator + plugins matlab astra .. astra diff --git a/python/docSRC/plugins.rst b/python/docSRC/plugins.rst new file mode 100644 index 0000000..dc7c607 --- /dev/null +++ b/python/docSRC/plugins.rst @@ -0,0 +1,8 @@ +Plugins: the :mod:`plugin` module +========================================= + +.. automodule:: astra.plugin + :members: + :undoc-members: + :show-inheritance: + diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py new file mode 100644 index 0000000..6677930 --- /dev/null +++ b/samples/python/s018_plugin.py @@ -0,0 +1,138 @@ +#----------------------------------------------------------------------- +#Copyright 2015 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +import astra +import numpy as np +import six + +# Define the plugin class (has to subclass astra.plugin.base) +# Note that usually, these will be defined in a separate package/module +class SIRTPlugin(astra.plugin.base): + """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. + + Optional options: + + 'rel_factor': relaxation factor + """ + required_options=[] + optional_options=['rel_factor'] + + def initialize(self,cfg): + self.W = astra.OpTomo(cfg['ProjectorId']) + self.vid = cfg['ReconstructionDataId'] + self.sid = cfg['ProjectionDataId'] + try: + self.rel = cfg['option']['rel_factor'] + except KeyError: + self.rel = 1 + + def run(self, its): + v = astra.data2d.get_shared(self.vid) + s = astra.data2d.get_shared(self.sid) + W = self.W + for i in range(its): + v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size + +if __name__=='__main__': + + vol_geom = astra.create_vol_geom(256, 256) + proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False)) + + # As before, create a sinogram from a phantom + import scipy.io + P = scipy.io.loadmat('phantom.mat')['phantom256'] + proj_id = astra.create_projector('cuda',proj_geom,vol_geom) + + # construct the OpTomo object + W = astra.OpTomo(proj_id) + + sinogram = W * P + sinogram = sinogram.reshape([180, 384]) + + # Register the plugin with ASTRA + # A default set of plugins to load can be defined in: + # - /etc/astra-toolbox/plugins.txt + # - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt + # - [USER_HOME_PATH]/.astra-toolbox/plugins.txt + # - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt + # In these files, create a separate line for each plugin with: + # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS] + # + # So in this case, it would be a line: + # SIRT-PLUGIN s018_plugin.SIRTPlugin + # + astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin') + + # To get help on a registered plugin, use get_help + six.print_(astra.plugin.get_help('SIRT-PLUGIN')) + + # Create data structures + sid = astra.data2d.create('-sino', proj_geom, sinogram) + vid = astra.data2d.create('-vol', vol_geom) + + # Create config using plugin name + cfg = astra.astra_dict('SIRT-PLUGIN') + cfg['ProjectorId'] = proj_id + cfg['ProjectionDataId'] = sid + cfg['ReconstructionDataId'] = vid + + # Create algorithm object + alg_id = astra.algorithm.create(cfg) + + # Run algorithm for 100 iterations + astra.algorithm.run(alg_id, 100) + + # Get reconstruction + rec = astra.data2d.get(vid) + + # Options for the plugin go in cfg['option'] + cfg = astra.astra_dict('SIRT-PLUGIN') + cfg['ProjectorId'] = proj_id + cfg['ProjectionDataId'] = sid + cfg['ReconstructionDataId'] = vid + cfg['option'] = {} + cfg['option']['rel_factor'] = 1.5 + alg_id_rel = astra.algorithm.create(cfg) + astra.algorithm.run(alg_id_rel, 100) + rec_rel = astra.data2d.get(vid) + + # We can also use OpTomo to call the plugin + rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) + + import pylab as pl + pl.gray() + pl.figure(1) + pl.imshow(rec,vmin=0,vmax=1) + pl.figure(2) + pl.imshow(rec_rel,vmin=0,vmax=1) + pl.figure(3) + pl.imshow(rec_op,vmin=0,vmax=1) + pl.show() + + # Clean up. + astra.projector.delete(proj_id) + astra.algorithm.delete([alg_id, alg_id_rel]) + astra.data2d.delete([vid, sid]) diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp new file mode 100644 index 0000000..df13f31 --- /dev/null +++ b/src/PluginAlgorithm.cpp @@ -0,0 +1,294 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp + 2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see . + +----------------------------------------------------------------------- +$Id$ +*/ + +#ifdef ASTRA_PYTHON + +#include "astra/PluginAlgorithm.h" +#include +#include +#include +#include +#include +#include + +namespace astra { + +CPluginAlgorithm::CPluginAlgorithm(PyObject* pyclass){ + instance = PyObject_CallObject(pyclass, NULL); +} + +CPluginAlgorithm::~CPluginAlgorithm(){ + if(instance!=NULL){ + Py_DECREF(instance); + instance = NULL; + } +} + +bool CPluginAlgorithm::initialize(const Config& _cfg){ + if(instance==NULL) return false; + PyObject *cfgDict = XMLNode2dict(_cfg.self); + PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict); + Py_DECREF(cfgDict); + if(retVal==NULL) return false; + m_bIsInitialized = true; + Py_DECREF(retVal); + return m_bIsInitialized; +} + +void CPluginAlgorithm::run(int _iNrIterations){ + if(instance==NULL) return; + PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations); + if(retVal==NULL) return; + Py_DECREF(retVal); +} + +const char ps = +#ifdef _WIN32 + '\\'; +#else + '/'; +#endif + +std::vector CPluginAlgorithmFactory::getPluginPathList(){ + std::vector list; + list.push_back("/etc/astra-toolbox"); + PyObject *ret, *retb; + ret = PyObject_CallMethod(inspect,"getfile","O",astra); + if(ret!=NULL){ + retb = PyObject_CallMethod(six,"b","O",ret); + Py_DECREF(ret); + if(retb!=NULL){ + std::string astra_inst (PyBytes_AsString(retb)); + Py_DECREF(retb); + ret = PyObject_CallMethod(ospath,"dirname","s",astra_inst.c_str()); + if(ret!=NULL){ + retb = PyObject_CallMethod(six,"b","O",ret); + Py_DECREF(ret); + if(retb!=NULL){ + list.push_back(std::string(PyBytes_AsString(retb))); + Py_DECREF(retb); + } + } + } + } + ret = PyObject_CallMethod(ospath,"expanduser","s","~"); + if(ret!=NULL){ + retb = PyObject_CallMethod(six,"b","O",ret); + Py_DECREF(ret); + if(retb!=NULL){ + list.push_back(std::string(PyBytes_AsString(retb)) + ps + ".astra-toolbox"); + Py_DECREF(retb); + } + } + const char *envval = getenv("ASTRA_PLUGIN_PATH"); + if(envval!=NULL){ + list.push_back(std::string(envval)); + } + return list; +} + +CPluginAlgorithmFactory::CPluginAlgorithmFactory(){ + Py_Initialize(); + pluginDict = PyDict_New(); + ospath = PyImport_ImportModule("os.path"); + inspect = PyImport_ImportModule("inspect"); + six = PyImport_ImportModule("six"); + astra = PyImport_ImportModule("astra"); + std::vector fls = getPluginPathList(); + std::vector items; + for(unsigned int i=0;i items; + boost::split(items, str, boost::is_any_of(".")); + PyObject *pyclass = PyImport_ImportModule(items[0].c_str()); + if(pyclass==NULL) return NULL; + PyObject *submod = pyclass; + for(unsigned int i=1;i= 3 +PyObject * pyStringFromString(std::string str){ + return PyUnicode_FromString(str.c_str()); +} +#else +PyObject * pyStringFromString(std::string str){ + return PyBytes_FromString(str.c_str()); +} +#endif + +PyObject* stringToPythonValue(std::string str){ + if(str.find(";")!=std::string::npos){ + std::vector rows, row; + boost::split(rows, str, boost::is_any_of(";")); + PyObject *mat = PyList_New(rows.size()); + for(unsigned int i=0; i(row[j]))); + } + PyList_SetItem(mat, i, rowlist); + } + return mat; + } + if(str.find(",")!=std::string::npos){ + std::vector vec; + boost::split(vec, str, boost::is_any_of(",")); + PyObject *veclist = PyList_New(vec.size()); + for(unsigned int i=0;i(vec[i]))); + } + return veclist; + } + try{ + return PyLong_FromLong(boost::lexical_cast(str)); + }catch(const boost::bad_lexical_cast &){ + try{ + return PyFloat_FromDouble(boost::lexical_cast(str)); + }catch(const boost::bad_lexical_cast &){ + return pyStringFromString(str); + } + } +} + +PyObject* XMLNode2dict(XMLNode node){ + PyObject *dct = PyDict_New(); + PyObject *opts = PyDict_New(); + if(node.hasAttribute("type")){ + PyObject *obj = pyStringFromString(node.getAttribute("type").c_str()); + PyDict_SetItemString(dct, "type", obj); + Py_DECREF(obj); + } + std::list nodes = node.getNodes(); + std::list::iterator it = nodes.begin(); + while(it!=nodes.end()){ + XMLNode subnode = *it; + if(subnode.getName()=="Option"){ + PyObject *obj = stringToPythonValue(subnode.getAttribute("value")); + PyDict_SetItemString(opts, subnode.getAttribute("key").c_str(), obj); + Py_DECREF(obj); + }else{ + PyObject *obj = stringToPythonValue(subnode.getContent()); + PyDict_SetItemString(dct, subnode.getName().c_str(), obj); + Py_DECREF(obj); + } + ++it; + } + PyDict_SetItemString(dct, "options", opts); + Py_DECREF(opts); + return dct; +} + +} +#endif \ No newline at end of file -- cgit v1.2.3 From 11af4b554df9a8a5c31d9dcbc1ea849b32394ba3 Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Wed, 24 Jun 2015 18:36:03 +0200 Subject: Better way of passing options to Python plugin using inspect --- python/astra/plugin.py | 24 ++++++++++++------------ samples/python/s018_plugin.py | 13 ++++--------- 2 files changed, 16 insertions(+), 21 deletions(-) (limited to 'samples') diff --git a/python/astra/plugin.py b/python/astra/plugin.py index ccdb2cb..891f6c9 100644 --- a/python/astra/plugin.py +++ b/python/astra/plugin.py @@ -26,22 +26,20 @@ from . import plugin_c as p from . import log +import inspect class base(object): def astra_init(self, cfg): try: - try: - req = self.required_options - except AttributeError: - log.warn("Plugin '" + self.__class__.__name__ + "' does not specify required options") - req = {} - - try: - opt = self.optional_options - except AttributeError: - log.warn("Plugin '" + self.__class__.__name__ + "' does not specify optional options") - opt = {} + args, varargs, varkw, defaults = inspect.getargspec(self.initialize) + nopt = len(defaults) + if nopt>0: + req = args[2:-nopt] + opt = args[-nopt:] + else: + req = args[2:] + opt = [] try: optDict = cfg['options'] @@ -60,7 +58,9 @@ class base(object): if not cfgKeys.issubset(reqKeys | optKeys): log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) - self.initialize(cfg) + args = [optDict[k] for k in req] + kwargs = dict((k,optDict[k]) for k in opt if k in optDict) + self.initialize(cfg, *args, **kwargs) except Exception as e: log.error(str(e)) raise diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 6677930..90e09ac 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -33,21 +33,16 @@ import six class SIRTPlugin(astra.plugin.base): """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. - Optional options: + Options: - 'rel_factor': relaxation factor + 'rel_factor': relaxation factor (optional) """ - required_options=[] - optional_options=['rel_factor'] - def initialize(self,cfg): + def initialize(self,cfg, rel_factor = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] self.sid = cfg['ProjectionDataId'] - try: - self.rel = cfg['option']['rel_factor'] - except KeyError: - self.rel = 1 + self.rel = rel_factor def run(self, its): v = astra.data2d.get_shared(self.vid) -- cgit v1.2.3 From d91b51f6d58003de84a9d6dd8189fceba0e81a5a Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Mon, 20 Jul 2015 14:07:21 +0200 Subject: Allow registering plugins without explicit name, and fix exception handling when running in Matlab --- include/astra/PluginAlgorithm.h | 3 ++ matlab/mex/astra_mex_plugin_c.cpp | 23 ++++------ python/astra/plugin.py | 71 ++++++++++++----------------- python/astra/plugin_c.pyx | 14 ++++-- samples/python/s018_plugin.py | 23 +++++----- src/PluginAlgorithm.cpp | 95 +++++++++++++++++++++++++++++++-------- 6 files changed, 138 insertions(+), 91 deletions(-) (limited to 'samples') diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h index a82c579..b56228e 100644 --- a/include/astra/PluginAlgorithm.h +++ b/include/astra/PluginAlgorithm.h @@ -64,9 +64,12 @@ public: CPluginAlgorithm * getPlugin(std::string name); bool registerPlugin(std::string name, std::string className); + bool registerPlugin(std::string className); bool registerPluginClass(std::string name, PyObject * className); + bool registerPluginClass(PyObject * className); PyObject * getRegistered(); + std::map getRegisteredMap(); std::string getHelp(std::string name); diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp index 2d9b9a0..177fcf4 100644 --- a/matlab/mex/astra_mex_plugin_c.cpp +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -37,9 +37,6 @@ $Id$ #include "astra/PluginAlgorithm.h" -#include "Python.h" -#include "bytesobject.h" - using namespace std; using namespace astra; @@ -52,29 +49,25 @@ using namespace astra; void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); - PyObject *dict = fact->getRegistered(); - PyObject *key, *value; - Py_ssize_t pos = 0; - while (PyDict_Next(dict, &pos, &key, &value)) { - mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value)); + std::map mp = fact->getRegisteredMap(); + for(std::map::iterator it=mp.begin();it!=mp.end();it++){ + mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str()); } - Py_DECREF(dict); } //----------------------------------------------------------------------------------------- -/** astra_mex_plugin('register', name, class_name); +/** astra_mex_plugin('register', class_name); * * Register plugin. */ void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { - if (3 <= nrhs) { - string name = mexToString(prhs[1]); - string class_name = mexToString(prhs[2]); + if (2 <= nrhs) { + string class_name = mexToString(prhs[1]); astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); - fact->registerPlugin(name, class_name); + fact->registerPlugin(class_name); }else{ - mexPrintf("astra_mex_plugin('register', name, class_name);\n"); + mexPrintf("astra_mex_plugin('register', class_name);\n"); } } diff --git a/python/astra/plugin.py b/python/astra/plugin.py index f8fc3bd..4b32e6e 100644 --- a/python/astra/plugin.py +++ b/python/astra/plugin.py @@ -32,60 +32,47 @@ import traceback class base(object): def astra_init(self, cfg): - try: - args, varargs, varkw, defaults = inspect.getargspec(self.initialize) - if not defaults is None: - nopt = len(defaults) - else: - nopt = 0 - if nopt>0: - req = args[2:-nopt] - opt = args[-nopt:] - else: - req = args[2:] - opt = [] + args, varargs, varkw, defaults = inspect.getargspec(self.initialize) + if not defaults is None: + nopt = len(defaults) + else: + nopt = 0 + if nopt>0: + req = args[2:-nopt] + opt = args[-nopt:] + else: + req = args[2:] + opt = [] - try: - optDict = cfg['options'] - except KeyError: - optDict = {} + try: + optDict = cfg['options'] + except KeyError: + optDict = {} - cfgKeys = set(optDict.keys()) - reqKeys = set(req) - optKeys = set(opt) + cfgKeys = set(optDict.keys()) + reqKeys = set(req) + optKeys = set(opt) - if not reqKeys.issubset(cfgKeys): - for key in reqKeys.difference(cfgKeys): - log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") - raise ValueError("Missing required options") + if not reqKeys.issubset(cfgKeys): + for key in reqKeys.difference(cfgKeys): + log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") + raise ValueError("Missing required options") - if not cfgKeys.issubset(reqKeys | optKeys): - log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) + if not cfgKeys.issubset(reqKeys | optKeys): + log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) - args = [optDict[k] for k in req] - kwargs = dict((k,optDict[k]) for k in opt if k in optDict) - self.initialize(cfg, *args, **kwargs) - except Exception: - log.error(traceback.format_exc().replace("%","%%")) - raise + args = [optDict[k] for k in req] + kwargs = dict((k,optDict[k]) for k in opt if k in optDict) + self.initialize(cfg, *args, **kwargs) - def astra_run(self, its): - try: - self.run(its) - except Exception: - log.error(traceback.format_exc().replace("%","%%")) - raise - -def register(name, className): +def register(className): """Register plugin with ASTRA. - :param name: Plugin name to register - :type name: :class:`str` :param className: Class name or class object to register :type className: :class:`str` or :class:`class` """ - p.register(name,className) + p.register(className) def get_registered(): """Get dictionary of registered plugins. diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx index 91b3cd5..8d6816b 100644 --- a/python/astra/plugin_c.pyx +++ b/python/astra/plugin_c.pyx @@ -38,7 +38,9 @@ from . import utils cdef extern from "astra/PluginAlgorithm.h" namespace "astra": cdef cppclass CPluginAlgorithmFactory: + bool registerPlugin(string className) bool registerPlugin(string name, string className) + bool registerPluginClass(object className) bool registerPluginClass(string name, object className) object getRegistered() string getHelp(string name) @@ -46,11 +48,17 @@ cdef extern from "astra/PluginAlgorithm.h" namespace "astra": cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory": cdef CPluginAlgorithmFactory* getSingletonPtr() -def register(name, className): +def register(className, name=None): if inspect.isclass(className): - fact.registerPluginClass(six.b(name), className) + if name==None: + fact.registerPluginClass(className) + else: + fact.registerPluginClass(six.b(name), className) else: - fact.registerPlugin(six.b(name), six.b(className)) + if name==None: + fact.registerPlugin(six.b(className)) + else: + fact.registerPlugin(six.b(name), six.b(className)) def get_registered(): return fact.getRegistered() diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 90e09ac..31cca95 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -38,6 +38,10 @@ class SIRTPlugin(astra.plugin.base): 'rel_factor': relaxation factor (optional) """ + # The astra_name variable defines the name to use to + # call the plugin from ASTRA + astra_name = "SIRT-PLUGIN" + def initialize(self,cfg, rel_factor = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] @@ -68,18 +72,13 @@ if __name__=='__main__': sinogram = sinogram.reshape([180, 384]) # Register the plugin with ASTRA - # A default set of plugins to load can be defined in: - # - /etc/astra-toolbox/plugins.txt - # - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt - # - [USER_HOME_PATH]/.astra-toolbox/plugins.txt - # - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt - # In these files, create a separate line for each plugin with: - # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS] - # - # So in this case, it would be a line: - # SIRT-PLUGIN s018_plugin.SIRTPlugin - # - astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin') + # First we import the package that contains the plugin + import s018_plugin + # Then, we register the plugin class with ASTRA + astra.plugin.register(s018_plugin.SIRTPlugin) + + # Get a list of registered plugins + six.print_(astra.plugin.get_registered()) # To get help on a registered plugin, use get_help six.print_(astra.plugin.get_help('SIRT-PLUGIN')) diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp index d6cf731..7f7ff61 100644 --- a/src/PluginAlgorithm.cpp +++ b/src/PluginAlgorithm.cpp @@ -100,7 +100,10 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){ PyObject *cfgDict = XMLNode2dict(_cfg.self); PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict); Py_DECREF(cfgDict); - if(retVal==NULL) return false; + if(retVal==NULL){ + logPythonError(); + return false; + } m_bIsInitialized = true; Py_DECREF(retVal); return m_bIsInitialized; @@ -108,8 +111,11 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){ void CPluginAlgorithm::run(int _iNrIterations){ if(instance==NULL) return; - PyObject *retVal = PyObject_CallMethod(instance, "astra_run", "i",_iNrIterations); - if(retVal==NULL) return; + PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations); + if(retVal==NULL){ + logPythonError(); + return; + } Py_DECREF(retVal); } @@ -157,18 +163,6 @@ CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){ if(six!=NULL) Py_DECREF(six); } -bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ - PyObject *str = PyBytes_FromString(className.c_str()); - PyDict_SetItemString(pluginDict, name.c_str(), str); - Py_DECREF(str); - return true; -} - -bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ - PyDict_SetItemString(pluginDict, name.c_str(), className); - return true; -} - PyObject * getClassFromString(std::string str){ std::vector items; boost::split(items, str, boost::is_any_of(".")); @@ -190,6 +184,43 @@ PyObject * getClassFromString(std::string str){ return pyclass; } +bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ + PyObject *str = PyBytes_FromString(className.c_str()); + PyDict_SetItemString(pluginDict, name.c_str(), str); + Py_DECREF(str); + return true; +} + +bool CPluginAlgorithmFactory::registerPlugin(std::string className){ + PyObject *pyclass = getClassFromString(className); + if(pyclass==NULL) return false; + bool ret = registerPluginClass(pyclass); + Py_DECREF(pyclass); + return ret; +} + +bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ + PyDict_SetItemString(pluginDict, name.c_str(), className); + return true; +} + +bool CPluginAlgorithmFactory::registerPluginClass(PyObject * className){ + PyObject *astra_name = PyObject_GetAttrString(className,"astra_name"); + if(astra_name==NULL){ + logPythonError(); + return false; + } + PyObject *retb = PyObject_CallMethod(six,"b","O",astra_name); + if(retb!=NULL){ + PyDict_SetItemString(pluginDict,PyBytes_AsString(retb),className); + Py_DECREF(retb); + }else{ + logPythonError(); + } + Py_DECREF(astra_name); + return true; +} + CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){ PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); if(className==NULL) return NULL; @@ -212,12 +243,34 @@ PyObject * CPluginAlgorithmFactory::getRegistered(){ return pluginDict; } +std::map CPluginAlgorithmFactory::getRegisteredMap(){ + std::map ret; + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(pluginDict, &pos, &key, &value)) { + PyObject * keyb = PyObject_Bytes(key); + PyObject * valb = PyObject_Bytes(value); + ret[PyBytes_AsString(keyb)] = PyBytes_AsString(valb); + Py_DECREF(keyb); + Py_DECREF(valb); + } + return ret; +} + std::string CPluginAlgorithmFactory::getHelp(std::string name){ PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); - if(className==NULL) return ""; - std::string str = std::string(PyBytes_AsString(className)); + if(className==NULL){ + ASTRA_ERROR("Plugin %s not found!",name.c_str()); + return ""; + } std::string ret = ""; - PyObject *pyclass = getClassFromString(str); + PyObject *pyclass; + if(PyBytes_Check(className)){ + std::string str = std::string(PyBytes_AsString(className)); + pyclass = getClassFromString(str); + }else{ + pyclass = className; + } if(pyclass==NULL) return ""; if(inspect!=NULL && six!=NULL){ PyObject *retVal = PyObject_CallMethod(inspect,"getdoc","O",pyclass); @@ -228,9 +281,13 @@ std::string CPluginAlgorithmFactory::getHelp(std::string name){ ret = std::string(PyBytes_AsString(retb)); Py_DECREF(retb); } + }else{ + logPythonError(); } } - Py_DECREF(pyclass); + if(PyBytes_Check(className)){ + Py_DECREF(pyclass); + } return ret; } -- cgit v1.2.3 From e07449189a05e3bcdc8ad4a9fbb95c0751f567bb Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Fri, 4 Dec 2015 15:15:16 +0100 Subject: Add sample for experimental composite geometry code --- python/astra/PyIncludes.pxd | 2 + python/astra/experimental.pyx | 84 ++++++++++++++++++++++++++++ samples/python/s018_experimental_multires.py | 84 ++++++++++++++++++++++++++++ 3 files changed, 170 insertions(+) create mode 100644 python/astra/experimental.pyx create mode 100644 samples/python/s018_experimental_multires.py (limited to 'samples') diff --git a/python/astra/PyIncludes.pxd b/python/astra/PyIncludes.pxd index 35dea5f..e9e2bdb 100644 --- a/python/astra/PyIncludes.pxd +++ b/python/astra/PyIncludes.pxd @@ -224,6 +224,7 @@ cdef extern from "astra/Float32VolumeData3DMemory.h" namespace "astra": int getRowCount() int getColCount() int getSliceCount() + bool isInitialized() @@ -255,6 +256,7 @@ cdef extern from "astra/Float32ProjectionData3DMemory.h" namespace "astra": int getDetectorColCount() int getDetectorRowCount() int getAngleCount() + bool isInitialized() cdef extern from "astra/Float32Data3D.h" namespace "astra": cdef cppclass CFloat32Data3D: diff --git a/python/astra/experimental.pyx b/python/astra/experimental.pyx new file mode 100644 index 0000000..da27504 --- /dev/null +++ b/python/astra/experimental.pyx @@ -0,0 +1,84 @@ +#----------------------------------------------------------------------- +# Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +# 2014-2015, CWI, Amsterdam +# +# Contact: astra@uantwerpen.be +# Website: http://sf.net/projects/astra-toolbox +# +# This file is part of the ASTRA Toolbox. +# +# +# The ASTRA Toolbox is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# The ASTRA Toolbox is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +# distutils: language = c++ +# distutils: libraries = astra + +include "config.pxi" + +import six +from .PyIncludes cimport * +from libcpp.vector cimport vector + +cdef extern from "astra/CompositeGeometryManager.h" namespace "astra": + cdef cppclass CCompositeGeometryManager: + bool doFP(CProjector3D *, vector[CFloat32VolumeData3DMemory *], vector[CFloat32ProjectionData3DMemory *]) + bool doBP(CProjector3D *, vector[CFloat32VolumeData3DMemory *], vector[CFloat32ProjectionData3DMemory *]) + +cdef extern from *: + CFloat32VolumeData3DMemory * dynamic_cast_vol_mem "dynamic_cast" (CFloat32Data3D * ) except NULL + CFloat32ProjectionData3DMemory * dynamic_cast_proj_mem "dynamic_cast" (CFloat32Data3D * ) except NULL + +cimport PyProjector3DManager +from .PyProjector3DManager cimport CProjector3DManager +cimport PyData3DManager +from .PyData3DManager cimport CData3DManager + +cdef CProjector3DManager * manProj = PyProjector3DManager.getSingletonPtr() +cdef CData3DManager * man3d = PyData3DManager.getSingletonPtr() + +def do_composite(projector_id, vol_ids, proj_ids, t): + cdef vector[CFloat32VolumeData3DMemory *] vol + cdef CFloat32VolumeData3DMemory * pVolObject + cdef CFloat32ProjectionData3DMemory * pProjObject + for v in vol_ids: + pVolObject = dynamic_cast_vol_mem(man3d.get(v)) + if pVolObject == NULL: + raise Exception("Data object not found") + if not pVolObject.isInitialized(): + raise Exception("Data object not initialized properly") + vol.push_back(pVolObject) + cdef vector[CFloat32ProjectionData3DMemory *] proj + for v in proj_ids: + pProjObject = dynamic_cast_proj_mem(man3d.get(v)) + if pProjObject == NULL: + raise Exception("Data object not found") + if not pProjObject.isInitialized(): + raise Exception("Data object not initialized properly") + proj.push_back(pProjObject) + cdef CCompositeGeometryManager m + cdef CProjector3D * projector = manProj.get(projector_id) # may be NULL + if t == "FP": + if not m.doFP(projector, vol, proj): + raise Exception("Failed to perform FP") + else: + if not m.doBP(projector, vol, proj): + raise Exception("Failed to perform BP") + +def do_composite_FP(projector_id, vol_ids, proj_ids): + do_composite(projector_id, vol_ids, proj_ids, "FP") + +def do_composite_BP(projector_id, vol_ids, proj_ids): + do_composite(projector_id, vol_ids, proj_ids, "BP") diff --git a/samples/python/s018_experimental_multires.py b/samples/python/s018_experimental_multires.py new file mode 100644 index 0000000..cf38e53 --- /dev/null +++ b/samples/python/s018_experimental_multires.py @@ -0,0 +1,84 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +import astra +import numpy as np +from astra.experimental import do_composite_FP + +astra.log.setOutputScreen(astra.log.STDERR, astra.log.DEBUG) + +# low res part (voxels of 4x4x4) +vol_geom1 = astra.create_vol_geom(32, 16, 32, -64, 0, -64, 64, -64, 64) + +# high res part (voxels of 1x1x1) +vol_geom2 = astra.create_vol_geom(128, 64, 128, 0, 64, -64, 64, -64, 64) + + +# Split the output in two parts as well, for demonstration purposes +angles1 = np.linspace(0, np.pi/2, 90, False) +angles2 = np.linspace(np.pi/2, np.pi, 90, False) +proj_geom1 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles1) +proj_geom2 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles2) + +# Create a simple hollow cube phantom +cube1 = np.zeros((32,32,16)) +cube1[4:28,4:28,4:16] = 1 + +cube2 = np.zeros((128,128,64)) +cube2[16:112,16:112,0:112] = 1 +cube2[33:97,33:97,4:28] = 0 + +vol1 = astra.data3d.create('-vol', vol_geom1, cube1) +vol2 = astra.data3d.create('-vol', vol_geom2, cube2) + +proj1 = astra.data3d.create('-proj3d', proj_geom1, 0) +proj2 = astra.data3d.create('-proj3d', proj_geom2, 0) + +# The actual geometries don't matter for this composite FP/BP case +projector = astra.create_projector('cuda3d', proj_geom1, vol_geom1) + +do_composite_FP(projector, [vol1, vol2], [proj1, proj2]) + +proj_data1 = astra.data3d.get(proj1) +proj_data2 = astra.data3d.get(proj2) + +# Display a single projection image +import pylab +pylab.gray() +pylab.figure(1) +pylab.imshow(proj_data1[:,0,:]) +pylab.figure(2) +pylab.imshow(proj_data2[:,0,:]) +pylab.show() + + +# Clean up. Note that GPU memory is tied up in the algorithm object, +# and main RAM in the data objects. +astra.data3d.delete(vol1) +astra.data3d.delete(vol2) +astra.data3d.delete(proj1) +astra.data3d.delete(proj2) +astra.projector3d.delete(projector) -- cgit v1.2.3 From f2227eaca7248b6b01a5776f0cb750cff4c0279a Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Thu, 7 Jan 2016 13:29:37 +0100 Subject: Rename sample with conflicting filename --- samples/python/s018_experimental_multires.py | 84 ---------------------------- samples/python/s019_experimental_multires.py | 84 ++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 84 deletions(-) delete mode 100644 samples/python/s018_experimental_multires.py create mode 100644 samples/python/s019_experimental_multires.py (limited to 'samples') diff --git a/samples/python/s018_experimental_multires.py b/samples/python/s018_experimental_multires.py deleted file mode 100644 index cf38e53..0000000 --- a/samples/python/s018_experimental_multires.py +++ /dev/null @@ -1,84 +0,0 @@ -#----------------------------------------------------------------------- -#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam -# -#Author: Daniel M. Pelt -#Contact: D.M.Pelt@cwi.nl -#Website: http://dmpelt.github.io/pyastratoolbox/ -# -# -#This file is part of the Python interface to the -#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). -# -#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify -#it under the terms of the GNU General Public License as published by -#the Free Software Foundation, either version 3 of the License, or -#(at your option) any later version. -# -#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, -#but WITHOUT ANY WARRANTY; without even the implied warranty of -#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -#GNU General Public License for more details. -# -#You should have received a copy of the GNU General Public License -#along with the Python interface to the ASTRA Toolbox. If not, see . -# -#----------------------------------------------------------------------- - -import astra -import numpy as np -from astra.experimental import do_composite_FP - -astra.log.setOutputScreen(astra.log.STDERR, astra.log.DEBUG) - -# low res part (voxels of 4x4x4) -vol_geom1 = astra.create_vol_geom(32, 16, 32, -64, 0, -64, 64, -64, 64) - -# high res part (voxels of 1x1x1) -vol_geom2 = astra.create_vol_geom(128, 64, 128, 0, 64, -64, 64, -64, 64) - - -# Split the output in two parts as well, for demonstration purposes -angles1 = np.linspace(0, np.pi/2, 90, False) -angles2 = np.linspace(np.pi/2, np.pi, 90, False) -proj_geom1 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles1) -proj_geom2 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles2) - -# Create a simple hollow cube phantom -cube1 = np.zeros((32,32,16)) -cube1[4:28,4:28,4:16] = 1 - -cube2 = np.zeros((128,128,64)) -cube2[16:112,16:112,0:112] = 1 -cube2[33:97,33:97,4:28] = 0 - -vol1 = astra.data3d.create('-vol', vol_geom1, cube1) -vol2 = astra.data3d.create('-vol', vol_geom2, cube2) - -proj1 = astra.data3d.create('-proj3d', proj_geom1, 0) -proj2 = astra.data3d.create('-proj3d', proj_geom2, 0) - -# The actual geometries don't matter for this composite FP/BP case -projector = astra.create_projector('cuda3d', proj_geom1, vol_geom1) - -do_composite_FP(projector, [vol1, vol2], [proj1, proj2]) - -proj_data1 = astra.data3d.get(proj1) -proj_data2 = astra.data3d.get(proj2) - -# Display a single projection image -import pylab -pylab.gray() -pylab.figure(1) -pylab.imshow(proj_data1[:,0,:]) -pylab.figure(2) -pylab.imshow(proj_data2[:,0,:]) -pylab.show() - - -# Clean up. Note that GPU memory is tied up in the algorithm object, -# and main RAM in the data objects. -astra.data3d.delete(vol1) -astra.data3d.delete(vol2) -astra.data3d.delete(proj1) -astra.data3d.delete(proj2) -astra.projector3d.delete(projector) diff --git a/samples/python/s019_experimental_multires.py b/samples/python/s019_experimental_multires.py new file mode 100644 index 0000000..cf38e53 --- /dev/null +++ b/samples/python/s019_experimental_multires.py @@ -0,0 +1,84 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +import astra +import numpy as np +from astra.experimental import do_composite_FP + +astra.log.setOutputScreen(astra.log.STDERR, astra.log.DEBUG) + +# low res part (voxels of 4x4x4) +vol_geom1 = astra.create_vol_geom(32, 16, 32, -64, 0, -64, 64, -64, 64) + +# high res part (voxels of 1x1x1) +vol_geom2 = astra.create_vol_geom(128, 64, 128, 0, 64, -64, 64, -64, 64) + + +# Split the output in two parts as well, for demonstration purposes +angles1 = np.linspace(0, np.pi/2, 90, False) +angles2 = np.linspace(np.pi/2, np.pi, 90, False) +proj_geom1 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles1) +proj_geom2 = astra.create_proj_geom('parallel3d', 1.0, 1.0, 128, 192, angles2) + +# Create a simple hollow cube phantom +cube1 = np.zeros((32,32,16)) +cube1[4:28,4:28,4:16] = 1 + +cube2 = np.zeros((128,128,64)) +cube2[16:112,16:112,0:112] = 1 +cube2[33:97,33:97,4:28] = 0 + +vol1 = astra.data3d.create('-vol', vol_geom1, cube1) +vol2 = astra.data3d.create('-vol', vol_geom2, cube2) + +proj1 = astra.data3d.create('-proj3d', proj_geom1, 0) +proj2 = astra.data3d.create('-proj3d', proj_geom2, 0) + +# The actual geometries don't matter for this composite FP/BP case +projector = astra.create_projector('cuda3d', proj_geom1, vol_geom1) + +do_composite_FP(projector, [vol1, vol2], [proj1, proj2]) + +proj_data1 = astra.data3d.get(proj1) +proj_data2 = astra.data3d.get(proj2) + +# Display a single projection image +import pylab +pylab.gray() +pylab.figure(1) +pylab.imshow(proj_data1[:,0,:]) +pylab.figure(2) +pylab.imshow(proj_data2[:,0,:]) +pylab.show() + + +# Clean up. Note that GPU memory is tied up in the algorithm object, +# and main RAM in the data objects. +astra.data3d.delete(vol1) +astra.data3d.delete(vol2) +astra.data3d.delete(proj1) +astra.data3d.delete(proj2) +astra.projector3d.delete(projector) -- cgit v1.2.3 From b529ff854f0e1191108e31d6be294d31b50c666e Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Thu, 7 Jan 2016 13:36:02 +0100 Subject: Add multi-GPU sample --- samples/matlab/s020_3d_multiGPU.m | 38 +++++++++++++++++++++++++ samples/python/s020_3d_multiGPU.py | 57 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 samples/matlab/s020_3d_multiGPU.m create mode 100644 samples/python/s020_3d_multiGPU.py (limited to 'samples') diff --git a/samples/matlab/s020_3d_multiGPU.m b/samples/matlab/s020_3d_multiGPU.m new file mode 100644 index 0000000..bade325 --- /dev/null +++ b/samples/matlab/s020_3d_multiGPU.m @@ -0,0 +1,38 @@ +% ----------------------------------------------------------------------- +% This file is part of the ASTRA Toolbox +% +% Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +% 2014-2015, CWI, Amsterdam +% License: Open Source under GPLv3 +% Contact: astra@uantwerpen.be +% Website: http://sf.net/projects/astra-toolbox +% ----------------------------------------------------------------------- + + +% Set up multi-GPU usage. +% This only works for 3D GPU forward projection and back projection. +astra_mex('set_gpu_index', [0 1]); + +% Optionally, you can also restrict the amount of GPU memory ASTRA will use. +% The line commented below sets this to 1GB. +%astra_mex('set_gpu_index', [0 1], 'memory', 1024*1024*1024); + +vol_geom = astra_create_vol_geom(1024, 1024, 1024); + +angles = linspace2(0, pi, 1024); +proj_geom = astra_create_proj_geom('parallel3d', 1.0, 1.0, 1024, 1024, angles); + +% Create a simple hollow cube phantom +cube = zeros(1024,1024,1024); +cube(129:896,129:896,129:896) = 1; +cube(257:768,257:768,257:768) = 0; + +% Create projection data from this +[proj_id, proj_data] = astra_create_sino3d_cuda(cube, proj_geom, vol_geom); + +% Backproject projection data +[bproj_id, bproj_data] = astra_create_backprojection3d_cuda(proj_data, proj_geom, vol_geom); + +astra_mex_data3d('delete', proj_id); +astra_mex_data3d('delete', bproj_id); + diff --git a/samples/python/s020_3d_multiGPU.py b/samples/python/s020_3d_multiGPU.py new file mode 100644 index 0000000..d6799c4 --- /dev/null +++ b/samples/python/s020_3d_multiGPU.py @@ -0,0 +1,57 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- + +import astra +import numpy as np + +# Set up multi-GPU usage. +# This only works for 3D GPU forward projection and back projection. +astra.astra.set_gpu_index([0,1]) + +# Optionally, you can also restrict the amount of GPU memory ASTRA will use. +# The line commented below sets this to 1GB. +#astra.astra.set_gpu_index([0,1], memory=1024*1024*1024) + +vol_geom = astra.create_vol_geom(1024, 1024, 1024) + +angles = np.linspace(0, np.pi, 1024,False) +proj_geom = astra.create_proj_geom('parallel3d', 1.0, 1.0, 1024, 1024, angles) + +# Create a simple hollow cube phantom +cube = np.zeros((1024,1024,1024)) +cube[128:895,128:895,128:895] = 1 +cube[256:767,256:767,256:767] = 0 + +# Create projection data from this +proj_id, proj_data = astra.create_sino3d_gpu(cube, proj_geom, vol_geom) + +# Backproject projection data +bproj_id, bproj_data = astra.create_backprojection3d_gpu(proj_data, proj_geom, vol_geom) + +# Clean up. Note that GPU memory is tied up in the algorithm object, +# and main RAM in the data objects. +astra.data3d.delete(proj_id) +astra.data3d.delete(bproj_id) -- cgit v1.2.3 From 1e26f7602b6685c584fd4d857353f390622e3a34 Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Mon, 25 Apr 2016 10:47:59 +0200 Subject: Change flatten to ravel in Python code --- python/astra/optomo.py | 4 ++-- samples/python/s009_projection_matrix.py | 2 +- samples/python/s015_fp_bp.py | 6 +++--- samples/python/s017_OpTomo.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) (limited to 'samples') diff --git a/python/astra/optomo.py b/python/astra/optomo.py index dd10713..5a92998 100644 --- a/python/astra/optomo.py +++ b/python/astra/optomo.py @@ -125,7 +125,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator): algorithm.delete(fp_id) self.data_mod.delete([vid,sid]) - return s.flatten() + return s.ravel() def rmatvec(self,s): """Implements the transpose operator. @@ -147,7 +147,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator): algorithm.delete(bp_id) self.data_mod.delete([vid,sid]) - return v.flatten() + return v.ravel() def __mul__(self,v): """Provides easy forward operator by *. diff --git a/samples/python/s009_projection_matrix.py b/samples/python/s009_projection_matrix.py index c4c4557..e20d58c 100644 --- a/samples/python/s009_projection_matrix.py +++ b/samples/python/s009_projection_matrix.py @@ -46,7 +46,7 @@ W = astra.matrix.get(matrix_id) # Manually use this projection matrix to do a projection: import scipy.io P = scipy.io.loadmat('phantom.mat')['phantom256'] -s = W.dot(P.flatten()) +s = W.dot(P.ravel()) s = np.reshape(s, (len(proj_geom['ProjectionAngles']),proj_geom['DetectorCount'])) import pylab diff --git a/samples/python/s015_fp_bp.py b/samples/python/s015_fp_bp.py index fa0bf86..ff0b30a 100644 --- a/samples/python/s015_fp_bp.py +++ b/samples/python/s015_fp_bp.py @@ -46,12 +46,12 @@ class astra_wrap(object): def matvec(self,v): sid, s = astra.create_sino(np.reshape(v,(vol_geom['GridRowCount'],vol_geom['GridColCount'])),self.proj_id) astra.data2d.delete(sid) - return s.flatten() + return s.ravel() def rmatvec(self,v): bid, b = astra.create_backprojection(np.reshape(v,(len(proj_geom['ProjectionAngles']),proj_geom['DetectorCount'],)),self.proj_id) astra.data2d.delete(bid) - return b.flatten() + return b.ravel() vol_geom = astra.create_vol_geom(256, 256) proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False)) @@ -65,7 +65,7 @@ proj_id = astra.create_projector('cuda',proj_geom,vol_geom) sinogram_id, sinogram = astra.create_sino(P, proj_id) # Reshape the sinogram into a vector -b = sinogram.flatten() +b = sinogram.ravel() # Call lsqr with ASTRA FP and BP import scipy.sparse.linalg diff --git a/samples/python/s017_OpTomo.py b/samples/python/s017_OpTomo.py index 967fa64..214e9a7 100644 --- a/samples/python/s017_OpTomo.py +++ b/samples/python/s017_OpTomo.py @@ -50,7 +50,7 @@ pylab.figure(2) pylab.imshow(sinogram) # Run the lsqr linear solver -output = scipy.sparse.linalg.lsqr(W, sinogram.flatten(), iter_lim=150) +output = scipy.sparse.linalg.lsqr(W, sinogram.ravel(), iter_lim=150) rec = output[0].reshape([256, 256]) pylab.figure(3) -- cgit v1.2.3 From ed717202a0c917958892e26322d6ea5173f7b32c Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Mon, 25 Apr 2016 17:04:39 +0200 Subject: Use FP/BP out argument in sample plugin --- samples/python/s018_plugin.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) (limited to 'samples') diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 31cca95..85b5486 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -30,30 +30,38 @@ import six # Define the plugin class (has to subclass astra.plugin.base) # Note that usually, these will be defined in a separate package/module -class SIRTPlugin(astra.plugin.base): - """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. +class LandweberPlugin(astra.plugin.base): + """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm. Options: - 'rel_factor': relaxation factor (optional) + 'Relaxation': relaxation factor (optional) """ # The astra_name variable defines the name to use to # call the plugin from ASTRA - astra_name = "SIRT-PLUGIN" + astra_name = "LANDWEBER-PLUGIN" - def initialize(self,cfg, rel_factor = 1): + def initialize(self,cfg, Relaxation = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] self.sid = cfg['ProjectionDataId'] - self.rel = rel_factor + self.rel = Relaxation def run(self, its): v = astra.data2d.get_shared(self.vid) s = astra.data2d.get_shared(self.sid) + tv = np.zeros(v.shape, dtype=np.float32) + ts = np.zeros(s.shape, dtype=np.float32) W = self.W for i in range(its): - v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size + W.FP(v,out=ts) + ts -= s # ts = W*v - s + + W.BP(ts,out=tv) + tv *= self.rel / s.size + + v -= tv # v = v - rel * W'*(W*v-s) / s.size if __name__=='__main__': @@ -75,20 +83,20 @@ if __name__=='__main__': # First we import the package that contains the plugin import s018_plugin # Then, we register the plugin class with ASTRA - astra.plugin.register(s018_plugin.SIRTPlugin) + astra.plugin.register(s018_plugin.LandweberPlugin) # Get a list of registered plugins six.print_(astra.plugin.get_registered()) # To get help on a registered plugin, use get_help - six.print_(astra.plugin.get_help('SIRT-PLUGIN')) + six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN')) # Create data structures sid = astra.data2d.create('-sino', proj_geom, sinogram) vid = astra.data2d.create('-vol', vol_geom) # Create config using plugin name - cfg = astra.astra_dict('SIRT-PLUGIN') + cfg = astra.astra_dict('LANDWEBER-PLUGIN') cfg['ProjectorId'] = proj_id cfg['ProjectionDataId'] = sid cfg['ReconstructionDataId'] = vid @@ -103,18 +111,18 @@ if __name__=='__main__': rec = astra.data2d.get(vid) # Options for the plugin go in cfg['option'] - cfg = astra.astra_dict('SIRT-PLUGIN') + cfg = astra.astra_dict('LANDWEBER-PLUGIN') cfg['ProjectorId'] = proj_id cfg['ProjectionDataId'] = sid cfg['ReconstructionDataId'] = vid cfg['option'] = {} - cfg['option']['rel_factor'] = 1.5 + cfg['option']['Relaxation'] = 1.5 alg_id_rel = astra.algorithm.create(cfg) astra.algorithm.run(alg_id_rel, 100) rec_rel = astra.data2d.get(vid) # We can also use OpTomo to call the plugin - rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) + rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5}) import pylab as pl pl.gray() -- cgit v1.2.3