/ani/mrses

To get this branch, use:
bzr branch http://suren.me/webbzr/ani/mrses

« back to all changes in this revision

Viewing changes to cell/atlas_potrf.c

  • Committer: Suren A. Chilingaryan
  • Date: 2010-04-28 04:30:08 UTC
  • Revision ID: csa@dside.dyndns.org-20100428043008-vd9z0nso9axezvlp
Initial import

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#include <stdio.h>
 
2
#include <math.h>
 
3
 
 
4
#include <cblas.h>
 
5
#include "mrses.h"
 
6
 
 
7
int atlas_spotrf2(const int N, MRSESDataType *A, const int lda)
 
8
{
 
9
   int j;
 
10
   MRSESDataType Ajj, *Ac=A, *An=A+lda;
 
11
 
 
12
   for (j=0; j != N; j++)
 
13
   {
 
14
      Ajj = Ac[j] - cblas_sdot(j, Ac, 1, Ac, 1);
 
15
      if (Ajj > 0)
 
16
      {
 
17
         Ac[j] = Ajj = sqrt(Ajj);
 
18
         if (j != N-1)
 
19
         {
 
20
            cblas_sgemv(CblasColMajor, CblasTrans, j, N-j-1, -1,
 
21
                       An, lda, Ac, 1, 1, An+j, lda);
 
22
            cblas_sscal(N-j-1, 1/Ajj, An+j, lda);
 
23
            Ac = An;
 
24
            An += lda;
 
25
         }
 
26
      }
 
27
      else
 
28
      {
 
29
         Ac[j] = Ajj;
 
30
         return(j+1);
 
31
      }
 
32
   }
 
33
   return(0);
 
34
}
 
35
 
 
36
 
 
37
#define TYPE MRSESDataType
 
38
#define ATL_rzero 0
 
39
#define ATL_rone 1
 
40
#define ATL_rnone -1
 
41
#define ONE ATL_rone
 
42
 
 
43
inline int ATL_potrfU_4(TYPE *A, const short int lda)
 
44
{
 
45
   TYPE *pA1=A+lda, *pA2=pA1+lda, *pA3=pA2+lda;
 
46
   TYPE L11 = *A, L21 = *pA1, L31 = *pA2, L41 = *pA3;
 
47
   TYPE L22 = pA1[1], L32 = pA2[1], L42 = pA3[1];
 
48
   TYPE L33 = pA2[2], L43 = pA3[2];
 
49
   TYPE L44 = pA3[3];
 
50
   int iret=0;
 
51
 
 
52
   if (L11 > ATL_rzero)
 
53
   {
 
54
      *A = L11 = sqrt(L11);
 
55
      L11 = ATL_rone / L11;
 
56
      L21 *= L11;
 
57
      L31 *= L11;
 
58
      L41 *= L11;
 
59
      *pA1 = L21; *pA2 = L31; *pA3 = L41;
 
60
      L22 -= L21*L21;
 
61
      if (L22 > ATL_rzero)
 
62
      {
 
63
         pA1[1] = L22 = sqrt(L22);
 
64
         L22 = ATL_rone / L22;
 
65
         L32 = (L32 - L31*L21) * L22;
 
66
         L42 = (L42 - L41*L21) * L22;
 
67
         L33 -= L31*L31 + L32*L32;
 
68
         pA2[1] = L32; pA3[1] = L42;
 
69
         if (L33 > ATL_rzero)
 
70
         {
 
71
            pA2[2] = L33 = sqrt(L33);
 
72
            L43 = (L43 - L41*L31 - L42*L32) / L33;
 
73
            L44 -= L41*L41 + L42*L42 + L43*L43;
 
74
            pA3[2] = L43;
 
75
            if (L44 > ATL_rzero)
 
76
            {
 
77
               pA3[3] = sqrt(L44);
 
78
               return(0);
 
79
            }
 
80
            else iret=4;
 
81
         }
 
82
         else iret=3;
 
83
      }
 
84
      else iret=2;
 
85
   }
 
86
   else iret=1;
 
87
   return(iret);
 
88
}
 
89
 
 
90
inline int ATL_potrfU_3(TYPE *A, const short int lda)
 
91
{
 
92
   TYPE *pA1=A+lda, *pA2=pA1+lda;
 
93
   register TYPE L11 = *A, L21 = *pA1, L31 = *pA2;
 
94
   register TYPE L22=pA1[1], L32=pA2[1];
 
95
   register TYPE L33=pA2[2];
 
96
   int iret=0;
 
97
 
 
98
   if (L11 > ATL_rzero)
 
99
   {
 
100
      *A = L11 = sqrt(L11);
 
101
      L11 = ATL_rone / L11;
 
102
      L21 *= L11;
 
103
      L31 *= L11;
 
104
      *pA1 = L21; *pA2 = L31;
 
105
      L22 -= L21*L21;
 
106
      if (L22 > ATL_rzero)
 
107
      {
 
108
         L22 = sqrt(L22);
 
109
         L32 = (L32 - L31*L21) / L22;
 
110
         L33 -= L31*L31 + L32*L32;
 
111
         pA1[1] = L22; pA2[1] = L32;
 
112
         if (L33 > ATL_rzero)
 
113
         {
 
114
            pA2[2] = sqrt(L33);
 
115
            return(0);
 
116
         }
 
117
         else iret=3;
 
118
      }
 
119
      else iret=2;
 
120
   }
 
121
   else iret=1;
 
122
   return(iret);
 
123
}
 
124
 
 
125
inline int ATL_potrfU_2(TYPE *A, const short int lda)
 
126
{
 
127
   TYPE *pA1 = A+lda;
 
128
   register TYPE L11=*A, L21=*pA1, L22 = pA1[1];
 
129
 
 
130
   if (L11 > ATL_rzero)
 
131
   {
 
132
      *A = L11 = sqrt(L11);
 
133
      *pA1 = L21 = L21 / L11;
 
134
      L22 -= L21*L21;
 
135
      if (L22 > ATL_rzero)
 
136
      {
 
137
         pA1[1] = sqrt(L22);
 
138
         return(0);
 
139
      }
 
140
      else return(2);
 
141
   }
 
142
   else return(1);
 
143
}
 
144
 
 
145
#define BASE float
 
146
#define INDEX short int
 
147
inline void
 
148
gsl_cblas_strsm_clut_1 (const short int M, const short int N, const float *A, float *B, const short int lda)
 
149
{
 
150
    register short int i, j, k;
 
151
    float Ajj, Bij;
 
152
 
 
153
    for (i = 0; i < N; i++) {
 
154
      for (j = 0; j < M; j++) {
 
155
          Ajj = A[lda * j + j];
 
156
          Bij = B[lda * i + j] / Ajj;
 
157
          
 
158
          B[lda * i + j] = Bij;
 
159
          for (k = j + 1; k < M; k++) {
 
160
            B[lda * i + k] -= A[k * lda + j] * Bij;
 
161
          }
 
162
      }
 
163
    }
 
164
}
 
165
 
 
166
inline void
 
167
gsl_cblas_ssyrk_cut_m11 (const short int N, const short int M, const float *A, float *C, const short int lda) {
 
168
  register short int i, j, k;
 
169
 
 
170
    for (i = 0; i < N; i++) {
 
171
      for (j = 0; j <= i; j++) {
 
172
        float temp = 0.0;
 
173
        for (k = 0; k < M; k++) {
 
174
          temp += A[i * lda + k] * A[j * lda + k];
 
175
        }
 
176
        C[i * lda + j] -= temp;
 
177
      }
 
178
    }
 
179
}
 
180
 
 
181
inline void gsl_spotrf_step(const short int M, const short int N, float *A, const short int lda) {
 
182
  register short int i, j, k;
 
183
 
 
184
  float *B = A + M*lda;
 
185
  float *C = B + M;
 
186
 
 
187
  for (i = 0; i < N; i++) {
 
188
      for (j = 0; j < M; j++) {
 
189
          float sum = 0;
 
190
          
 
191
          for (k = 0; k < j; k++) {
 
192
            sum += A[j * lda + k] *  B[i * lda + k];
 
193
          }
 
194
          
 
195
          B[i * lda + j] =  (B[i * lda + j] - sum) / A[j * lda + j];
 
196
      }
 
197
 
 
198
      for (j = 0; j <= i; j++) {
 
199
        float temp = 0.0;
 
200
 
 
201
        for (k = 0; k < M; k++) {
 
202
          temp += B[i * lda + k] * B[j * lda + k];
 
203
        }
 
204
 
 
205
        C[i * lda + j] -= temp;
 
206
      }
 
207
  }
 
208
}
 
209
 
 
210
 
 
211
int atlas_spotrf_u(const short int N, TYPE *A, const short int lda)
 
212
{
 
213
   TYPE *Ac, *An;
 
214
   short int Nleft, Nright, ierr;
 
215
 
 
216
  if (N > 4)
 
217
  {
 
218
      Nleft = N >> 1;
 
219
      Nright = N - Nleft;
 
220
      ierr = atlas_spotrf_u(Nleft, A, lda);
 
221
      if (!ierr)
 
222
      {
 
223
        Ac = A + lda * Nleft;
 
224
        An = Ac + Nleft; //SHIFT;
 
225
        gsl_spotrf_step(Nleft, Nright, A, lda);
 
226
 
 
227
/*
 
228
//       gsl_cblas_strsm_clut_1(Nleft, Nright, A, Ac, lda);
 
229
//       gsl_cblas_ssyrk_cut_m11(Nright, Nleft, Ac, An, lda);
 
230
 
 
231
 
 
232
//         cblas_strsm(CblasColMajor, CblasLeft, CblasUpper, CblasTrans,
 
233
//                    CblasNonUnit, Nleft, Nright, ONE, A, lda, Ac, lda);
 
234
 
 
235
 
 
236
 
 
237
//         cblas_ssyrk(CblasColMajor, CblasUpper, CblasTrans, Nright, Nleft,
 
238
//                  ATL_rnone, Ac, lda, ATL_rone, An, lda);
 
239
 
 
240
*/
 
241
         ierr = atlas_spotrf_u(Nright, An, lda);
 
242
         if (ierr) return(ierr+Nleft);
 
243
      }
 
244
      else return(ierr);
 
245
   }
 
246
      else if (N==4) return(ATL_potrfU_4(A, lda));
 
247
      else if (N==3) return(ATL_potrfU_3(A, lda));
 
248
      else if (N==2) return(ATL_potrfU_2(A, lda));
 
249
      else if (N==1)
 
250
      {
 
251
         if (*A > ATL_rzero) *A = sqrt(*A);
 
252
         else return(1);
 
253
      }
 
254
   return(0);
 
255
}