17#ifndef __dwi_tractography_algorithms_tensor_prob_h__
18#define __dwi_tractography_algorithms_tensor_prob_h__
29 namespace Tractography
38 class Shared :
public Tensor_Det::Shared {
MEMALIGN(Shared)
41 Tensor_Det::Shared (diff_path, property_set) {
43 if (is_act() && act().backtrack())
44 throw Exception (
"Sorry, backtracking not currently enabled for TensorProb algorithm");
46 properties[
"method"] =
"TensorProb";
71 bool init()
override {
// Stub override: this algorithm provides no real track truncation.
// The body is only assert (0), so debug builds abort if it is ever reached.
// Release builds (NDEBUG) compile the assert away, so the call returns
// without modifying `tck`.
// NOTE(review): presumably truncate_track is only invoked for ACT
// backtracking. The Shared constructor throws ("Sorry, backtracking not
// currently enabled for TensorProb algorithm") when act().backtrack() is
// set, so this should be unreachable. Confirm against the caller.
87 void truncate_track (
GeneratedTrack& tck,                  // unused
const size_t length_to_revert_from,   // unused
const size_t revert_step)             // unused
override { assert (0); }
98 log_signal (H.rows()) { }
100 void operator() (
float* data) {
101 for (ssize_t i = 0; i < residuals.size(); ++i)
102 log_signal[i] = data[i] >
float (0.0) ? -std::log (data[i]) : float (0.0);
104 residuals = H * log_signal;
106 for (ssize_t i = 0; i < residuals.size(); ++i) {
107 residuals[i] = residuals[i] ? (data[i] - std::exp (-residuals[i])) :
float(0.0);
108 data[i] += uniform_int (
rng) ? residuals[i] : -residuals[i];
113 const Eigen::MatrixXf& H;
114 std::uniform_int_distribution<> uniform_int;
115 Eigen::VectorXf residuals, log_signal;
124 for (
size_t i = 0; i < 8; ++i)
125 raw_signals.push_back (Eigen::VectorXf (size(3)));
130 bool get (
const Eigen::Vector3f& pos, Eigen::VectorXf& data) {
131 if (!scanner (pos)) {
139 for (ssize_t z = 0; z < 2; ++z) {
140 index(2) = clamp (P[2]+z, size(2));
141 for (ssize_t y = 0; y < 2; ++y) {
142 index(1) = clamp (P[1]+y, size(1));
143 for (ssize_t x = 0; x < 2; ++x) {
144 index(0) = clamp (P[0]+x, size(0));
146 get_values (raw_signals[i]);
147 data += factors[i] * raw_signals[i];
154 return !std::isnan (data[0]);
thread_local Math::RNG rng
thread-local, but globally accessible RNG to vastly simplify multi-threading
constexpr default_type NaN