/ani/mrses

To get this branch, use:
bzr branch http://suren.me/webbzr/ani/mrses
1 by Suren A. Chilingaryan
Initial import
1
#include <stdio.h>
2
#include <string.h>
3
#include <sys/time.h>
4
#include <assert.h>
5
6
#include <cblas.h>
7
#include <lapack.h>
8
9
#include "mrses_impl.h"
10
#include "mrses_ppu.h"
11
12
#ifdef USE_FAST_RANDOM
/*
 * FASTRND(var, max, add): advance the caller's local `unsigned int g_seed`
 * by one linear-congruential step and store an integer in
 * [add, add + max) into `var`. Bits 16..30 of the new state are scaled
 * by max / 0x8000.
 *
 * The body is wrapped in do { } while (0) so the macro behaves as a single
 * statement (safe inside an unbraced if/else) and must be terminated with
 * a semicolon by the caller.
 */
# define FASTRND(var, max, add) \
    do { \
	g_seed = (214013 * g_seed + 2531011); \
	(var) = ((int)((max) * ((g_seed >> 16) & 0x7FFF) / (0x7FFF + 1.0))) + (add); \
    } while (0)
#endif /* USE_FAST_RANDOM */

/*
 * rnd(i): integer in [0, i) drawn with rand_r() from the per-thread seed.
 * The macro expects a pointer named `data` with an `unsigned int seed`
 * member to be in scope at the point of use.
 *
 * A random_r() based variant was tried and left disabled:
 *   rnd(i) ((int)((i) * ((random_r(&data->seed, &rndval)?0:rndval) / (RAND_MAX + 1.0))))
 */
#undef rnd
#define rnd(i) ((int)((i) * (rand_r(&data->seed) / (RAND_MAX + 1.0))))
21
22
23
static inline int mrses_ppu_real_run(
24
    MRSESContext mrses, MRSESDataType *result,
25
    int width, int width2, int nA, int nB,
26
    MRSESDataType *A, MRSESDataType *B, MRSESDataType *mean,
27
    MRSESDataType *C, MRSESDataType *Ca, MRSESDataType *Cb
28
) {
29
    int i, err;
30
    
31
    char hmode = 'L';
32
    
33
    MRSESDataType detAB, detC;
34
    MRSESDataType rmahal, rcorr;
35
36
/*
37
	// We can save twice by not computing symmetric parts
38
    //blas_gemm(CblasRowMajor, CblasNoTrans, CblasTrans, width, width, nA, 1, A, nA, A, nA, 0, Ca, width);
39
    //blas_gemm(CblasRowMajor, CblasNoTrans, CblasTrans, width, width, nB, 1, B, nB, B, nB, 0, Cb, width);
40
    //printf("%f %f\n", Ca[4], Ca[20]);
41
42
    //blas_syrk(CblasRowMajor, CblasUpper, CblasTrans, nA, width, 1.0, A, nA, 0, Ca, width);
43
    //blas_syrk(CblasRowMajor, CblasUpper, CblasTrans, nB, width, 1.0, B, nB, 0, Cb, width);
44
*/
45
46
    blas_syrk(CblasRowMajor, CblasUpper, CblasNoTrans, width, nA, 1.0, A, nA, 0, Ca, width);
47
    blas_syrk(CblasRowMajor, CblasUpper, CblasNoTrans, width, nB, 1.0, B, nB, 0, Cb, width);
48
49
    memcpy(C, Ca, width2 * sizeof(MRSESDataType));
50
    blas_axpy(width2, 1, Cb, 1, C, 1);
51
    blas_scal(width2, 0.5, C, 1);
52
53
//    PRINT_MATRIX("%+6.4f ", Ca, 5, 5, 5)
54
//    PRINT_MATRIX("%+6.4f ", A, nA, 5, 16)
55
56
    lapack_potrf(&hmode, &width, C, &width, &err);
57
    if (err) return 1;
58
59
    lapack_potrf(&hmode, &width, Ca, &width, &err);
60
    if (err) return 1;
61
62
    lapack_potrf(&hmode, &width, Cb, &width, &err);
63
    if (err) return 1;
64
    
65
66
    detC = C[0]; detAB = Ca[0] * Cb[0];
67
    for (i = width + 1; i < width2; i+= (width+1)) {
68
	detAB *= (Ca[i] * Cb[i]);
69
	detC *= C[i];
70
    }
71
72
	/* we just computing sqrt(detX) actually */
73
    rcorr = 2 * logf(detC * detC / detAB);
74
75
/*
76
    blas_trsm(
77
	CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, CblasNonUnit,
78
	1, width, 1, C, width, mean, width //1?
79
    );
80
*/
81
82
    blas_trsv(
83
	CblasRowMajor, CblasUpper, CblasTrans, CblasNonUnit,
84
	width, C, width, mean, 1
85
    );
86
87
88
    rmahal = blas_dot(width, mean, 1, mean, 1);
89
90
    switch (mrses->dist) {
91
	case BHATTACHARYYA:
92
	    *result = rmahal/8 + rcorr/4;
93
	break;
94
	case MAHALANOBIS:
95
	    *result = rmahal;
96
	break;
97
	case CORCOR:
98
	    *result = rcorr;
99
	break;
100
	default:
101
	    *result = 0;
102
    }
103
104
/*    
105
    if (!block)
106
    printf("Det: %f = %f %f (%e)\n", rmahal/8 + rcorr/4, rmahal, rcorr, detC*detC);
107
*/
108
109
    return 0;
110
}
111
112
struct MRSESTemporaryDataT {
113
    MRSESDataType *A;
114
    MRSESDataType *B;
115
    MRSESDataType *C;
116
    MRSESDataType *Ca;
117
    MRSESDataType *Cb;
118
    MRSESDataType *mean;
119
    MRSESDataType *mean_copy;
120
121
#ifndef USE_FAST_RANDOM
122
//    struct random_data seed;
123
    unsigned int seed;
124
#endif /* USE_FAST_RANDOM */
125
};
126
typedef struct MRSESTemporaryDataT MRSESTemporaryDataS;
127
typedef struct MRSESTemporaryDataT *MRSESTemporaryData;
128
129
/*
 * Return the per-thread scratch structure, allocating it on first use.
 *
 * The single aligned allocation stored in thr->data is laid out as:
 *   [ properties x uint32_t counters ][ MRSESTemporaryDataS ]
 *   [ A ][ B ][ C ][ Ca ][ Cb ][ mean ][ mean_copy ]
 * with every region rounded up by calc_alloc(..., HW_ALIGN). Only the
 * leading counter region is zeroed; the work buffers are left uninitialized.
 *
 * If thr->data is already set, the existing structure inside it is returned
 * and nothing is allocated. Unless USE_FAST_RANDOM is defined, the rand_r()
 * seed is initialized from FIX_RANDOM or from the current microseconds.
 *
 * Returns NULL when the allocation fails.
 */
static inline MRSESTemporaryData mrses_ppu_malloc(HWThread thr, MRSESContext mrses) {
    void *mem = NULL;
    unsigned char *alloc;
    MRSESTemporaryData data;
    int nAB;
    int properties = mrses->properties;
    int width;
    int pos, size;

#ifndef USE_FAST_RANDOM
# ifndef FIX_RANDOM
    struct timeval tv;
# endif /* FIX_RANDOM */
#endif /* USE_FAST_RANDOM */

    /* Offset of the structure: it follows the counter region. */
    pos = calc_alloc(properties * sizeof(uint32_t), HW_ALIGN);
    if (thr->data) return (MRSESTemporaryData)((unsigned char*)thr->data + pos);

    width = mrses->width;
    nAB = max(mrses->nA, mrses->nB);

    size = (
	calc_alloc(properties * sizeof(uint32_t), HW_ALIGN) +
	calc_alloc(sizeof(MRSESTemporaryDataS), HW_ALIGN) +
	2 * calc_alloc(width * nAB * sizeof(MRSESDataType), HW_ALIGN) +
	3 * calc_alloc(width * width * sizeof(MRSESDataType), HW_ALIGN) +
	2 * calc_alloc(width * sizeof(MRSESDataType), HW_ALIGN)
    );

    /*
     * posix_memalign() reports failure through its return value and does
     * not promise anything about the output pointer in that case, so the
     * status must be checked rather than the pointer alone.
     */
    if (posix_memalign(&mem, HW_ALIGN, size) != 0) return NULL;
    if (!mem) return NULL;
    alloc = (unsigned char*)mem;

    memset(alloc, 0, properties * sizeof(uint32_t));

    data = (MRSESTemporaryData)(alloc + pos);
    pos += calc_alloc(sizeof(MRSESTemporaryDataS), HW_ALIGN);

    data->A = (MRSESDataType*)(alloc + pos);
    pos += calc_alloc(width * nAB * sizeof(MRSESDataType), HW_ALIGN);

    data->B = (MRSESDataType*)(alloc + pos);
    pos += calc_alloc(width * nAB * sizeof(MRSESDataType), HW_ALIGN);

    data->C = (MRSESDataType*)(alloc + pos);
    pos += calc_alloc(width * width * sizeof(MRSESDataType), HW_ALIGN);

    data->Ca = (MRSESDataType*)(alloc + pos);
    pos += calc_alloc(width * width * sizeof(MRSESDataType), HW_ALIGN);

    data->Cb = (MRSESDataType*)(alloc + pos);
    pos += calc_alloc(width * width * sizeof(MRSESDataType), HW_ALIGN);

    data->mean = (MRSESDataType*)(alloc + pos);
    pos += calc_alloc(width * sizeof(MRSESDataType), HW_ALIGN);

    data->mean_copy = (MRSESDataType*)(alloc + pos);

#ifndef USE_FAST_RANDOM
# ifdef FIX_RANDOM
    data->seed = FIX_RANDOM;
# else /* FIX_RANDOM */
    gettimeofday(&tv, NULL);
    data->seed = tv.tv_usec;
# endif /* FIX_RANDOM */
#endif /* USE_FAST_RANDOM */

    /* Publish the block only once it is fully set up. */
    thr->data = (void*)alloc;

    return data;
}
202
203
204
/*
 * Evaluate the distance for a single block and store it in
 * mrses->result[block].
 *
 * The block's `width` entries of mrses->index are 1-based here (one is
 * subtracted before use); each selects a row of mrses->A / mrses->B (row
 * stride mrses->alloc) and an entry of mrses->mean.
 *
 * Returns 1 if the per-thread scratch area cannot be obtained, 0 otherwise.
 * A failed evaluation is not an error: the result is set to 0 so that the
 * caller can fall back to the Octave computation.
 */
int mrses_ppu_run(HWThread thr, void *hwctx __attribute__ ((unused)), int block, MRSESContext mrses) {
    int row, src;

    int width = mrses->width;
    int width2 = width * width;
    int stride = mrses->alloc;
    int nA = mrses->nA;
    int nB = mrses->nB;

    MRSESIntType *index = mrses->index + block * width;
    MRSESDataType *result = mrses->result + block;

    MRSESDataType *Afull = mrses->A;
    MRSESDataType *Bfull = mrses->B;
    MRSESDataType *Mfull = mrses->mean;

    MRSESTemporaryData data = mrses_ppu_malloc(thr, mrses);
    if (!data) return 1;

    /* Gather the selected rows into the contiguous scratch matrices. */
    for (row = 0; row < width; row++) {
	src = index[row] - 1;

	memcpy(data->A + row * nA, Afull + src * stride, nA * sizeof(MRSESDataType));
	memcpy(data->B + row * nB, Bfull + src * stride, nB * sizeof(MRSESDataType));
	data->mean[row] = Mfull[src];
    }

    if (mrses_ppu_real_run(
	mrses, result,
	width, width2, nA, nB,
	data->A, data->B, data->mean, data->C, data->Ca, data->Cb
    )) {
	*result = 0;	/* Just do octave computation instead */
    }

    return 0;
}
261
262
263
264
/*
 * Run the random replacement search for one group of blocks.
 *
 * Blocks [block_group * iterate_size, (block_group + 1) * iterate_size) are
 * processed in turn. For each block, the first `width` entries of its
 * `properties`-long index row (0-based here) form the current selection:
 *   1. the selection is loaded into the scratch matrices and evaluated;
 *   2. `iterations` times, a random slot drp_gen in [0, width) is replaced
 *      by a random candidate rpl_gen in [width, properties) and the set is
 *      re-evaluated. The candidate is kept (index entries swapped) unless
 *      its result is lower than the current one; a rejected slot is
 *      restored lazily at the start of the next iteration;
 *   3. the per-thread counters at the start of thr->data are incremented
 *      for every selected index and, if mrses->ires is set, the final
 *      selection is copied to it.
 *
 * A failed evaluation (non-zero return of mrses_ppu_real_run) leaves its
 * output untouched, so it is handled explicitly: a failed initial set
 * scores 0 and a failed candidate is always rejected.
 *
 * Returns 1 if the per-thread scratch area cannot be obtained, 0 otherwise.
 */
int mrses_ppu_iterate(HWThread thr, void *hwctx __attribute__ ((unused)), int block_group, MRSESContext mrses) {
    int err;
    int i, idx;

    int iterate_size = mrses->iterate_size;
    int block = block_group * iterate_size;
    int block_end = block + iterate_size;

    int width = mrses->width;
    int width2 = width * width;
    int alloc = mrses->alloc;
    int nA = mrses->nA;
    int nB = mrses->nB;
    int properties = mrses->properties;
    int iterations = mrses->iterations;

    unsigned int *hist;

    MRSESIntType *index;
    MRSESIntType *ires = mrses->ires;
    MRSESIntType drp_gen, rpl_gen, rst_gen;
    MRSESDataType result = 0, cur_result = 0;

    MRSESDataType *Afull = mrses->A;
    MRSESDataType *Bfull = mrses->B;
    MRSESDataType *Mfull = mrses->mean;

    MRSESTemporaryData data;

    MRSESDataType *A;
    MRSESDataType *B;
    MRSESDataType *C;
    MRSESDataType *Ca;
    MRSESDataType *Cb;
    MRSESDataType *mean;
    MRSESDataType *mean_copy;

#ifdef TRACE_TIMINGS
    struct timeval tv1, tv2;
#endif /* TRACE_TIMINGS */

#ifdef USE_FAST_RANDOM
    unsigned int g_seed;
# ifndef TRACE_TIMINGS
    struct timeval tvseed;
# endif /* TRACE_TIMINGS */
#endif /* USE_FAST_RANDOM */

#ifdef TRACE_TIMINGS
    gettimeofday(&tv1, NULL);
#endif /* TRACE_TIMINGS */

#ifdef USE_FAST_RANDOM
    /* Seed the inline generator from the current microseconds. */
# ifdef TRACE_TIMINGS
    g_seed = tv1.tv_usec;
# else
    gettimeofday(&tvseed, NULL);
    g_seed = tvseed.tv_usec;
# endif /* TRACE_TIMINGS */
#endif /* USE_FAST_RANDOM */

    data = mrses_ppu_malloc(thr, mrses);
    if (!data) return 1;

    A = data->A;
    B = data->B;
    Ca = data->Ca;
    Cb = data->Cb;
    C = data->C;
    mean = data->mean;
    mean_copy = data->mean_copy;

    /* The counters occupy the very beginning of the per-thread allocation. */
    hist = (uint32_t*)thr->data;

    /*
     * Without at least one candidate outside the selection, rpl_gen would
     * be `width` and index[rpl_gen] would lie outside this block's row.
     * There is nothing to replace in that case, so skip the search.
     */
    if (properties <= width) iterations = 0;

    for (; block < block_end; ++block) {
	index = mrses->index + block * properties;

	for (i = 0; i < width; ++i) {
	    idx = index[i];

	    memcpy(A + i * nA, Afull + idx * alloc, nA * sizeof(MRSESDataType));
	    memcpy(B + i * nB, Bfull + idx * alloc, nB * sizeof(MRSESDataType));
	    mean[i] = Mfull[idx];
	}

	/* The solver overwrites its mean vector, so always pass a copy. */
	memcpy(mean_copy, mean, width * sizeof(MRSESDataType));

	err = mrses_ppu_real_run(
	    mrses, &cur_result,
	    width, width2, nA, nB,
	    A, B, mean_copy, C, Ca, Cb
	);
	/* On failure cur_result would still hold the previous block's score. */
	if (err) cur_result = 0;

	rst_gen = width;	/* `width` means: no slot is waiting to be restored */
	for (i = 0; i < iterations; i++) {
#ifdef USE_FAST_RANDOM
	    FASTRND(drp_gen, width, 0);
	    FASTRND(rpl_gen, properties - width, width);
#else /* USE_FAST_RANDOM */
	    drp_gen = rnd(width);
	    rpl_gen = rnd(properties - width) + width;
#endif /* USE_FAST_RANDOM */

	    /*
	     * Undo the previously rejected replacement, unless the same slot
	     * is about to be overwritten anyway.
	     */
	    if ((rst_gen < width)&&(rst_gen != drp_gen)) {
		idx = index[rst_gen];
		memcpy(A + rst_gen * nA, Afull + idx * alloc, nA * sizeof(MRSESDataType));
		memcpy(B + rst_gen * nB, Bfull + idx * alloc, nB * sizeof(MRSESDataType));
		rst_gen = width;
	    }

	    idx = index[rpl_gen];
	    memcpy(A + drp_gen * nA, Afull + idx * alloc, nA * sizeof(MRSESDataType));
	    memcpy(B + drp_gen * nB, Bfull + idx * alloc, nB * sizeof(MRSESDataType));

	    memcpy(mean_copy, mean, width * sizeof(MRSESDataType));
	    mean_copy[drp_gen] = Mfull[idx];

	    err = mrses_ppu_real_run(
		mrses, &result,
		width, width2, nA, nB,
		A, B, mean_copy, C, Ca, Cb
	    );

	    /*
	     * A failed evaluation leaves `result` stale; never accept such a
	     * candidate.
	     */
	    if (err || (result < cur_result)) {
		rst_gen = drp_gen;
	    } else {
		cur_result = result;
		SWAPij(index, drp_gen, rpl_gen);

		mean[drp_gen] = Mfull[idx];
	    }
	}

	for (i = 0; i < width; i++) {
	    hist[index[i]]++;
	}

	if (ires) {
	    memcpy(ires + width * block, index, width * sizeof(MRSESIntType));
	}
    }

#ifdef TRACE_TIMINGS
    gettimeofday(&tv2, NULL);
    /* Casts make the arguments match %p and %li on every platform. */
    printf("Thead %p: %li ms\n", (void*)thr, (long)((tv2.tv_sec - tv1.tv_sec)*1000+(tv2.tv_usec - tv1.tv_usec)/1000));
#endif /* TRACE_TIMINGS */

    return 0;
}