17#ifndef __registration_metric_cross_correlation_h__
18#define __registration_metric_cross_correlation_h__
26 namespace Registration
35 using requires_precompute = int;
// Per-voxel evaluation routine of the cross-correlation metric, templated on
// the registration parameter type.
// NOTE(review): this is an excerpt -- the leading numbers are the original
// file's line numbers, and the function name, its first parameters and several
// body lines are not shown here.
37 template <
class Params>
// `gradient` is received by non-const reference; no write to it appears in the
// visible lines.
40 Eigen::Matrix<default_type, Eigen::Dynamic, 1>& gradient) {
// Both images stored on `params` must be valid before they are sampled.
44 assert (params.processed_mask.valid());
45 assert (params.processed_image.valid());
// Weighted evaluation is not supported: the assertion fails (debug builds)
// when this->weighted is set.
46 assert (!this->
weighted &&
"FIXME: set_weights not implemented for CrossCorrelationNoGradient metric");
// Position the mask at the location given by `iter`.
// NOTE(review): presumably (0, 3) selects axes 0..2 -- confirm against assign_pos_of.
48 assign_pos_of (iter, 0, 3).
to (params.processed_mask);
// Voxels whose mask value is false take a separate path (its body is not shown).
49 if (!params.processed_mask.value())
// Step along axis 3 of the processed image and back again.
// NOTE(review): val1/val2 are presumably read around these two steps; the
// reads themselves are not visible in this excerpt.
53 ++params.processed_image.index(3);
55 --params.processed_image.index(3);
// The result is the product of the two mean-offset values. Note the operand
// order differs between the two factors: (mean1 - val1) versus (val2 - mean2).
// mean1/mean2 are defined outside the visible lines.
57 return (mean1 - val1) * (val2 - mean2);
// Precomputation step for the cross-correlation metric: allocates the
// processed mask and processed image on `parameters` and fills them with a
// multi-threaded loop running CCNoGradientPrecomputeFunctor.
// NOTE(review): this is an excerpt -- the function signature, the declaration
// of `overlap`, some functor arguments and the code following the loop are not
// shown here.
60 template <
class ParamType>
62 DEBUG (
"precomputing cross correlation data...");
// Local aliases for the image, mask and interpolator types carried by `parameters`.
64 using Im1Type =
decltype(parameters.im1_image);
65 using Im2Type =
decltype(parameters.im2_image);
66 using MidwayImageType =
decltype(parameters.midway_image);
67 using Im1MaskType =
decltype(parameters.im1_mask);
68 using Im2MaskType =
decltype(parameters.im2_mask);
69 using Im1ImageInterpolatorType =
typename ParamType::Im1InterpType;
70 using Im2ImageInterpolatorType =
typename ParamType::Im2InterpType;
71 using PImageType =
typename ParamType::ProcessedImageType;
74 using Im1MaskInterpolatorType =
typename ParamType::Mask1InterpolatorType;
75 using Im2MaskInterpolatorType =
typename ParamType::Mask2InterpolatorType;
// The midway image is required to be three-dimensional.
77 assert (parameters.midway_image.ndim() == 3);
82 Header midway_header (parameters.midway_image);
// The processed mask is a boolean scratch image built from the midway header.
85 parameters.processed_mask = Header::scratch (midway_header).template get_image<bool>();
// The processed image gets its own scratch header with size 2 along axis 3.
// NOTE(review): the midway image is asserted to be 3D above, so the header's
// dimensionality is presumably raised on a line not shown here -- confirm.
87 auto cc_header = Header::scratch (parameters.midway_image);
89 cc_header.size(3) = 2;
91 parameters.processed_image = PImageType::scratch (cc_header);
// Threaded loop over the processed image.
// NOTE(review): presumably (0, 3) restricts the loop to axes 0..2 -- confirm
// against the ThreadedLoop overload taking a progress message.
93 auto loop =
ThreadedLoop (
"precomputing cross correlation data...", parameters.processed_image, 0, 3);
// Run the precompute functor with the processed image and processed mask as
// the looped images. `overlap` is handed to the functor (stored there as
// global_cnt); its declaration is not visible here.
94 loop.run (CCNoGradientPrecomputeFunctor<
decltype(parameters.transformation),
100 Im1ImageInterpolatorType,
101 Im2ImageInterpolatorType,
102 Im1MaskInterpolatorType,
103 Im2MaskInterpolatorType> (
104 parameters.transformation,
105 parameters.im1_image,
106 parameters.im2_image,
107 parameters.midway_image,
112 overlap), parameters.processed_image, parameters.processed_mask);
// Diagnostic emitted for the zero-overlap case; the guarding condition is not
// shown in this excerpt.
123 DEBUG (
"Cross Correlation metric: zero overlap");
// Functor run by the threaded precompute loop. For each voxel it maps a
// position through the two half transforms, checks both masks, and samples
// both images.
// NOTE(review): this is an excerpt -- the `template <` line, some template and
// constructor parameters, all branch bodies, the writes to pimage/mask and the
// closing brace are not shown here.
135 typename LinearTrafoType,
138 typename MidwayImageType,
141 typename Im1ImageInterpolatorType,
142 typename Im2ImageInterpolatorType,
143 typename Im1MaskInterpolatorType,
144 typename Im2MaskInterpolatorType
146 struct CCNoGradientPrecomputeFunctor {
// NOTE(review): MEMALIGN is a project macro defined elsewhere; presumably it
// provides aligned allocation for this type -- confirm.
MEMALIGN(CCNoGradientPrecomputeFunctor)
147 CCNoGradientPrecomputeFunctor (
148 const LinearTrafoType& transformation,
151 const MidwayImageType& midway,
// The two half transforms are copied out of `transformation` at construction.
157 trafo_half (transformation.get_transform_half()),
158 trafo_half_inverse (transformation.get_transform_half_inverse()),
// global_cnt is initialised from the caller's `overlap` argument.
166 global_cnt (overlap),
// Both input images must be valid before interpolators are built on them.
170 assert (in1.valid());
171 assert (in2.valid());
// Each interpolator is heap-allocated and handed to the member's reset().
// The member holder types are declared outside the visible lines.
172 im1_image_interp.reset (
new Im1ImageInterpolatorType (in1));
173 im2_image_interp.reset (
new Im2ImageInterpolatorType (in2));
175 im1_mask_interp.reset (
new Im1MaskInterpolatorType (msk1));
177 im2_mask_interp.reset (
new Im2MaskInterpolatorType (msk2));
// Destructor; its body is not shown in this excerpt.
180 ~CCNoGradientPrecomputeFunctor () {
// Per-voxel call made by the threaded loop with the processed image and mask.
187 template <
typename ProcessedImageType,
typename MaskImageType>
188 void operator() (ProcessedImageType& pimage, MaskImageType& mask) {
// Both images must sit on the same voxel (axes 0..2), with pimage at index 0
// along axis 3.
189 assert(mask.index(0) == pimage.index(0));
190 assert(mask.index(1) == pimage.index(1));
191 assert(mask.index(2) == pimage.index(2));
192 assert(pimage.index(3) == 0);
// Map `pos` through the half transform (the computation of `pos` itself is
// not shown).
196 pos1 = trafo_half * pos;
// Image-1 mask test: the interpolator is moved to pos1, then checked for
// falsiness and for a value below 0.5. The action taken in either case is
// not shown.
// NOTE(review): a false interpolator presumably means pos1 lies outside the
// mask image -- confirm against the interpolator class.
198 im1_mask_interp->scanner(pos1);
199 if (!(*im1_mask_interp))
201 if (im1_mask_interp->value() < 0.5)
// Same procedure for image 2, using the inverse half transform.
205 pos2 = trafo_half_inverse * pos;
207 im2_mask_interp->scanner(pos2);
208 if (!(*im2_mask_interp))
210 if (im2_mask_interp->value() < 0.5)
// Sample image 1 at pos1 into v1; the branch for a false interpolator is not
// shown.
214 im1_image_interp->scanner(pos1);
215 if (!(*im1_image_interp))
217 v1 = im1_image_interp->value();
// Sample image 2 at pos2 into v2; the branch for a false interpolator is not
// shown.
221 im2_image_interp->scanner(pos2);
222 if (!(*im2_image_interp))
224 v2 = im2_image_interp->value();
// Half transforms, fixed for the lifetime of the functor.
240 const Eigen::Transform<default_type, 3, Eigen::AffineCompact> trafo_half, trafo_half_inverse;
// Member position vectors, reused across calls to operator().
250 Eigen::Vector3d vox, pos, pos1, pos2;
a dummy image to iterate over, useful for multi-threaded looping.
double default_type
the default type used throughout MRtrix
Eigen::Transform< default_type, 3, Eigen::AffineCompact > transform_type
the type for the affine transform of an image:
ThreadedLoopRunOuter< decltype(Loop(vector< size_t >()))> ThreadedLoop(const HeaderType &source, const vector< size_t > &outer_axes, const vector< size_t > &inner_axes)
Multi-threaded loop object.
T to(const std::string &string)