diff options
| -rw-r--r-- | build/linux/Makefile.in | 16 | ||||
| -rw-r--r-- | include/astra/AstraObjectFactory.h | 55 | ||||
| -rw-r--r-- | include/astra/Globals.h | 2 | ||||
| -rw-r--r-- | include/astra/PluginAlgorithm.h | 90 | ||||
| -rw-r--r-- | matlab/mex/astra_mex_plugin_c.cpp | 132 | ||||
| -rw-r--r-- | matlab/mex/mexInitFunctions.cpp | 3 | ||||
| -rw-r--r-- | matlab/tools/astra_mex_plugin.m | 24 | ||||
| -rw-r--r-- | python/astra/PyIncludes.pxd | 2 | ||||
| -rw-r--r-- | python/astra/__init__.py | 1 | ||||
| -rw-r--r-- | python/astra/data2d_c.pyx | 27 | ||||
| -rw-r--r-- | python/astra/plugin.py | 121 | ||||
| -rw-r--r-- | python/astra/plugin_c.pyx | 67 | ||||
| -rw-r--r-- | python/astra/utils.pyx | 10 | ||||
| -rw-r--r-- | python/docSRC/index.rst | 1 | ||||
| -rw-r--r-- | python/docSRC/plugins.rst | 8 | ||||
| -rw-r--r-- | samples/python/s018_plugin.py | 132 | ||||
| -rw-r--r-- | src/Globals.cpp | 3 | ||||
| -rw-r--r-- | src/PluginAlgorithm.cpp | 401 | 
18 files changed, 1077 insertions, 18 deletions
| diff --git a/build/linux/Makefile.in b/build/linux/Makefile.in index abbebe2..bdffd4c 100644 --- a/build/linux/Makefile.in +++ b/build/linux/Makefile.in @@ -50,11 +50,17 @@ LDFLAGS+=-fopenmp  endif  ifeq ($(python),yes) -PYCPPFLAGS  = ${CPPFLAGS} +PYTHON      = @PYTHON@ +PYLIBDIR = $(shell $(PYTHON) -c 'from distutils.sysconfig import get_config_var; import six; six.print_(get_config_var("LIBDIR"))') +PYINCDIR = $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; import six; six.print_(get_python_inc())') +PYLIBVER = `basename $(PYINCDIR)` +CPPFLAGS += -DASTRA_PYTHON -I$(PYINCDIR) +PYCPPFLAGS  = $(CPPFLAGS)  PYCPPFLAGS  += -I../include -PYLDFLAGS = ${LDFLAGS} +PYLDFLAGS = $(LDFLAGS)  PYLDFLAGS   += -L../build/linux/.libs -PYTHON      = @PYTHON@ +LIBS		+= -l$(PYLIBVER) +LDFLAGS += -L$(PYLIBDIR)  endif  BOOST_CPPFLAGS= @@ -235,6 +241,10 @@ MATLAB_MEX=\  	matlab/mex/astra_mex_data3d_c.$(MEXSUFFIX) \  	matlab/mex/astra_mex_direct_c.$(MEXSUFFIX) +ifeq ($(python),yes) +ALL_OBJECTS+=src/PluginAlgorithm.lo +MATLAB_MEX+=matlab/mex/astra_mex_plugin_c.$(MEXSUFFIX) +endif  OBJECT_DIRS = src/ tests/ cuda/2d/ cuda/3d/ matlab/mex/ ./  DEPDIRS = $(addsuffix $(DEPDIR),$(OBJECT_DIRS)) diff --git a/include/astra/AstraObjectFactory.h b/include/astra/AstraObjectFactory.h index 1ed4955..325989e 100644 --- a/include/astra/AstraObjectFactory.h +++ b/include/astra/AstraObjectFactory.h @@ -40,6 +40,10 @@ $Id$  #include "AlgorithmTypelist.h" +#ifdef ASTRA_PYTHON +#include "PluginAlgorithm.h" +#endif +  namespace astra { @@ -59,20 +63,27 @@ public:  	 */  	~CAstraObjectFactory(); -	/** Create, but don't initialize, a new projector object. +	/** Create, but don't initialize, a new object.  	 * -	 * @param _sType Type of the new projector. -	 * @return Pointer to a new, unitialized projector. +	 * @param _sType Type of the new object. +	 * @return Pointer to a new, uninitialized object.  	 */  	T* create(std::string _sType); -	/** Create and initialize a new projector object. +	/** Create and initialize a new object.  	 * -	 * @param _cfg Configuration object to create and initialize a new projector. +	 * @param _cfg Configuration object to create and initialize a new object.  	 * @return Pointer to a new, initialized projector.  	 */  	T* create(const Config& _cfg); +	/** Find a plugin. +	* +	* @param _sType Name of plugin to find. +	* @return Pointer to a new, uninitialized object, or NULL if not found. +	*/ +	T* findPlugin(std::string _sType); +  }; @@ -93,6 +104,15 @@ CAstraObjectFactory<T, TypeList>::~CAstraObjectFactory()  } + +//---------------------------------------------------------------------------------------- +// Hook for finding plugin in registered plugins. +template <typename T, typename TypeList> +T* CAstraObjectFactory<T, TypeList>::findPlugin(std::string _sType) +{ +	return NULL; +} +  //----------------------------------------------------------------------------------------  // Create   template <typename T, typename TypeList> @@ -101,6 +121,9 @@ T* CAstraObjectFactory<T, TypeList>::create(std::string _sType)  	functor_find<T> finder = functor_find<T>();  	finder.tofind = _sType;  	CreateObject<TypeList>::find(finder); +	if (finder.res == NULL) { +		finder.res = findPlugin(_sType); +	}  	return finder.res;  } @@ -109,14 +132,11 @@ T* CAstraObjectFactory<T, TypeList>::create(std::string _sType)  template <typename T, typename TypeList>  T* CAstraObjectFactory<T, TypeList>::create(const Config& _cfg)  { -	functor_find<T> finder = functor_find<T>(); -	finder.tofind = _cfg.self.getAttribute("type"); -	CreateObject<TypeList>::find(finder); -	if (finder.res == NULL) return NULL; -	if (finder.res->initialize(_cfg)) -		return finder.res; - -	delete finder.res; +	T* object = create(_cfg.self.getAttribute("type")); +	if (object == NULL) return NULL; +	if (object->initialize(_cfg)) +		return object; +	delete object;  	return NULL;  }  //---------------------------------------------------------------------------------------- @@ -131,6 +151,15 @@ T* CAstraObjectFactory<T, TypeList>::create(const Config& _cfg)  */  class _AstraExport CAlgorithmFactory : public CAstraObjectFactory<CAlgorithm, AlgorithmTypeList> {}; +#ifdef ASTRA_PYTHON +template <> +inline CAlgorithm* CAstraObjectFactory<CAlgorithm, AlgorithmTypeList>::findPlugin(std::string _sType) +	{ +		CPluginAlgorithmFactory *fac = CPluginAlgorithmFactory::getSingletonPtr(); +		return fac->getPlugin(_sType); +	} +#endif +  /**   * Class used to create 2D projectors from a string or a config object  */ diff --git a/include/astra/Globals.h b/include/astra/Globals.h index 4de07d1..f70c3a9 100644 --- a/include/astra/Globals.h +++ b/include/astra/Globals.h @@ -146,6 +146,8 @@ namespace astra {  	const float32 PIdiv2 = PI / 2;  	const float32 PIdiv4 = PI / 4;  	const float32 eps = 1e-7f; +	 +	extern _AstraExport bool running_in_matlab;  }  //---------------------------------------------------------------------------------------- diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h new file mode 100644 index 0000000..667e813 --- /dev/null +++ b/include/astra/PluginAlgorithm.h @@ -0,0 +1,90 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +           2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. + +----------------------------------------------------------------------- +$Id$ +*/ + +#ifndef _INC_ASTRA_PLUGINALGORITHM +#define _INC_ASTRA_PLUGINALGORITHM + +#ifdef ASTRA_PYTHON + +#include "astra/Algorithm.h" +#include "astra/Singleton.h" +#include "astra/XMLDocument.h" +#include "astra/XMLNode.h" + +// Slightly hackish forward declaration of PyObject +struct _object; +typedef _object PyObject; + + +namespace astra { +class _AstraExport CPluginAlgorithm : public CAlgorithm { + +public: + +    CPluginAlgorithm(PyObject* pyclass); +    ~CPluginAlgorithm(); + +    bool initialize(const Config& _cfg); +    void run(int _iNrIterations); + +private: +    PyObject * instance; + +}; + +class _AstraExport CPluginAlgorithmFactory : public Singleton<CPluginAlgorithmFactory> { + +public: + +    CPluginAlgorithmFactory(); +    ~CPluginAlgorithmFactory(); + +    CPluginAlgorithm * getPlugin(std::string name); + +    bool registerPlugin(std::string name, std::string className); +    bool registerPlugin(std::string className); +    bool registerPluginClass(std::string name, PyObject * className); +    bool registerPluginClass(PyObject * className); +     +    PyObject * getRegistered(); +    std::map<std::string, std::string> getRegisteredMap(); +     +    std::string getHelp(std::string name); + +private: +    PyObject * pluginDict; +    PyObject *inspect, *six; +}; + +PyObject* XMLNode2dict(XMLNode node); + +} + +#endif + +#endif diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp new file mode 100644 index 0000000..177fcf4 --- /dev/null +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -0,0 +1,132 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +           2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. + +----------------------------------------------------------------------- +$Id$ +*/ + +/** \file astra_mex_plugin_c.cpp + * + *  \brief Manages Python plugins. + */ + +#include <mex.h> +#include "mexHelpFunctions.h" +#include "mexInitFunctions.h" + +#include "astra/PluginAlgorithm.h" + +using namespace std; +using namespace astra; + + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('get_registered'); + * + * Print registered plugins. + */ +void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ +    astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); +    std::map<std::string, std::string> mp = fact->getRegisteredMap(); +    for(std::map<std::string,std::string>::iterator it=mp.begin();it!=mp.end();it++){ +        mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str()); +    } +} + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('register', class_name); + * + * Register plugin. + */ +void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ +    if (2 <= nrhs) { +        string class_name = mexToString(prhs[1]); +        astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); +        fact->registerPlugin(class_name); +    }else{ +        mexPrintf("astra_mex_plugin('register', class_name);\n"); +    } +} + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('get_help', name); + * + * Get help about plugin. + */ +void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ +    if (2 <= nrhs) { +        string name = mexToString(prhs[1]); +        astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); +        mexPrintf((fact->getHelp(name)+"\n").c_str()); +    }else{ +        mexPrintf("astra_mex_plugin('get_help', name);\n"); +    } +} + + +//----------------------------------------------------------------------------------------- + +static void printHelp() +{ +	mexPrintf("Please specify a mode of operation.\n"); +	mexPrintf("   Valid modes: register, get_registered, get_help\n"); +} + +//----------------------------------------------------------------------------------------- +/** + * ... = astra_mex(type,...); + */ +void mexFunction(int nlhs, mxArray* plhs[], +				 int nrhs, const mxArray* prhs[]) +{ + +	// INPUT0: Mode +	string sMode = ""; +	if (1 <= nrhs) { +		sMode = mexToString(prhs[0]);	 +	} else { +		printHelp(); +		return; +	} + +	initASTRAMex(); + +	// SWITCH (MODE) +	if (sMode ==  std::string("get_registered")) {  +		astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);  +    }else if (sMode ==  std::string("get_help")) {  +        astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);  +    }else if (sMode ==  std::string("register")) {  +		astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);  +	} else { +		printHelp(); +	} + +	return; +} + + diff --git a/matlab/mex/mexInitFunctions.cpp b/matlab/mex/mexInitFunctions.cpp index 89a31a1..bd3df2c 100644 --- a/matlab/mex/mexInitFunctions.cpp +++ b/matlab/mex/mexInitFunctions.cpp @@ -17,6 +17,9 @@ void logCallBack(const char *msg, size_t len){   */  void initASTRAMex(){      if(mexIsInitialized) return; + +    astra::running_in_matlab=true; +      if(!astra::CLogger::setCallbackScreen(&logCallBack)){          mexErrMsgTxt("Error initializing mex functions.");      } diff --git a/matlab/tools/astra_mex_plugin.m b/matlab/tools/astra_mex_plugin.m new file mode 100644 index 0000000..4159365 --- /dev/null +++ b/matlab/tools/astra_mex_plugin.m @@ -0,0 +1,24 @@ +function [varargout] = astra_mex_plugin(varargin) +%------------------------------------------------------------------------ +% Reference page in Help browser +%    <a href="matlab:docsearch('astra_mex_plugin' )">astra_mex_plugin</a>. +%------------------------------------------------------------------------ +%------------------------------------------------------------------------ +% This file is part of the ASTRA Toolbox +%  +% Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +%            2014-2015, CWI, Amsterdam +% License: Open Source under GPLv3 +% Contact: astra@uantwerpen.be +% Website: http://sf.net/projects/astra-toolbox +%------------------------------------------------------------------------ +% $Id$ +if nargout == 0 +    astra_mex_plugin_c(varargin{:}); +    if exist('ans','var') +        varargout{1} = ans; +    end +else +    varargout = cell(1,nargout); +    [varargout{:}] = astra_mex_plugin_c(varargin{:}); +end
\ No newline at end of file diff --git a/python/astra/PyIncludes.pxd b/python/astra/PyIncludes.pxd index 35dea5f..77346b0 100644 --- a/python/astra/PyIncludes.pxd +++ b/python/astra/PyIncludes.pxd @@ -62,6 +62,7 @@ cdef extern from "astra/VolumeGeometry2D.h" namespace "astra":  		float32 getWindowMaxX()  		float32 getWindowMaxY()  		Config* getConfiguration() +		bool isEqual(CVolumeGeometry2D*)  cdef extern from "astra/Float32Data2D.h" namespace "astra":  	cdef cppclass CFloat32CustomMemory: @@ -89,6 +90,7 @@ cdef extern from "astra/ProjectionGeometry2D.h" namespace "astra":  		float32 getProjectionAngle(int)  		float32 getDetectorWidth()  		Config* getConfiguration() +		bool isEqual(CProjectionGeometry2D*)  cdef extern from "astra/Float32Data2D.h" namespace "astra::CFloat32Data2D":  	cdef enum TWOEDataType "astra::CFloat32Data2D::EDataType": diff --git a/python/astra/__init__.py b/python/astra/__init__.py index 6c15d30..10ed74d 100644 --- a/python/astra/__init__.py +++ b/python/astra/__init__.py @@ -34,6 +34,7 @@ from . import algorithm  from . import projector  from . import projector3d  from . import matrix +from . import plugin  from . import log  from .optomo import OpTomo diff --git a/python/astra/data2d_c.pyx b/python/astra/data2d_c.pyx index 4919bf2..801fd8e 100644 --- a/python/astra/data2d_c.pyx +++ b/python/astra/data2d_c.pyx @@ -34,6 +34,9 @@ from cython cimport view  cimport PyData2DManager  from .PyData2DManager cimport CData2DManager +cimport PyProjector2DManager +from .PyProjector2DManager cimport CProjector2DManager +  cimport PyXMLDocument  from .PyXMLDocument cimport XMLDocument @@ -54,6 +57,8 @@ import operator  from six.moves import reduce  cdef CData2DManager * man2d = <CData2DManager * >PyData2DManager.getSingletonPtr() +cdef CProjector2DManager * manProj = <CProjector2DManager * >PyProjector2DManager.getSingletonPtr() +  cdef extern from "CFloat32CustomPython.h":      cdef cppclass CFloat32CustomPython: @@ -164,7 +169,6 @@ def store(i, data):      cdef CFloat32Data2D * pDataObject = getObject(i)      fillDataObject(pDataObject, data) -  def get_geometry(i):      cdef CFloat32Data2D * pDataObject = getObject(i)      cdef CFloat32ProjectionData2D * pDataObject2 @@ -179,6 +183,27 @@ def get_geometry(i):          raise Exception("Not a known data object")      return geom +cdef CProjector2D * getProjector(i) except NULL: +    cdef CProjector2D * proj = manProj.get(i) +    if proj == NULL: +        raise Exception("Projector not initialized.") +    if not proj.isInitialized(): +        raise Exception("Projector not initialized.") +    return proj + +def check_compatible(i, proj_id): +    cdef CProjector2D * proj = getProjector(proj_id) +    cdef CFloat32Data2D * pDataObject = getObject(i) +    cdef CFloat32ProjectionData2D * pDataObject2 +    cdef CFloat32VolumeData2D * pDataObject3 +    if pDataObject.getType() == TWOPROJECTION: +        pDataObject2 = <CFloat32ProjectionData2D * >pDataObject +        return pDataObject2.getGeometry().isEqual(proj.getProjectionGeometry()) +    elif pDataObject.getType() == TWOVOLUME: +        pDataObject3 = <CFloat32VolumeData2D * >pDataObject +        return pDataObject3.getGeometry().isEqual(proj.getVolumeGeometry()) +    else: +        raise Exception("Not a known data object")  def change_geometry(i, geom):      cdef Config *cfg diff --git a/python/astra/plugin.py b/python/astra/plugin.py new file mode 100644 index 0000000..3e3528d --- /dev/null +++ b/python/astra/plugin.py @@ -0,0 +1,121 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- + +from . import plugin_c as p +from . import log +from . import data2d +from . import data2d_c +from . import data3d +from . import projector +import inspect +import traceback + +class base(object): + +    def astra_init(self, cfg): +        args, varargs, varkw, defaults = inspect.getargspec(self.initialize) +        if not defaults is None: +            nopt = len(defaults) +        else: +            nopt = 0 +        if nopt>0: +            req = args[2:-nopt] +            opt = args[-nopt:] +        else: +            req = args[2:] +            opt = [] + +        try: +            optDict = cfg['options'] +        except KeyError: +            optDict = {} + +        cfgKeys = set(optDict.keys()) +        reqKeys = set(req) +        optKeys = set(opt) + +        if not reqKeys.issubset(cfgKeys): +            for key in reqKeys.difference(cfgKeys): +                log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") +            raise ValueError("Missing required options") + +        if not cfgKeys.issubset(reqKeys | optKeys): +            log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) + +        args = [optDict[k] for k in req] +        kwargs = dict((k,optDict[k]) for k in opt if k in optDict) +        self.initialize(cfg, *args, **kwargs) + +class ReconstructionAlgorithm2D(base): + +    def astra_init(self, cfg): +        self.pid = cfg['ProjectorId'] +        self.s = data2d.get_shared(cfg['ProjectionDataId']) +        self.v = data2d.get_shared(cfg['ReconstructionDataId']) +        self.vg = projector.volume_geometry(self.pid) +        self.pg = projector.projection_geometry(self.pid) +        if not data2d_c.check_compatible(cfg['ProjectionDataId'], self.pid): +            raise ValueError("Projection data and projector not compatible") +        if not data2d_c.check_compatible(cfg['ReconstructionDataId'], self.pid): +            raise ValueError("Reconstruction data and projector not compatible") +        super(ReconstructionAlgorithm2D,self).astra_init(cfg) + +class ReconstructionAlgorithm3D(base): + +    def astra_init(self, cfg): +        self.pid = cfg['ProjectorId'] +        self.s = data3d.get_shared(cfg['ProjectionDataId']) +        self.v = data3d.get_shared(cfg['ReconstructionDataId']) +        self.vg = data3d.get_geometry(cfg['ReconstructionDataId']) +        self.pg = data3d.get_geometry(cfg['ProjectionDataId']) +        super(ReconstructionAlgorithm3D,self).astra_init(cfg) + +def register(className): +    """Register plugin with ASTRA. +     +    :param className: Class name or class object to register +    :type className: :class:`str` or :class:`class` +     +    """ +    p.register(className) + +def get_registered(): +    """Get dictionary of registered plugins. +     +    :returns: :class:`dict` -- Registered plugins. +     +    """ +    return p.get_registered() + +def get_help(name): +    """Get help for registered plugin. +     +    :param name: Plugin name to get help for +    :type name: :class:`str` +    :returns: :class:`str` -- Help string (docstring). +     +    """ +    return p.get_help(name)
\ No newline at end of file diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx new file mode 100644 index 0000000..8d6816b --- /dev/null +++ b/python/astra/plugin_c.pyx @@ -0,0 +1,67 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- +# distutils: language = c++ +# distutils: libraries = astra + +import six +import inspect + +from libcpp.string cimport string +from libcpp cimport bool + +cdef CPluginAlgorithmFactory *fact = getSingletonPtr() + +from . import utils + +cdef extern from "astra/PluginAlgorithm.h" namespace "astra": +    cdef cppclass CPluginAlgorithmFactory: +        bool registerPlugin(string className) +        bool registerPlugin(string name, string className) +        bool registerPluginClass(object className) +        bool registerPluginClass(string name, object className) +        object getRegistered() +        string getHelp(string name) + +cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory": +    cdef CPluginAlgorithmFactory* getSingletonPtr() + +def register(className, name=None): +    if inspect.isclass(className): +        if name==None: +            fact.registerPluginClass(className) +        else: +            fact.registerPluginClass(six.b(name), className) +    else: +        if name==None: +            fact.registerPlugin(six.b(className)) +        else: +            fact.registerPlugin(six.b(name), six.b(className)) + +def get_registered(): +    return fact.getRegistered() + +def get_help(name): +    return utils.wrap_from_bytes(fact.getHelp(six.b(name))) diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx index 260c308..07727ce 100644 --- a/python/astra/utils.pyx +++ b/python/astra/utils.pyx @@ -29,9 +29,13 @@  cimport numpy as np  import numpy as np  import six +if six.PY3: +    import builtins +else: +    import __builtin__  from libcpp.string cimport string -from libcpp.list cimport list  from libcpp.vector cimport vector +from libcpp.list cimport list  from cython.operator cimport dereference as deref, preincrement as inc  from cpython.version cimport PY_MAJOR_VERSION @@ -91,6 +95,8 @@ cdef void readDict(XMLNode root, _dc):      dc = convert_item(_dc)      for item in dc:          val = dc[item] +        if isinstance(val, __builtins__.list) or isinstance(val, tuple): +            val = np.array(val,dtype=np.float64)          if isinstance(val, np.ndarray):              if val.size == 0:                  break @@ -125,6 +131,8 @@ cdef void readOptions(XMLNode node, dc):          val = dc[item]          if node.hasOption(item):              raise Exception('Duplicate Option: %s' % item) +        if isinstance(val, __builtins__.list) or isinstance(val, tuple): +            val = np.array(val,dtype=np.float64)          if isinstance(val, np.ndarray):              if val.size == 0:                  break diff --git a/python/docSRC/index.rst b/python/docSRC/index.rst index b7cc6d6..dcc6590 100644 --- a/python/docSRC/index.rst +++ b/python/docSRC/index.rst @@ -19,6 +19,7 @@ Contents:     creators     functions     operator +   plugins     matlab     astra  .. astra diff --git a/python/docSRC/plugins.rst b/python/docSRC/plugins.rst new file mode 100644 index 0000000..dc7c607 --- /dev/null +++ b/python/docSRC/plugins.rst @@ -0,0 +1,8 @@ +Plugins: the :mod:`plugin` module +========================================= + +.. automodule:: astra.plugin +    :members: +    :undoc-members: +    :show-inheritance: + diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py new file mode 100644 index 0000000..31cca95 --- /dev/null +++ b/samples/python/s018_plugin.py @@ -0,0 +1,132 @@ +#----------------------------------------------------------------------- +#Copyright 2015 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- + +import astra +import numpy as np +import six + +# Define the plugin class (has to subclass astra.plugin.base) +# Note that usually, these will be defined in a separate package/module +class SIRTPlugin(astra.plugin.base): +    """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. + +    Options: + +    'rel_factor': relaxation factor (optional) +    """ + +    # The astra_name variable defines the name to use to +    # call the plugin from ASTRA +    astra_name = "SIRT-PLUGIN" + +    def initialize(self,cfg, rel_factor = 1): +        self.W = astra.OpTomo(cfg['ProjectorId']) +        self.vid = cfg['ReconstructionDataId'] +        self.sid = cfg['ProjectionDataId'] +        self.rel = rel_factor + +    def run(self, its): +        v = astra.data2d.get_shared(self.vid) +        s = astra.data2d.get_shared(self.sid) +        W = self.W +        for i in range(its): +            v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size + +if __name__=='__main__': + +    vol_geom = astra.create_vol_geom(256, 256) +    proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False)) + +    # As before, create a sinogram from a phantom +    import scipy.io +    P = scipy.io.loadmat('phantom.mat')['phantom256'] +    proj_id = astra.create_projector('cuda',proj_geom,vol_geom) + +    # construct the OpTomo object +    W = astra.OpTomo(proj_id) + +    sinogram = W * P +    sinogram = sinogram.reshape([180, 384]) + +    # Register the plugin with ASTRA +    # First we import the package that contains the plugin +    import s018_plugin +    # Then, we register the plugin class with ASTRA +    astra.plugin.register(s018_plugin.SIRTPlugin) + +    # Get a list of registered plugins +    six.print_(astra.plugin.get_registered()) + +    # To get help on a registered plugin, use get_help +    six.print_(astra.plugin.get_help('SIRT-PLUGIN')) + +    # Create data structures +    sid = astra.data2d.create('-sino', proj_geom, sinogram) +    vid = astra.data2d.create('-vol', vol_geom) + +    # Create config using plugin name +    cfg = astra.astra_dict('SIRT-PLUGIN') +    cfg['ProjectorId'] = proj_id +    cfg['ProjectionDataId'] = sid +    cfg['ReconstructionDataId'] = vid + +    # Create algorithm object +    alg_id = astra.algorithm.create(cfg) + +    # Run algorithm for 100 iterations +    astra.algorithm.run(alg_id, 100) + +    # Get reconstruction +    rec = astra.data2d.get(vid) + +    # Options for the plugin go in cfg['option'] +    cfg = astra.astra_dict('SIRT-PLUGIN') +    cfg['ProjectorId'] = proj_id +    cfg['ProjectionDataId'] = sid +    cfg['ReconstructionDataId'] = vid +    cfg['option'] = {} +    cfg['option']['rel_factor'] = 1.5 +    alg_id_rel = astra.algorithm.create(cfg) +    astra.algorithm.run(alg_id_rel, 100) +    rec_rel = astra.data2d.get(vid) + +    # We can also use OpTomo to call the plugin +    rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) + +    import pylab as pl +    pl.gray() +    pl.figure(1) +    pl.imshow(rec,vmin=0,vmax=1) +    pl.figure(2) +    pl.imshow(rec_rel,vmin=0,vmax=1) +    pl.figure(3) +    pl.imshow(rec_op,vmin=0,vmax=1) +    pl.show() + +    # Clean up. +    astra.projector.delete(proj_id) +    astra.algorithm.delete([alg_id, alg_id_rel]) +    astra.data2d.delete([vid, sid]) diff --git a/src/Globals.cpp b/src/Globals.cpp index 813f9c9..904a459 100644 --- a/src/Globals.cpp +++ b/src/Globals.cpp @@ -28,5 +28,8 @@ $Id$  #include "astra/Globals.h" +namespace astra{ +    bool running_in_matlab=false; +}  // nothing to see here :) diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp new file mode 100644 index 0000000..8f7dfc5 --- /dev/null +++ b/src/PluginAlgorithm.cpp @@ -0,0 +1,401 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +           2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. + +----------------------------------------------------------------------- +$Id$ +*/ + +#ifdef ASTRA_PYTHON + +#include "astra/PluginAlgorithm.h" +#include "astra/Logging.h" +#include <boost/algorithm/string.hpp> +#include <boost/algorithm/string/split.hpp> +#include <boost/lexical_cast.hpp> +#include <iostream> +#include <fstream> +#include <string> + +#include <Python.h> +#include "bytesobject.h" + +namespace astra { + + + +void logPythonError(){ +    if(PyErr_Occurred()){ +        PyObject *ptype, *pvalue, *ptraceback; +        PyErr_Fetch(&ptype, &pvalue, &ptraceback); +        PyErr_NormalizeException(&ptype, &pvalue, &ptraceback); +        PyObject *traceback = PyImport_ImportModule("traceback"); +        if(traceback!=NULL){ +            PyObject *exc; +            if(ptraceback==NULL){ +                exc = PyObject_CallMethod(traceback,"format_exception_only","OO",ptype, pvalue); +            }else{ +                exc = PyObject_CallMethod(traceback,"format_exception","OOO",ptype, pvalue, ptraceback); +            } +            if(exc!=NULL){ +                PyObject *six = PyImport_ImportModule("six"); +                if(six!=NULL){ +                    PyObject *iter = PyObject_GetIter(exc); +                    if(iter!=NULL){ +                        PyObject *line; +                        std::string errStr = ""; +                        while(line = PyIter_Next(iter)){ +                            PyObject *retb = PyObject_CallMethod(six,"b","O",line); +                            if(retb!=NULL){ +                                errStr += std::string(PyBytes_AsString(retb)); +                                Py_DECREF(retb); +                            } +                            Py_DECREF(line); +                        } +                        ASTRA_ERROR("%s",errStr.c_str()); +                        Py_DECREF(iter); +                    } +                    Py_DECREF(six); +                } +                Py_DECREF(exc); +            } +            Py_DECREF(traceback); +        } +        if(ptype!=NULL) Py_DECREF(ptype); +        if(pvalue!=NULL) Py_DECREF(pvalue); +        if(ptraceback!=NULL) Py_DECREF(ptraceback); +    } +} + + +CPluginAlgorithm::CPluginAlgorithm(PyObject* pyclass){ +    instance = PyObject_CallObject(pyclass, NULL); +    if(instance==NULL) logPythonError(); +} + +CPluginAlgorithm::~CPluginAlgorithm(){ +    if(instance!=NULL){ +        Py_DECREF(instance); +        instance = NULL; +    } +} + +bool CPluginAlgorithm::initialize(const Config& _cfg){ +    if(instance==NULL) return false; +    PyObject *cfgDict = XMLNode2dict(_cfg.self); +    PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict); +    Py_DECREF(cfgDict); +    if(retVal==NULL){ +        logPythonError(); +        return false; +    } +    m_bIsInitialized = true; +    Py_DECREF(retVal); +    return m_bIsInitialized; +} + +void CPluginAlgorithm::run(int _iNrIterations){ +    if(instance==NULL) return; +    PyGILState_STATE state = PyGILState_Ensure(); +    PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations); +    if(retVal==NULL){ +        logPythonError(); +    }else{ +        Py_DECREF(retVal); +    } +    PyGILState_Release(state); +} + +void fixLapackLoading(){ +    // When running in Matlab, we need to force numpy +    // to use its internal lapack library instead of +    // Matlab's MKL library to avoid errors. To do this, +    // we set Python's dlopen flags to RTLD_NOW|RTLD_DEEPBIND +    // and import 'numpy.linalg.lapack_lite' here. We reset +    // Python's dlopen flags afterwards. +    PyObject *sys = PyImport_ImportModule("sys"); +    if(sys!=NULL){ +        PyObject *curFlags = PyObject_CallMethod(sys,"getdlopenflags",NULL); +        if(curFlags!=NULL){ +            PyObject *retVal = PyObject_CallMethod(sys, "setdlopenflags", "i",10); +            if(retVal!=NULL){ +                PyObject *lapack = PyImport_ImportModule("numpy.linalg.lapack_lite"); +                if(lapack!=NULL){ +                    Py_DECREF(lapack); +                } +                PyObject_CallMethod(sys, "setdlopenflags", "O",curFlags); +                Py_DECREF(retVal); +            } +            Py_DECREF(curFlags); +        } +        Py_DECREF(sys); +    } +} + +CPluginAlgorithmFactory::CPluginAlgorithmFactory(){ +    if(!Py_IsInitialized()){ +        Py_Initialize(); +        PyEval_InitThreads(); +    } +#ifndef _MSC_VER +    if(astra::running_in_matlab) fixLapackLoading(); +#endif +    pluginDict = PyDict_New(); +    inspect = PyImport_ImportModule("inspect"); +    six = PyImport_ImportModule("six"); +} + +CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){ +    if(pluginDict!=NULL){ +        Py_DECREF(pluginDict); +    } +    if(inspect!=NULL) Py_DECREF(inspect); +    if(six!=NULL) Py_DECREF(six); +} + +PyObject * getClassFromString(std::string str){ +    std::vector<std::string> items; +    boost::split(items, str, boost::is_any_of(".")); +    PyObject *pyclass = PyImport_ImportModule(items[0].c_str()); +    if(pyclass==NULL){ +        logPythonError(); +        return NULL; +    } +    PyObject *submod = pyclass; +    for(unsigned int i=1;i<items.size();i++){ +        submod = PyObject_GetAttrString(submod,items[i].c_str()); +        Py_DECREF(pyclass); +        pyclass = submod; +        if(pyclass==NULL){ +            logPythonError(); +            return NULL; +        } +    } +    return pyclass; +} + +bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ +    PyObject *str = PyBytes_FromString(className.c_str()); +    PyDict_SetItemString(pluginDict, name.c_str(), str); +    Py_DECREF(str); +    return true; +} + +bool CPluginAlgorithmFactory::registerPlugin(std::string className){ +    PyObject *pyclass = getClassFromString(className); +    if(pyclass==NULL) return false; +    bool ret = registerPluginClass(pyclass); +    Py_DECREF(pyclass); +    return ret; +} + +bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ +    PyDict_SetItemString(pluginDict, name.c_str(), className); +    return true; +} + +bool CPluginAlgorithmFactory::registerPluginClass(PyObject * className){ +    PyObject *astra_name = PyObject_GetAttrString(className,"astra_name"); +    if(astra_name==NULL){ +        logPythonError(); +        return false; +    } +    PyObject *retb = PyObject_CallMethod(six,"b","O",astra_name); +    if(retb!=NULL){ +        PyDict_SetItemString(pluginDict,PyBytes_AsString(retb),className); +        Py_DECREF(retb); +    }else{ +        logPythonError(); +    } +    Py_DECREF(astra_name); +    return true; +} + +CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){ +    PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); +    if(className==NULL) return NULL; +    CPluginAlgorithm *alg = NULL; +    if(PyBytes_Check(className)){ +        std::string str = std::string(PyBytes_AsString(className)); +    	PyObject *pyclass = getClassFromString(str); +        if(pyclass!=NULL){ +            alg = new CPluginAlgorithm(pyclass); +            Py_DECREF(pyclass); +        } +    }else{ +        alg = new CPluginAlgorithm(className); +    } +    return alg; +} + +PyObject * CPluginAlgorithmFactory::getRegistered(){ +    Py_INCREF(pluginDict); +    return pluginDict; +} + +std::map<std::string, std::string> CPluginAlgorithmFactory::getRegisteredMap(){ +    std::map<std::string, std::string> ret; +    PyObject *key, *value; +    Py_ssize_t pos = 0; +    while (PyDict_Next(pluginDict, &pos, &key, &value)) { +        PyObject *keystr = PyObject_Str(key); +        if(keystr!=NULL){ +            PyObject *valstr = PyObject_Str(value); +            if(valstr!=NULL){ +                PyObject * keyb = PyObject_CallMethod(six,"b","O",keystr); +                if(keyb!=NULL){ +                    PyObject * valb = PyObject_CallMethod(six,"b","O",valstr); +                    if(valb!=NULL){ +                        ret[PyBytes_AsString(keyb)] = PyBytes_AsString(valb); +                        Py_DECREF(valb); +                    } +                    Py_DECREF(keyb); +                } +                Py_DECREF(valstr); +            } +            Py_DECREF(keystr); +        } +        logPythonError(); +    } +    return ret; +} + +std::string CPluginAlgorithmFactory::getHelp(std::string name){ +    PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); +    if(className==NULL){ +        ASTRA_ERROR("Plugin %s not found!",name.c_str()); +        PyErr_Clear(); +        return ""; +    } +    std::string ret = ""; +    PyObject *pyclass; +    if(PyBytes_Check(className)){ +        std::string str = std::string(PyBytes_AsString(className)); +        pyclass = getClassFromString(str); +    }else{ +        pyclass = className; +    } +    if(pyclass==NULL) return ""; +    if(inspect!=NULL && six!=NULL){ +        PyObject *retVal = PyObject_CallMethod(inspect,"getdoc","O",pyclass); +        if(retVal!=NULL){ +            if(retVal!=Py_None){ +                PyObject *retb = PyObject_CallMethod(six,"b","O",retVal); +                if(retb!=NULL){ +                    ret = std::string(PyBytes_AsString(retb)); +                    Py_DECREF(retb); +                } +            } +            Py_DECREF(retVal); +        }else{ +            logPythonError(); +        } +    } +    if(PyBytes_Check(className)){ +        Py_DECREF(pyclass); +    } +    return ret; +} + +DEFINE_SINGLETON(CPluginAlgorithmFactory); + +#if PY_MAJOR_VERSION >= 3 +PyObject * pyStringFromString(std::string str){ +    return PyUnicode_FromString(str.c_str()); +} +#else +PyObject * pyStringFromString(std::string str){ +    return PyBytes_FromString(str.c_str()); +} +#endif + +PyObject* stringToPythonValue(std::string str){ +    if(str.find(";")!=std::string::npos){ +        std::vector<std::string> rows, row; +        boost::split(rows, str, boost::is_any_of(";")); +        PyObject *mat = PyList_New(rows.size()); +        for(unsigned int i=0; i<rows.size(); i++){ +            boost::split(row, rows[i], boost::is_any_of(",")); +            PyObject *rowlist = PyList_New(row.size()); +            for(unsigned int j=0;j<row.size();j++){ +                PyList_SetItem(rowlist, j, PyFloat_FromDouble(boost::lexical_cast<double>(row[j]))); +            } +            PyList_SetItem(mat, i, rowlist); +        } +        return mat; +    } +    if(str.find(",")!=std::string::npos){ +        std::vector<std::string> vec; +        boost::split(vec, str, boost::is_any_of(",")); +        PyObject *veclist = PyList_New(vec.size()); +        for(unsigned int i=0;i<vec.size();i++){ +            PyList_SetItem(veclist, i, PyFloat_FromDouble(boost::lexical_cast<double>(vec[i]))); +        } +        return veclist; +    } +    try{ +        return PyLong_FromLong(boost::lexical_cast<long>(str)); +    }catch(const boost::bad_lexical_cast &){ +        try{ +            return PyFloat_FromDouble(boost::lexical_cast<double>(str)); +        }catch(const boost::bad_lexical_cast &){ +            return pyStringFromString(str); +        } +    } +} + +PyObject* XMLNode2dict(XMLNode node){ +    PyObject *dct = PyDict_New(); +    PyObject *opts = PyDict_New(); +    if(node.hasAttribute("type")){ +        PyObject *obj = pyStringFromString(node.getAttribute("type").c_str()); +        PyDict_SetItemString(dct, "type", obj); +        Py_DECREF(obj); +    } +    std::list<XMLNode> nodes = node.getNodes(); +    std::list<XMLNode>::iterator it = nodes.begin(); +    while(it!=nodes.end()){ +        XMLNode subnode = *it; +        if(subnode.getName()=="Option"){ +            PyObject *obj; +            if(subnode.hasAttribute("value")){ +                obj = stringToPythonValue(subnode.getAttribute("value")); +            }else{ +                obj = stringToPythonValue(subnode.getContent()); +            } +            PyDict_SetItemString(opts, subnode.getAttribute("key").c_str(), obj); +            Py_DECREF(obj); +        }else{ +            PyObject *obj = stringToPythonValue(subnode.getContent()); +            PyDict_SetItemString(dct, subnode.getName().c_str(), obj); +            Py_DECREF(obj); +        } +        ++it; +    } +    PyDict_SetItemString(dct, "options", opts); +    Py_DECREF(opts); +    return dct; +} + +} +#endif | 
