RTK  2.4.1
Reconstruction Toolkit
rtkProjectionsDecompositionNegativeLogLikelihood.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright RTK Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * https://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 
19 #ifndef rtkProjectionsDecompositionNegativeLogLikelihood_h
20 #define rtkProjectionsDecompositionNegativeLogLikelihood_h
21 
23 #include <itkVectorImage.h>
25 #include <itkVariableSizeMatrix.h>
26 #include "rtkMacro.h"
27 
28 namespace rtk
29 {
38 // We have to define the cost function first
40 {
41 public:
42  ITK_DISALLOW_COPY_AND_MOVE(ProjectionsDecompositionNegativeLogLikelihood);
43 
48  itkNewMacro(Self);
50 
51  // enum { SpaceDimension=m_NumberOfMaterials };
52 
56 
// Matrix types for the detector response, the material attenuations and the incident spectrum.
// Note: the incident spectrum type is single precision (float); the other two are double precision.
57  using DetectorResponseType = vnl_matrix<double>;
58  using MaterialAttenuationsType = vnl_matrix<double>;
59  using IncidentSpectrumType = vnl_matrix<float>;
63 
64  // Constructor
66  {
    // The object starts out flagged as not initialized.
69  m_Initialized = false;
70  }
71 
72  // Destructor
74 
76  GetValue(const ParametersType & itkNotUsed(parameters)) const override
77  {
    // Placeholder implementation: the parameters are ignored and the returned
    // measure is always 0.
78  long double measure = 0;
79  return measure;
80  }
81  void
82  GetDerivative(const ParametersType & itkNotUsed(lineIntegrals),
83  DerivativeType & itkNotUsed(derivatives)) const override
84  {
85  itkExceptionMacro(<< "Not implemented");
86  }
87  virtual void
89  {}
90 
93  {
94  // Return the inverses of the diagonal components (i.e. the inverse variances, to be used directly in WLS
95  // reconstruction)
97  diag.SetSize(m_NumberOfMaterials);
98  diag.Fill(0);
99 
    // Entry `mat` is 1 / (inverse of m_Fischer)[mat][mat].
    // NOTE(review): m_Fischer.GetInverse() is evaluated on every iteration although
    // m_Fischer is not modified in this loop; it looks like it could be computed
    // once before the loop - confirm.
100  for (unsigned int mat = 0; mat < m_NumberOfMaterials; mat++)
101  diag[mat] = 1. / m_Fischer.GetInverse()[mat][mat];
102  return diag;
103  }
104 
107  {
108  // Return the whole Fisher information matrix (spelled "Fischer" in the identifiers)
110  fischer.SetSize(m_NumberOfMaterials * m_NumberOfMaterials);
111  fischer.Fill(0);
112 
    // Flatten m_Fischer row by row: element (i, j) is stored at index
    // i * m_NumberOfMaterials + j of the returned vector.
113  for (unsigned int i = 0; i < m_NumberOfMaterials; i++)
114  for (unsigned int j = 0; j < m_NumberOfMaterials; j++)
115  fischer[i * m_NumberOfMaterials + j] = m_Fischer[i][j];
116  return fischer;
117  }
118 
119  virtual void
120  ComputeFischerMatrix(const ParametersType & itkNotUsed(lineIntegrals))
121  {}
122 
123  unsigned int
124  GetNumberOfParameters() const override
125  {
126  return m_NumberOfMaterials;
127  }
128 
129  virtual vnl_vector<double>
130  ForwardModel(const ParametersType & lineIntegrals) const
131  {
132  vnl_vector<double> attenuationFactors;
133  attenuationFactors.set_size(m_NumberOfEnergies);
134  GetAttenuationFactors(lineIntegrals, attenuationFactors);
135 
136  // Apply detector response, getting the lambdas
137  return (m_IncidentSpectrumAndDetectorResponseProduct * attenuationFactors);
138  }
139 
140  void
141  GetAttenuationFactors(const ParametersType & lineIntegrals, vnl_vector<double> & attenuationFactors) const
142  {
143  // Apply attenuation at each energy
144  vnl_vector<double> vnlLineIntegrals;
145 
146  // Initialize the line integrals vnl vector
147  vnlLineIntegrals.set_size(m_NumberOfMaterials);
148  for (unsigned int m = 0; m < m_NumberOfMaterials; m++)
149  vnlLineIntegrals[m] = lineIntegrals[m];
150 
151  // Apply the material attenuations matrix
152  attenuationFactors = this->m_MaterialAttenuations * vnlLineIntegrals;
153 
154  // Compute the negative exponential
155  for (unsigned int energy = 0; energy < m_NumberOfEnergies; energy++)
156  {
157  attenuationFactors[energy] = std::exp(-attenuationFactors[energy]);
158  }
159  }
160 
163  {
    // Builds a starting estimate with one entry per material: for each material,
    // the smallest length that would, on its own, explain the log-transformed
    // attenuation observed in any spectral bin.
165  initialGuess.SetSize(m_NumberOfMaterials);
166 
167  // Compute the mean attenuation in each bin, weighted by the input spectrum
168  // Needs to be done for each pixel, since the input spectrum is variable
    // NOTE(review): only row 0 of m_IncidentSpectrum is read below.
169  MeanAttenuationInBinType MeanAttenuationInBin;
170  MeanAttenuationInBin.SetSize(this->m_NumberOfMaterials, this->m_NumberOfSpectralBins);
171  MeanAttenuationInBin.Fill(0);
172 
173  for (unsigned int mat = 0; mat < this->m_NumberOfMaterials; mat++)
174  {
175  for (unsigned int bin = 0; bin < m_NumberOfSpectralBins; bin++)
176  {
177  double accumulate = 0;
178  double accumulateWeights = 0;
    // Energies of the bin run from m_Thresholds[bin] - 1 up to (but excluding)
    // m_Thresholds[bin + 1], clamped to the number of rows of m_MaterialAttenuations.
    // NOTE(review): this reads m_Thresholds[bin + 1] for the last bin, so m_Thresholds
    // presumably holds at least m_NumberOfSpectralBins + 1 entries, and the "- 1"
    // suggests 1-based threshold values - confirm.
179  for (int energy = m_Thresholds[bin] - 1;
180  (energy < m_Thresholds[bin + 1]) && (energy < (int)(this->m_MaterialAttenuations.rows()));
181  energy++)
182  {
183  accumulate += this->m_MaterialAttenuations[energy][mat] * this->m_IncidentSpectrum[0][energy];
184  accumulateWeights += this->m_IncidentSpectrum[0][energy];
185  }
    // Spectrum-weighted mean attenuation of this material over the bin.
    // NOTE(review): accumulateWeights stays 0 if the energy range is empty or the
    // spectrum is zero over it, making this a floating-point division by zero -
    // confirm this cannot happen.
186  MeanAttenuationInBin[mat][bin] = accumulate / accumulateWeights;
187  }
188  }
189 
190  for (unsigned int mat = 0; mat < m_NumberOfMaterials; mat++)
191  {
192  // Initialise to a very high value
193  initialGuess[mat] = 1e10;
194  for (unsigned int bin = 0; bin < m_NumberOfSpectralBins; bin++)
195  {
196  // Compute the length of current material required to obtain the attenuation
197  // observed in current bin. Keep only the minimum among all bins
    // NOTE(review): BinwiseLogTransform() is called once per (material, bin) pair;
    // its result does not appear to depend on either index, so it could probably
    // be computed once before the loops - confirm.
198  double requiredLength = this->BinwiseLogTransform()[bin] / MeanAttenuationInBin[mat][bin];
199  if (initialGuess[mat] > requiredLength)
200  initialGuess[mat] = requiredLength;
201  }
202  }
203 
204  return initialGuess;
205  }
206 
209  {
    // For each measured value, computes log(non-attenuated signal / measured signal).
210  itk::VariableLengthVector<double> logTransforms;
211  logTransforms.SetSize(m_NumberOfSpectralBins);
212 
213  vnl_vector<double> ones, nonAttenuated;
214  ones.set_size(m_NumberOfEnergies);
215  ones.fill(1.0);
216 
217  // The way m_IncidentSpectrumAndDetectorResponseProduct works is
218  // it is multiplied by the vector of attenuation factors (here
219  // filled with ones, since we want the non-attenuated signal)
220  nonAttenuated = m_IncidentSpectrumAndDetectorResponseProduct * ones;
221 
    // NOTE(review): the loop bound is m_MeasuredData.GetSize() whereas logTransforms
    // has m_NumberOfSpectralBins entries; this assumes the two sizes agree - confirm.
222  for (unsigned int i = 0; i < m_MeasuredData.GetSize(); i++)
223  {
224  // Divide by the actually measured photon counts and apply log
    // NOTE(review): entries with m_MeasuredData[i] <= 0 are never assigned, and no
    // Fill() follows the SetSize() above; whether they hold a defined value depends
    // on what SetSize() does with new elements - confirm.
225  if (m_MeasuredData[i] > 0)
226  logTransforms[i] = log(nonAttenuated[i] / m_MeasuredData[i]);
227  }
228 
229  return logTransforms;
230  }
231 
232  virtual vnl_vector<double>
233  GetVariances(const ParametersType & itkNotUsed(lineIntegrals)) const
234  {
235  vnl_vector<double> meaninglessResult;
236  meaninglessResult.set_size(m_NumberOfSpectralBins);
237  meaninglessResult.fill(0.);
238  return (meaninglessResult);
239  }
240 
// Accessor declarations, one itkSetMacro / itkGetMacro pair per model input.
// NOTE(review): the macro bodies are not visible here; presumably each pair
// provides Set<Name>() / Get<Name>() for the member m_<Name> - confirm against
// the ITK macro definitions.

// Measured data
241  itkSetMacro(MeasuredData, MeasuredDataType);
242  itkGetMacro(MeasuredData, MeasuredDataType);
243 
// Detector response matrix
244  itkSetMacro(DetectorResponse, DetectorResponseType);
245  itkGetMacro(DetectorResponse, DetectorResponseType);
246 
// Material attenuations matrix (indexed [energy][material] elsewhere in this class)
247  itkSetMacro(MaterialAttenuations, MaterialAttenuationsType);
248  itkGetMacro(MaterialAttenuations, MaterialAttenuationsType);
249 
// Number of energies
250  itkSetMacro(NumberOfEnergies, unsigned int);
251  itkGetMacro(NumberOfEnergies, unsigned int);
252 
// Number of materials (also the number of cost-function parameters)
253  itkSetMacro(NumberOfMaterials, unsigned int);
254  itkGetMacro(NumberOfMaterials, unsigned int);
255 
// Incident spectrum matrix
256  itkSetMacro(IncidentSpectrum, IncidentSpectrumType);
257  itkGetMacro(IncidentSpectrum, IncidentSpectrumType);
258 
// Number of spectral bins
259  itkSetMacro(NumberOfSpectralBins, unsigned int);
260  itkGetMacro(NumberOfSpectralBins, unsigned int);
261 
// Bin thresholds
262  itkSetMacro(Thresholds, ThresholdsType);
263  itkGetMacro(Thresholds, ThresholdsType);
264 
265 protected:
// Number of energies; used as the length of the attenuation-factor vectors.
272  unsigned int m_NumberOfEnergies;
// Number of materials; this is the value returned by GetNumberOfParameters().
273  unsigned int m_NumberOfMaterials;
277 };
278 
279 } // namespace rtk
280 
281 #endif
virtual vnl_vector< double > ForwardModel(const ParametersType &lineIntegrals) const
void GetAttenuationFactors(const ParametersType &lineIntegrals, vnl_vector< double > &attenuationFactors) const
void GetDerivative(const ParametersType &, DerivativeType &) const override
#define itkSetMacro(name, type)
Superclass::ParametersType ParametersType
Array< ParametersValueType > DerivativeType
virtual vnl_vector< double > GetVariances(const ParametersType &) const