/ani/mrses

To get this branch, use:
bzr branch http://suren.me/webbzr/ani/mrses

« back to all changes in this revision

Viewing changes to mrses_hw_debug.m

  • Committer: Suren A. Chilingaryan
  • Date: 2010-04-28 04:30:08 UTC
  • Revision ID: csa@dside.dyndns.org-20100428043008-vd9z0nso9axezvlp
Initial import

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
function res=mrses_hw_debug(A,B,k,Niter,Ncycle,distmod,block)
 
2
  if (nargin<7)
 
3
    block=256;
 
4
  end
 
5
  if (nargin<6)
 
6
    distmod=1;
 
7
  end
 
8
  if (nargin<5)
 
9
    Ncycle=1000;
 
10
  end
 
11
  if (nargin<4)
 
12
    Niter=500;
 
13
  end
 
14
  if (nargin<3)
 
15
    k=5;
 
16
  end
 
17
  if (nargin<2)
 
18
    error('As minimum two matrixes needed for MRSES');
 
19
  end
 
20
  if (nargin>6)
 
21
    error('Too much parameters');
 
22
  end
 
23
 
 
24
  sa=size(A);sb=size(B);
 
25
  if (sa(2)==sb(2))
 
26
    genes=sa(2);
 
27
  else
 
28
    error('Features dimension mismatch');
 
29
  end
 
30
 
 
31
  nA=sa(1); nB=sb(1);
 
32
 
 
33
  %optki=zeros(Ncycle,k);
 
34
  
 
35
  ctx = mrses_hw();
 
36
  mrses_hw(ctx, 1, k, block, single(A), single(B), distmod);
 
37
 
 
38
  mean_a = mean(A,1);
 
39
  mean_b = mean(B,1);
 
40
  mdiff = mean_a - mean_b;
 
41
 
 
42
  dA = (A - ones([nA,1]) * mean_a) / sqrt(nA);
 
43
  dB = (B - ones([nB,1]) * mean_b) / sqrt(nB);
 
44
 
 
45
  for icycle=0:block:(Ncycle-1)
 
46
    block_size = min(block, Ncycle - icycle);
 
47
    %   SELECT k GENES {ki} FOR TEST AND EXCLUDE THEM FROM ALL GENES {ke}
 
48
    
 
49
    ki=int16([]); ke=int16([]); 
 
50
    for i=1:block_size
 
51
        tt=randperm(genes);% randomizing genes
 
52
        
 
53
        ki(:,i)=tt(1:k);        % selecting first k
 
54
        ke(:,i)=tt(k+1:end);    % the rest unuzed
 
55
    end
 
56
    
 
57
    cur_dist_hw = mrses_hw(ctx, 10, block_size, ki);
 
58
    cur_dist = multi_bmc(dA, dB, mdiff, ki, distmod);
 
59
%    find(cur_dist ~= cur_dist_hw)
 
60
    dist_diff = abs(cur_dist_hw - cur_dist);
 
61
    allowed = abs(cur_dist)/100000;
 
62
    find ( dist_diff > allowed)
 
63
 
 
64
%    for i=1:block_size
 
65
%       check_dist(i) = bmc(A(:,ki(:,i)),B(:,ki(:,i)), distmod); 
 
66
%    end
 
67
%    find(cur_dist ~= check_dist)
 
68
 
 
69
    for iter=1:Niter
 
70
        xki=ceil(rand(1,block_size)*k);                 % selecting random gen from selected
 
71
        xke=ceil(rand(1,block_size)*(genes-k));         % selected random gen from non-selected
 
72
 
 
73
        idx_i = sub2ind(size(ki), xki, 1:block_size);
 
74
        idx_e = sub2ind(size(ke), xke, 1:block_size);
 
75
        
 
76
        t=ki(idx_i);
 
77
        ki(idx_i)=ke(idx_e);
 
78
        ke(idx_e)=t;
 
79
        
 
80
        dist = multi_bmc(dA, dB, mdiff, ki, distmod);
 
81
%        for i=1:block_size
 
82
%           check_dist(i)=bmc(A(:,ki(:,i)),B(:,ki(:,i)),distmod); % compute distance between A and B with currently selected genes
 
83
%       end
 
84
        %find(dist ~= check_dist)
 
85
%       dist_diff = abs(dist - check_dist);
 
86
%       allowed = abs(dist)/1000000;
 
87
%       find ( dist_diff > allowed)
 
88
 
 
89
        
 
90
        bad = find(dist < cur_dist);
 
91
        idx_i = idx_i(bad);
 
92
        idx_e = idx_e(bad);
 
93
        
 
94
        t=ki(idx_i);
 
95
        ki(idx_i)=ke(idx_e);
 
96
        ke(idx_e)=t;
 
97
        
 
98
        cur_dist = max(dist, cur_dist);
 
99
    end
 
100
    optki(:,(icycle+1):(icycle+block_size))=ki; % save finally selected genes
 
101
  end
 
102
  mrses_hw(ctx);
 
103
  
 
104
  optki=reshape(optki,1,[]);
 
105
  [n,g]=hist(optki,1:genes);
 
106
  H=[n./Ncycle;g];
 
107
  res=flipud(sortrows(H'));
 
108
 
 
109
 
 
110
% DISTANCE CALCULATOR
 
111
% ifs are taking to much time, do a separate functions for each distance
 
112
function dist=multi_bmc(dA, dB, mean_diff, ki, distmod)
 
113
    block_size = size(ki,2);
 
114
    
 
115
    %ki(:,1) = [5,7,9,11,13];
 
116
    for i=1:block_size
 
117
        x = dA(:, ki(:,i));
 
118
        %x(1:5,:)'
 
119
        c1 = x'*x;
 
120
        x = dB(:, ki(:,i));
 
121
        c2 = x'*x;
 
122
        
 
123
        c=(c1+c2)./2;
 
124
        
 
125
        [L,p] = chol(c);
 
126
        %detc = prod(diag(L,0))^2;
 
127
        %tmp = diag(L,0);
 
128
        %detc = tmp' * tmp;
 
129
        if p > 0
 
130
            detc = 0;
 
131
        else
 
132
            detc = det(L)^2;
 
133
        end
 
134
 
 
135
%       rcorr(i)=log((detc.*detc)./(det(c1).*det(c2)));
 
136
        rcorr(i)=2.*log(detc./sqrt(det(c1).*det(c2)));
 
137
%       rcorr(i)=2.*log(detc./sqrt(det(c1*c2)));
 
138
 
 
139
        if detc == 0
 
140
            mdiff = mean_diff(ki(:,i));
 
141
            rmahal(i)=(mdiff*pinv(c))*mdiff';
 
142
        else
 
143
%           rmahal(i)=(mdiff/c)*mdiff';
 
144
%           rmahal(i)=((mdiff/L)/L')*mdiff';
 
145
%           rmahal(i) = prod(mdiff/L)^2;
 
146
            tmp = mean_diff(ki(:,i))/L;
 
147
            rmahal(i) = tmp * tmp';
 
148
        end
 
149
    end
 
150
 
 
151
 
 
152
    if (distmod==1) 
 
153
        dist = rmahal./8 + rcorr./4;
 
154
    elseif (distmode==2) 
 
155
        dist = rmahal;
 
156
    else 
 
157
        dist = rcorr;
 
158
    end
 
159
 
 
160
function dist=bmc(x1,x2,distmod)
 
161
    c1=cov1(x1);
 
162
    c2=cov1(x2);
 
163
    c=(c1+c2)./2;
 
164
 
 
165
    if (distmod~=2)
 
166
        rcorr=2.*log(det(c)./sqrt(det(c1).*det(c2)));
 
167
    end
 
168
 
 
169
    if (distmod~=3)
 
170
        m1=mean(x1);
 
171
        m2=mean(x2);
 
172
        rmahal=((m2-m1)/c)*(m2-m1)';
 
173
    end
 
174
    
 
175
    if (distmod==1) 
 
176
        dist = rmahal./8+rcorr./4;
 
177
    elseif (distmode==2) 
 
178
        dist=rmahal;
 
179
    else 
 
180
        dist=rcorr;
 
181
    end
 
182
 
 
183
 
 
184
function c = cov1(x, m)
 
185
    [rows, cols] = size(x);
 
186
    %rows = size(x(:,1))
 
187
    
 
188
    if (nargin<2) 
 
189
        nX = x - ones([rows,1]) * mean(x);
 
190
    else
 
191
        nX = x - ones([rows,1]) * m';
 
192
    end
 
193
 
 
194
    c = nX' * nX / rows;
 
195
 
 
196
function c = cov2(x)
 
197
    c = x' * x;