/normxcorr/trunk

To get this branch, use:
bzr branch http://suren.me/webbzr/normxcorr/trunk

« back to all changes in this revision

Viewing changes to dict_hw/matlab/normxcorr_hw.cu

  • Committer: Suren A. Chilingaryan
  • Date: 2009-12-12 01:38:41 UTC
  • Revision ID: csa@dside.dyndns.org-20091212013841-feih3qa4i28x75j4
Provide stand-alone library

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#include <stdio.h>
 
2
#include <stdlib.h>
 
3
#include <sys/time.h>
 
4
 
 
5
#include <mex.h>
 
6
#include <dict_hw.h>
 
7
 
 
8
#include "normxcorr_hw_msg.h"
 
9
 
 
10
 
 
11
#define USE_UNDOCUMENTED
 
12
//#define VALIDATE_LSUM
 
13
//#define VALIDATE_PEAK
 
14
 
 
15
typedef enum {
 
16
    ACTION_SETUP = 1,
 
17
#ifdef VALIDATE_LSUM
 
18
    ACTION_COMPUTE_BASE_FRAGMENT = 2,
 
19
#endif /* VALIDAT_LSUM */
 
20
    ACTION_SET_BASE_POINTS = 3,
 
21
    ACTION_COMPUTE_BASE = 4,
 
22
#ifdef VALIDATE_PEAK
 
23
    ACTION_COMPUTE_FRAGMENT = 11,
 
24
    ACTION_GET_CORRECTIONS = 15,
 
25
#endif /* VALIDATE_PEAK */
 
26
    ACTION_SET_POINTS = 12,
 
27
    ACTION_COMPUTE = 13,
 
28
    ACTION_GET_POINTS = 14,
 
29
} TAction;
 
30
 
 
31
 
 
32
static DICTContext pstate = NULL;
 
33
static mxArray *coords = NULL;                  // Matlab array with current coordinates
 
34
 
 
35
#ifndef EXTERN_C
 
36
# ifdef __cplusplus
 
37
   #define EXTERN_C extern "C"
 
38
# else
 
39
   #define EXTERN_C extern
 
40
# endif
 
41
#endif
 
42
 
 
43
#ifdef USE_UNDOCUMENTED
 
44
EXTERN_C mxArray *mxCreateSharedDataCopy(const mxArray *pr);
 
45
#endif /* USE_UNDOCUMENTED */
 
46
 
 
47
 
 
48
static void selfClean() {
 
49
    if (pstate) {
 
50
        reportMessage("cleaning normxcorr_hw instance");
 
51
 
 
52
        dictDestroyContext(pstate);
 
53
        pstate = NULL;
 
54
    }
 
55
    
 
56
    if (coords) {
 
57
        mxDestroyArray(coords);
 
58
        coords = NULL;
 
59
    }
 
60
}
 
61
 
 
62
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
 
63
    int err;
 
64
 
 
65
    mxArray *idMatrix;
 
66
    int32_t *idPtr;
 
67
    int64_t *errPtr;
 
68
 
 
69
    static int32_t id = 0;
 
70
    static int ncp;
 
71
    static int width, height;
 
72
    static int corr_size;
 
73
    
 
74
    float *points;
 
75
    
 
76
 
 
77
    DICTContext ps;
 
78
    TAction action;
 
79
 
 
80
    int iprop;
 
81
    
 
82
    const mxArray *input;
 
83
    const mxArray *base;
 
84
 
 
85
#ifdef VALIDATE_LSUM
 
86
    const mxArray *lsum;
 
87
    const mxArray *denom;
 
88
#endif /* VALIDATE_LSUM */
 
89
 
 
90
    const mxArray *x, *y;
 
91
 
 
92
    if (!nrhs) {
 
93
        reportMessage("Initializing normxcorr_hw instance");
 
94
 
 
95
        if (nlhs != 1) {
 
96
            reportError("You should accept a single result from initialization call");
 
97
            return;
 
98
        }
 
99
        
 
100
        idMatrix = mxCreateNumericMatrix(1, 1, mxINT32_CLASS, mxREAL);
 
101
        if (!idMatrix) {
 
102
            reportError("Initialization is failed");
 
103
            return;
 
104
        }
 
105
        
 
106
        dictSetLogger(reportError, reportMessage);
 
107
 
 
108
        if (pstate) {
 
109
            dictDestroyContext(pstate);
 
110
            pstate = dictCreateContext();
 
111
        } else {
 
112
            id = dictDetectHardware();
 
113
            if (id > 0) {
 
114
                pstate = dictCreateContext();
 
115
            }
 
116
        }
 
117
        
 
118
        if (pstate) {
 
119
            mexAtExit(selfClean);
 
120
        } else if (id > 0) {
 
121
            reportError("Context initialization has failed");
 
122
            id = -1;
 
123
        }
 
124
 
 
125
        idPtr = (int32_t*)mxGetData(idMatrix);
 
126
        idPtr[0] = id;
 
127
        
 
128
        plhs[0] = idMatrix;
 
129
        return;
 
130
    } else {
 
131
        if (!pstate) {
 
132
            reportError("normxcorr_hw should be initialized first");
 
133
            return;
 
134
        }
 
135
    }
 
136
 
 
137
        // Clean request
 
138
    if (nrhs == 1) {
 
139
        selfClean();
 
140
        return;
 
141
    }
 
142
 
 
143
    ps = pstate;
 
144
 
 
145
    action = (TAction)int(mxGetScalar((mxArray*)prhs[1]));
 
146
 
 
147
//    reportMessage("Executing normxcorr_hw action: %u", action);
 
148
 
 
149
    switch (action) {
 
150
     case ACTION_COMPUTE:
 
151
        if (nrhs != 3) {
 
152
            reportError("Compute action expects 1 argument, but %i is passed", nrhs - 2);
 
153
            return;
 
154
        }
 
155
 
 
156
        input = prhs[2];
 
157
 
 
158
        if (mxGetNumberOfDimensions(input) != 2) {
 
159
            reportError("Invalid dimensionality of base matrix, 2D matrix is expected");
 
160
            return;
 
161
        }
 
162
        
 
163
        if (mxGetClassID(input) != mxUINT8_CLASS) {
 
164
            reportError("Invalid type of image data, should be 8bit integers");
 
165
            return;
 
166
        }
 
167
        
 
168
        if ((mxGetN(input) != width)||(mxGetM(input) != height)) {
 
169
            reportError("Invalid size of image (%ix%i), but (%ix%i) is expected", mxGetN(input), mxGetM(input), width, height);
 
170
            return;
 
171
        }
 
172
 
 
173
        dictLoadImage(ps, (unsigned char*)mxGetData(input));
 
174
         break;
 
175
     case ACTION_COMPUTE_BASE:
 
176
        if (nrhs != 3) {
 
177
            reportError("ComputeBase action expects 1 argument, but %i is passed", nrhs - 2);
 
178
            return;
 
179
        }
 
180
 
 
181
        base = prhs[2];
 
182
    
 
183
        if (mxGetNumberOfDimensions(base) != 2) {
 
184
            reportError("Invalid dimensionality of base matrix, 2D matrix is expected");
 
185
            return;
 
186
        }
 
187
 
 
188
        if (mxGetClassID(base) != mxUINT8_CLASS) {
 
189
            reportError("Invalid matrix. The data type (%s) is not supported", mxGetClassName(base));
 
190
            return;
 
191
        }
 
192
 
 
193
        width = mxGetN(base);
 
194
        height = mxGetM(base);
 
195
        
 
196
        dictLoadTemplateImage(ps, (unsigned char*)mxGetData(base), width, height);
 
197
     break;
 
198
     case ACTION_SET_BASE_POINTS:
 
199
        if (nrhs != 4) {
 
200
            reportError("SET_POINTS action expects two arrays with 'x' and 'y' coordinates of control points");
 
201
            return;
 
202
        }
 
203
 
 
204
        x = prhs[2];
 
205
        y = prhs[3];
 
206
        
 
207
        if (    (mxGetClassID(x) != mxSINGLE_CLASS)||
 
208
                (mxGetClassID(y) != mxSINGLE_CLASS)||
 
209
                (mxGetN(x)*mxGetM(x) != ncp)||
 
210
                (mxGetN(y)*mxGetM(y) != ncp)
 
211
        ) {
 
212
            reportError("Invalid control points are specified");
 
213
            return;
 
214
        }
 
215
        
 
216
        dictSetTemplatePoints(ps,  (float*)mxGetData(x),  (float*)mxGetData(y));
 
217
     break;
 
218
     case ACTION_SET_POINTS:
 
219
        if (nrhs != 4) {
 
220
            reportError("SET_POINTS action expects two arrays with 'x' and 'y' coordinates of control points");
 
221
            return;
 
222
        }
 
223
 
 
224
        x = prhs[2];
 
225
        y = prhs[3];
 
226
        
 
227
        if (    (mxGetClassID(x) != mxSINGLE_CLASS)||
 
228
                (mxGetClassID(y) != mxSINGLE_CLASS)||
 
229
                (mxGetN(x)*mxGetM(x) != ncp)||
 
230
                (mxGetN(y)*mxGetM(y) != ncp)
 
231
        ) {
 
232
            reportError("Invalid control points are specified");
 
233
            return;
 
234
        }
 
235
        
 
236
        dictSetCurrentPoints(ps,  (float*)mxGetData(x),  (float*)mxGetData(y));
 
237
     break;
 
238
     case ACTION_GET_POINTS:
 
239
        if (nrhs != 2) {
 
240
            reportError("GetPoints action do not expect any arguments");
 
241
            return;
 
242
        }
 
243
        if (nlhs != 1) {
 
244
            reportError("GetPoints action returns a single matrix");
 
245
            return;
 
246
        }
 
247
 
 
248
        if (!coords) {
 
249
            reportError("normxcorr is not properly initialized, the result matrix is not allocated");
 
250
            return;
 
251
        }
 
252
        
 
253
        dictCompute(ps);
 
254
        
 
255
#ifdef USE_UNDOCUMENTED
 
256
        plhs[0] = mxCreateSharedDataCopy(coords);
 
257
//    mxArray *mxCreateSharedDataCopy(const mxArray *pr);
 
258
//    bool mxUnshareArray(const mxArray *pr, const bool noDeepCopy);    // true if not successful
 
259
//    mxArray *mxUnreference(const mxArray *pr);
 
260
#else /* USE_UNDOCUMENTED */
 
261
        plhs[0] = mxDuplicateArray(coords);
 
262
#endif /* USE_UNDOCUMENTED */
 
263
     break;     
 
264
     case ACTION_SETUP:
 
265
        if (nrhs != 6) {
 
266
            reportError("SETUP action expects 'ncp', 'corrsize', 'precision', and 'optimization level' parameters");
 
267
            return;
 
268
        }
 
269
 
 
270
        ncp = (int)mxGetScalar(prhs[2]);
 
271
        corr_size = (int)mxGetScalar(prhs[3]);
 
272
        iprop = (int)mxGetScalar(prhs[4]);
 
273
        
 
274
        err = dictSetup(ps, ncp, corr_size, (int)mxGetScalar(prhs[4]), (iprop>3)?DICT_FLAGS_DEFAULT:DICT_FLAGS_FIXED_FFT_SIZE);
 
275
 
 
276
        if (!err) {
 
277
            if (coords) mxDestroyArray(coords);
 
278
            coords = mxCreateNumericMatrix(ncp, 2, mxSINGLE_CLASS, mxREAL);
 
279
            if (coords) mexMakeArrayPersistent(coords);
 
280
            else {
 
281
                reportError("Allocation of result matrix of size %u*float bytes is failed", ncp);
 
282
                err = DICT_ERROR_MALLOC;
 
283
            }
 
284
        }
 
285
        
 
286
        if (!err) {
 
287
            points = (float*)mxGetData(coords);
 
288
            dictSetPointsBuffer(ps, points, points + ncp);
 
289
        }
 
290
        
 
291
        //mexMakeMemoryPersistent(ps->coords);
 
292
        //mexLock() mexUnlock()
 
293
 
 
294
        if (nlhs == 1) {
 
295
            idMatrix = mxCreateNumericMatrix(1, 1, mxINT64_CLASS, mxREAL);
 
296
            if (idMatrix) {
 
297
                errPtr = (int64_t*)mxGetData(idMatrix);
 
298
                errPtr[0] = err;
 
299
                plhs[0] = idMatrix;
 
300
            } else {
 
301
                reportError("Initialization of result matrix is failed");
 
302
                return;
 
303
            }
 
304
        }
 
305
     break;
 
306
#ifdef VALIDATE_PEAK
 
307
     case ACTION_COMPUTE_FRAGMENT:
 
308
        if (nlhs > 0) {
 
309
            icp = (unsigned int)mxGetScalar(prhs[2]) - 1;
 
310
            idMatrix = mxCreateNumericMatrix(size, size, mxSINGLE_CLASS, mxREAL);
 
311
 
 
312
            dictProcessFragment(ps, icp, 1, prhs[3]);
 
313
            dictGetCorrelations(ps, icp, (float*)mxGetPr(idMatrix));
 
314
        
 
315
            plhs[0] = idMatrix;
 
316
        }
 
317
     break;
 
318
     case ACTION_GET_CORRECTIONS:
 
319
        if (nlhs > 0) {
 
320
            idMatrix = mxCreateNumericMatrix(ncp, 2, mxSINGLE_CLASS, mxREAL);
 
321
            float *points = (float*)mxGetData(idMatrix); 
 
322
        
 
323
            dictGetCorrections(ps, points, points + ncp);
 
324
        
 
325
            plhs[0] = idMatrix
 
326
        }
 
327
     break;
 
328
#endif /* VALIDATE_PEAK */
 
329
#ifdef VALIDATE_LSUM
 
330
     case ACTION_COMPUTE_BASE_FRAGMENT:
 
331
        if (nrhs != 4) {
 
332
            reportError("ComputeBaseFragment action expects 2 arguments, but %i is passed", nrhs - 2);
 
333
            return;
 
334
        }
 
335
 
 
336
        icp = (unsigned int)mxGetScalar(prhs[2]) - 1;
 
337
        if (icp >= ps->ncp) {
 
338
            reportError("The control point (%i) is out of range (0-%u)", icp, ps->ncp - 1);
 
339
            return;
 
340
        }
 
341
 
 
342
        base = prhs[3];
 
343
    
 
344
        if (mxGetNumberOfDimensions(base) != 2) {
 
345
            reportError("Invalid dimensionality of base matrix, 2D matrix is expected");
 
346
            return;
 
347
        }
 
348
 
 
349
        if (mxGetClassID(base) != mxUINT8_CLASS) {
 
350
            reportError("Invalid matrix. The data type (%s) is not supported", mxGetClassName(base));
 
351
            return;
 
352
        }
 
353
 
 
354
        fft_size = 6 * corr_size + 1;
 
355
        if (nlhs > 0) lsum = mxCreateNumericMatrix(fft_size, fft_size, mxSINGLE_CLASS, mxREAL);
 
356
        else lsum = NULL;
 
357
        if (nlhs > 1) denom = mxCreateNumericMatrix(fft_size, fft_size, mxSINGLE_CLASS, mxREAL);
 
358
        else denom = NULL;
 
359
        
 
360
        dictSetDimensions(ps, GetN(base), GetM(base));
 
361
        dictLoadTemplateFragment(ps, icp, 1, base);
 
362
        dictGetLocalSum(ps, icp, lsum, denom);
 
363
 
 
364
        if (nlhs > 0) {
 
365
            plhs[0] = lsum;
 
366
            if (nlhs > 1) plhs[1] = denom;
 
367
        }
 
368
     break;
 
369
#endif /* VALIDATE_LSUM */
 
370
 
 
371
     default:
 
372
        reportError("Unknown request %i", action);
 
373
    }
 
374
}