RTK  2.5.0
Reconstruction Toolkit
rtkProjectionsDecompositionNegativeLogLikelihood.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright RTK Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * https://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 
19 #ifndef rtkProjectionsDecompositionNegativeLogLikelihood_h
20 #define rtkProjectionsDecompositionNegativeLogLikelihood_h
21 
23 #include <itkVectorImage.h>
25 #include <itkVariableSizeMatrix.h>
26 #include "rtkMacro.h"
27 
28 namespace rtk
29 {
38 // We have to define the cost function first
// NOTE(review): the class declaration line (and its base class) is not visible in this
// listing; the members below are the public interface of the decomposition cost function.
40 {
41 public:
// Copy and move operations are disabled for this object type.
42  ITK_DISALLOW_COPY_AND_MOVE(ProjectionsDecompositionNegativeLogLikelihood);
43 
// Standard ITK factory method New().
48  itkNewMacro(Self);
// Use itkOverrideGetNameOfClassMacro when the ITK version provides it
// (the fallback branch is not visible in this listing).
49 #ifdef itkOverrideGetNameOfClassMacro
50  itkOverrideGetNameOfClassMacro(ProjectionsDecompositionNegativeLogLikelihood);
51 #else
53 #endif
54 
// Left disabled: the number of materials is the runtime member m_NumberOfMaterials,
// so it cannot serve as a compile-time enum constant.
55  // enum { SpaceDimension=m_NumberOfMaterials };
56 
60 
// Matrix types of the forward model.
// MaterialAttenuations is indexed [energy][material] (see GetInitialGuess).
// IncidentSpectrum is single precision; GetInitialGuess reads only its row 0.
61  using DetectorResponseType = vnl_matrix<double>;
62  using MaterialAttenuationsType = vnl_matrix<double>;
63  using IncidentSpectrumType = vnl_matrix<float>;
67 
68  // Constructor
// Only part of the constructor is visible here; the visible statement marks the
// cost function as not yet initialized.
70  {
73  m_Initialized = false;
74  }
75 
76  // Destructor
78 
80  GetValue(const ParametersType & itkNotUsed(parameters)) const override
81  {
// Base implementation: the parameters are ignored and the measure is always 0.
// NOTE(review): derived cost functions presumably override this with the actual
// negative log-likelihood — confirm against subclasses.
82  long double measure = 0;
83  return measure;
84  }
85  void
86  GetDerivative(const ParametersType & itkNotUsed(lineIntegrals),
87  DerivativeType & itkNotUsed(derivatives)) const override
88  {
89  itkExceptionMacro(<< "Not implemented");
90  }
// Virtual hook with an empty default implementation.
// NOTE(review): its name and parameter list are not visible in this listing.
91  virtual void
93  {}
94 
97  {
98  // Return the inverses of the diagonal components (i.e. the inverse variances, to be used directly in WLS
99  // reconstruction)
// Concretely: entry mat is 1 / (m_Fischer^-1)[mat][mat], i.e. the inverse of the
// diagonal of the inverted Fisher information matrix.
// NOTE(review): m_Fischer.GetInverse() is evaluated once per material inside the loop;
// if it recomputes the inverse on every call, hoisting it out of the loop would avoid
// repeated inversions — verify against itk::VariableSizeMatrix.
101  diag.SetSize(m_NumberOfMaterials);
102  diag.Fill(0);
103 
104  for (unsigned int mat = 0; mat < m_NumberOfMaterials; mat++)
105  diag[mat] = 1. / m_Fischer.GetInverse()[mat][mat];
106  return diag;
107  }
108 
111  {
112  // Return the whole Fischer information matrix
// Flattened row-major into a vector of m_NumberOfMaterials^2 entries:
// element (i, j) of m_Fischer is stored at index i * m_NumberOfMaterials + j.
114  fischer.SetSize(m_NumberOfMaterials * m_NumberOfMaterials);
115  fischer.Fill(0);
116 
117  for (unsigned int i = 0; i < m_NumberOfMaterials; i++)
118  for (unsigned int j = 0; j < m_NumberOfMaterials; j++)
119  fischer[i * m_NumberOfMaterials + j] = m_Fischer[i][j];
120  return fischer;
121  }
122 
123  virtual void
124  ComputeFischerMatrix(const ParametersType & itkNotUsed(lineIntegrals))
125  {}
126 
127  unsigned int
128  GetNumberOfParameters() const override
129  {
130  return m_NumberOfMaterials;
131  }
132 
133  virtual vnl_vector<double>
134  ForwardModel(const ParametersType & lineIntegrals) const
135  {
136  vnl_vector<double> attenuationFactors;
137  attenuationFactors.set_size(m_NumberOfEnergies);
138  GetAttenuationFactors(lineIntegrals, attenuationFactors);
139 
140  // Apply detector response, getting the lambdas
141  return (m_IncidentSpectrumAndDetectorResponseProduct * attenuationFactors);
142  }
143 
144  void
145  GetAttenuationFactors(const ParametersType & lineIntegrals, vnl_vector<double> & attenuationFactors) const
146  {
147  // Apply attenuation at each energy
148  vnl_vector<double> vnlLineIntegrals;
149 
150  // Initialize the line integrals vnl vector
151  vnlLineIntegrals.set_size(m_NumberOfMaterials);
152  for (unsigned int m = 0; m < m_NumberOfMaterials; m++)
153  vnlLineIntegrals[m] = lineIntegrals[m];
154 
155  // Apply the material attenuations matrix
156  attenuationFactors = this->m_MaterialAttenuations * vnlLineIntegrals;
157 
158  // Compute the negative exponential
159  for (unsigned int energy = 0; energy < m_NumberOfEnergies; energy++)
160  {
161  attenuationFactors[energy] = std::exp(-attenuationFactors[energy]);
162  }
163  }
164 
167  {
// Initial guess for the material line integrals: for each material, the smallest
// length (over all bins) of that material alone that would explain the bin-wise
// log-transformed measurement.
169  initialGuess.SetSize(m_NumberOfMaterials);
170 
171  // Compute the mean attenuation in each bin, weighted by the input spectrum
172  // Needs to be done for each pixel, since the input spectrum is variable
173  MeanAttenuationInBinType MeanAttenuationInBin;
174  MeanAttenuationInBin.SetSize(this->m_NumberOfMaterials, this->m_NumberOfSpectralBins);
175  MeanAttenuationInBin.Fill(0);
176 
177  for (unsigned int mat = 0; mat < this->m_NumberOfMaterials; mat++)
178  {
179  for (unsigned int bin = 0; bin < m_NumberOfSpectralBins; bin++)
180  {
181  double accumulate = 0;
182  double accumulateWeights = 0;
// Energy window of this bin: indices from m_Thresholds[bin] - 1 up to (excluding)
// m_Thresholds[bin + 1], capped at the number of tabulated energies. This reads
// m_Thresholds[bin + 1], so m_Thresholds needs m_NumberOfSpectralBins + 1 entries.
// Only row 0 of m_IncidentSpectrum is used as the weighting spectrum.
183  for (int energy = m_Thresholds[bin] - 1;
184  (energy < m_Thresholds[bin + 1]) && (energy < (int)(this->m_MaterialAttenuations.rows()));
185  energy++)
186  {
187  accumulate += this->m_MaterialAttenuations[energy][mat] * this->m_IncidentSpectrum[0][energy];
188  accumulateWeights += this->m_IncidentSpectrum[0][energy];
189  }
// NOTE(review): if the window is empty or the spectrum is zero over it,
// accumulateWeights is 0 and this division yields NaN/inf.
190  MeanAttenuationInBin[mat][bin] = accumulate / accumulateWeights;
191  }
192  }
193 
194  for (unsigned int mat = 0; mat < m_NumberOfMaterials; mat++)
195  {
196  // Initialise to a very high value
197  initialGuess[mat] = 1e10;
198  for (unsigned int bin = 0; bin < m_NumberOfSpectralBins; bin++)
199  {
200  // Compute the length of current material required to obtain the attenuation
201  // observed in current bin. Keep only the minimum among all bins
// NOTE(review): BinwiseLogTransform() is recomputed for every (material, bin) pair;
// its result does not depend on mat or bin and could be computed once before the loops.
202  double requiredLength = this->BinwiseLogTransform()[bin] / MeanAttenuationInBin[mat][bin];
203  if (initialGuess[mat] > requiredLength)
204  initialGuess[mat] = requiredLength;
205  }
206  }
207 
208  return initialGuess;
209  }
210 
213  {
214  itk::VariableLengthVector<double> logTransforms;
215  logTransforms.SetSize(m_NumberOfSpectralBins);
216 
217  vnl_vector<double> ones, nonAttenuated;
218  ones.set_size(m_NumberOfEnergies);
219  ones.fill(1.0);
220 
221  // The way m_IncidentSpectrumAndDetectorResponseProduct works is
222  // it is mutliplied by the vector of attenuations factors (here
223  // filled with ones, since we want the non-attenuated signal)
224  nonAttenuated = m_IncidentSpectrumAndDetectorResponseProduct * ones;
225 
226  for (unsigned int i = 0; i < m_MeasuredData.GetSize(); i++)
227  {
228  // Divide by the actually measured photon counts and apply log
229  if (m_MeasuredData[i] > 0)
230  logTransforms[i] = log(nonAttenuated[i] / m_MeasuredData[i]);
231  }
232 
233  return logTransforms;
234  }
235 
236  virtual vnl_vector<double>
237  GetVariances(const ParametersType & itkNotUsed(lineIntegrals)) const
238  {
239  vnl_vector<double> meaninglessResult;
240  meaninglessResult.set_size(m_NumberOfSpectralBins);
241  meaninglessResult.fill(0.);
242  return (meaninglessResult);
243  }
244 
// Set<Name>/Get<Name> accessors for the configuration of the cost function,
// generated by ITK macros.
// NOTE(review): setting the spectrum or detector response does not visibly recompute
// m_IncidentSpectrumAndDetectorResponseProduct (used by ForwardModel and
// BinwiseLogTransform); presumably an initialization step does — confirm.
// Measured counts per spectral bin (read by BinwiseLogTransform).
245  itkSetMacro(MeasuredData, MeasuredDataType);
246  itkGetMacro(MeasuredData, MeasuredDataType);
247 
248  itkSetMacro(DetectorResponse, DetectorResponseType);
249  itkGetMacro(DetectorResponse, DetectorResponseType);
250 
// Indexed [energy][material].
251  itkSetMacro(MaterialAttenuations, MaterialAttenuationsType);
252  itkGetMacro(MaterialAttenuations, MaterialAttenuationsType);
253 
254  itkSetMacro(NumberOfEnergies, unsigned int);
255  itkGetMacro(NumberOfEnergies, unsigned int);
256 
// Also the number of optimized parameters (GetNumberOfParameters).
257  itkSetMacro(NumberOfMaterials, unsigned int);
258  itkGetMacro(NumberOfMaterials, unsigned int);
259 
260  itkSetMacro(IncidentSpectrum, IncidentSpectrumType);
261  itkGetMacro(IncidentSpectrum, IncidentSpectrumType);
262 
263  itkSetMacro(NumberOfSpectralBins, unsigned int);
264  itkGetMacro(NumberOfSpectralBins, unsigned int);
265 
// Bin boundaries used as energy indices by GetInitialGuess;
// m_NumberOfSpectralBins + 1 entries are read there.
266  itkSetMacro(Thresholds, ThresholdsType);
267  itkGetMacro(Thresholds, ThresholdsType);
268 
269 protected:
// Only some data members are visible in this listing. These two sizes bound the
// energy and material loops of the methods in this class.
276  unsigned int m_NumberOfEnergies;
277  unsigned int m_NumberOfMaterials;
281 };
282 
283 } // namespace rtk
284 
285 #endif
virtual vnl_vector< double > ForwardModel(const ParametersType &lineIntegrals) const
void GetAttenuationFactors(const ParametersType &lineIntegrals, vnl_vector< double > &attenuationFactors) const
void GetDerivative(const ParametersType &, DerivativeType &) const override
#define itkSetMacro(name, type)
Superclass::ParametersType ParametersType
Array< ParametersValueType > DerivativeType
virtual vnl_vector< double > GetVariances(const ParametersType &) const