RTK  2.0.1
Reconstruction Toolkit
rtkSchlomka2008NegativeLogLikelihood.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright RTK Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 
19 #ifndef rtkSchlomka2008NegativeLogLikelihood_h
20 #define rtkSchlomka2008NegativeLogLikelihood_h
21 
23 #include "rtkMacro.h"
24 
25 #include <itkVectorImage.h>
27 #include <itkVariableSizeMatrix.h>
28 
29 namespace rtk
30 {
44 // We have to define the cost function first
46 {
47 public:
48  ITK_DISALLOW_COPY_AND_ASSIGN(Schlomka2008NegativeLogLikelihood);
49 
54  itkNewMacro( Self );
56 
60 
65 
66  // Constructor
68  {
70  }
71 
72  // Destructor: defaulted, no custom cleanup. `override` is valid because the base class destructor is virtual.
73  ~Schlomka2008NegativeLogLikelihood() override = default;
74 
75  void Initialize() override
76  {
77  // This method computes the combined m_IncidentSpectrumAndDetectorResponseProduct
78  // from m_DetectorResponse and m_IncidentSpectrum
79 
80  // In spectral CT, m_DetectorResponse has as many rows as the number of bins,
81  // and m_IncidentSpectrum has only one row (there is only one spectrum illuminating
82  // the object)
84  for (unsigned int i=0; i<m_DetectorResponse.rows(); i++)
85  for (unsigned int j=0; j<m_DetectorResponse.cols(); j++)
87  }
88 
89  // Not used with a simplex optimizer, but may be useful later
90  // for gradient based methods
91  void GetDerivative( const ParametersType & lineIntegrals,
92  DerivativeType & derivatives ) const override
93  {
94  // Set the size of the derivatives vector
95  derivatives.set_size(m_NumberOfMaterials);
96 
97  // Get some required data
98  vnl_vector<double> attenuationFactors;
99  attenuationFactors.set_size(this->m_NumberOfEnergies);
100  GetAttenuationFactors(lineIntegrals, attenuationFactors);
101  vnl_vector<double> lambdas = ForwardModel(lineIntegrals);
102 
103  // Compute the vector of 1 - m_b / lambda_b
104  vnl_vector<double> weights;
105  weights.set_size(m_NumberOfSpectralBins);
106  for (unsigned int i=0; i<m_NumberOfSpectralBins; i++)
107  weights[i] = 1 - (m_MeasuredData[i] / lambdas[i]);
108 
109  // Prepare intermediate variables
110  vnl_vector<double> intermediate_a;
111  vnl_vector<double> partial_derivative_a;
112 
113  for (unsigned int a=0; a<m_NumberOfMaterials; a++)
114  {
115  // Compute the partial derivatives of lambda_b with respect to the material line integrals
116  intermediate_a = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a));
117  partial_derivative_a = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a;
118 
119  // Multiply them together element-wise, then dot product with the weights
120  derivatives[a] = dot_product(partial_derivative_a,weights);
121  }
122  }
123 
124  // Main method
125  MeasureType GetValue( const ParametersType & parameters ) const override
126  {
127  // Forward model: compute the expected number of counts in each bin
128  vnl_vector<double> forward = ForwardModel(parameters);
129 
130  long double measure = 0;
131  // Compute the negative log likelihood from the lambdas
132  for (unsigned int i=0; i<m_NumberOfSpectralBins; i++)
133  measure += forward[i] - std::log((long double)forward[i]) * m_MeasuredData[i];
134  return measure;
135  }
136 
137  void ComputeFischerMatrix(const ParametersType & lineIntegrals) override
138  {
139  // Get some required data
140  vnl_vector<double> attenuationFactors;
141  attenuationFactors.set_size(this->m_NumberOfEnergies);
142  GetAttenuationFactors(lineIntegrals, attenuationFactors);
143  vnl_vector<double> lambdas = ForwardModel(lineIntegrals);
144 
145  // Compute the vector of m_b / lambda_b^2
146  vnl_vector<double> weights;
147  weights.set_size(m_NumberOfSpectralBins);
148  for (unsigned int i=0; i<m_NumberOfSpectralBins; i++)
149  weights[i] = m_MeasuredData[i] / (lambdas[i] * lambdas[i]);
150 
151  // Prepare intermediate variables
152  vnl_vector<double> intermediate_a;
153  vnl_vector<double> intermediate_a_prime;
154  vnl_vector<double> partial_derivative_a;
155  vnl_vector<double> partial_derivative_a_prime;
156 
157  // Compute the Fischer information matrix
159  for (unsigned int a=0; a<m_NumberOfMaterials; a++)
160  {
161  for (unsigned int a_prime=0; a_prime<m_NumberOfMaterials; a_prime++)
162  {
163  // Compute the partial derivatives of lambda_b with respect to the material line integrals
164  intermediate_a = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a));
165  intermediate_a_prime = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a_prime));
166 
167  partial_derivative_a = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a;
168  partial_derivative_a_prime = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a_prime;
169 
170  // Multiply them together element-wise, then dot product with the weights
171  partial_derivative_a_prime = element_product(partial_derivative_a, partial_derivative_a_prime);
172  m_Fischer[a][a_prime] = dot_product(partial_derivative_a_prime,weights);
173  }
174  }
175  }
176 
177 };
178 
179 }// namespace RTK
180 
181 #endif
void ComputeFischerMatrix(const ParametersType &lineIntegrals) override
void GetAttenuationFactors(const ParametersType &lineIntegrals, vnl_vector< double > &attenuationFactors) const
virtual vnl_vector< double > ForwardModel(const ParametersType &lineIntegrals) const
MeasureType GetValue(const ParametersType &parameters) const override
void GetDerivative(const ParametersType &lineIntegrals, DerivativeType &derivatives) const override
~Schlomka2008NegativeLogLikelihood() override=default