summaryrefslogtreecommitdiffstats
path: root/matlab/algorithms/DART/tools/DARToptimizerBoneStudy.m
blob: cfd4febf30a5bdd27d6936b953833af3fa254fd1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
%--------------------------------------------------------------------------
% This file is part of the ASTRA Toolbox
%
% Copyright: 2010-2021, imec Vision Lab, University of Antwerp
%            2014-2021, CWI, Amsterdam
% License: Open Source under GPLv3
% Contact: astra@astra-toolbox.com
% Website: http://www.astra-toolbox.com/
%--------------------------------------------------------------------------

classdef DARToptimizerBoneStudy < handle
	% DARToptimizerBoneStudy  Tune numeric parameters of a DART object.
	%
	% The optimizer runs fminsearch over a set of named properties of
	% D_base, using the file-local function optim_func as the objective.
	% Every objective evaluation is logged in the 'stats' property.
	
	%----------------------------------------------------------------------
	properties (SetAccess=public,GetAccess=public)

		% optimization options (forwarded to optimset in run())
		max_evals = 100;	% passed as 'MaxFunEvals'
		tolerance = 0.1;	% passed as 'TolX'
		display = 'off';	% passed as 'display'
		
		% DART options
		DART_iterations = 50;	% iteration count passed to D.iterate() by the objective
		
		% DART object whose properties are read, optimized and written back
		D_base = [];
		
	end
	
	%----------------------------------------------------------------------
	properties (SetAccess=private,GetAccess=public)
		
		% Log of all objective evaluations; fields are created in the
		% constructor and appended to by optim_func.
		stats = struct();
		
	end	
	
	%----------------------------------------------------------------------
	methods (Access=public)

		%------------------------------------------------------------------
		% Constructor
		%   D_base: DART object to optimize; stored as-is.
		function this = DARToptimizerBoneStudy(D_base)

			this.D_base = D_base;
					
			% one entry is appended to each field per objective evaluation
			this.stats.params = {};
			this.stats.values = [];
			this.stats.rmse = [];
			this.stats.f_250 = [];			
			this.stats.f_100 = [];			
			this.stats.w_250 = [];			
			this.stats.w_125 = [];						
			
		end	
	
		%------------------------------------------------------------------
		% Optimize the given parameters with fminsearch.
		%   params:         cell array of property paths relative to D_base.
		%                   Each path is spliced into an eval'ed statement,
		%                   so it must come from trusted code.
		%   initial_values: (optional) start values, one per entry of
		%                   params. If omitted, the current values are read
		%                   from D_base.
		%   opt_values:     values returned by fminsearch; they are also
		%                   written back into D_base.
		function opt_values = run(this, params, initial_values)
			
			if nargin < 3
				% start from the values currently stored in D_base
				for i = 1:numel(params)
					initial_values(i) = eval(['this.D_base.' params{i} ';']);
				end
			end
			
			% fminsearch
			options = optimset('display', this.display, 'MaxFunEvals', this.max_evals, 'TolX', this.tolerance);
			opt_values = fminsearch(@optim_func, initial_values, options, this.D_base, params, this);

			% save to D_base
			% The value is assigned directly inside the evaluated statement
			% instead of being printed with '%d': for non-integer values
			% sprintf falls back to '%e', which would store a copy rounded
			% to 7 significant digits.
			for i = 1:numel(params)
				eval(['this.D_base.' params{i} ' = opt_values(i);']);
			end
			
		end
		%------------------------------------------------------------------
	end
	
end
	
%--------------------------------------------------------------------------
function rmse = optim_func(values, D_base, params, Optim)
% Objective function handed to fminsearch by DARToptimizerBoneStudy.run().
%   values: candidate parameter values, one per entry of params
%   D_base: DART object; it is deep-copied, the original is not modified here
%   params: cell array of property paths relative to the DART object.
%           Each path is spliced into an eval'ed statement, so it must
%           come from trusted code.
%   Optim:  the DARToptimizerBoneStudy handle; supplies DART_iterations and
%           receives one new entry in each field of Optim.stats
%   rmse:   first output of compute_rmse for the reconstruction D.S

	% copy DART 
	D = D_base.deepcopy();
	
	% set parameters
	% The value is assigned directly inside the evaluated statement instead
	% of being printed with '%d': for non-integer values sprintf falls back
	% to '%e', so DART would run with values rounded to 7 significant
	% digits while the unrounded ones are logged in Optim.stats.values.
	for i = 1:numel(params)
		eval(['D.' params{i} ' = values(i);']);
		% append the value to the output prefix of this run
		D.output.pre = [D.output.pre num2str(values(i)) '_'];
	end
	
	% evaluate
	if D.initialized == 0
		D.initialize();
	end
	% reset the global random number generator before every evaluation
	rng('default');
	D.iterate(Optim.DART_iterations);

	% compute rmse
	% roi.mat is resolved through the current folder / MATLAB path
	ROI = load('roi.mat');
	[rmse, f_250, f_100, w_250, w_125] = compute_rmse(D.S, ROI);
	%projection = D.tomography.project(D.S);
	%proj_diff = sum((projection(:) - D.base.sinogram(:)).^2);
	
	% save
	Optim.stats.params{end+1} = params;
	Optim.stats.values(end+1,:) = values;
	Optim.stats.rmse(end+1) = rmse;
	Optim.stats.f_250(end+1) = f_250;
	Optim.stats.f_100(end+1) = f_100;
	Optim.stats.w_250(end+1) = w_250;	
	Optim.stats.w_125(end+1) = w_125;		
	
	% progress line: candidate values and resulting rmse
	disp([num2str(values) ': ' num2str(rmse)]);
	
end