summaryrefslogtreecommitdiffstats
path: root/matlab/mex
diff options
context:
space:
mode:
Diffstat (limited to 'matlab/mex')
-rw-r--r--matlab/mex/astra_mex_c.cpp20
1 files changed, 20 insertions, 0 deletions
diff --git a/matlab/mex/astra_mex_c.cpp b/matlab/mex/astra_mex_c.cpp
index 0068664..816f1f3 100644
--- a/matlab/mex/astra_mex_c.cpp
+++ b/matlab/mex/astra_mex_c.cpp
@@ -36,6 +36,8 @@ $Id$
#include "astra/Globals.h"
+#include "../cuda/2d/darthelper.h"
+
using namespace std;
using namespace astra;
@@ -72,6 +74,22 @@ void astra_mex_use_cuda(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs
}
//-----------------------------------------------------------------------------------------
+/** set_gpu_index = astra_mex('set_gpu_index');
+ *
+ * Set active GPU
+ */
+void astra_mex_set_gpu_index(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
+{
+#ifdef ASTRA_CUDA
+ if (nrhs >= 2) {
+ bool ret = astraCUDA::setGPUIndex((int)mxGetScalar(prhs[1]));
+ if (!ret)
+ mexPrintf("Failed to set GPU %d\n", (int)mxGetScalar(prhs[1]));
+ }
+#endif
+}
+
+//-----------------------------------------------------------------------------------------
/** version_number = astra_mex('version');
*
* Fetch the version number of the toolbox.
@@ -117,6 +135,8 @@ void mexFunction(int nlhs, mxArray* plhs[],
astra_mex_use_cuda(nlhs, plhs, nrhs, prhs);
} else if (sMode == std::string("credits")) {
astra_mex_credits(nlhs, plhs, nrhs, prhs);
+ } else if (sMode == std::string("set_gpu_index")) {
+ astra_mex_set_gpu_index(nlhs, plhs, nrhs, prhs);
} else {
printHelp();
}