From 6589fa197d9f87f7a37f46943aa995d97f50bb46 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Mon, 7 Aug 2017 17:21:12 +0100
Subject: added TGV_PD, removed useless code

---
 src/Python/fista_module.cpp | 245 ++++++++++++++++++++++++++------------------
 1 file changed, 146 insertions(+), 99 deletions(-)

(limited to 'src/Python')

diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
index c2d9352..eacda3d 100644
--- a/src/Python/fista_module.cpp
+++ b/src/Python/fista_module.cpp
@@ -30,6 +30,7 @@ limitations under the License.
 #include "FGP_TV_core.h"
 #include "LLT_model_core.h"
 #include "PatchBased_Regul_core.h"
+#include "TGV_PD_core.h"
 #include "utils.h"
 
 
@@ -103,101 +104,8 @@ If unsuccessful in a MEX file, the MEX file terminates and returns control to th
 enough free heap space to create the mxArray.
 */
 
-void mexErrMessageText(char* text) {
-	std::cerr << text << std::endl;
-}
-
-/*
-double mxGetScalar(const mxArray *pm);
-args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
-Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double.
-*/
-
-template<typename T>
-double mxGetScalar(const np::ndarray plh) {
-	return (double)bp::extract<T>(plh[0]);
-}
-
-
-
-template<typename T>
-T * mxGetData(const np::ndarray pm) {
-	//args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
-	//Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double.
-	/*Access the numpy array pointer:
-	char * get_data() const;
-	Returns:	Array�s raw data pointer as a char
-	Note:	This returns char so stride math works properly on it.User will have to reinterpret_cast it.
-	probably this would work.
-	A = reinterpret_cast<float *>(prhs[0]);
-	*/
-	return reinterpret_cast<T *>(prhs[0]);
-}
-
-template<typename T>
-np::ndarray zeros(int dims, int * dim_array, T el) {
-	bp::tuple shape;
-	if (dims == 3)
-		shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
-	else if (dims == 2)
-		shape = bp::make_tuple(dim_array[0], dim_array[1]);
-	np::dtype dtype = np::dtype::get_builtin<T>();
-	np::ndarray zz = np::zeros(shape, dtype);
-	return zz;
-}
 
 
-
-
-bp::list mexFunction(np::ndarray input) {
-	int number_of_dims = input.get_nd();
-	int dim_array[3];
-
-	dim_array[0] = input.shape(0);
-	dim_array[1] = input.shape(1);
-	if (number_of_dims == 2) {
-		dim_array[2] = -1;
-	}
-	else {
-		dim_array[2] = input.shape(2);
-	}
-
-	/**************************************************************************/
-	np::ndarray zz = zeros(3, dim_array, (int)0);
-	np::ndarray fzz = zeros(3, dim_array, (float)0);
-	/**************************************************************************/
-
-	int * A = reinterpret_cast<int *>(input.get_data());
-	int * B = reinterpret_cast<int *>(zz.get_data());
-	float * C = reinterpret_cast<float *>(fzz.get_data());
-
-	//Copy data and cast
-	for (int i = 0; i < dim_array[0]; i++) {
-		for (int j = 0; j < dim_array[1]; j++) {
-			for (int k = 0; k < dim_array[2]; k++) {
-				int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i;
-				int val = (*(A + index));
-				float fval = (float)val;
-				std::memcpy(B + index, &val, sizeof(int));
-				std::memcpy(C + index, &fval, sizeof(float));
-			}
-		}
-	}
-
-	bp::list result;
-
-	result.append<int>(number_of_dims);
-	result.append<int>(dim_array[0]);
-	result.append<int>(dim_array[1]);
-	result.append<int>(dim_array[2]);
-	result.append<np::ndarray>(zz);
-	result.append<np::ndarray>(fzz);
-
-	//result.append<bp::tuple>(tup);
-	return result;
-
-}
-
 bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) {
 	
 	// the result is in the following list
@@ -487,7 +395,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
 		np::ndarray npP1_old = np::zeros(shape, dtype);
 		np::ndarray npP2_old = np::zeros(shape, dtype);
 		np::ndarray npR1     = np::zeros(shape, dtype);
-		np::ndarray npR2     = zeros(2, dim_array, (float)0);
+		np::ndarray npR2     = np::zeros(shape, dtype);
 
 		D      = reinterpret_cast<float *>(npD.get_data());
 		D_old  = reinterpret_cast<float *>(npD_old.get_data());
@@ -866,7 +774,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d
 }
 
 
-bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  double d_h, double d_lambda) {
+bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW,  double d_h) {
 	// the result is in the following list
 	bp::list result;
 
@@ -899,6 +807,7 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 
 	///*Handling inputs*/
 	//A = (float *)mxGetData(prhs[0]);    /* the image to regularize/filter */
+	A = reinterpret_cast<float *>(input.get_data());
 	//SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */
 	//SimilW = (int)mxGetScalar(prhs[2]);  /* the similarity window ratio */
 	//h = (float)mxGetScalar(prhs[3]);  /* parameter for the PB filtering function */
@@ -907,6 +816,8 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 	//if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0");
 	//if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0");
 
+	lambda = (float)d_lambda;
+	h = (float)d_h;
 	SearchW = SearchW_real + 2 * SimilW;
 
 	/* SearchW_full = 2*SearchW + 1; */ /* the full searching window  size */
@@ -918,7 +829,6 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 	newsizeY = M + 2 * (padXY); /* the Y size of the padded array */
 	newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */
 	int N_dims[] = { newsizeX, newsizeY, newsizeZ };
-
 	/******************************2D case ****************************/
 	if (numdims == 2) {
 		///*Handling output*/
@@ -943,12 +853,13 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 		/*Perform padding of image A to the size of [newsizeX * newsizeY] */
 		switchpad_crop = 0; /*padding*/
 		pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
-
+		
 		/* Do PB regularization with the padded array  */
 		PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda);
-
+		
 		switchpad_crop = 1; /*cropping*/
 		pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+		
 		result.append<np::ndarray>(npB);
 	}
 	else
@@ -983,6 +894,141 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW,  doub
 		result.append<np::ndarray>(npB);
 	} /*end else ndims*/
 
+	return result;
+}
+
+bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) {
+	// the result is in the following list
+	bp::list result;
+	int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll;
+	//const int  *dim_array;
+	float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0;
+
+	//number_of_dims = mxGetNumberOfDimensions(prhs[0]);
+	//dim_array = mxGetDimensions(prhs[0]);
+	number_of_dims = input.get_nd();
+	int dim_array[3];
+
+	dim_array[0] = input.shape(0);
+	dim_array[1] = input.shape(1);
+	if (number_of_dims == 2) {
+		dim_array[2] = -1;
+	}
+	else {
+		dim_array[2] = input.shape(2);
+	}
+	/*Handling Matlab input data*/
+	//A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/
+	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+	
+	A = reinterpret_cast<float *>(input.get_data());
+
+	//lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/
+	//alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/
+	//alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/
+	//iter = (int)mxGetScalar(prhs[4]); /*iterations number*/
+	//if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations");
+	lambda = (float)d_lambda;
+	alpha1 = (float)d_alpha1;
+	alpha0 = (float)d_alpha0;
+
+	/*Handling Matlab output data*/
+	dimX = dim_array[0]; dimY = dim_array[1];
+
+	if (number_of_dims == 2) {
+		/*2D case*/
+		dimZ = 1;
+		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+		np::dtype dtype = np::dtype::get_builtin<float>();
+
+		np::ndarray npU = np::zeros(shape, dtype);
+		np::ndarray npP1 = np::zeros(shape, dtype);
+		np::ndarray npP2 = np::zeros(shape, dtype);
+		np::ndarray npQ1 = np::zeros(shape, dtype);
+		np::ndarray npQ2 = np::zeros(shape, dtype);
+		np::ndarray npQ3 = np::zeros(shape, dtype);
+		np::ndarray npV1 = np::zeros(shape, dtype);
+		np::ndarray npV1_old = np::zeros(shape, dtype);
+		np::ndarray npV2 = np::zeros(shape, dtype);
+		np::ndarray npV2_old = np::zeros(shape, dtype);
+		np::ndarray npU_old = np::zeros(shape, dtype);
+
+		U = reinterpret_cast<float *>(npU.get_data());
+		U_old = reinterpret_cast<float *>(npU_old.get_data());
+		P1 = reinterpret_cast<float *>(npP1.get_data());
+		P2 = reinterpret_cast<float *>(npP2.get_data());
+		Q1 = reinterpret_cast<float *>(npQ1.get_data());
+		Q2 = reinterpret_cast<float *>(npQ2.get_data());
+		Q3 = reinterpret_cast<float *>(npQ3.get_data());
+		V1 = reinterpret_cast<float *>(npV1.get_data());
+		V1_old = reinterpret_cast<float *>(npV1_old.get_data());
+		V2 = reinterpret_cast<float *>(npV2.get_data());
+		V2_old = reinterpret_cast<float *>(npV2_old.get_data());
+		//U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		/*dual variables*/
+		/*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+		V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+		V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/
+		/*printf("%i \n", i);*/
+		L2 = 12.0; /*Lipshitz constant*/
+		tau = 1.0 / pow(L2, 0.5);
+		sigma = 1.0 / pow(L2, 0.5);
+
+		/*Copy A to U*/
+		copyIm(A, U, dimX, dimY, dimZ);
+		/* Here primal-dual iterations begin for 2D */
+		for (ll = 0; ll < iter; ll++) {
+
+			/* Calculate Dual Variable P */
+			DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma);
+
+			/*Projection onto convex set for P*/
+			ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1);
+
+			/* Calculate Dual Variable Q */
+			DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma);
+
+			/*Projection onto convex set for Q*/
+			ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0);
+
+			/*saving U into U_old*/
+			copyIm(U, U_old, dimX, dimY, dimZ);
+
+			/*adjoint operation  -> divergence and projection of P*/
+			DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau);
+
+			/*get updated solution U*/
+			newU(U, U_old, dimX, dimY, dimZ);
+
+			/*saving V into V_old*/
+			copyIm(V1, V1_old, dimX, dimY, dimZ);
+			copyIm(V2, V2_old, dimX, dimY, dimZ);
+
+			/* upd V*/
+			UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau);
+
+			/*get new V*/
+			newU(V1, V1_old, dimX, dimY, dimZ);
+			newU(V2, V2_old, dimX, dimY, dimZ);
+		} /*end of iterations*/
+	
+		result.append<np::ndarray>(npU);
+	}
+	
+
+	
+	
 	return result;
 }
 
@@ -997,8 +1043,9 @@ BOOST_PYTHON_MODULE(regularizers)
 	np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
 	np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
 
-	def("mexFunction", mexFunction);
 	def("SplitBregman_TV", SplitBregman_TV);
 	def("FGP_TV", FGP_TV);
 	def("LLT_model", LLT_model);
+	def("PatchBased_Regul", PatchBased_Regul);
+	def("TGV_PD", TGV_PD);
 }
\ No newline at end of file
-- 
cgit v1.2.3