17#ifndef __registration_metric_evaluate_h__
18#define __registration_metric_evaluate_h__
28 namespace Registration
39 template <
class MetricType,
typename U =
void>
44 template <
class MetricType>
45 struct metric_requires_precompute<MetricType, typename Void2<typename MetricType::requires_precompute>::type> {
NOMEMALIGN
49 template <
class MetricType,
typename U =
void>
50 struct metric_requires_initialisation {
NOMEMALIGN
54 template <
class MetricType>
55 struct metric_requires_initialisation<MetricType, typename Void2<typename MetricType::requires_initialisation>::type> {
NOMEMALIGN
61 template <
class MetricType,
class ParamType>
65 using TransformParamType =
typename ParamType::TransformParamType;
68 template <
class U = MetricType>
71 template <
class U = MetricType>
72 Evaluate (
const MetricType& metric_, ParamType& parameters,
typename metric_requires_initialisation<U>::yes = 0) :
77 metric.init (parameters.im1_image, parameters.im2_image);
81 template <
class U = MetricType>
82 Evaluate (
const MetricType& metric_, ParamType& parameters,
typename metric_requires_initialisation<U>::no = 0) :
88 template <
class U = MetricType>
89 default_type operator() (
const Eigen::Matrix<default_type, Eigen::Dynamic, 1>& x, Eigen::Matrix<default_type, Eigen::Dynamic, 1>& gradient,
typename metric_requires_precompute<U>::yes = 0) {
90 Eigen::VectorXd overall_cost_function = Eigen::VectorXd::Zero(1,1);
93 params.transformation.set_parameter_vector(x);
96 DEBUG (
"Reorienting FODs...");
97 std::shared_ptr<Image<default_type> > im1_image_reoriented;
98 std::shared_ptr<Image<default_type> > im2_image_reoriented;
103 if (
params.mc_settings.size()) {
104 DEBUG (
"Tissue contrast specific FOD reorientation");
108 DEBUG (
"FOD reorientation");
114 params.set_im1_iterpolator (*im1_image_reoriented);
115 params.set_im2_iterpolator (*im2_image_reoriented);
127 DEBUG (
"Metric evaluate iteration: " +
str(
iteration++) +
", cost: " +
str(overall_cost_function.transpose()));
128 DEBUG (
" x: " +
str(x.transpose()));
129 DEBUG (
" gradient: " +
str(gradient.transpose()));
130 DEBUG (
" norm(gradient): " +
str(gradient.norm()));
132 return overall_cost_function(0);
170 template <
class U = MetricType>
171 default_type operator() (
const Eigen::Matrix<default_type, Eigen::Dynamic, 1>& x, Eigen::Matrix<default_type, Eigen::Dynamic, 1>& gradient,
typename metric_requires_precompute<U>::no = 0) {
172 Eigen::VectorXd overall_cost_function = Eigen::VectorXd::Zero(1,1);
174 params.transformation.set_parameter_vector(x);
177 DEBUG (
"Reorienting FODs...");
178 std::shared_ptr<Image<default_type> > im1_image_reoriented;
179 std::shared_ptr<Image<default_type> > im2_image_reoriented;
184 if (
params.mc_settings.size()) {
185 DEBUG (
"Tissue contrast specific FOD reorientation");
189 DEBUG (
"FOD reorientation");
195 params.set_im1_iterpolator (*im1_image_reoriented);
196 params.set_im2_iterpolator (*im2_image_reoriented);
200 if (
params.loop_density < 1.0) {
201 DEBUG (
"stochastic gradient descent, density: " +
str(
params.loop_density));
214 if (
params.robust_estimate_subset) {
215 assert(
params.robust_estimate_subset_from.size() == 3);
216 assert(
params.robust_estimate_subset_size.size() == 3);
220 for (
auto i =
Loop(0,3) (subset); i; ++i) {
229 DEBUG (
"Metric evaluate iteration: " +
str(
iteration++) +
", cost: " +
str(overall_cost_function.transpose()));
230 DEBUG (
" x: " +
str(x.transpose()));
231 DEBUG (
" gradient: " +
str(gradient.transpose()));
232 DEBUG (
" norm(gradient): " +
str(gradient.norm()));
234 return overall_cost_function(0);
238 return params.transformation.size();
246 params.transformation.get_parameter_vector(x);
250 void set_directions (
const Eigen::MatrixXd& dir) {
static Image scratch(const Header &template_header, const std::string &label="scratch image")
Eigen::MatrixXd directions
FORCE_INLINE LoopAlongAxes Loop()
thread_local Math::RNG rng
thread-local, but globally accessible RNG to vastly simplify multi-threading
double default_type
the default type used throughout MRtrix
std::string str(const T &value, int precision=0)
ThreadedLoopRunOuter< decltype(Loop(vector< size_t >()))> ThreadedLoop(const HeaderType &source, const vector< size_t > &outer_axes, const vector< size_t > &inner_axes)
Multi-threaded loop object.
std::remove_reference< Functor >::type & functor