/ani/mrses

To get this branch, use:
bzr branch http://suren.me/webbzr/ani/mrses

« back to all changes in this revision

Viewing changes to cell/mrses_ppu.c

  • Committer: Suren A. Chilingaryan
  • Date: 2010-04-28 04:30:08 UTC
  • Revision ID: csa@dside.dyndns.org-20100428043008-vd9z0nso9axezvlp
Initial import

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#include <stdio.h>
 
2
#include <string.h>
 
3
#include <sys/time.h>
 
4
#include <assert.h>
 
5
 
 
6
#include <cblas.h>
 
7
#include <lapack.h>
 
8
 
 
9
#include "mrses_impl.h"
 
10
#include "mrses_ppu.h"
 
11
 
 
12
#ifdef USE_FAST_RANDOM
 
13
#define FASTRND(var, max, add) \
 
14
    g_seed = (214013*g_seed+2531011); \
 
15
    var = ((int)((max)*((g_seed>>16)&0x7FFF) / (0x7FFF + 1.0))) + add;
 
16
#endif
 
17
 
 
18
#undef rnd
 
19
//#define rnd(i) ((int)((i) * ((random_r(&data->seed, &rndval)?0:rndval) / (RAND_MAX + 1.0))))
 
20
#define rnd(i) ((int)((i) * (rand_r(&data->seed) / (RAND_MAX + 1.0))))
 
21
 
 
22
 
 
23
static inline int mrses_ppu_real_run(
 
24
    MRSESContext mrses, MRSESDataType *result,
 
25
    int width, int width2, int nA, int nB,
 
26
    MRSESDataType *A, MRSESDataType *B, MRSESDataType *mean,
 
27
    MRSESDataType *C, MRSESDataType *Ca, MRSESDataType *Cb
 
28
) {
 
29
    int i, err;
 
30
    
 
31
    char hmode = 'L';
 
32
    
 
33
    MRSESDataType detAB, detC;
 
34
    MRSESDataType rmahal, rcorr;
 
35
 
 
36
/*
 
37
        // We can save twice by not computing symmetric parts
 
38
    //blas_gemm(CblasRowMajor, CblasNoTrans, CblasTrans, width, width, nA, 1, A, nA, A, nA, 0, Ca, width);
 
39
    //blas_gemm(CblasRowMajor, CblasNoTrans, CblasTrans, width, width, nB, 1, B, nB, B, nB, 0, Cb, width);
 
40
    //printf("%f %f\n", Ca[4], Ca[20]);
 
41
 
 
42
    //blas_syrk(CblasRowMajor, CblasUpper, CblasTrans, nA, width, 1.0, A, nA, 0, Ca, width);
 
43
    //blas_syrk(CblasRowMajor, CblasUpper, CblasTrans, nB, width, 1.0, B, nB, 0, Cb, width);
 
44
*/
 
45
 
 
46
    blas_syrk(CblasRowMajor, CblasUpper, CblasNoTrans, width, nA, 1.0, A, nA, 0, Ca, width);
 
47
    blas_syrk(CblasRowMajor, CblasUpper, CblasNoTrans, width, nB, 1.0, B, nB, 0, Cb, width);
 
48
 
 
49
    memcpy(C, Ca, width2 * sizeof(MRSESDataType));
 
50
    blas_axpy(width2, 1, Cb, 1, C, 1);
 
51
    blas_scal(width2, 0.5, C, 1);
 
52
 
 
53
//    PRINT_MATRIX("%+6.4f ", Ca, 5, 5, 5)
 
54
//    PRINT_MATRIX("%+6.4f ", A, nA, 5, 16)
 
55
 
 
56
    lapack_potrf(&hmode, &width, C, &width, &err);
 
57
    if (err) return 1;
 
58
 
 
59
    lapack_potrf(&hmode, &width, Ca, &width, &err);
 
60
    if (err) return 1;
 
61
 
 
62
    lapack_potrf(&hmode, &width, Cb, &width, &err);
 
63
    if (err) return 1;
 
64
    
 
65
 
 
66
    detC = C[0]; detAB = Ca[0] * Cb[0];
 
67
    for (i = width + 1; i < width2; i+= (width+1)) {
 
68
        detAB *= (Ca[i] * Cb[i]);
 
69
        detC *= C[i];
 
70
    }
 
71
 
 
72
        /* we just computing sqrt(detX) actually */
 
73
    rcorr = 2 * logf(detC * detC / detAB);
 
74
 
 
75
/*
 
76
    blas_trsm(
 
77
        CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, CblasNonUnit,
 
78
        1, width, 1, C, width, mean, width //1?
 
79
    );
 
80
*/
 
81
 
 
82
    blas_trsv(
 
83
        CblasRowMajor, CblasUpper, CblasTrans, CblasNonUnit,
 
84
        width, C, width, mean, 1
 
85
    );
 
86
 
 
87
 
 
88
    rmahal = blas_dot(width, mean, 1, mean, 1);
 
89
 
 
90
    switch (mrses->dist) {
 
91
        case BHATTACHARYYA:
 
92
            *result = rmahal/8 + rcorr/4;
 
93
        break;
 
94
        case MAHALANOBIS:
 
95
            *result = rmahal;
 
96
        break;
 
97
        case CORCOR:
 
98
            *result = rcorr;
 
99
        break;
 
100
        default:
 
101
            *result = 0;
 
102
    }
 
103
 
 
104
/*    
 
105
    if (!block)
 
106
    printf("Det: %f = %f %f (%e)\n", rmahal/8 + rcorr/4, rmahal, rcorr, detC*detC);
 
107
*/
 
108
 
 
109
    return 0;
 
110
}
 
111
 
 
112
struct MRSESTemporaryDataT {
 
113
    MRSESDataType *A;
 
114
    MRSESDataType *B;
 
115
    MRSESDataType *C;
 
116
    MRSESDataType *Ca;
 
117
    MRSESDataType *Cb;
 
118
    MRSESDataType *mean;
 
119
    MRSESDataType *mean_copy;
 
120
 
 
121
#ifndef USE_FAST_RANDOM
 
122
//    struct random_data seed;
 
123
    unsigned int seed;
 
124
#endif /* USE_FAST_RANDOM */
 
125
};
 
126
typedef struct MRSESTemporaryDataT MRSESTemporaryDataS;
 
127
typedef struct MRSESTemporaryDataT *MRSESTemporaryData;
 
128
 
 
129
static inline MRSESTemporaryData mrses_ppu_malloc(HWThread thr, MRSESContext mrses) {
 
130
    unsigned char *alloc;
 
131
    MRSESTemporaryData data;
 
132
    int nAB;
 
133
    int properties = mrses->properties;
 
134
    int width, width2;
 
135
    int pos, size;
 
136
 
 
137
#ifndef USE_FAST_RANDOM
 
138
# ifndef FIX_RANDOM
 
139
    struct timeval tv;
 
140
# endif /* FIX_RANDOM */
 
141
#endif /* USE_FAST_RANDOM */
 
142
    
 
143
    pos = calc_alloc(properties * sizeof(uint32_t), HW_ALIGN);
 
144
    if (thr->data) return (MRSESTemporaryData)(thr->data + pos);
 
145
 
 
146
    width = mrses->width;    
 
147
    width2 = width * width;
 
148
    nAB = max(mrses->nA, mrses->nB);
 
149
 
 
150
    size = (
 
151
        calc_alloc(properties * sizeof(uint32_t), HW_ALIGN) +
 
152
        calc_alloc(sizeof(MRSESTemporaryDataS), HW_ALIGN) +
 
153
        2 * calc_alloc(width * nAB * sizeof(MRSESDataType), HW_ALIGN) + 
 
154
        3 * calc_alloc(width * width * sizeof(MRSESDataType), HW_ALIGN) +
 
155
        2 * calc_alloc(width * sizeof(MRSESDataType), HW_ALIGN)
 
156
    );
 
157
 
 
158
    posix_memalign((void*)&alloc, HW_ALIGN, size);
 
159
    if (!alloc) return NULL;
 
160
 
 
161
    memset(alloc, 0, properties * sizeof(uint32_t));
 
162
    
 
163
    data = (MRSESTemporaryData)(alloc + pos);
 
164
    pos += calc_alloc(sizeof(MRSESTemporaryDataS), HW_ALIGN);
 
165
 
 
166
    data->A = (MRSESDataType*)(alloc + pos);
 
167
    pos += calc_alloc(width * nAB * sizeof(MRSESDataType), HW_ALIGN);
 
168
 
 
169
    data->B = (MRSESDataType*)(alloc + pos);
 
170
    pos += calc_alloc(width * nAB * sizeof(MRSESDataType), HW_ALIGN);
 
171
 
 
172
    data->C = (MRSESDataType*)(alloc + pos);
 
173
    pos += calc_alloc(width * width * sizeof(MRSESDataType), HW_ALIGN);
 
174
 
 
175
    data->Ca = (MRSESDataType*)(alloc + pos);
 
176
    pos += calc_alloc(width * width * sizeof(MRSESDataType), HW_ALIGN);
 
177
 
 
178
    data->Cb = (MRSESDataType*)(alloc + pos);
 
179
    pos += calc_alloc(width * width * sizeof(MRSESDataType), HW_ALIGN);
 
180
 
 
181
    data->mean = (MRSESDataType*)(alloc + pos);
 
182
    pos += calc_alloc(width * sizeof(MRSESDataType), HW_ALIGN);
 
183
 
 
184
    data->mean_copy = (MRSESDataType*)(alloc + pos);
 
185
 
 
186
#ifndef USE_FAST_RANDOM
 
187
# ifdef FIX_RANDOM
 
188
//    srandom_r(FIX_RANDOM, &data->seed);
 
189
    data->seed = FIX_RANDOM;
 
190
# else /* FIX_RANDOM */
 
191
    gettimeofday(&tv, NULL);
 
192
    data->seed = tv.tv_usec;
 
193
//    memset(&data->seed, 0, sizeof(struct random_data));
 
194
//    srandom_r(tv.tv_usec, &data->seed);
 
195
# endif /* FIX_RANDOM */
 
196
#endif /* USE_FAST_RANDOM */
 
197
 
 
198
    thr->data = (void*)alloc;
 
199
 
 
200
    return data;
 
201
}
 
202
 
 
203
 
 
204
int mrses_ppu_run(HWThread thr, void *hwctx __attribute__ ((unused)), int block, MRSESContext mrses) {
 
205
    int err;
 
206
    int i, idx;
 
207
 
 
208
    int width = mrses->width;
 
209
    int width2 = width * width;
 
210
    int alloc = mrses->alloc;
 
211
    int nA = mrses->nA;
 
212
    int nB = mrses->nB;
 
213
    
 
214
    MRSESIntType *index = mrses->index + block * width;
 
215
    MRSESDataType *result = mrses->result + block;
 
216
    
 
217
    MRSESDataType *Afull = mrses->A;
 
218
    MRSESDataType *Bfull = mrses->B;
 
219
    MRSESDataType *Mfull = mrses->mean;
 
220
 
 
221
    MRSESTemporaryData data;
 
222
 
 
223
    MRSESDataType *A;
 
224
    MRSESDataType *B;
 
225
    MRSESDataType *C;
 
226
    MRSESDataType *Ca;
 
227
    MRSESDataType *Cb;
 
228
    MRSESDataType *mean;
 
229
 
 
230
    data = mrses_ppu_malloc(thr, mrses);
 
231
    if (!data) return 1;
 
232
 
 
233
    A = data->A;
 
234
    B = data->B;
 
235
    Ca = data->Ca;
 
236
    Cb = data->Cb;
 
237
    C = data->C;
 
238
    mean = data->mean;
 
239
 
 
240
    for (i = 0; i < width; i++) {
 
241
        idx = index[i] - 1;
 
242
            
 
243
        memcpy(A + i * nA, Afull + idx * alloc, nA * sizeof(MRSESDataType));
 
244
        memcpy(B + i * nB, Bfull + idx * alloc, nB * sizeof(MRSESDataType));
 
245
        mean[i] = Mfull[idx];
 
246
    }
 
247
    
 
248
    err = mrses_ppu_real_run(
 
249
        mrses, result,
 
250
        width, width2, nA, nB,
 
251
        A, B, mean, C, Ca, Cb
 
252
    );
 
253
 
 
254
    if (err) {
 
255
        MRSESDataType *result = mrses->result + block;
 
256
        *result = 0;    /* Just do octave computation instead */
 
257
    }
 
258
    
 
259
    return 0;
 
260
}
 
261
 
 
262
 
 
263
 
 
264
int mrses_ppu_iterate(HWThread thr, void *hwctx __attribute__ ((unused)), int block_group, MRSESContext mrses) {
 
265
    int err;
 
266
    int i, idx;
 
267
 
 
268
    int iterate_size = mrses->iterate_size;
 
269
    int block = block_group * iterate_size;
 
270
    int block_end = block + iterate_size;
 
271
    
 
272
    int width = mrses->width;
 
273
    int width2 = width * width;
 
274
    int alloc = mrses->alloc;
 
275
    int nA = mrses->nA;
 
276
    int nB = mrses->nB;
 
277
    int properties = mrses->properties;
 
278
    int iterations = mrses->iterations;
 
279
    
 
280
    unsigned int *hist;
 
281
 
 
282
    MRSESIntType *index;
 
283
    MRSESIntType *ires = mrses->ires; 
 
284
    MRSESIntType drp_gen, rpl_gen, rst_gen;
 
285
    MRSESDataType result = 0, cur_result = 0;
 
286
    
 
287
    MRSESDataType *Afull = mrses->A;
 
288
    MRSESDataType *Bfull = mrses->B;
 
289
    MRSESDataType *Mfull = mrses->mean;
 
290
 
 
291
    MRSESTemporaryData data;
 
292
 
 
293
    MRSESDataType *A;
 
294
    MRSESDataType *B;
 
295
    MRSESDataType *C;
 
296
    MRSESDataType *Ca;
 
297
    MRSESDataType *Cb;
 
298
    MRSESDataType *mean;
 
299
    MRSESDataType *mean_copy;
 
300
 
 
301
#ifdef TRACE_TIMINGS
 
302
    struct timeval tv1, tv2;
 
303
#endif /* TRACE_TIMINGS */
 
304
 
 
305
#ifdef USE_FAST_RANDOM
 
306
    unsigned int g_seed;
 
307
# ifndef TRACE_TIMINGS
 
308
    struct timeval tvseed;
 
309
# endif /* TRACE_TIMINGS */
 
310
#else /* USE_FAST_RANDOM */
 
311
//    int32_t rndval;
 
312
#endif /* USE_FAST_RANDOM */
 
313
 
 
314
#ifdef TRACE_TIMINGS
 
315
    gettimeofday(&tv1, NULL);
 
316
#endif /* TRACE_TIMINGS */
 
317
 
 
318
#ifdef USE_FAST_RANDOM
 
319
# ifdef TRACE_TIMINGS
 
320
    g_seed = tv1.tv_usec;
 
321
# else
 
322
    gettimeofday(&tvseed, NULL);
 
323
    g_seed = tvseed.tv_usec;
 
324
# endif /* TRACE_TIMINGS */
 
325
#endif /* USE_FAST_RANDOM */
 
326
 
 
327
 
 
328
    data = mrses_ppu_malloc(thr, mrses);
 
329
    if (!data) return 1;
 
330
 
 
331
    A = data->A;
 
332
    B = data->B;
 
333
    Ca = data->Ca;
 
334
    Cb = data->Cb;
 
335
    C = data->C;
 
336
    mean = data->mean;
 
337
    mean_copy = data->mean_copy;
 
338
    
 
339
    hist = (uint32_t*)thr->data;
 
340
 
 
341
 
 
342
    for (;block < block_end; ++block) { 
 
343
        index = mrses->index + block * properties;
 
344
        
 
345
        
 
346
        for (i = 0; i < width; ++i) {
 
347
            idx = index[i];
 
348
 
 
349
            memcpy(A + i * nA, Afull + idx * alloc, nA * sizeof(MRSESDataType));
 
350
            memcpy(B + i * nB, Bfull + idx * alloc, nB * sizeof(MRSESDataType));
 
351
            mean[i] = Mfull[idx];
 
352
        }
 
353
    
 
354
        memcpy(mean_copy, mean, width * sizeof(MRSESDataType));
 
355
 
 
356
        err = mrses_ppu_real_run(
 
357
            mrses, &cur_result,
 
358
            width, width2, nA, nB,
 
359
            A, B, mean_copy, C, Ca, Cb
 
360
        );
 
361
 
 
362
        rst_gen = width;
 
363
        for (i = 0; i < iterations; i++) {
 
364
#ifdef USE_FAST_RANDOM
 
365
            FASTRND(drp_gen, width, 0);
 
366
            FASTRND(rpl_gen, properties - width, width);
 
367
#else /* USE_FAST_RANDOM */
 
368
            drp_gen = rnd(width);
 
369
            rpl_gen = rnd(properties - width) + width;
 
370
#endif /* USE_FAST_RANDOM */
 
371
        
 
372
            if ((rst_gen < width)&&(rst_gen != drp_gen)) {
 
373
                idx = index[rst_gen];
 
374
                memcpy(A + rst_gen * nA, Afull + idx * alloc, nA * sizeof(MRSESDataType));
 
375
                memcpy(B + rst_gen * nB, Bfull + idx * alloc, nB * sizeof(MRSESDataType));
 
376
                rst_gen = width;
 
377
            }   
 
378
 
 
379
            idx = index[rpl_gen];
 
380
            memcpy(A + drp_gen * nA, Afull + idx * alloc, nA * sizeof(MRSESDataType));
 
381
            memcpy(B + drp_gen * nB, Bfull + idx * alloc, nB * sizeof(MRSESDataType));
 
382
        
 
383
            memcpy(mean_copy, mean, width * sizeof(MRSESDataType));
 
384
            mean_copy[drp_gen] = Mfull[idx];
 
385
        
 
386
            err = mrses_ppu_real_run(
 
387
                mrses, &result,
 
388
                width, width2, nA, nB,
 
389
                A, B, mean_copy, C, Ca, Cb
 
390
            );
 
391
        
 
392
            if (result < cur_result) {
 
393
                rst_gen = drp_gen;
 
394
            } else {
 
395
//          printf("%i=%i %i=%i %f %f\n", drp_gen, index[drp_gen], rpl_gen, index[rpl_gen], cur_result, result);
 
396
                cur_result = result;
 
397
                SWAPij(index, drp_gen, rpl_gen);
 
398
 
 
399
                mean[drp_gen] = Mfull[idx];
 
400
            }
 
401
        }
 
402
        for (i = 0; i < width; i++) {
 
403
            hist[index[i]]++;
 
404
//      printf("%5i ", index[i] + 1);
 
405
        }
 
406
        if (ires) {
 
407
            memcpy(ires + width * block, index, width * sizeof(MRSESIntType));
 
408
        } 
 
409
        
 
410
//    printf("\n\n\n");
 
411
    }
 
412
    
 
413
#ifdef TRACE_TIMINGS
 
414
    gettimeofday(&tv2, NULL);
 
415
    printf("Thead %p: %li ms\n", thr, (tv2.tv_sec - tv1.tv_sec)*1000+(tv2.tv_usec - tv1.tv_usec)/1000);
 
416
#endif /* TRACE_TIMINGS */
 
417
 
 
418
    return 0;
 
419
}