RTK  2.4.1
Reconstruction Toolkit
rtkSchlomka2008NegativeLogLikelihood.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright RTK Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * https://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 
19 #ifndef rtkSchlomka2008NegativeLogLikelihood_h
20 #define rtkSchlomka2008NegativeLogLikelihood_h
21 
23 #include "rtkMacro.h"
24 
25 #include <itkVectorImage.h>
27 #include <itkVariableSizeMatrix.h>
28 
29 namespace rtk
30 {
44 // We have to define the cost function first
46 {
47 public:
48  ITK_DISALLOW_COPY_AND_MOVE(Schlomka2008NegativeLogLikelihood);
49 
54  itkNewMacro(Self);
56 
60 
65 
66  // Constructor
68 
69  // Destructor
// Defaulted destructor. The `override` specifier only compiles if the base
// class declares a virtual destructor.
70  ~Schlomka2008NegativeLogLikelihood() override = default;
71 
// Initialize: per the comments below, builds
// m_IncidentSpectrumAndDetectorResponseProduct from m_DetectorResponse and
// m_IncidentSpectrum by visiting every (row, column) of m_DetectorResponse.
72  void
73  Initialize() override
74  {
75  // This method computes the combined m_IncidentSpectrumAndDetectorResponseProduct
76  // from m_DetectorResponse and m_IncidentSpectrum
77 
78  // In spectral CT, m_DetectorResponse has as many rows as the number of bins,
79  // and m_IncidentSpectrum has only one row (there is only one spectrum illuminating
80  // the object)
// NOTE(review): in this listing the two loop headers below are followed
// directly by the closing brace, i.e. the statement executed for each (i, j)
// is not visible here. Presumably it stores the product of the detector
// response entry and the matching incident spectrum sample -- TODO confirm
// against the original header before relying on this block.
82  for (unsigned int i = 0; i < m_DetectorResponse.rows(); i++)
83  for (unsigned int j = 0; j < m_DetectorResponse.cols(); j++)
85  }
86 
87  // Not used with a simplex optimizer, but may be useful later
88  // for gradient based methods
89  void
90  GetDerivative(const ParametersType & lineIntegrals, DerivativeType & derivatives) const override
91  {
92  // Set the size of the derivatives vector
93  derivatives.set_size(m_NumberOfMaterials);
94 
95  // Get some required data
96  vnl_vector<double> attenuationFactors;
97  attenuationFactors.set_size(this->m_NumberOfEnergies);
98  GetAttenuationFactors(lineIntegrals, attenuationFactors);
99  vnl_vector<double> lambdas = ForwardModel(lineIntegrals);
100 
101  // Compute the vector of 1 - m_b / lambda_b
102  vnl_vector<double> weights;
103  weights.set_size(m_NumberOfSpectralBins);
104  for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
105  weights[i] = 1 - (m_MeasuredData[i] / lambdas[i]);
106 
107  // Prepare intermediate variables
108  vnl_vector<double> intermediate_a;
109  vnl_vector<double> partial_derivative_a;
110 
111  for (unsigned int a = 0; a < m_NumberOfMaterials; a++)
112  {
113  // Compute the partial derivatives of lambda_b with respect to the material line integrals
114  intermediate_a = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a));
115  partial_derivative_a = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a;
116 
117  // Multiply them together element-wise, then dot product with the weights
118  derivatives[a] = dot_product(partial_derivative_a, weights);
119  }
120  }
121 
122  // Main method
124  GetValue(const ParametersType & parameters) const override
125  {
126  // Forward model: compute the expected number of counts in each bin
127  vnl_vector<double> forward = ForwardModel(parameters);
128 
129  long double measure = 0;
130  // Compute the negative log likelihood from the lambdas
131  for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
132  measure += forward[i] - std::log((long double)forward[i]) * m_MeasuredData[i];
133  return measure;
134  }
135 
136  void
137  ComputeFischerMatrix(const ParametersType & lineIntegrals) override
138  {
139  // Get some required data
140  vnl_vector<double> attenuationFactors;
141  attenuationFactors.set_size(this->m_NumberOfEnergies);
142  GetAttenuationFactors(lineIntegrals, attenuationFactors);
143  vnl_vector<double> lambdas = ForwardModel(lineIntegrals);
144 
145  // Compute the vector of m_b / lambda_b^2
146  vnl_vector<double> weights;
147  weights.set_size(m_NumberOfSpectralBins);
148  for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
149  weights[i] = m_MeasuredData[i] / (lambdas[i] * lambdas[i]);
150 
151  // Prepare intermediate variables
152  vnl_vector<double> intermediate_a;
153  vnl_vector<double> intermediate_a_prime;
154  vnl_vector<double> partial_derivative_a;
155  vnl_vector<double> partial_derivative_a_prime;
156 
157  // Compute the Fischer information matrix
159  for (unsigned int a = 0; a < m_NumberOfMaterials; a++)
160  {
161  for (unsigned int a_prime = 0; a_prime < m_NumberOfMaterials; a_prime++)
162  {
163  // Compute the partial derivatives of lambda_b with respect to the material line integrals
164  intermediate_a = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a));
165  intermediate_a_prime = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a_prime));
166 
167  partial_derivative_a = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a;
168  partial_derivative_a_prime = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a_prime;
169 
170  // Multiply them together element-wise, then dot product with the weights
171  partial_derivative_a_prime = element_product(partial_derivative_a, partial_derivative_a_prime);
172  m_Fischer[a][a_prime] = dot_product(partial_derivative_a_prime, weights);
173  }
174  }
175  }
176 };
177 
178 } // namespace rtk
179 
180 #endif
virtual vnl_vector< double > ForwardModel(const ParametersType &lineIntegrals) const
void ComputeFischerMatrix(const ParametersType &lineIntegrals) override
void GetAttenuationFactors(const ParametersType &lineIntegrals, vnl_vector< double > &attenuationFactors) const
void GetDerivative(const ParametersType &lineIntegrals, DerivativeType &derivatives) const override
MeasureType GetValue(const ParametersType &parameters) const override
~Schlomka2008NegativeLogLikelihood() override=default