Program Listing for File trx.tpp
↰ Return to documentation for file (include/trx/trx.tpp)
// Taken from: https://stackoverflow.com/a/25389481
#ifndef TRX_H
#define TRX_TPP_STANDALONE
#define TRX_TPP_OPEN_NAMESPACE
#include <trx/trx.h>
#undef TRX_TPP_STANDALONE
namespace trx {
#endif
using Eigen::Dynamic;
using Eigen::half;
using Eigen::Index;
using Eigen::Map;
using Eigen::Matrix;
using Eigen::RowMajor;
/// Create `path` together with any missing parent directories.
/// An already-existing directory is not an error.
/// @throws TrxIOError if the directories cannot be created; the message carries the
///         filesystem error reported by create_directories so failures are diagnosable.
inline void mkdir_or_throw(const std::string &path) {
  std::error_code ec;
  trx::fs::create_directories(path, ec);
  if (ec) {
    throw TrxIOError("Could not create directory " + path + ": " + ec.message());
  }
}
/// Build the header of an empty TRX file: a 4x4 identity VOXEL_TO_RASMM,
/// DIMENSIONS = {1, 1, 1}, and zero vertices / streamlines.
inline json default_header() {
  std::vector<std::vector<float>> identity(4);
  for (int row = 0; row < 4; ++row) {
    identity[row].assign(4, 0.0f);
    identity[row][row] = 1.0f;
  }
  json::object fields;
  fields["VOXEL_TO_RASMM"] = identity;
  fields["DIMENSIONS"] = std::vector<uint16_t>{1, 1, 1};
  fields["NB_VERTICES"] = 0;
  fields["NB_STREAMLINES"] = 0;
  return json(fields);
}
inline std::string folder_from_path(const std::string &elem_filename, const std::string &root) {
trx::fs::path elem_path(elem_filename);
trx::fs::path folder_path = elem_path.parent_path();
std::string folder;
if (!root.empty()) {
trx::fs::path rel_path = elem_path.lexically_relative(trx::fs::path(root));
std::string rel_str = rel_path.string();
if (!rel_str.empty() && rel_str.rfind("..", 0) != 0) {
folder = rel_path.parent_path().string();
} else {
folder = folder_path.string();
}
} else {
folder = folder_path.string();
}
if (folder == ".") {
folder.clear();
}
return folder;
}
/// Copy the coefficients currently viewed by `mapped_matrix._matrix` into its owned buffer,
/// re-point `_matrix` at that buffer (or at nullptr when there are no coefficients), and then
/// unmap the backing file so the data no longer depends on it.
template <typename T>
void materialize_matrix_map_and_unmap(MMappedMatrix<T> &mapped_matrix) {
  auto &view = mapped_matrix._matrix;
  auto &owned = mapped_matrix._matrix_owned;
  const int nrows = view.rows();
  const int ncols = view.cols();
  owned.resize(static_cast<size_t>(nrows) * static_cast<size_t>(ncols));
  if (owned.empty()) {
    trx::detail::remap(view, static_cast<T *>(nullptr), nrows, ncols);
  } else {
    std::copy_n(view.data(), owned.size(), owned.data());
    trx::detail::remap(view, owned.data(), nrows, ncols);
  }
  mapped_matrix.mmap.unmap();
}
/// Copy the coefficients currently viewed by `sequence._data` into `sequence._data_owned`,
/// re-point `_data` at that owned buffer (or at nullptr when empty), and unmap the positions
/// mapping (`mmap_pos`) so the data no longer depends on the backing file.
template <typename T>
void materialize_sequence_data_and_unmap(ArraySequence<T> &sequence) {
  auto &view = sequence._data;
  auto &owned = sequence._data_owned;
  const int nrows = view.rows();
  const int ncols = view.cols();
  owned.resize(static_cast<size_t>(nrows) * static_cast<size_t>(ncols));
  if (owned.empty()) {
    trx::detail::remap(view, static_cast<T *>(nullptr), nrows, ncols);
  } else {
    std::copy_n(view.data(), owned.size(), owned.data());
    trx::detail::remap(view, owned.data(), nrows, ncols);
  }
  sequence.mmap_pos.unmap();
}
/// Convert the `n` elements starting at `src` to OutT (via static_cast) and store them in `dst`.
/// `dst` is resized to exactly `n` elements; its previous contents are overwritten.
template <typename OutT, typename InT>
void copy_cast_buffer(const InT *src, size_t n, std::vector<OutT> &dst) {
  dst.resize(n);
  std::transform(src, src + n, dst.begin(), [](const InT &value) { return static_cast<OutT>(value); });
}
/// Convert `n` elements stored at `src` with on-disk dtype `dtype` into `dst` as OutT.
/// Accepted dtypes: uint8, uint16 (alias "ushort"), uint32, uint64, int8, int16, int32, int64,
/// float16, float32, float64. `dst` is resized to `n`.
/// @throws TrxDTypeError for any other dtype string.
template <typename OutT>
void copy_cast_from_dtype_buffer(const void *src, size_t n, const std::string &dtype, std::vector<OutT> &dst) {
  // The element type is carried by the pointer type handed to this helper.
  const auto convert_from = [&](const auto *typed_src) { copy_cast_buffer(typed_src, n, dst); };

  if (dtype == "uint8") {
    convert_from(static_cast<const uint8_t *>(src));
  } else if (dtype == "uint16" || dtype == "ushort") {
    convert_from(static_cast<const uint16_t *>(src));
  } else if (dtype == "uint32") {
    convert_from(static_cast<const uint32_t *>(src));
  } else if (dtype == "uint64") {
    convert_from(static_cast<const uint64_t *>(src));
  } else if (dtype == "int8") {
    convert_from(static_cast<const int8_t *>(src));
  } else if (dtype == "int16") {
    convert_from(static_cast<const int16_t *>(src));
  } else if (dtype == "int32") {
    convert_from(static_cast<const int32_t *>(src));
  } else if (dtype == "int64") {
    convert_from(static_cast<const int64_t *>(src));
  } else if (dtype == "float16") {
    convert_from(static_cast<const Eigen::half *>(src));
  } else if (dtype == "float32") {
    convert_from(static_cast<const float *>(src));
  } else if (dtype == "float64") {
    convert_from(static_cast<const double *>(src));
  } else {
    throw TrxDTypeError("Unsupported metadata dtype for conversion: " + dtype);
  }
}
/// Write the raw coefficients of `matrix` (rows * cols scalars, taken from `matrix.data()`) to
/// `filename`, truncating any existing file.
///
/// NOTE: no rows/cols header is written, whereas read_binary() expects one, so the two do not
/// round-trip; callers must know the shape out-of-band.
///
/// @throws std::ios_base::failure if the file cannot be opened, or if writing/closing fails
///         (previously such failures were silently ignored, leaving a missing or short file).
template <class Matrix> void write_binary(const std::string &filename, const Matrix &matrix) {
  std::ofstream out(filename, std::ios::out | std::ios::binary | std::ios::trunc);
  if (!out) {
    throw std::ios_base::failure("write_binary: could not open " + filename);
  }
  const typename Matrix::Index rows = matrix.rows();
  const typename Matrix::Index cols = matrix.cols();
  const auto *data = reinterpret_cast<const char *>(matrix.data()); // check_syntax off
  out.write(data, static_cast<std::streamsize>(rows * cols * sizeof(typename Matrix::Scalar)));
  out.close();
  // close() flushes; a failed write or flush leaves failbit/badbit set.
  if (!out) {
    throw std::ios_base::failure("write_binary: failed to write " + filename);
  }
}
/// Read a matrix stored as: rows (Matrix::Index), cols (Matrix::Index), then rows * cols raw
/// scalars written in `matrix.data()` order. `matrix` is resized to rows x cols.
///
/// @throws std::ios_base::failure if the file cannot be opened, the header or payload is
///         truncated, or the header declares negative dimensions. (Previously a missing file
///         silently produced a 0x0 matrix and a short file left garbage coefficients.)
template <class Matrix> void read_binary(const std::string &filename, Matrix &matrix) {
  std::ifstream in(filename, std::ios::in | std::ios::binary);
  if (!in) {
    throw std::ios_base::failure("read_binary: could not open " + filename);
  }
  typename Matrix::Index rows = 0;
  typename Matrix::Index cols = 0;
  auto *rows_ptr = reinterpret_cast<char *>(&rows); // check_syntax off
  auto *cols_ptr = reinterpret_cast<char *>(&cols); // check_syntax off
  in.read(rows_ptr, sizeof(typename Matrix::Index));
  in.read(cols_ptr, sizeof(typename Matrix::Index));
  if (!in) {
    throw std::ios_base::failure("read_binary: truncated header in " + filename);
  }
  if (rows < 0 || cols < 0) {
    throw std::ios_base::failure("read_binary: negative matrix dimensions in " + filename);
  }
  matrix.resize(rows, cols);
  auto *matrix_ptr = reinterpret_cast<char *>(matrix.data()); // check_syntax off
  in.read(matrix_ptr, static_cast<std::streamsize>(rows * cols * sizeof(typename Matrix::Scalar)));
  if (!in) {
    throw std::ios_base::failure("read_binary: truncated data in " + filename);
  }
}
/// Forward differences of the coefficients of `tmp` (read in storage order) into `lengths`.
///
/// Unlike numpy.ediff1d, the output has the SAME number of entries as `tmp`:
/// lengths(i) = tmp[i + 1] - tmp[i] for i < size - 1, and the last entry is set to `to_end`.
/// An empty `tmp` yields an empty `lengths`.
///
/// @param lengths Output; resized to (tmp.size(), 1).
/// @param tmp Input coefficients (not modified).
/// @param to_end Value stored in the last output entry.
template <typename DT>
void ediff1d(Matrix<DT, Dynamic, 1> &lengths, Matrix<DT, Dynamic, Dynamic> &tmp, uint32_t to_end) {
  // View the input as a flat row vector of its own scalar type; the previous
  // Map<Matrix<uint32_t, ...>> over a DT* only compiled when DT was uint32_t.
  const Map<const Matrix<DT, 1, Dynamic>> v(tmp.data(), tmp.size());
  lengths.resize(v.size(), 1);
  if (v.size() == 0) {
    // Nothing to difference; writing lengths(v.size() - 1) would index -1.
    return;
  }
  for (Eigen::Index i = 0; i + 1 < v.size(); ++i) {
    lengths(i) = v(i + 1) - v(i);
  }
  lengths(v.size() - 1) = to_end;
}
/// Build a TRX array filename for `arr` from `filename`:
///   "<stem>.<dtype>" when `arr` has one column, "<stem>.<ncols>.<dtype>" otherwise,
/// where <dtype> comes from the scalar type of `arr`.
/// Caveat: if filename has an extension, it will be replaced by the generated dtype extension.
template <typename DT>
std::string _generate_filename_from_data(const Eigen::MatrixBase<DT> &arr, std::string filename) {
  const std::string ext = get_ext(filename);
  // Drop ".<ext>" (extension plus its dot) when there is one.
  const std::string stem = ext.empty() ? filename : filename.substr(0, filename.length() - ext.length() - 1);
  const std::string dtype = dtype_from_scalar<typename DT::Scalar>();
  const Eigen::Index n_cols = arr.cols();
  if (n_cols == 1) {
    return stem + "." + dtype;
  }
  return stem + "." + std::to_string(static_cast<long long>(n_cols)) + "." + dtype;
}
/// Create a new, data-less TrxFile whose header is a copy of this one with NB_VERTICES and
/// NB_STREAMLINES reset to 0.
template <typename DT>
std::unique_ptr<TrxFile<DT>> TrxFile<DT>::make_empty_like() const {
  auto result = std::make_unique<TrxFile<DT>>();
  json zeroed = _json_set(this->header, "NB_VERTICES", 0);
  result->header = _json_set(zeroed, "NB_STREAMLINES", 0);
  return result;
}
/// Construct a TrxFile.
///
/// - nb_vertices == 0 && nb_streamlines == 0: an empty file with no streamlines. `init_as`
///   must then be nullptr, otherwise TrxArgumentError is thrown.
/// - nb_vertices > 0 && nb_streamlines > 0: arrays are preallocated by _initialize_empty_trx
///   (memory-mapped files in a new temporary directory). This object takes over that directory
///   and its cleanup. When `init_as` is given, its DPV/DPS fields are allocated too.
/// - any other combination throws TrxArgumentError.
///
/// The resulting header always holds exactly VOXEL_TO_RASMM, DIMENSIONS, NB_VERTICES and
/// NB_STREAMLINES. VOXEL_TO_RASMM/DIMENSIONS are taken from `init_as` when provided, otherwise
/// they are the identity and {1, 1, 1}.
/// NOTE(review): other keys present in init_as->header are not carried over (the header built by
/// _initialize_empty_trx is discarded) — confirm this is intended.
///
/// @param reference Not used by this constructor (see the get_reference_info TODO).
template <typename DT>
TrxFile<DT>::TrxFile(int nb_vertices, int nb_streamlines, const TrxFile<DT> *init_as, std::string reference) {
std::vector<std::vector<float>> affine(4);
std::vector<uint16_t> dimensions(3);
// TODO: check if there's a more efficient way to do this with Eigen
if (init_as != nullptr) {
// Copy the spatial reference (affine and volume dimensions) from the template file.
for (int i = 0; i < 4; i++) {
affine[i] = {0, 0, 0, 0};
for (int j = 0; j < 4; j++) {
affine[i][j] = static_cast<float>(init_as->header["VOXEL_TO_RASMM"][i][j].number_value());
}
}
for (int i = 0; i < 3; i++) {
dimensions[i] = static_cast<uint16_t>(init_as->header["DIMENSIONS"][i].int_value());
}
}
// TODO: add else if for get_reference_info
else {
// identity matrix
for (int i = 0; i < 4; i++) {
affine[i] = {0, 0, 0, 0};
affine[i][i] = 1;
}
dimensions = {1, 1, 1};
}
if (nb_vertices == 0 && nb_streamlines == 0) {
if (init_as != nullptr) {
// raise error here
throw TrxArgumentError("Can't use init_as without declaring nb_vertices and nb_streamlines");
}
// Will be removed as completely unnecessary; these are placeholders.
// (The header set here is overwritten below anyway.)
this->header = {};
this->streamlines.reset();
// TODO: maybe create a matrix to map to of specified DT. Do we need this??
// set default datatype to half
// default data is null so will not set data. User will need configure desired datatype
// this->streamlines = ArraySequence<half>();
this->_uncompressed_folder_handle = "";
nb_vertices = 0;
nb_streamlines = 0;
} else if (nb_vertices > 0 && nb_streamlines > 0) {
// Allocate into a temporary TrxFile, then steal its arrays and its temporary directory.
auto trx = _initialize_empty_trx<DT>(nb_streamlines, nb_vertices, init_as);
this->streamlines = std::move(trx->streamlines);
this->groups = std::move(trx->groups);
this->data_per_streamline = std::move(trx->data_per_streamline);
this->data_per_vertex = std::move(trx->data_per_vertex);
this->data_per_group = std::move(trx->data_per_group);
this->_uncompressed_folder_handle = std::move(trx->_uncompressed_folder_handle);
this->_owns_uncompressed_folder = trx->_owns_uncompressed_folder;
this->_copy_safe = trx->_copy_safe;
// Disown the directory in the temporary so its destructor does not delete what we now use.
trx->_owns_uncompressed_folder = false;
trx->_uncompressed_folder_handle.clear();
} else {
throw TrxArgumentError("You must declare both NB_VERTICES AND NB_STREAMLINES");
}
// Rebuild the header from scratch with the four mandatory keys.
json::object header_obj;
header_obj["VOXEL_TO_RASMM"] = affine;
header_obj["DIMENSIONS"] = dimensions;
header_obj["NB_VERTICES"] = nb_vertices;
header_obj["NB_STREAMLINES"] = nb_streamlines;
this->header = json(header_obj);
// A freshly constructed file is never a slice, so it is always safe to copy/resize.
this->_copy_safe = true;
}
/// Allocate a TrxFile whose arrays are backed by freshly created memory-mapped files inside a
/// new temporary directory, which the returned object owns and deletes on destruction.
///
/// Always allocated: positions (nb_vertices x 3, dtype of DT) and offsets (nb_streamlines + 1,
/// uint64). Lengths are an in-memory zero vector of nb_streamlines entries.
/// When `init_as` is given, its header is copied (with the new counts) and one DPV / DPS array
/// is allocated per key of init_as, with the same number of columns. Single-column fields omit
/// the dimensionality in their filename, matching _generate_filename_from_data.
/// Groups and DPG are NOT allocated here.
///
/// @param nb_streamlines Number of streamlines to allocate.
/// @param nb_vertices Number of vertices to allocate.
/// @param init_as Optional template file (may be nullptr).
/// @throws TrxIOError if a dpv/dps sub-directory cannot be created. Any exception leaves no
///         temporary directory behind.
template <typename DT>
std::unique_ptr<TrxFile<DT>> _initialize_empty_trx(int nb_streamlines, int nb_vertices, const TrxFile<DT> *init_as) {
  auto trx = std::make_unique<TrxFile<DT>>();
  std::string tmp_dir = make_temp_dir("trx");
  // Hand the directory to `trx` immediately: if anything below throws, the TrxFile destructor
  // releases the mappings and removes the directory instead of leaking it.
  trx->_uncompressed_folder_handle = tmp_dir;
  trx->_owns_uncompressed_folder = true;

  json header = json::object();
  if (init_as != nullptr) {
    header = init_as->header;
  }
  header = _json_set(header, "NB_VERTICES", nb_vertices);
  header = _json_set(header, "NB_STREAMLINES", nb_streamlines);
  if (init_as != nullptr) {
    header = _json_set(header, "VOXEL_TO_RASMM", init_as->header["VOXEL_TO_RASMM"]);
    header = _json_set(header, "DIMENSIONS", init_as->header["DIMENSIONS"]);
  }
  std::string positions_dtype = dtype_from_scalar<DT>();
  std::string offsets_dtype = dtype_from_scalar<uint64_t>();

  // Positions: nb_vertices x 3.
  std::string positions_filename(tmp_dir);
  positions_filename += "/positions.3." + positions_dtype;
  std::tuple<int, int> shape = std::make_tuple(nb_vertices, 3);
  trx->streamlines = std::make_unique<ArraySequence<DT>>();
  trx->streamlines->mmap_pos = trx::_create_memmap(positions_filename, shape, "w+", positions_dtype);
  trx::detail::remap(trx->streamlines->_data, trx->streamlines->mmap_pos.data(), shape);

  // Offsets: nb_streamlines + 1 (includes the trailing sentinel).
  std::string offsets_filename(tmp_dir);
  offsets_filename += "/offsets." + offsets_dtype;
  std::tuple<int, int> shape_off = std::make_tuple(nb_streamlines + 1, 1);
  trx->streamlines->mmap_off = trx::_create_memmap(offsets_filename, shape_off, "w+", offsets_dtype);
  trx::detail::remap(trx->streamlines->_offsets, trx->streamlines->mmap_off.data(), shape_off);

  trx->streamlines->_lengths.resize(nb_streamlines);
  trx->streamlines->_lengths.setZero();

  if (init_as != nullptr) {
    std::string dpv_dirname;
    std::string dps_dirname;
    if (init_as->data_per_vertex.size() > 0) {
      dpv_dirname = tmp_dir + "/dpv/";
      mkdir_or_throw(dpv_dirname);
    }
    if (init_as->data_per_streamline.size() > 0) {
      dps_dirname = tmp_dir + "/dps/";
      mkdir_or_throw(dps_dirname);
    }
    for (auto const &x : init_as->data_per_vertex) {
      std::string dpv_dtype = dtype_from_scalar<DT>();
      // Keep the source column count. (Testing rows() == 1 here previously truncated
      // multi-column fields of a single-vertex file to one column.)
      const int rows = nb_vertices;
      const int cols = static_cast<int>(x.second->_data.cols());
      std::string dpv_filename = dpv_dirname + x.first;
      if (cols != 1) {
        dpv_filename += "." + std::to_string(cols);
      }
      dpv_filename += "." + dpv_dtype;
      std::tuple<int, int> dpv_shape = std::make_tuple(rows, cols);
      auto &dpv = trx->data_per_vertex[x.first];
      dpv = std::make_unique<ArraySequence<DT>>();
      dpv->mmap_pos = trx::_create_memmap(dpv_filename, dpv_shape, "w+", dpv_dtype);
      trx::detail::remap(dpv->_data, dpv->mmap_pos.data(), rows, cols);
      // DPV shares the streamline offsets/lengths of this file.
      trx::detail::remap(dpv->_offsets, trx->streamlines->_offsets.data(),
                         int(trx->streamlines->_offsets.rows()), int(trx->streamlines->_offsets.cols()));
      dpv->_lengths = trx->streamlines->_lengths;
    }
    for (auto const &x : init_as->data_per_streamline) {
      std::string dps_dtype = dtype_from_scalar<DT>();
      // `cols` is always initialized now (it was left uninitialized when rows() == 1).
      const int rows = nb_streamlines;
      const int cols = static_cast<int>(x.second->_matrix.cols());
      std::string dps_filename = dps_dirname + x.first;
      if (cols != 1) {
        dps_filename += "." + std::to_string(cols);
      }
      dps_filename += "." + dps_dtype;
      std::tuple<int, int> dps_shape = std::make_tuple(rows, cols);
      auto &dps = trx->data_per_streamline[x.first];
      dps = std::make_unique<trx::MMappedMatrix<DT>>();
      dps->mmap = trx::_create_memmap(dps_filename, dps_shape, std::string("w+"), dps_dtype);
      trx::detail::remap(dps->_matrix, dps->mmap.data(), rows, cols);
    }
  }
  trx->header = header;
  return trx;
}
/// Build a TrxFile from already-located entries of an (optionally zipped) TRX archive.
///
/// @param header Parsed TRX header; NB_VERTICES / NB_STREAMLINES drive every size check.
/// @param dict_pointer_size Entry path -> (offset passed to _create_memmap, number of scalar
///        elements in the entry).
/// @param root_zip When non-empty, every entry is mapped from this single file; otherwise each
///        entry is mapped from its own path.
/// @param root Directory used to express each entry's folder relative to the archive root.
/// @throws TrxFormatError on size/dimensionality mismatches (including a zero NB_STREAMLINES /
///         NB_VERTICES when DPS / DPV entries are present), unknown entries, or missing
///         positions/offsets.
/// @throws TrxDTypeError for unsupported offsets or metadata dtypes.
template <typename DT>
std::unique_ptr<TrxFile<DT>>
TrxFile<DT>::_create_trx_from_pointer(json header,
                                      std::map<std::string, std::tuple<long long, long long>> dict_pointer_size,
                                      std::string root_zip,
                                      std::string root) {
  auto trx = std::make_unique<trx::TrxFile<DT>>();
  trx->header = header;
  trx->streamlines = std::make_unique<ArraySequence<DT>>();
  std::string filename;
  // Iterate in reverse so that "positions" and "offsets" (which sort after "dpv"/"dps"/"groups")
  // are processed first, before DPS/DPV entries that depend on them being initialized.
  for (auto x = dict_pointer_size.rbegin(); x != dict_pointer_size.rend(); ++x) {
    std::string elem_filename = x->first;
    if (root_zip.size() > 0) {
      filename = root_zip;
    } else {
      filename = elem_filename;
    }
    std::string folder = folder_from_path(elem_filename, root);
    auto [base, dim, ext] = trx::detail::_split_ext_with_dimensionality(elem_filename);
    long long mem_adress = std::get<0>(x->second);
    long long size = std::get<1>(x->second);
    if (base == "positions" && (folder.empty() || folder == ".")) {
      const auto nb_vertices = static_cast<int64_t>(trx->header["NB_VERTICES"].int_value());
      const auto expected = nb_vertices * 3;
      if (size != expected || dim != 3) {
        throw TrxFormatError("Wrong data size/dimensionality: size=" + std::to_string(size) +
                             " expected=" + std::to_string(expected) + " dim=" + std::to_string(dim) +
                             " filename=" + elem_filename);
      }
      std::tuple<int, int> shape = std::make_tuple(static_cast<int>(trx->header["NB_VERTICES"].int_value()), 3);
      trx->streamlines->mmap_pos = trx::_create_memmap(filename, shape, "r+", ext, mem_adress);
      trx::detail::remap(trx->streamlines->_data, trx->streamlines->mmap_pos.data(), shape);
    } else if (base == "offsets" && (folder.empty() || folder == ".")) {
      const auto nb_streamlines = static_cast<int64_t>(trx->header["NB_STREAMLINES"].int_value());
      const auto nb_vertices = static_cast<uint64_t>(trx->header["NB_VERTICES"].int_value());
      const auto expected = nb_streamlines + 1;
      // Some writers omit the trailing sentinel offset; it is reconstructed as NB_VERTICES.
      const bool missing_sentinel = (size == nb_streamlines && dim == 1);
      if ((size != expected && !missing_sentinel) || dim != 1) {
        throw TrxFormatError("Wrong offsets size/dimensionality: size=" + std::to_string(size) +
                             " expected=" + std::to_string(expected) + " dim=" + std::to_string(dim) +
                             " filename=" + elem_filename);
      }
      const int nb_str = static_cast<int>(trx->header["NB_STREAMLINES"].int_value());
      const int offsets_rows = missing_sentinel ? (nb_str + 1) : static_cast<int>(size);
      std::tuple<int, int> shape = std::make_tuple(offsets_rows, 1);
      trx->streamlines->mmap_off =
          trx::_create_memmap(filename, std::make_tuple(static_cast<int>(size), 1), "r+", ext, mem_adress);
      if (ext == "uint64") {
        if (missing_sentinel) {
          trx->streamlines->_offsets_owned.resize(static_cast<size_t>(offsets_rows));
          auto *src = reinterpret_cast<const uint64_t *>(trx->streamlines->mmap_off.data()); // NOLINT
          for (int i = 0; i < static_cast<int>(size); ++i) {
            trx->streamlines->_offsets_owned[static_cast<size_t>(i)] = src[i];
          }
          trx->streamlines->_offsets_owned.back() = nb_vertices;
          trx::detail::remap(trx->streamlines->_offsets, trx->streamlines->_offsets_owned.data(), shape);
        } else {
          trx::detail::remap(trx->streamlines->_offsets, trx->streamlines->mmap_off.data(), shape);
        }
      } else if (ext == "uint32") {
        // Widen 32-bit offsets into an owned uint64 buffer.
        trx->streamlines->_offsets_owned.resize(static_cast<size_t>(offsets_rows));
        auto *src = reinterpret_cast<const uint32_t *>(trx->streamlines->mmap_off.data()); // NOLINT
        for (int i = 0; i < static_cast<int>(size); ++i) {
          trx->streamlines->_offsets_owned[static_cast<size_t>(i)] = static_cast<uint64_t>(src[i]);
        }
        if (missing_sentinel) {
          trx->streamlines->_offsets_owned.back() = nb_vertices;
        }
        trx::detail::remap(trx->streamlines->_offsets, trx->streamlines->_offsets_owned.data(), shape);
      } else {
        throw TrxDTypeError("Unsupported offsets datatype: " + ext);
      }
      Matrix<uint64_t, Dynamic, 1> offsets = trx->streamlines->_offsets;
      trx->streamlines->_lengths =
          trx::detail::_compute_lengths(offsets, static_cast<int>(trx->header["NB_VERTICES"].int_value()));
    } else if (folder == "dps") {
      const int nb_rows = static_cast<int>(trx->header["NB_STREAMLINES"].int_value());
      // Validate the row count before dividing by it (a zero count used to divide by zero).
      if (nb_rows <= 0) {
        throw TrxFormatError("Wrong dps size/dimensionality");
      }
      const int nb_scalar = static_cast<int>(size / nb_rows);
      if (size % nb_rows != 0 || nb_scalar != dim) {
        throw TrxFormatError("Wrong dps size/dimensionality");
      }
      std::tuple<int, int> shape = std::make_tuple(nb_rows, nb_scalar);
      trx->data_per_streamline[base] = std::make_unique<MMappedMatrix<DT>>();
      trx->data_per_streamline[base]->mmap = trx::_create_memmap(filename, shape, "r+", ext, mem_adress);
      const std::string expected_dtype = dtype_from_scalar<DT>();
      if (ext == expected_dtype) {
        trx::detail::remap(trx->data_per_streamline[base]->_matrix, trx->data_per_streamline[base]->mmap.data(), shape);
        materialize_matrix_map_and_unmap(*trx->data_per_streamline[base]);
      } else {
        const size_t n = static_cast<size_t>(std::get<0>(shape)) * static_cast<size_t>(std::get<1>(shape));
        copy_cast_from_dtype_buffer<DT>(trx->data_per_streamline[base]->mmap.data(), n, ext,
                                        trx->data_per_streamline[base]->_matrix_owned);
        trx::detail::remap(trx->data_per_streamline[base]->_matrix,
                           trx->data_per_streamline[base]->_matrix_owned.data(), shape);
        trx->data_per_streamline[base]->mmap.unmap();
      }
    } else if (folder == "dpv") {
      const int nb_rows = static_cast<int>(trx->header["NB_VERTICES"].int_value());
      // Validate the row count before dividing by it (a zero count used to divide by zero).
      if (nb_rows <= 0) {
        throw TrxFormatError("Wrong dpv size/dimensionality");
      }
      const int nb_scalar = static_cast<int>(size / nb_rows);
      if (size % nb_rows != 0 || nb_scalar != dim) {
        throw TrxFormatError("Wrong dpv size/dimensionality");
      }
      std::tuple<int, int> shape = std::make_tuple(nb_rows, nb_scalar);
      trx->data_per_vertex[base] = std::make_unique<ArraySequence<DT>>();
      trx->data_per_vertex[base]->mmap_pos = trx::_create_memmap(filename, shape, "r+", ext, mem_adress);
      const std::string expected_dtype = dtype_from_scalar<DT>();
      if (ext == expected_dtype) {
        trx::detail::remap(trx->data_per_vertex[base]->_data, trx->data_per_vertex[base]->mmap_pos.data(), shape);
      } else {
        const size_t n = static_cast<size_t>(std::get<0>(shape)) * static_cast<size_t>(std::get<1>(shape));
        copy_cast_from_dtype_buffer<DT>(trx->data_per_vertex[base]->mmap_pos.data(), n, ext,
                                        trx->data_per_vertex[base]->_data_owned);
        trx::detail::remap(trx->data_per_vertex[base]->_data, trx->data_per_vertex[base]->_data_owned.data(), shape);
        trx->data_per_vertex[base]->mmap_pos.unmap();
      }
      // DPV shares the streamline offsets/lengths (already loaded thanks to the reverse order).
      trx::detail::remap(trx->data_per_vertex[base]->_offsets, trx->streamlines->_offsets.data(),
                         int(trx->streamlines->_offsets.rows()), int(trx->streamlines->_offsets.cols()));
      trx->data_per_vertex[base]->_lengths = trx->streamlines->_lengths;
    } else if (folder.rfind("dpg", 0) == 0) {
      std::tuple<int, int> shape;
      if (size != dim) {
        throw TrxFormatError("Wrong dpg size/dimensionality");
      } else {
        shape = std::make_tuple(1, static_cast<int>(size));
      }
      std::string data_name = path_basename(base);
      std::string sub_folder = path_basename(folder);
      trx->data_per_group[sub_folder][data_name] = std::make_unique<MMappedMatrix<DT>>();
      trx->data_per_group[sub_folder][data_name]->mmap = trx::_create_memmap(filename, shape, "r+", ext, mem_adress);
      const std::string expected_dtype = dtype_from_scalar<DT>();
      if (ext == expected_dtype) {
        trx::detail::remap(trx->data_per_group[sub_folder][data_name]->_matrix,
                           trx->data_per_group[sub_folder][data_name]->mmap.data(), shape);
        materialize_matrix_map_and_unmap(*trx->data_per_group[sub_folder][data_name]);
      } else {
        const size_t n = static_cast<size_t>(std::get<0>(shape)) * static_cast<size_t>(std::get<1>(shape));
        copy_cast_from_dtype_buffer<DT>(trx->data_per_group[sub_folder][data_name]->mmap.data(), n, ext,
                                        trx->data_per_group[sub_folder][data_name]->_matrix_owned);
        trx::detail::remap(trx->data_per_group[sub_folder][data_name]->_matrix,
                           trx->data_per_group[sub_folder][data_name]->_matrix_owned.data(), shape);
        trx->data_per_group[sub_folder][data_name]->mmap.unmap();
      }
    } else if (folder == "groups") {
      std::tuple<int, int> shape;
      if (dim != 1) {
        throw TrxFormatError("Wrong group dimensionality");
      } else {
        shape = std::make_tuple(static_cast<int>(size), 1);
      }
      // Groups are loaded lazily: only remember where the data lives.
      trx->groups[base] = nullptr;
      typename TrxFile<DT>::GroupBackingInfo info;
      info.filename = filename;
      info.rows = std::get<0>(shape);
      info.cols = std::get<1>(shape);
      info.dtype = ext;
      info.mem_offset = mem_adress;
      trx->group_backing_info_[base] = std::move(info);
    } else {
      throw TrxFormatError("Entry is not part of a valid TRX structure: " + elem_filename);
    }
  }
  if (trx->streamlines->_data.size() == 0 || trx->streamlines->_offsets.size() == 0) {
    throw TrxFormatError("Missing essential data.");
  }
  return trx;
}
/// Return an independent copy of this TrxFile backed by a fresh temporary directory.
///
/// Positions, offsets, lengths, DPS, DPV, groups and DPG are all copied. For a sliced
/// (non-copy-safe) file, NB_STREAMLINES / NB_VERTICES in the copy's header are recomputed from
/// the data actually held. Without streamline data, a header-only TrxFile is returned.
///
/// @throws TrxIOError if the groups/dpg directories cannot be created in the copy's temporary
///         folder (previously such errors were ignored and surfaced later, less clearly).
template <typename DT> std::unique_ptr<TrxFile<DT>> TrxFile<DT>::deepcopy() {
  if (!this->streamlines || this->streamlines->_data.size() == 0 || this->streamlines->_offsets.size() == 0) {
    auto empty_copy = std::make_unique<trx::TrxFile<DT>>();
    empty_copy->header = this->header;
    return empty_copy;
  }
  // Determine effective counts (handle sliced/non-copy-safe data)
  json tmp_header = this->header;
  int nb_streamlines = 0;
  int nb_vertices = 0;
  if (!this->_copy_safe) {
    nb_streamlines = static_cast<int>(this->num_streamlines());
    nb_vertices = static_cast<int>(this->streamlines->_data.size() / 3);
    tmp_header = _json_set(tmp_header, "NB_STREAMLINES", nb_streamlines);
    tmp_header = _json_set(tmp_header, "NB_VERTICES", nb_vertices);
  } else {
    nb_streamlines = tmp_header["NB_STREAMLINES"].int_value();
    nb_vertices = tmp_header["NB_VERTICES"].int_value();
  }
  // Allocate a fresh TrxFile with memory-mapped positions/offsets/DPS/DPV shaped like `this`.
  auto copy = _initialize_empty_trx<DT>(nb_streamlines, nb_vertices, this);
  copy->header = tmp_header;
  // Copy positions and offsets.
  copy->streamlines->_data = this->streamlines->_data;
  copy->streamlines->_offsets = this->streamlines->_offsets;
  // Ensure the sentinel offset equals the (possibly recomputed) vertex count.
  if (copy->streamlines->_offsets.size() > 0) {
    copy->streamlines->_offsets(copy->streamlines->_offsets.size() - 1) = static_cast<uint64_t>(nb_vertices);
  }
  copy->streamlines->_lengths = this->streamlines->_lengths;
  // Copy DPS
  for (auto const &kv : this->data_per_streamline) {
    auto it = copy->data_per_streamline.find(kv.first);
    if (it != copy->data_per_streamline.end()) {
      it->second->_matrix = kv.second->_matrix;
    }
  }
  // Copy DPV
  for (auto const &kv : this->data_per_vertex) {
    auto it = copy->data_per_vertex.find(kv.first);
    if (it != copy->data_per_vertex.end()) {
      it->second->_data = kv.second->_data;
      // _offsets is already correctly bound to copy->streamlines->_offsets by _initialize_empty_trx
      it->second->_lengths = kv.second->_lengths;
    }
  }
  // Copy groups (not covered by _initialize_empty_trx)
  std::string tmp_dir = copy->_uncompressed_folder_handle;
  if (!this->groups.empty()) {
    // Groups may be lazily backed; make sure every matrix is in memory before copying.
    this->ensure_all_groups_loaded();
    std::string groups_dirname = tmp_dir + SEPARATOR + "groups" + SEPARATOR;
    mkdir_or_throw(groups_dirname);
    for (auto const &kv : this->groups) {
      std::string group_dtype = dtype_from_scalar<uint32_t>();
      int rows = static_cast<int>(kv.second->_matrix.rows());
      int cols = static_cast<int>(kv.second->_matrix.cols());
      std::string group_filename = groups_dirname + kv.first;
      group_filename = _generate_filename_from_data(kv.second->_matrix, group_filename);
      std::tuple<int, int> group_shape = std::make_tuple(rows, cols);
      copy->groups[kv.first] = std::make_unique<MMappedMatrix<uint32_t>>();
      copy->groups[kv.first]->mmap = _create_memmap(group_filename, group_shape, "w+", group_dtype);
      trx::detail::remap(copy->groups[kv.first]->_matrix, copy->groups[kv.first]->mmap.data(), rows, cols);
      copy->groups[kv.first]->_matrix = kv.second->_matrix;
    }
  }
  // Copy DPG (not covered by _initialize_empty_trx)
  const std::string dpg_dirname = tmp_dir + SEPARATOR + "dpg" + SEPARATOR;
  for (auto const &group_kv : this->data_per_group) {
    std::string dpg_subdirname = dpg_dirname + group_kv.first;
    mkdir_or_throw(dpg_subdirname);
    for (auto const &field : group_kv.second) {
      std::string dpg_dtype = dtype_from_scalar<DT>();
      int rows = static_cast<int>(field.second->_matrix.rows());
      int cols = static_cast<int>(field.second->_matrix.cols());
      std::string dpg_filename = dpg_subdirname + SEPARATOR + field.first;
      dpg_filename = _generate_filename_from_data(field.second->_matrix, dpg_filename);
      std::tuple<int, int> dpg_shape = std::make_tuple(rows, cols);
      copy->data_per_group[group_kv.first][field.first] = std::make_unique<MMappedMatrix<DT>>();
      copy->data_per_group[group_kv.first][field.first]->mmap =
          _create_memmap(dpg_filename, dpg_shape, "w+", dpg_dtype);
      trx::detail::remap(copy->data_per_group[group_kv.first][field.first]->_matrix,
                         copy->data_per_group[group_kv.first][field.first]->mmap.data(), rows, cols);
      copy->data_per_group[group_kv.first][field.first]->_matrix = field.second->_matrix;
    }
  }
  return copy;
}
/// Determine how much of a preallocated TrxFile is actually in use.
/// The last streamline with a non-zero length is located with _dichotomic_search. Everything up
/// to and including it counts as used.
/// @return (number of used streamlines, sum of their lengths); (0, 0) when nothing is used.
template <typename DT> std::tuple<int, int> TrxFile<DT>::_get_real_len() {
  auto &lengths = this->streamlines->_lengths;
  if (lengths.size() == 0) {
    return std::make_tuple(0, 0);
  }
  const int last_used = trx::detail::_dichotomic_search(lengths);
  if (last_used == -1) {
    return std::make_tuple(0, 0);
  }
  const int used_vertices = lengths(Eigen::seq(0, last_used), 0).sum();
  return std::make_tuple(last_used + 1, used_vertices);
}
/// Copy streamline geometry (positions, offsets, lengths) and the matching DPV / DPS rows from
/// `trx` into this preallocated TrxFile, starting at streamline `strs_start` / vertex `pts_start`.
///
/// @param trx Source file; its first streamlines are copied.
/// @param strs_start First destination streamline (-1 means 0).
/// @param pts_start First destination vertex (-1 means 0).
/// @param nb_strs_to_copy Number of source streamlines to copy. -1 means the SOURCE's used range
///        as reported by trx->_get_real_len().
/// @return (strs_end, pts_end), one past the last destination streamline / vertex written, or
///         (strs_start, pts_start) when there are no vertices to copy.
/// @throws std::out_of_range if `trx` lacks a DPV/DPS key that this file has.
template <typename DT>
std::tuple<int, int>
TrxFile<DT>::_copy_fixed_arrays_from(TrxFile<DT> *trx, int strs_start, int pts_start, int nb_strs_to_copy) {
  int curr_strs_len = 0;
  int curr_pts_len = 0;
  if (nb_strs_to_copy == -1) {
    // Measure the source. Measuring `this` (the freshly allocated, all-zero destination)
    // returned (0, 0) and made e.g. resize() silently copy nothing.
    std::tuple<int, int> curr = trx->_get_real_len();
    curr_strs_len = std::get<0>(curr);
    curr_pts_len = std::get<1>(curr);
  } else {
    curr_strs_len = nb_strs_to_copy;
    curr_pts_len = trx->streamlines->_lengths(Eigen::seq(0, curr_strs_len - 1)).sum();
  }
  if (pts_start == -1) {
    pts_start = 0;
  }
  if (strs_start == -1) {
    strs_start = 0;
  }
  const int strs_end = strs_start + curr_strs_len;
  const int pts_end = pts_start + curr_pts_len;
  if (curr_pts_len == 0) {
    return std::make_tuple(strs_start, pts_start);
  }
  this->streamlines->_data.block(pts_start, 0, curr_pts_len, this->streamlines->_data.cols()) =
      trx->streamlines->_data.block(0, 0, curr_pts_len, trx->streamlines->_data.cols());
  // curr_strs_len + 1 offsets are written, shifted by pts_start, so the start of the slot
  // following the copied streamlines is set as well.
  this->streamlines->_offsets.block(strs_start, 0, curr_strs_len + 1, 1) =
      (trx->streamlines->_offsets.block(0, 0, curr_strs_len + 1, 1).array() + pts_start).matrix();
  this->streamlines->_lengths.block(strs_start, 0, curr_strs_len, 1) =
      trx->streamlines->_lengths.block(0, 0, curr_strs_len, 1);
  for (auto const &x : this->data_per_vertex) {
    auto &dst = *x.second;
    const auto &src = *trx->data_per_vertex.at(x.first);
    dst._data.block(pts_start, 0, curr_pts_len, dst._data.cols()) =
        src._data.block(0, 0, curr_pts_len, src._data.cols());
    // DPV shares THIS file's streamline offsets/lengths (as set up by _initialize_empty_trx).
    // Binding them to the source's buffers left this file pointing into `trx`, which dangles
    // once the source is closed.
    trx::detail::remap(dst._offsets, this->streamlines->_offsets.data(),
                       static_cast<int>(this->streamlines->_offsets.rows()),
                       static_cast<int>(this->streamlines->_offsets.cols()));
    dst._lengths = this->streamlines->_lengths;
  }
  for (auto const &x : this->data_per_streamline) {
    auto &dst = *x.second;
    const auto &src = *trx->data_per_streamline.at(x.first);
    dst._matrix.block(strs_start, 0, curr_strs_len, dst._matrix.cols()) =
        src._matrix.block(0, 0, curr_strs_len, src._matrix.cols());
  }
  return std::make_tuple(strs_end, pts_end);
}
/// Release every array member, delete the owned temporary folder (if any), and reset this
/// TrxFile to an empty, copy-safe state whose header is the default empty-file header
/// (identity VOXEL_TO_RASMM, DIMENSIONS {1, 1, 1}, zero vertices and streamlines).
template <typename DT> void TrxFile<DT>::close() {
  // Drop the mmap-backed members before removing the directory that backs them.
  this->streamlines.reset();
  this->groups.clear();
  this->group_backing_info_.clear();
  this->data_per_streamline.clear();
  this->data_per_vertex.clear();
  this->data_per_group.clear();
  this->_cleanup_temporary_directory();
  this->_uncompressed_folder_handle.clear();
  this->_owns_uncompressed_folder = false;
  this->_copy_safe = true;
  // Reuse the shared builder instead of duplicating the empty-header construction here.
  this->header = default_header();
}
/// Destructor: releases every array member first and only then removes the temporary
/// uncompressed folder (when this object owns it), so that no mapping outlives its backing file.
/// NOTE(review): destructors are implicitly noexcept; if rm_dir (reached through
/// _cleanup_temporary_directory) can throw, the program would terminate — confirm it cannot.
template <typename DT>
TrxFile<DT>::~TrxFile() {
// Release mmap-backed members before deleting temporary backing directory.
this->streamlines.reset();
this->groups.clear();
this->group_backing_info_.clear();
this->data_per_streamline.clear();
this->data_per_vertex.clear();
this->data_per_group.clear();
// Safe now: nothing above still references files inside the folder.
this->_cleanup_temporary_directory();
}
/// Remove the uncompressed folder if (and only if) this TrxFile owns it, then forget it.
/// Caveat: cleanup is best-effort; filesystem errors are ignored.
template <typename DT>
void TrxFile<DT>::_cleanup_temporary_directory() {
  const bool owns_folder = this->_owns_uncompressed_folder && !this->_uncompressed_folder_handle.empty();
  if (!owns_folder) {
    return;
  }
  // Best-effort: the rm_dir status is deliberately discarded.
  static_cast<void>(rm_dir(this->_uncompressed_folder_handle));
  this->_uncompressed_folder_handle.clear();
  this->_owns_uncompressed_folder = false;
}
template <typename DT>
// Resize this TrxFile to `nb_streamlines` streamlines and `nb_vertices` vertices.
// Passing -1 keeps the currently used amount: for streamlines the real length reported by
// _get_real_len(), for vertices the sum of _lengths.
// A new TrxFile is built with _initialize_empty_trx(). The fixed-size arrays are copied into it
// (truncated when shrinking streamlines). Groups are copied keeping only indices < the effective
// streamline count. data_per_group is copied unless `delete_dpg` is true.
// Caveats:
// - Downsizing vertices is not supported: the call silently returns.
// - Reducing streamlines truncates data.
// - A same-size resize is a no-op.
// Throws TrxArgumentError when the dataset is a slice (_copy_safe == false).
// NOTE(review): the freshly built `trx` is never moved/assigned into *this within this function,
// so the resized copy appears to be discarded when `trx` goes out of scope. Confirm against
// _initialize_empty_trx() ownership semantics.
// NOTE(review): this->close() is only called when delete_dpg is true or data_per_group is
// non-empty. Callers that keep using *this afterwards (e.g. normalize_for_save) may observe
// cleared state on those paths.
void TrxFile<DT>::resize(int nb_streamlines, int nb_vertices, bool delete_dpg) {
if (!this->_copy_safe) {
throw TrxArgumentError("Cannot resize a sliced dataset.");
}
// Real (used) streamline / vertex counts of the current data.
std::tuple<int, int> sp_end = this->_get_real_len();
int strs_end = std::get<0>(sp_end);
int ptrs_end = std::get<1>(sp_end);
if (nb_streamlines != -1 && nb_streamlines < strs_end) {
strs_end = nb_streamlines;
}
if (nb_vertices == -1) {
ptrs_end = this->streamlines->_lengths.sum();
nb_vertices = ptrs_end;
} else if (nb_vertices < ptrs_end) {
// Shrinking vertices would break streamline consistency: silently ignored.
return;
}
if (nb_streamlines == -1) {
nb_streamlines = strs_end;
}
// Already the requested size: nothing to do.
if (nb_streamlines == this->header["NB_STREAMLINES"].int_value() &&
nb_vertices == this->header["NB_VERTICES"].int_value()) {
return;
}
auto trx = _initialize_empty_trx(nb_streamlines, nb_vertices, this);
// When shrinking, only the first nb_streamlines streamlines are copied.
if (nb_streamlines < this->header["NB_STREAMLINES"].int_value())
trx->_copy_fixed_arrays_from(this, -1, -1, nb_streamlines);
else {
trx->_copy_fixed_arrays_from(this);
}
std::string tmp_dir = trx->_uncompressed_folder_handle;
if (this->groups.size() > 0) {
// Lazily loaded groups must be materialized before they can be sliced.
this->ensure_all_groups_loaded();
std::string group_dir = tmp_dir + SEPARATOR + "groups" + SEPARATOR;
mkdir_or_throw(group_dir);
for (auto const &x : this->groups) {
std::string group_dtype = dtype_from_scalar<uint32_t>();
std::string group_name = group_dir + x.first + "." + group_dtype;
// NOTE(review): ori_length is unused.
int ori_length = this->groups[x.first]->_matrix.size();
std::vector<int> keep_rows;
std::vector<int> keep_cols = {0};
// Slicing: keep every row whose streamline index is below the new streamline count.
for (int i = 0; i < x.second->_matrix.rows(); ++i) {
for (int j = 0; j < x.second->_matrix.cols(); ++j) {
if (static_cast<int>(x.second->_matrix(i, j)) < strs_end) {
keep_rows.push_back(i);
}
}
}
// Gather the kept rows (column 0 only) into a temporary before writing the new memmap.
Matrix<uint32_t, Dynamic, Dynamic> tmp = this->groups[x.first]->_matrix(keep_rows, keep_cols);
std::tuple<int, int> group_shape = std::make_tuple(tmp.size(), 1);
trx->groups[x.first] = std::make_unique<MMappedMatrix<uint32_t>>();
trx->groups[x.first]->mmap = trx::_create_memmap(group_name, group_shape, "w+", group_dtype);
trx::detail::remap(trx->groups[x.first]->_matrix, trx->groups[x.first]->mmap.data(), group_shape);
// update values
for (int i = 0; i < trx->groups[x.first]->_matrix.rows(); ++i) {
for (int j = 0; j < trx->groups[x.first]->_matrix.cols(); ++j) {
trx->groups[x.first]->_matrix(i, j) = tmp(i, j);
}
}
}
}
// Caller asked to drop data_per_group: close this file and stop here.
if (delete_dpg) {
this->close();
return;
}
if (this->data_per_group.size() > 0) {
// Copy every data_per_group entry verbatim into <tmp_dir>/dpg/<group>/.
std::string dpg_dir = tmp_dir + SEPARATOR + "dpg" + SEPARATOR;
mkdir_or_throw(dpg_dir);
for (auto const &x : this->data_per_group) {
std::string dpg_subdir = dpg_dir + x.first;
mkdir_or_throw(dpg_subdir);
if (trx->data_per_group.find(x.first) == trx->data_per_group.end()) {
trx->data_per_group.emplace(x.first, std::map<std::string, std::unique_ptr<MMappedMatrix<DT>>>{});
} else {
trx->data_per_group[x.first].clear();
}
for (auto const &y : this->data_per_group[x.first]) {
std::string dpg_dtype = dtype_from_scalar<DT>();
std::string dpg_filename = dpg_subdir + SEPARATOR + y.first;
dpg_filename = _generate_filename_from_data(this->data_per_group[x.first][y.first]->_matrix, dpg_filename);
std::tuple<int, int> dpg_shape = std::make_tuple(this->data_per_group[x.first][y.first]->_matrix.rows(),
this->data_per_group[x.first][y.first]->_matrix.cols());
if (trx->data_per_group[x.first].find(y.first) == trx->data_per_group[x.first].end()) {
trx->data_per_group[x.first][y.first] = std::make_unique<MMappedMatrix<DT>>();
}
trx->data_per_group[x.first][y.first]->mmap = _create_memmap(dpg_filename, dpg_shape, "w+", dpg_dtype);
trx::detail::remap(trx->data_per_group[x.first][y.first]->_matrix,
trx->data_per_group[x.first][y.first]->mmap.data(), dpg_shape);
// update values
for (int i = 0; i < trx->data_per_group[x.first][y.first]->_matrix.rows(); ++i) {
for (int j = 0; j < trx->data_per_group[x.first][y.first]->_matrix.cols(); ++j) {
trx->data_per_group[x.first][y.first]->_matrix(i, j) =
this->data_per_group[x.first][y.first]->_matrix(i, j);
}
}
}
}
this->close();
}
}
// Extract a .zip/.trx archive into a fresh temporary directory and load it.
// The returned TrxFile owns that directory and removes it on cleanup.
// If loading the extracted directory fails, the directory is removed before the error propagates
// (previously it leaked because nothing owned it yet).
template <typename DT> std::unique_ptr<TrxFile<DT>> TrxFile<DT>::load_from_zip(const std::string &filename) {
  const std::string temp_dir = extract_trx_archive(filename);
  std::unique_ptr<TrxFile<DT>> trx;
  try {
    trx = TrxFile<DT>::load_from_directory(temp_dir);
  } catch (...) {
    // No TrxFile has taken ownership of the extracted folder yet: clean it up here.
    rm_dir(temp_dir);
    throw;
  }
  trx->_uncompressed_folder_handle = temp_dir;
  trx->_owns_uncompressed_folder = true;
  return trx;
}
// Load an uncompressed TRX directory.
// The path is canonicalized when possible. header.json is opened with a few short retries, and
// the data files are mapped via populate_fps/_create_trx_from_pointer. The returned TrxFile does
// NOT own the directory.
// Throws TrxIOError if header.json cannot be opened or parsed. The open-failure message includes
// errno and a listing of the directory to ease diagnosis.
template <typename DT> std::unique_ptr<TrxFile<DT>> TrxFile<DT>::load_from_directory(const std::string &path) {
  std::string directory = path;
  {
    std::error_code ec;
    trx::fs::path resolved = trx::fs::weakly_canonical(trx::fs::path(path), ec);
    if (!ec) {
      directory = resolved.string();
    }
  }
  std::string header_name = directory + SEPARATOR + "header.json";
  constexpr int kMaxOpenAttempts = 5;
  std::ifstream header_file;
  int open_err = 0;
  for (int attempt = 0; attempt < kMaxOpenAttempts; ++attempt) {
    errno = 0;
    header_file.open(header_name);
    if (header_file.is_open()) {
      break;
    }
    // Capture errno immediately: sleeping and the filesystem queries below may overwrite it.
    open_err = errno;
    // No point waiting after the final attempt.
    if (attempt + 1 < kMaxOpenAttempts) {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }
  }
  if (!header_file.is_open()) {
    std::error_code ec;
    const bool exists = trx::fs::exists(directory, ec);
    std::string detail = "Failed to open header.json at: " + header_name;
    detail += " exists=" + std::string(exists ? "true" : "false");
    detail += " errno=" + std::to_string(open_err) + " msg=" + std::string(std::strerror(open_err));
    if (exists) {
      std::vector<std::string> files;
      // Non-throwing iteration: a filesystem_error here would mask the real failure.
      std::error_code list_ec;
      for (trx::fs::directory_iterator it(directory, list_ec), end; !list_ec && it != end; it.increment(list_ec)) {
        files.push_back(it->path().filename().string());
      }
      if (!files.empty()) {
        std::sort(files.begin(), files.end());
        detail += " files=[";
        for (size_t i = 0; i < files.size(); ++i) {
          if (i > 0) {
            detail += ",";
          }
          detail += files[i];
        }
        detail += "]";
      }
    }
    throw TrxIOError(detail);
  }
  std::string jstream((std::istreambuf_iterator<char>(header_file)), std::istreambuf_iterator<char>());
  header_file.close();
  std::string err;
  json header = json::parse(jstream, err);
  if (!err.empty()) {
    throw TrxIOError("Failed to parse header.json: " + err);
  }
  std::map<std::string, std::tuple<long long, long long>> files_pointer_size;
  populate_fps(directory, files_pointer_size);
  auto trx = TrxFile<DT>::_create_trx_from_pointer(header, files_pointer_size, "", directory);
  trx->_uncompressed_folder_handle = directory;
  trx->_owns_uncompressed_folder = false;
  return trx;
}
// Load a TRX dataset from either a directory or an archive.
// Directories are dispatched to load_from_directory. Anything else is treated as a zip/trx archive.
// Throws TrxIOError when `path` does not exist.
template <typename DT> std::unique_ptr<TrxFile<DT>> TrxFile<DT>::load(const std::string &path) {
  const trx::fs::path candidate(path);
  if (!trx::fs::exists(candidate)) {
    throw TrxIOError("Input path does not exist: " + path);
  }
  std::error_code dir_ec;
  const bool is_dir = trx::fs::is_directory(candidate, dir_ec);
  return (is_dir && !dir_ec) ? TrxFile<DT>::load_from_directory(path) : TrxFile<DT>::load_from_zip(path);
}
// Free-function convenience wrapper; identical to TrxFile<DT>::load(path).
template <typename DT> std::unique_ptr<TrxFile<DT>> load(const std::string &path) {
  std::unique_ptr<TrxFile<DT>> loaded = TrxFile<DT>::load(path);
  return loaded;
}
inline std::unique_ptr<TrxFile<float>>
load_float32_positions(const std::string &path, const LoadFloat32Options &options) {
const TrxScalarType dtype = detect_positions_scalar_type(path, TrxScalarType::Float32);
if (dtype == TrxScalarType::Float32) {
return load<float>(path);
}
auto out = load<float>(path);
if (!out || !out->streamlines) {
return out;
}
const Eigen::Index n_vertices = static_cast<Eigen::Index>(out->num_vertices());
if (n_vertices <= 0) {
return out;
}
const Eigen::Index chunk_rows = static_cast<Eigen::Index>(std::max<size_t>(1, options.chunk_rows));
with_trx_reader(path, [&](auto &reader, TrxScalarType) -> int {
const auto &src = reader->streamlines->_data;
if (src.rows() != n_vertices || src.cols() != 3) {
throw TrxFormatError("load_float32_positions(): unexpected positions shape during conversion");
}
for (Eigen::Index row0 = 0; row0 < n_vertices; row0 += chunk_rows) {
const Eigen::Index row1 = std::min<Eigen::Index>(n_vertices, row0 + chunk_rows);
for (Eigen::Index r = row0; r < row1; ++r) {
out->streamlines->_data(r, 0) = static_cast<float>(src(r, 0));
out->streamlines->_data(r, 1) = static_cast<float>(src(r, 1));
out->streamlines->_data(r, 2) = static_cast<float>(src(r, 2));
}
}
return 0;
});
return out;
}
// Opens `path` via TrxFile<DT>::load and keeps the loaded file alive for the reader's lifetime.
template <typename DT> TrxReader<DT>::TrxReader(const std::string &path) : trx_(TrxFile<DT>::load(path)) {}
// Move construction steals the loaded file from `other`.
template <typename DT> TrxReader<DT>::TrxReader(TrxReader &&other) noexcept : trx_(std::move(other.trx_)) {}
// Move assignment; self-assignment is a no-op.
template <typename DT> TrxReader<DT> &TrxReader<DT>::operator=(TrxReader &&other) noexcept {
  if (&other == this) {
    return *this;
  }
  trx_ = std::move(other.trx_);
  return *this;
}
template <typename Fn>
auto with_trx_reader(const std::string &path, Fn &&fn)
-> decltype(fn(std::declval<TrxReader<float> &>(), TrxScalarType::Float32)) {
const TrxScalarType dtype = detect_positions_scalar_type(path, TrxScalarType::Float32);
switch (dtype) {
case TrxScalarType::Float16: {
TrxReader<Eigen::half> reader(path);
return fn(reader, dtype);
}
case TrxScalarType::Float64: {
TrxReader<double> reader(path);
return fn(reader, dtype);
}
case TrxScalarType::Float32:
default: {
TrxReader<float> reader(path);
return fn(reader, dtype);
}
}
}
// Convenience overload: save with default TrxSaveOptions except for the compression setting.
template <typename DT> void TrxFile<DT>::save(const std::string &filename, TrxCompression compression) {
  TrxSaveOptions save_opts;
  save_opts.compression = compression;
  this->save(filename, save_opts);
}
// Make offsets, lengths and header counts consistent before saving.
// - Finds the longest prefix of streamlines whose offsets are non-decreasing and within the
//   position rows, and calls resize() if that prefix is shorter than the stored data.
// - Recomputes _lengths from consecutive offsets.
// - Rewrites NB_STREAMLINES / NB_VERTICES in the header.
// Throws TrxFormatError on missing data, offsets beyond the positions, values exceeding int range,
// non-monotonic offsets, or lengths exceeding uint32 range.
template <typename DT> void TrxFile<DT>::normalize_for_save() {
  if (!this->streamlines) {
    throw TrxFormatError("Cannot normalize TRX without streamline data");
  }
  if (this->streamlines->_offsets.size() == 0) {
    throw TrxFormatError("Cannot normalize TRX without offsets data");
  }
  // Non-empty offsets are guaranteed above, so offsets_count >= 1 and the subtraction is safe.
  const size_t offsets_count = static_cast<size_t>(this->streamlines->_offsets.size());
  const size_t total_streamlines = offsets_count - 1;
  const uint64_t data_rows = static_cast<uint64_t>(this->streamlines->_data.rows());
  // Stop at the first offset that decreases or points past the positions.
  size_t used_streamlines = total_streamlines;
  for (size_t i = 1; i < offsets_count; ++i) {
    const uint64_t prev = static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(i - 1)));
    const uint64_t curr = static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(i)));
    if (curr < prev || curr > data_rows) {
      used_streamlines = i - 1;
      break;
    }
  }
  const uint64_t used_vertices =
      static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(used_streamlines)));
  if (used_vertices > data_rows) {
    throw TrxFormatError("TRX offsets exceed positions row count");
  }
  if (used_vertices > static_cast<uint64_t>(std::numeric_limits<int>::max()) ||
      used_streamlines > static_cast<size_t>(std::numeric_limits<int>::max())) {
    throw TrxFormatError("TRX normalize_for_save exceeds supported int range");
  }
  if (used_streamlines < total_streamlines || used_vertices < data_rows) {
    this->resize(static_cast<int>(used_streamlines), static_cast<int>(used_vertices));
  }
  // Recompute per-streamline lengths from consecutive offsets.
  const size_t normalized_streamlines = this->num_streamlines();
  for (size_t i = 0; i < normalized_streamlines; ++i) {
    const uint64_t curr = static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(i)));
    const uint64_t next = static_cast<uint64_t>(this->streamlines->_offsets(static_cast<Eigen::Index>(i + 1)));
    if (next < curr) {
      throw TrxFormatError("TRX offsets must be monotonically increasing");
    }
    const uint64_t diff = next - curr;
    if (diff > static_cast<uint64_t>(std::numeric_limits<uint32_t>::max())) {
      throw TrxFormatError("TRX streamline length exceeds uint32 range");
    }
    this->streamlines->_lengths(static_cast<Eigen::Index>(i)) = static_cast<uint32_t>(diff);
  }
  this->header = _json_set(this->header, "NB_STREAMLINES", static_cast<int>(normalized_streamlines));
  this->header = _json_set(this->header, "NB_VERTICES", static_cast<int>(this->num_vertices()));
}
// Save this TrxFile as a .zip/.trx archive or as a plain directory.
// The mode comes from options.mode; in Auto mode the extension of `filename` decides.
// The backing directory (_uncompressed_folder_handle) is the source of truth: its header.json is
// rewritten first, then the directory is archived or copied. Optionally the positions are
// converted to options.positions_dtype.
// Throws TrxDTypeError for unsupported extensions, TrxFormatError for inconsistent
// offsets/header/positions, and TrxIOError for filesystem failures.
template <typename DT> void TrxFile<DT>::save(const std::string &filename, const TrxSaveOptions &options) {
  std::string ext = get_ext(filename);
  if (ext.size() > 0 && ext != "zip" && ext != "trx") {
    throw TrxDTypeError("Unsupported extension: " + ext);
  }
  TrxFile<DT> *save_trx = this;
  // Consistency checks. Offsets must exist, match the header counts, be non-decreasing, and end
  // at the number of position rows.
  if (!save_trx->streamlines || save_trx->streamlines->_offsets.size() == 0) {
    throw TrxFormatError("Cannot save TRX without offsets data");
  }
  if (save_trx->header["NB_STREAMLINES"].is_number()) {
    const auto nb_streamlines = static_cast<size_t>(save_trx->header["NB_STREAMLINES"].int_value());
    if (save_trx->streamlines->_offsets.size() != static_cast<Eigen::Index>(nb_streamlines + 1)) {
      throw TrxFormatError("TRX offsets size does not match NB_STREAMLINES");
    }
  }
  if (save_trx->header["NB_VERTICES"].is_number()) {
    const auto nb_vertices = static_cast<uint64_t>(save_trx->header["NB_VERTICES"].int_value());
    const auto last = static_cast<uint64_t>(save_trx->num_vertices());
    if (last != nb_vertices) {
      throw TrxFormatError("TRX offsets sentinel does not match NB_VERTICES");
    }
  }
  for (Eigen::Index i = 1; i < save_trx->streamlines->_offsets.size(); ++i) {
    if (save_trx->streamlines->_offsets(i) < save_trx->streamlines->_offsets(i - 1)) {
      throw TrxFormatError("TRX offsets must be monotonically increasing");
    }
  }
  if (save_trx->streamlines->_data.size() > 0) {
    const auto last = static_cast<uint64_t>(save_trx->num_vertices());
    if (last != static_cast<uint64_t>(save_trx->streamlines->_data.rows())) {
      throw TrxFormatError("TRX positions row count does not match offsets sentinel");
    }
  }
  // Refresh header.json in the backing directory so the archive/copy contains the current header.
  std::string tmp_dir_name = save_trx->_uncompressed_folder_handle;
  if (!tmp_dir_name.empty()) {
    const std::string header_path = tmp_dir_name + SEPARATOR + "header.json";
    std::ofstream out_json(header_path, std::ios::out | std::ios::trunc);
    if (!out_json.is_open()) {
      throw TrxIOError("Failed to write header.json to: " + header_path);
    }
    out_json << save_trx->header.dump() << '\n';
    out_json.close();
    // A failed write/flush would otherwise leave a truncated header behind silently.
    if (!out_json) {
      throw TrxIOError("Failed to write header.json to: " + header_path);
    }
  }
  const bool write_archive = options.mode == TrxSaveMode::Archive ||
                             (options.mode == TrxSaveMode::Auto && (ext == "zip" || ext == "trx"));
  if (write_archive) {
    // Flush memory-mapped arrays to disk so the archive writer reads current contents.
    auto sync_seq = [&](auto &seq) {
      if (!seq) {
        return;
      }
      std::error_code ec;
      seq->mmap_pos.sync(ec);
      seq->mmap_off.sync(ec);
    };
    auto sync_mat = [&](auto &mat) {
      if (!mat) {
        return;
      }
      std::error_code ec;
      mat->mmap.sync(ec);
    };
    sync_seq(save_trx->streamlines);
    for (auto &kv : save_trx->groups) {
      sync_mat(kv.second);
    }
    for (auto &kv : save_trx->data_per_streamline) {
      sync_mat(kv.second);
    }
    for (auto &kv : save_trx->data_per_vertex) {
      sync_seq(kv.second);
    }
    for (auto &group_kv : save_trx->data_per_group) {
      for (auto &kv : group_kv.second) {
        sync_mat(kv.second);
      }
    }
    // Optional positions dtype conversion: write converted positions to a temp file and
    // substitute it for the original entry in the archive.
    std::unordered_set<std::string> skip;
    std::string converted_pos_path;
    std::string converted_pos_entry;
    detail::TempFileGuard tmp_pos_guard;
    if (options.positions_dtype.has_value() && save_trx->streamlines) {
      const TrxScalarType target = *options.positions_dtype;
      const std::string cur_dtype = detect_positions_dtype(tmp_dir_name);
      const std::string new_dtype_str = scalar_type_name(target);
      if (!cur_dtype.empty() && cur_dtype != new_dtype_str) {
        skip.insert("positions.3." + cur_dtype);
        tmp_pos_guard.path = detail::make_unique_temp_path("trx_pos_convert");
        {
          auto src_any = load_any(tmp_dir_name);
          write_positions_as_dtype(src_any, target, tmp_pos_guard.path);
        }
        converted_pos_path = tmp_pos_guard.path;
        converted_pos_entry = "positions.3." + new_dtype_str;
      }
    }
    write_trx_archive(filename, tmp_dir_name, options.compression,
                      converted_pos_path, converted_pos_entry,
                      skip.empty() ? nullptr : &skip);
    // tmp_pos_guard destructor removes the temp file after archive is written.
  } else {
    std::error_code ec;
    if (!trx::fs::exists(tmp_dir_name, ec) || !trx::fs::is_directory(tmp_dir_name, ec)) {
      throw TrxIOError("Temporary TRX directory does not exist: " + tmp_dir_name);
    }
    // If the destination IS the backing directory, removing it before copying would delete the
    // source data. In that case the header was already rewritten above and the data files are in
    // place, so the remove + copy step is skipped.
    bool dest_is_backing_dir = false;
    if (trx::fs::exists(filename, ec) && trx::fs::is_directory(filename, ec)) {
      if (!options.overwrite_existing) {
        throw TrxIOError("Output directory already exists: " + filename);
      }
      std::error_code eq_ec;
      dest_is_backing_dir = trx::fs::equivalent(filename, tmp_dir_name, eq_ec) && !eq_ec;
      if (!dest_is_backing_dir && rm_dir(filename) != 0) {
        throw TrxIOError("Could not remove existing directory " + filename);
      }
    }
    trx::fs::path dest_path(filename);
    if (!dest_is_backing_dir) {
      if (dest_path.has_parent_path()) {
        mkdir_or_throw(dest_path.parent_path().string());
      }
      copy_dir(tmp_dir_name, filename);
    }
    ec.clear();
    if (!trx::fs::exists(filename, ec) || !trx::fs::is_directory(filename, ec)) {
      throw TrxIOError("Failed to create output directory: " + filename);
    }
    // Optional positions dtype conversion performed in place in the output directory.
    if (options.positions_dtype.has_value() && save_trx->streamlines) {
      const TrxScalarType target = *options.positions_dtype;
      const std::string cur_dtype = detect_positions_dtype(filename);
      const std::string new_dtype_str = scalar_type_name(target);
      if (!cur_dtype.empty() && cur_dtype != new_dtype_str) {
        const std::string old_pos = filename + SEPARATOR + "positions.3." + cur_dtype;
        const std::string new_pos = filename + SEPARATOR + "positions.3." + new_dtype_str;
        {
          auto src_any = load_any(filename);
          write_positions_as_dtype(src_any, target, new_pos);
        }
        std::error_code rm_ec;
        trx::fs::remove(old_pos, rm_ec);
      }
    }
    const trx::fs::path header_path = dest_path / "header.json";
    if (!trx::fs::exists(header_path)) {
      throw TrxFormatError("Missing header.json in output directory: " + header_path.string());
    }
  }
}
template <typename DT>
// Read whitespace-separated numbers from a text file and store them as DPS field `name`
// through add_dps_from_vector.
// Throws TrxIOError if the file cannot be opened, and TrxFormatError if a token is not numeric.
void TrxFile<DT>::add_dps_from_text(const std::string &name, const std::string &dtype, const std::string &path) {
  std::ifstream stream(path);
  if (!stream.is_open()) {
    throw TrxIOError("Failed to open DPS text file: " + path);
  }
  std::vector<double> parsed;
  for (double v = 0.0; stream >> v;) {
    parsed.push_back(v);
  }
  // Extraction stopping before EOF means an unparsable token was hit.
  const bool stopped_on_bad_token = stream.fail() && !stream.eof();
  if (stopped_on_bad_token) {
    throw TrxFormatError("Failed to parse DPS text file: " + path);
  }
  add_dps_from_vector(name, dtype, parsed);
}
template <typename DT>
template <typename T>
// Add (or replace) the per-streamline field `name`, stored as dps/<name>.<dtype> in the backing
// directory. Only float16/float32/float64 are accepted (case-insensitive).
// values.size() must equal the number of streamlines. When the on-disk dtype differs from DT,
// the in-memory matrix is an owned DT copy and the file mapping is released.
// Any previous file for this field, whatever its dtype, is removed first. Otherwise re-adding
// with another dtype would leave two conflicting files for the same field.
// Throws TrxArgumentError, TrxDTypeError, TrxIOError or TrxFormatError on invalid input.
void TrxFile<DT>::add_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector<T> &values) {
  if (name.empty()) {
    throw TrxArgumentError("DPS name cannot be empty");
  }
  std::string dtype_norm = dtype;
  std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  if (!trx::detail::_is_dtype_valid(dtype_norm)) {
    throw TrxDTypeError("Unsupported DPS dtype: " + dtype);
  }
  if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") {
    throw TrxDTypeError("Unsupported DPS dtype: " + dtype);
  }
  if (this->_uncompressed_folder_handle.empty()) {
    throw TrxIOError("TRX file has no backing directory to store DPS data");
  }
  size_t nb_streamlines = 0;
  if (this->streamlines) {
    nb_streamlines = static_cast<size_t>(this->streamlines->_lengths.size());
  } else if (this->header["NB_STREAMLINES"].is_number()) {
    nb_streamlines = static_cast<size_t>(this->header["NB_STREAMLINES"].int_value());
  }
  if (values.size() != nb_streamlines) {
    throw TrxFormatError("DPS values (" + std::to_string(values.size()) + ") do not match number of streamlines (" +
                         std::to_string(nb_streamlines) + ")");
  }
  std::string dps_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dps" + SEPARATOR;
  mkdir_or_throw(dps_dirname);
  std::string dps_filename = dps_dirname + name + "." + dtype_norm;
  // Release the previous in-memory entry (and its memory map) before touching its file.
  auto existing = this->data_per_streamline.find(name);
  if (existing != this->data_per_streamline.end()) {
    this->data_per_streamline.erase(existing);
  }
  // Remove every stale on-disk file for this field: "<name>.<dtype>" or "<name>.<ncols>.<dtype>".
  {
    const std::string prefix = name + ".";
    std::vector<trx::fs::path> stale;
    std::error_code ec;
    for (trx::fs::directory_iterator it(dps_dirname, ec), end; !ec && it != end; it.increment(ec)) {
      const std::string fname = it->path().filename().string();
      if (fname.size() <= prefix.size() || fname.compare(0, prefix.size(), prefix) != 0) {
        continue;
      }
      const std::string rest = fname.substr(prefix.size());
      const size_t last_dot = rest.rfind('.');
      const std::string file_dtype = last_dot == std::string::npos ? rest : rest.substr(last_dot + 1);
      const std::string dims = last_dot == std::string::npos ? std::string() : rest.substr(0, last_dot);
      const bool dims_ok =
          std::all_of(dims.begin(), dims.end(), [](unsigned char c) { return std::isdigit(c) != 0; });
      if (dims_ok && trx::detail::_is_dtype_valid(file_dtype)) {
        stale.push_back(it->path());
      }
    }
    // Collected first, removed afterwards: modifying a directory while iterating is unspecified.
    for (const auto &stale_path : stale) {
      std::error_code rm_ec;
      trx::fs::remove(stale_path, rm_ec);
    }
    std::error_code rm_ec;
    if (trx::fs::exists(dps_filename, rm_ec)) {
      trx::fs::remove(dps_filename, rm_ec);
    }
  }
  const int rows = static_cast<int>(nb_streamlines);
  const int cols = 1;
  std::tuple<int, int> shape = std::make_tuple(rows, cols);
  const size_t n = static_cast<size_t>(rows * cols);
  auto matrix = std::make_unique<trx::MMappedMatrix<DT>>();
  matrix->mmap = trx::_create_memmap(dps_filename, shape, "w+", dtype_norm);
  const std::string expected_dtype = dtype_from_scalar<DT>();
  if (dtype_norm == expected_dtype) {
    // On-disk dtype matches DT: memory-map directly and write as DT.
    trx::detail::remap(matrix->_matrix, matrix->mmap.data(), rows, cols);
    for (int i = 0; i < rows; ++i) {
      matrix->_matrix(i, 0) = static_cast<DT>(values[static_cast<size_t>(i)]);
    }
  } else {
    // On-disk dtype differs from DT (e.g. float32 DPS inside a float16 TrxFile).
    // Write values as dtype_norm bytes into the mmap, then cross-cast into owned
    // DT memory so the in-memory matrix uses the correct element size.
    if (dtype_norm == "float16") {
      auto *ptr = reinterpret_cast<Eigen::half *>(matrix->mmap.data());
      for (size_t i = 0; i < n; ++i) ptr[i] = static_cast<Eigen::half>(values[i]);
    } else if (dtype_norm == "float32") {
      auto *ptr = reinterpret_cast<float *>(matrix->mmap.data());
      for (size_t i = 0; i < n; ++i) ptr[i] = static_cast<float>(values[i]);
    } else {
      auto *ptr = reinterpret_cast<double *>(matrix->mmap.data());
      for (size_t i = 0; i < n; ++i) ptr[i] = static_cast<double>(values[i]);
    }
    copy_cast_from_dtype_buffer<DT>(matrix->mmap.data(), n, dtype_norm, matrix->_matrix_owned);
    trx::detail::remap(matrix->_matrix, matrix->_matrix_owned.data(), shape);
    matrix->mmap.unmap();
  }
  this->data_per_streamline[name] = std::move(matrix);
}
template <typename DT>
template <typename T>
// Add (or replace) the per-vertex field `name`, stored as dpv/<name>.<dtype> in the backing
// directory. Only float16/float32/float64 are accepted (case-insensitive).
// values.size() must equal the number of position rows. The resulting sequence shares the
// streamlines' offsets and copies their lengths.
// Any previous file for this field, whatever its dtype, is removed first. Otherwise re-adding
// with another dtype would leave two conflicting files for the same field.
// Throws TrxArgumentError, TrxDTypeError, TrxIOError or TrxFormatError on invalid input.
void TrxFile<DT>::add_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector<T> &values) {
  if (name.empty()) {
    throw TrxArgumentError("DPV name cannot be empty");
  }
  std::string dtype_norm = dtype;
  std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  if (!trx::detail::_is_dtype_valid(dtype_norm)) {
    throw TrxDTypeError("Unsupported DPV dtype: " + dtype);
  }
  if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") {
    throw TrxDTypeError("Unsupported DPV dtype: " + dtype);
  }
  if (this->_uncompressed_folder_handle.empty()) {
    throw TrxIOError("TRX file has no backing directory to store DPV data");
  }
  size_t nb_vertices = 0;
  if (this->streamlines) {
    nb_vertices = static_cast<size_t>(this->streamlines->_data.rows());
  } else if (this->header["NB_VERTICES"].is_number()) {
    nb_vertices = static_cast<size_t>(this->header["NB_VERTICES"].int_value());
  }
  if (values.size() != nb_vertices) {
    throw TrxFormatError("DPV values (" + std::to_string(values.size()) + ") do not match number of vertices (" +
                         std::to_string(nb_vertices) + ")");
  }
  std::string dpv_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dpv" + SEPARATOR;
  mkdir_or_throw(dpv_dirname);
  std::string dpv_filename = dpv_dirname + name + "." + dtype_norm;
  // Release the previous in-memory entry (and its memory map) before touching its file.
  auto existing = this->data_per_vertex.find(name);
  if (existing != this->data_per_vertex.end()) {
    this->data_per_vertex.erase(existing);
  }
  // Remove every stale on-disk file for this field: "<name>.<dtype>" or "<name>.<ncols>.<dtype>".
  {
    const std::string prefix = name + ".";
    std::vector<trx::fs::path> stale;
    std::error_code ec;
    for (trx::fs::directory_iterator it(dpv_dirname, ec), end; !ec && it != end; it.increment(ec)) {
      const std::string fname = it->path().filename().string();
      if (fname.size() <= prefix.size() || fname.compare(0, prefix.size(), prefix) != 0) {
        continue;
      }
      const std::string rest = fname.substr(prefix.size());
      const size_t last_dot = rest.rfind('.');
      const std::string file_dtype = last_dot == std::string::npos ? rest : rest.substr(last_dot + 1);
      const std::string dims = last_dot == std::string::npos ? std::string() : rest.substr(0, last_dot);
      const bool dims_ok =
          std::all_of(dims.begin(), dims.end(), [](unsigned char c) { return std::isdigit(c) != 0; });
      if (dims_ok && trx::detail::_is_dtype_valid(file_dtype)) {
        stale.push_back(it->path());
      }
    }
    // Collected first, removed afterwards: modifying a directory while iterating is unspecified.
    for (const auto &stale_path : stale) {
      std::error_code rm_ec;
      trx::fs::remove(stale_path, rm_ec);
    }
    std::error_code rm_ec;
    if (trx::fs::exists(dpv_filename, rm_ec)) {
      trx::fs::remove(dpv_filename, rm_ec);
    }
  }
  const int rows = static_cast<int>(nb_vertices);
  const int cols = 1;
  std::tuple<int, int> shape = std::make_tuple(rows, cols);
  const size_t n = static_cast<size_t>(rows * cols);
  auto seq = std::make_unique<trx::ArraySequence<DT>>();
  seq->mmap_pos = trx::_create_memmap(dpv_filename, shape, "w+", dtype_norm);
  const std::string expected_dtype = dtype_from_scalar<DT>();
  if (dtype_norm == expected_dtype) {
    // On-disk dtype matches DT: memory-map directly and write as DT.
    trx::detail::remap(seq->_data, seq->mmap_pos.data(), rows, cols);
    for (int i = 0; i < rows; ++i) {
      seq->_data(i, 0) = static_cast<DT>(values[static_cast<size_t>(i)]);
    }
  } else {
    // On-disk dtype differs from DT (e.g. float32 DPV inside a float16 TrxFile).
    // Write values as dtype_norm bytes into the mmap, then cross-cast into owned
    // DT memory so the in-memory matrix uses the correct element size.
    if (dtype_norm == "float16") {
      auto *ptr = reinterpret_cast<Eigen::half *>(seq->mmap_pos.data());
      for (size_t i = 0; i < n; ++i) ptr[i] = static_cast<Eigen::half>(values[i]);
    } else if (dtype_norm == "float32") {
      auto *ptr = reinterpret_cast<float *>(seq->mmap_pos.data());
      for (size_t i = 0; i < n; ++i) ptr[i] = static_cast<float>(values[i]);
    } else {
      auto *ptr = reinterpret_cast<double *>(seq->mmap_pos.data());
      for (size_t i = 0; i < n; ++i) ptr[i] = static_cast<double>(values[i]);
    }
    copy_cast_from_dtype_buffer<DT>(seq->mmap_pos.data(), n, dtype_norm, seq->_data_owned);
    trx::detail::remap(seq->_data, seq->_data_owned.data(), shape);
    seq->mmap_pos.unmap();
  }
  // Share the streamlines' offsets (non-owning remap) and copy their lengths.
  if (this->streamlines && this->streamlines->_offsets.size() > 0) {
    trx::detail::remap(seq->_offsets, this->streamlines->_offsets.data(), int(this->streamlines->_offsets.rows()),
                       int(this->streamlines->_offsets.cols()));
    seq->_lengths = this->streamlines->_lengths;
  }
  this->data_per_vertex[name] = std::move(seq);
}
template <typename DT>
// Create (or replace) group `name` from streamline indices, stored as groups/<name>.uint32 in
// the backing directory.
// Throws TrxArgumentError for an empty name or an index >= the number of streamlines, and
// TrxIOError when there is no backing directory.
void TrxFile<DT>::add_group_from_indices(const std::string &name, const std::vector<uint32_t> &indices) {
  if (name.empty()) {
    throw TrxArgumentError("Group name cannot be empty");
  }
  if (this->_uncompressed_folder_handle.empty()) {
    throw TrxIOError("TRX file has no backing directory to store groups");
  }
  size_t streamline_count = 0;
  if (this->streamlines) {
    streamline_count = static_cast<size_t>(this->streamlines->_lengths.size());
  } else if (this->header["NB_STREAMLINES"].is_number()) {
    streamline_count = static_cast<size_t>(this->header["NB_STREAMLINES"].int_value());
  }
  // Report the first out-of-range index, if any.
  const auto bad_index = std::find_if(indices.begin(), indices.end(),
                                      [streamline_count](uint32_t idx) { return idx >= streamline_count; });
  if (bad_index != indices.end()) {
    throw TrxArgumentError("Group index out of range: " + std::to_string(*bad_index));
  }
  const std::string groups_dirname = this->_uncompressed_folder_handle + SEPARATOR + "groups" + SEPARATOR;
  mkdir_or_throw(groups_dirname);
  const std::string group_filename = groups_dirname + name + ".uint32";
  {
    std::error_code rm_ec;
    if (trx::fs::exists(group_filename, rm_ec)) {
      trx::fs::remove(group_filename, rm_ec);
    }
  }
  // Drop any previous in-memory group and its backing metadata.
  if (auto it = this->groups.find(name); it != this->groups.end()) {
    this->groups.erase(it);
  }
  if (auto it = this->group_backing_info_.find(name); it != this->group_backing_info_.end()) {
    this->group_backing_info_.erase(it);
  }
  const int row_count = static_cast<int>(indices.size());
  std::tuple<int, int> shape = std::make_tuple(row_count, 1);
  auto group = std::make_unique<trx::MMappedMatrix<uint32_t>>();
  group->mmap = trx::_create_memmap(group_filename, shape, "w+", "uint32");
  trx::detail::remap(group->_matrix, group->mmap.data(), shape);
  for (int row = 0; row < row_count; ++row) {
    group->_matrix(row, 0) = indices[static_cast<size_t>(row)];
  }
  this->groups[name] = std::move(group);
}
template <typename DT>
// Store `affine` in the header as VOXEL_TO_RASMM, serialized as a 4x4 nested list (row-major).
void TrxFile<DT>::set_voxel_to_rasmm(const Eigen::Matrix4f &affine) {
  std::vector<std::vector<float>> rows;
  rows.reserve(4);
  for (int r = 0; r < 4; ++r) {
    std::vector<float> row;
    row.reserve(4);
    for (int c = 0; c < 4; ++c) {
      row.push_back(affine(r, c));
    }
    rows.push_back(std::move(row));
  }
  this->header = _json_set(this->header, "VOXEL_TO_RASMM", rows);
}
// Create a streaming TRX writer.
// positions_dtype is lower-cased and must be "float16" or "float32". A temporary working
// directory is created and positions.tmp is opened inside it.
inline TrxStream::TrxStream(std::string positions_dtype) : positions_dtype_(std::move(positions_dtype)) {
  for (char &ch : positions_dtype_) {
    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
  }
  const bool supported = positions_dtype_ == "float32" || positions_dtype_ == "float16";
  if (!supported) {
    throw TrxArgumentError("TrxStream only supports float16/float32 positions for now");
  }
  tmp_dir_ = make_temp_dir("trx_proto");
  positions_path_ = tmp_dir_ + SEPARATOR + "positions.tmp";
  ensure_positions_stream();
}
// Releases the positions stream and deletes the temporary working directory.
inline TrxStream::~TrxStream() {
  cleanup_tmp();
}
// Select where metadata (DPS/DPV) is staged. Only allowed before finalize.
inline TrxStream &TrxStream::set_metadata_mode(MetadataMode mode) {
  if (!finalized_) {
    metadata_mode_ = mode;
    return *this;
  }
  throw TrxArgumentError("Cannot adjust metadata mode after finalize");
}
// Set the chunk size (in bytes) used when streaming metadata to disk. Only allowed before finalize.
inline TrxStream &TrxStream::set_metadata_buffer_max_bytes(std::size_t max_bytes) {
  if (!finalized_) {
    metadata_buffer_max_bytes_ = max_bytes;
    return *this;
  }
  throw TrxArgumentError("Cannot adjust metadata buffer after finalize");
}
// Open the temporary positions file for binary writing if it is not already open.
// NOTE(review): the file is opened with std::ios::trunc, so reopening after a close discards
// previously written positions. Confirm that callers never reopen mid-stream.
inline void TrxStream::ensure_positions_stream() {
  if (positions_out_.is_open()) {
    return;
  }
  positions_out_.open(positions_path_, std::ios::binary | std::ios::out | std::ios::trunc);
  if (!positions_out_.is_open()) {
    throw TrxIOError("Failed to open TrxStream temp positions file: " + positions_path_);
  }
}
// Create <tmp_dir_>/<subdir>/ (e.g. "dps") for on-disk metadata staging.
inline void TrxStream::ensure_metadata_dir(const std::string &subdir) {
  if (!tmp_dir_.empty()) {
    mkdir_or_throw(tmp_dir_ + SEPARATOR + subdir + SEPARATOR);
    return;
  }
  throw TrxIOError("TrxStream temp directory not initialized");
}
inline void TrxStream::flush_positions_buffer() {
if (positions_dtype_ == "float16") {
if (positions_buffer_half_.empty()) {
return;
}
ensure_positions_stream();
const size_t byte_count = positions_buffer_half_.size() * sizeof(half);
positions_out_.write(reinterpret_cast<const char *>(positions_buffer_half_.data()),
static_cast<std::streamsize>(byte_count));
if (!positions_out_) {
throw TrxIOError("Failed to write TrxStream positions buffer");
}
positions_buffer_half_.clear();
return;
}
if (positions_buffer_float_.empty()) {
return;
}
ensure_positions_stream();
const size_t byte_count = positions_buffer_float_.size() * sizeof(float);
positions_out_.write(reinterpret_cast<const char *>(positions_buffer_float_.data()),
static_cast<std::streamsize>(byte_count));
if (!positions_out_) {
throw TrxIOError("Failed to write TrxStream positions buffer");
}
positions_buffer_float_.clear();
}
// Discard staged positions, close the positions stream and delete the temporary directory.
// rm_dir's result is ignored (best-effort).
inline void TrxStream::cleanup_tmp() {
  positions_buffer_float_.clear();
  positions_buffer_half_.clear();
  if (positions_out_.is_open()) {
    positions_out_.close();
  }
  if (tmp_dir_.empty()) {
    return;
  }
  rm_dir(tmp_dir_);
  tmp_dir_.clear();
}
// Append one streamline of `point_count` xyz triplets read from `xyz`. An empty streamline only
// records a zero length.
// With positions_buffer_max_entries_ == 0 the points are written straight to the temp positions
// file. Otherwise they are staged and flushed once the buffer reaches that many entries.
// float16 streams convert each coordinate to half.
// Throws TrxArgumentError if already finalized, if `xyz` is null with a non-zero count, or if
// the count does not fit the uint32 length field. Throws TrxIOError on write failure.
inline void TrxStream::push_streamline(const float *xyz, size_t point_count) {
  if (finalized_) {
    throw TrxArgumentError("TrxStream already finalized");
  }
  if (point_count == 0) {
    lengths_.push_back(0);
    return;
  }
  if (xyz == nullptr) {
    throw TrxArgumentError("TrxStream streamline buffer is null");
  }
  // lengths_ stores uint32: refuse counts that would be silently truncated.
  if (point_count > static_cast<size_t>(std::numeric_limits<uint32_t>::max())) {
    throw TrxArgumentError("TrxStream streamline has too many points");
  }
  const size_t floats_count = point_count * 3;
  if (positions_buffer_max_entries_ == 0) {
    // Unbuffered mode: write this streamline directly.
    ensure_positions_stream();
    if (positions_dtype_ == "float16") {
      std::vector<half> tmp;
      tmp.reserve(floats_count);
      for (size_t i = 0; i < floats_count; ++i) {
        tmp.push_back(static_cast<half>(xyz[i]));
      }
      const size_t byte_count = tmp.size() * sizeof(half);
      positions_out_.write(reinterpret_cast<const char *>(tmp.data()), static_cast<std::streamsize>(byte_count));
      if (!positions_out_) {
        throw TrxIOError("Failed to write TrxStream positions");
      }
    } else {
      const size_t byte_count = floats_count * sizeof(float);
      positions_out_.write(reinterpret_cast<const char *>(xyz), static_cast<std::streamsize>(byte_count));
      if (!positions_out_) {
        throw TrxIOError("Failed to write TrxStream positions");
      }
    }
  } else if (positions_dtype_ == "float16") {
    // No exact-size reserve() here: reserving size()+n on every push defeats geometric growth and
    // reallocates/copies the whole buffer each call (quadratic). push_back grows amortized O(1).
    for (size_t i = 0; i < floats_count; ++i) {
      positions_buffer_half_.push_back(static_cast<half>(xyz[i]));
    }
    if (positions_buffer_half_.size() >= positions_buffer_max_entries_) {
      flush_positions_buffer();
    }
  } else {
    positions_buffer_float_.insert(positions_buffer_float_.end(), xyz, xyz + floats_count);
    if (positions_buffer_float_.size() >= positions_buffer_max_entries_) {
      flush_positions_buffer();
    }
  }
  total_vertices_ += point_count;
  lengths_.push_back(static_cast<uint32_t>(point_count));
}
// Append a streamline given as a flat [x0, y0, z0, x1, ...] buffer.
// Throws TrxArgumentError if the size is not a multiple of 3.
inline void TrxStream::push_streamline(const std::vector<float> &xyz_flat) {
  const size_t float_count = xyz_flat.size();
  if (float_count % 3 != 0) {
    throw TrxArgumentError("TrxStream streamline buffer must be a multiple of 3");
  }
  push_streamline(xyz_flat.data(), float_count / 3);
}
// Append a streamline given as a list of xyz points by flattening it first.
inline void TrxStream::push_streamline(const std::vector<std::array<float, 3>> &points) {
  if (points.empty()) {
    push_streamline(static_cast<const float *>(nullptr), 0);
    return;
  }
  std::vector<float> flat;
  flat.reserve(points.size() * 3);
  for (const auto &p : points) {
    flat.insert(flat.end(), p.begin(), p.end());
  }
  push_streamline(flat);
}
// Record the voxel-to-RAS+mm affine in the stream header as a 4x4 nested list (row-major).
inline TrxStream &TrxStream::set_voxel_to_rasmm(const Eigen::Matrix4f &affine) {
  std::vector<std::vector<float>> rows(4);
  for (int r = 0; r < 4; ++r) {
    auto &row = rows[static_cast<size_t>(r)];
    row.reserve(4);
    for (int c = 0; c < 4; ++c) {
      row.push_back(affine(r, c));
    }
  }
  header = _json_set(header, "VOXEL_TO_RASMM", rows);
  return *this;
}
// Record the reference volume dimensions in the stream header.
inline TrxStream &TrxStream::set_dimensions(const std::array<uint16_t, 3> &dims) {
  std::vector<uint16_t> dim_list(dims.begin(), dims.end());
  header = _json_set(header, "DIMENSIONS", dim_list);
  return *this;
}
template <typename T>
// Record the per-streamline field `name` with dtype float16/float32/float64 (case-insensitive).
// - MetadataMode::OnDisk: values are converted and streamed to <tmp>/dps/<name>.<dtype> in chunks
//   of at most metadata_buffer_max_bytes_ bytes, and the file is registered in metadata_files_.
// - Otherwise: values are kept in memory as doubles.
// Throws TrxArgumentError / TrxDTypeError on bad arguments and TrxIOError if the file cannot be
// opened or written. Write failures were previously silent.
inline void
TrxStream::push_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector<T> &values) {
  if (name.empty()) {
    throw TrxArgumentError("DPS name cannot be empty");
  }
  std::string dtype_norm = dtype;
  std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  if (!trx::detail::_is_dtype_valid(dtype_norm)) {
    throw TrxDTypeError("Unsupported DPS dtype: " + dtype);
  }
  if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") {
    throw TrxDTypeError("Unsupported DPS dtype: " + dtype);
  }
  if (metadata_mode_ == MetadataMode::OnDisk) {
    ensure_metadata_dir("dps");
    const std::string filename = tmp_dir_ + SEPARATOR + "dps" + SEPARATOR + name + "." + dtype_norm;
    std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc);
    if (!out.is_open()) {
      throw TrxIOError("Failed to open DPS file: " + filename);
    }
    // Convert `values` to the element type of `zero` and write them chunk by chunk.
    auto write_converted = [&](auto zero) {
      using Elem = decltype(zero);
      const size_t chunk_elems = std::max<std::size_t>(1, metadata_buffer_max_bytes_ / sizeof(Elem));
      std::vector<Elem> tmp;
      // Never reserve more than needed: the byte budget can far exceed the data size.
      tmp.reserve(std::min(chunk_elems, values.size()));
      size_t offset = 0;
      while (offset < values.size()) {
        const size_t count = std::min(chunk_elems, values.size() - offset);
        tmp.clear();
        for (size_t i = 0; i < count; ++i) {
          tmp.push_back(static_cast<Elem>(values[offset + i]));
        }
        out.write(reinterpret_cast<const char *>(tmp.data()), static_cast<std::streamsize>(count * sizeof(Elem)));
        if (!out) {
          throw TrxIOError("Failed to write DPS file: " + filename);
        }
        offset += count;
      }
    };
    if (dtype_norm == "float16") {
      write_converted(half{});
    } else if (dtype_norm == "float32") {
      write_converted(0.0f);
    } else {
      write_converted(0.0);
    }
    out.close();
    if (!out) {
      throw TrxIOError("Failed to write DPS file: " + filename);
    }
    metadata_files_.push_back({std::string("dps") + SEPARATOR + name + "." + dtype_norm, filename});
  } else {
    FieldValues field;
    field.dtype = dtype_norm;
    field.values.reserve(values.size());
    for (const auto &v : values) {
      field.values.push_back(static_cast<double>(v));
    }
    dps_[name] = std::move(field);
  }
}
/// Record a per-vertex (DPV) field on the stream.
///
/// @tparam T      Element type of the input; each value is cast to the requested dtype.
/// @param name    Field name; used as the file stem "dpv/<name>.<dtype>".
/// @param dtype   Case-insensitive output dtype: "float16", "float32" or "float64".
/// @param values  Field values to store.
/// @throws TrxArgumentError if `name` is empty.
/// @throws TrxDTypeError if `dtype` is not one of the supported float dtypes.
/// @throws TrxIOError if the on-disk file cannot be opened or is not fully written.
///
/// With MetadataMode::OnDisk the values are converted and written immediately into
/// the temporary directory, in chunks bounded by metadata_buffer_max_bytes_, and the
/// file is registered in metadata_files_. Otherwise they are kept in dpv_ (as
/// doubles) until finalize.
template <typename T>
inline void
TrxStream::push_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector<T> &values) {
  if (name.empty()) {
    throw TrxArgumentError("DPV name cannot be empty");
  }
  std::string dtype_norm = dtype;
  std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  if (!trx::detail::_is_dtype_valid(dtype_norm)) {
    throw TrxDTypeError("Unsupported DPV dtype: " + dtype);
  }
  if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") {
    throw TrxDTypeError("Unsupported DPV dtype: " + dtype);
  }
  if (metadata_mode_ != MetadataMode::OnDisk) {
    FieldValues field;
    field.dtype = dtype_norm;
    field.values.reserve(values.size());
    for (const auto &v : values) {
      field.values.push_back(static_cast<double>(v));
    }
    dpv_[name] = std::move(field);
    return;
  }
  ensure_metadata_dir("dpv");
  const std::string filename = tmp_dir_ + SEPARATOR + "dpv" + SEPARATOR + name + "." + dtype_norm;
  std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc);
  if (!out.is_open()) {
    throw TrxIOError("Failed to open DPV file: " + filename);
  }
  // Convert into the on-disk element type through a bounded staging buffer. The
  // buffer never exceeds the input size, so a large byte budget does not force a
  // large allocation for a small field.
  const auto write_as = [&](auto type_tag) {
    using Out = decltype(type_tag);
    const size_t chunk_elems = std::max<std::size_t>(1, metadata_buffer_max_bytes_ / sizeof(Out));
    std::vector<Out> tmp;
    tmp.reserve(std::min(chunk_elems, values.size()));
    for (size_t offset = 0; offset < values.size();) {
      const size_t count = std::min(chunk_elems, values.size() - offset);
      tmp.clear();
      for (size_t i = 0; i < count; ++i) {
        tmp.push_back(static_cast<Out>(values[offset + i]));
      }
      out.write(reinterpret_cast<const char *>(tmp.data()), static_cast<std::streamsize>(count * sizeof(Out)));
      offset += count;
    }
  };
  if (dtype_norm == "float16") {
    write_as(half{});
  } else if (dtype_norm == "float32") {
    write_as(float{});
  } else {
    write_as(double{});
  }
  // close() flushes; checking the stream afterwards also catches a failed final
  // flush (e.g. a full disk) instead of registering a truncated file.
  out.close();
  if (!out) {
    throw TrxIOError("Failed to write DPV file: " + filename);
  }
  metadata_files_.push_back({std::string("dpv") + SEPARATOR + name + "." + dtype_norm, filename});
}
inline void TrxStream::set_positions_buffer_max_bytes(std::size_t max_bytes) {
if (finalized_) {
throw TrxArgumentError("Cannot adjust buffer after finalize");
}
if (max_bytes == 0) {
positions_buffer_max_entries_ = 0;
positions_buffer_float_.clear();
positions_buffer_half_.clear();
return;
}
const std::size_t element_size = positions_dtype_ == "float16" ? sizeof(half) : sizeof(float);
const std::size_t entries = max_bytes / element_size;
const std::size_t aligned = (entries / 3) * 3;
positions_buffer_max_entries_ = aligned;
if (positions_buffer_max_entries_ == 0) {
positions_buffer_float_.clear();
positions_buffer_half_.clear();
}
}
/// Record a group (a named list of streamline indices) on the stream.
///
/// @param name     Group name; used as the file stem "groups/<name>.uint32".
/// @param indices  Streamline indices belonging to the group, stored as uint32.
/// @throws TrxArgumentError if `name` is empty.
/// @throws TrxIOError if the on-disk file cannot be opened or is not fully written.
///
/// With MetadataMode::OnDisk the indices are written immediately, in writes of at
/// most metadata_buffer_max_bytes_ bytes, and the file is registered in
/// metadata_files_. Otherwise a copy is kept in groups_ until finalize.
inline void TrxStream::push_group_from_indices(const std::string &name, const std::vector<uint32_t> &indices) {
  if (name.empty()) {
    throw TrxArgumentError("Group name cannot be empty");
  }
  if (metadata_mode_ != MetadataMode::OnDisk) {
    groups_[name] = indices;
    return;
  }
  ensure_metadata_dir("groups");
  const std::string filename = tmp_dir_ + SEPARATOR + "groups" + SEPARATOR + name + ".uint32";
  std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc);
  if (!out.is_open()) {
    throw TrxIOError("Failed to open group file: " + filename);
  }
  const size_t chunk_elems = std::max<std::size_t>(1, metadata_buffer_max_bytes_ / sizeof(uint32_t));
  for (size_t offset = 0; offset < indices.size();) {
    const size_t count = std::min(chunk_elems, indices.size() - offset);
    out.write(reinterpret_cast<const char *>(indices.data() + offset),
              static_cast<std::streamsize>(count * sizeof(uint32_t)));
    offset += count;
  }
  // Detect short writes and a failed final flush rather than registering a
  // truncated file.
  out.close();
  if (!out) {
    throw TrxIOError("Failed to write group file: " + filename);
  }
  metadata_files_.push_back({std::string("groups") + SEPARATOR + name + ".uint32", filename});
}
template <typename DT> void TrxStream::finalize(const std::string &filename, TrxCompression compression) {
if (finalized_) {
throw TrxArgumentError("TrxStream already finalized");
}
finalized_ = true;
flush_positions_buffer();
if (positions_out_.is_open()) {
positions_out_.flush();
positions_out_.close();
}
const size_t nb_streamlines = lengths_.size();
const size_t nb_vertices = total_vertices_;
TrxFile<DT> trx(static_cast<int>(nb_vertices), static_cast<int>(nb_streamlines));
json header_out = header;
header_out = _json_set(header_out, "NB_VERTICES", static_cast<int>(nb_vertices));
header_out = _json_set(header_out, "NB_STREAMLINES", static_cast<int>(nb_streamlines));
trx.header = header_out;
auto &positions = trx.streamlines->_data;
auto &offsets = trx.streamlines->_offsets;
auto &lengths = trx.streamlines->_lengths;
offsets(0, 0) = 0;
for (size_t i = 0; i < nb_streamlines; ++i) {
lengths(static_cast<Eigen::Index>(i)) = static_cast<uint32_t>(lengths_[i]);
offsets(static_cast<Eigen::Index>(i + 1), 0) = offsets(static_cast<Eigen::Index>(i), 0) + lengths_[i];
}
std::ifstream in(positions_path_, std::ios::binary);
if (!in.is_open()) {
throw TrxIOError("Failed to open TrxStream temp positions file for read: " + positions_path_);
}
for (size_t i = 0; i < nb_vertices; ++i) {
if (positions_dtype_ == "float16") {
half xyz[3];
in.read(reinterpret_cast<char *>(xyz), sizeof(xyz));
if (!in) {
throw TrxIOError("Failed to read TrxStream positions");
}
positions(static_cast<Eigen::Index>(i), 0) = static_cast<DT>(xyz[0]);
positions(static_cast<Eigen::Index>(i), 1) = static_cast<DT>(xyz[1]);
positions(static_cast<Eigen::Index>(i), 2) = static_cast<DT>(xyz[2]);
} else {
float xyz[3];
in.read(reinterpret_cast<char *>(xyz), sizeof(xyz));
if (!in) {
throw TrxIOError("Failed to read TrxStream positions");
}
positions(static_cast<Eigen::Index>(i), 0) = static_cast<DT>(xyz[0]);
positions(static_cast<Eigen::Index>(i), 1) = static_cast<DT>(xyz[1]);
positions(static_cast<Eigen::Index>(i), 2) = static_cast<DT>(xyz[2]);
}
}
for (const auto &kv : dps_) {
trx.add_dps_from_vector(kv.first, kv.second.dtype, kv.second.values);
}
for (const auto &kv : dpv_) {
trx.add_dpv_from_vector(kv.first, kv.second.dtype, kv.second.values);
}
for (const auto &kv : groups_) {
trx.add_group_from_indices(kv.first, kv.second);
}
if (metadata_mode_ == MetadataMode::OnDisk) {
for (const auto &meta : metadata_files_) {
const std::string dest = trx._uncompressed_folder_handle + SEPARATOR + meta.relative_path;
const trx::fs::path dest_path(dest);
if (dest_path.has_parent_path()) {
std::error_code parent_ec;
trx::fs::create_directories(dest_path.parent_path(), parent_ec);
}
std::error_code copy_ec;
trx::fs::copy_file(meta.absolute_path, dest, trx::fs::copy_options::overwrite_existing, copy_ec);
if (copy_ec) {
throw TrxIOError("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest);
}
}
}
trx.save(filename, compression);
trx.close();
cleanup_tmp();
}
/// Finalize to `filename`, choosing the output positions scalar type at runtime.
///
/// Float16 maps to finalize<half> and Float64 to finalize<double>. Every other
/// value (Float32 included) falls back to finalize<float>.
inline void TrxStream::finalize(const std::string &filename,
                                TrxScalarType output_dtype,
                                TrxCompression compression) {
  if (output_dtype == TrxScalarType::Float16) {
    finalize<half>(filename, compression);
  } else if (output_dtype == TrxScalarType::Float64) {
    finalize<double>(filename, compression);
  } else {
    finalize<float>(filename, compression);
  }
}
inline void TrxStream::finalize(const std::string &filename, const TrxSaveOptions &options) {
if (options.mode == TrxSaveMode::Directory) {
if (finalized_) {
throw TrxArgumentError("TrxStream already finalized");
}
if (options.overwrite_existing) {
finalize_directory(filename);
} else {
finalize_directory_persistent(filename);
}
return;
}
TrxScalarType out_type = TrxScalarType::Float32;
if (positions_dtype_ == "float16") {
out_type = TrxScalarType::Float16;
} else if (positions_dtype_ == "float64") {
out_type = TrxScalarType::Float64;
}
finalize(filename, out_type, options.compression);
}
/// Write the stream directly as an uncompressed TRX directory.
///
/// Files written into `directory`:
/// - header.json.
/// - positions.3.<positions dtype>, moved from the temporary positions file (or
///   copied if the move fails).
/// - offsets.uint64: nb_streamlines + 1 cumulative vertex counts, starting at 0.
/// - Either the on-disk metadata files (MetadataMode::OnDisk), or dps/, dpv/ and
///   groups/ files generated from the in-memory fields.
///
/// Ends with cleanup_tmp().
///
/// @param directory        Output directory; created if it does not exist.
/// @param remove_existing  If true, an existing `directory` is removed first
///                         (best effort: removal errors are ignored).
/// @throws TrxArgumentError if the stream was already finalized.
/// @throws TrxIOError if an output file cannot be opened, fully written or copied.
/// @throws TrxDTypeError if an in-memory field has an unsupported dtype.
inline void TrxStream::finalize_directory_impl(const std::string &directory, bool remove_existing) {
  if (finalized_) {
    throw TrxArgumentError("TrxStream already finalized");
  }
  finalized_ = true;
  flush_positions_buffer();
  if (positions_out_.is_open()) {
    positions_out_.flush();
    positions_out_.close();
  }
  const size_t nb_streamlines = lengths_.size();
  const size_t nb_vertices = total_vertices_;
  std::error_code ec;
  if (remove_existing && trx::fs::exists(directory, ec)) {
    trx::fs::remove_all(directory, ec);
    ec.clear();
  }
  // Create directory if it doesn't exist
  if (!trx::fs::exists(directory, ec)) {
    mkdir_or_throw(directory);
  }
  ec.clear();

  json header_out = header;
  header_out = _json_set(header_out, "NB_VERTICES", static_cast<int>(nb_vertices));
  header_out = _json_set(header_out, "NB_STREAMLINES", static_cast<int>(nb_streamlines));
  const std::string header_path = directory + SEPARATOR + "header.json";
  std::ofstream out_header(header_path, std::ios::out | std::ios::trunc);
  if (!out_header.is_open()) {
    throw TrxIOError("Failed to write header.json to: " + header_path);
  }
  out_header << header_out.dump() << std::endl;
  out_header.close();
  if (!out_header) {
    throw TrxIOError("Failed to write header.json to: " + header_path);
  }

  // A rename is cheapest; fall back to a copy when it fails (e.g. across filesystems).
  const std::string positions_name = "positions.3." + positions_dtype_;
  const std::string positions_dst = directory + SEPARATOR + positions_name;
  trx::fs::rename(positions_path_, positions_dst, ec);
  if (ec) {
    ec.clear();
    trx::fs::copy_file(positions_path_, positions_dst, trx::fs::copy_options::overwrite_existing, ec);
    if (ec) {
      throw TrxIOError("Failed to copy positions file to: " + positions_dst);
    }
  }

  const std::string offsets_dst = directory + SEPARATOR + "offsets.uint64";
  std::ofstream offsets_out(offsets_dst, std::ios::binary | std::ios::out | std::ios::trunc);
  if (!offsets_out.is_open()) {
    throw TrxIOError("Failed to open offsets file for write: " + offsets_dst);
  }
  uint64_t offset = 0;
  offsets_out.write(reinterpret_cast<const char *>(&offset), sizeof(offset));
  for (const auto length : lengths_) {
    offset += static_cast<uint64_t>(length);
    offsets_out.write(reinterpret_cast<const char *>(&offset), sizeof(offset));
  }
  offsets_out.close();
  if (!offsets_out) {
    throw TrxIOError("Failed to write offsets file: " + offsets_dst);
  }

  // Write one in-memory field, converting the stored doubles to the field's dtype
  // through a staging buffer bounded by metadata_buffer_max_bytes_.
  auto write_field_values = [&](const std::string &path, const FieldValues &values) {
    std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc);
    if (!out.is_open()) {
      throw TrxIOError("Failed to open metadata file: " + path);
    }
    const size_t count = values.values.size();
    const auto write_as = [&](auto type_tag) {
      using Out = decltype(type_tag);
      const size_t chunk = std::max<std::size_t>(1, metadata_buffer_max_bytes_ / sizeof(Out));
      std::vector<Out> tmp;
      tmp.reserve(std::min(chunk, count));
      for (size_t idx = 0; idx < count;) {
        const size_t n = std::min(chunk, count - idx);
        tmp.clear();
        for (size_t i = 0; i < n; ++i) {
          tmp.push_back(static_cast<Out>(values.values[idx + i]));
        }
        out.write(reinterpret_cast<const char *>(tmp.data()), static_cast<std::streamsize>(n * sizeof(Out)));
        idx += n;
      }
    };
    if (values.dtype == "float16") {
      write_as(half{});
    } else if (values.dtype == "float32") {
      write_as(float{});
    } else if (values.dtype == "float64") {
      write_as(double{});
    } else {
      throw TrxDTypeError("Unsupported metadata dtype: " + values.dtype);
    }
    out.close();
    if (!out) {
      throw TrxIOError("Failed to write metadata file: " + path);
    }
  };

  if (metadata_mode_ == MetadataMode::OnDisk) {
    for (const auto &meta : metadata_files_) {
      const std::string dest = directory + SEPARATOR + meta.relative_path;
      const trx::fs::path dest_path(dest);
      if (dest_path.has_parent_path()) {
        // Best effort: a failure here surfaces as a copy error just below.
        std::error_code parent_ec;
        trx::fs::create_directories(dest_path.parent_path(), parent_ec);
      }
      std::error_code copy_ec;
      trx::fs::copy_file(meta.absolute_path, dest, trx::fs::copy_options::overwrite_existing, copy_ec);
      if (copy_ec) {
        throw TrxIOError("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest);
      }
    }
  } else {
    if (!dps_.empty()) {
      mkdir_or_throw(directory + SEPARATOR + "dps");
      for (const auto &kv : dps_) {
        const std::string path = directory + SEPARATOR + "dps" + SEPARATOR + kv.first + "." + kv.second.dtype;
        write_field_values(path, kv.second);
      }
    }
    if (!dpv_.empty()) {
      mkdir_or_throw(directory + SEPARATOR + "dpv");
      for (const auto &kv : dpv_) {
        const std::string path = directory + SEPARATOR + "dpv" + SEPARATOR + kv.first + "." + kv.second.dtype;
        write_field_values(path, kv.second);
      }
    }
    if (!groups_.empty()) {
      mkdir_or_throw(directory + SEPARATOR + "groups");
      for (const auto &kv : groups_) {
        const std::string path = directory + SEPARATOR + "groups" + SEPARATOR + kv.first + ".uint32";
        std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc);
        if (!out.is_open()) {
          throw TrxIOError("Failed to open group file: " + path);
        }
        if (!kv.second.empty()) {
          out.write(reinterpret_cast<const char *>(kv.second.data()),
                    static_cast<std::streamsize>(kv.second.size() * sizeof(uint32_t)));
        }
        out.close();
        if (!out) {
          throw TrxIOError("Failed to write group file: " + path);
        }
      }
    }
  }
  cleanup_tmp();
}
/// Finalize into an uncompressed TRX directory. Any existing `directory` is
/// removed first (best effort).
inline void TrxStream::finalize_directory(const std::string &directory) {
  const bool remove_existing = true;
  finalize_directory_impl(directory, remove_existing);
}
/// Finalize into an uncompressed TRX directory without removing an existing
/// `directory` first. The files written by finalize are truncated/overwritten;
/// unrelated files already in the directory are left in place.
inline void TrxStream::finalize_directory_persistent(const std::string &directory) {
  const bool remove_existing = false;
  finalize_directory_impl(directory, remove_existing);
}
/// Load a per-vertex (DPV) field from an MRtrix track scalar file (TSF) and attach
/// it to this TRX under `name`.
///
/// Input formats:
/// - If the first line is "mrtrix track scalars", header lines are read up to
///   "END". A "file:" entry with a byte offset switches to binary data read from
///   that offset, with datatype Float32LE/BE or Float64LE/BE.
/// - Otherwise (no header, or a header without a usable "file:" entry) the content
///   is parsed as whitespace-separated text numbers.
///
/// In both formats a NaN closes the current streamline and an Inf ends the data.
/// Each streamline's value count must equal its TRX length, and the total must
/// equal the number of vertices. The values are written to
/// "<folder>/dpv/<name>.<dtype>" through a writable memory map. Any existing file
/// or DPV entry with that name is replaced, and the new entry shares this file's
/// offsets and lengths.
///
/// @throws TrxArgumentError if `name` is empty.
/// @throws TrxDTypeError for an unsupported output dtype or TSF datatype.
/// @throws TrxFormatError if there are no streamlines or the TSF is malformed or inconsistent.
/// @throws TrxIOError if there is no backing directory or the TSF cannot be opened.
template <typename DT>
void TrxFile<DT>::add_dpv_from_tsf(const std::string &name, const std::string &dtype, const std::string &path) {
  if (name.empty()) {
    throw TrxArgumentError("DPV name cannot be empty");
  }
  std::string dtype_norm = dtype;
  std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  if (!trx::detail::_is_dtype_valid(dtype_norm)) {
    throw TrxDTypeError("Unsupported DPV dtype: " + dtype);
  }
  if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") {
    throw TrxDTypeError("Unsupported DPV dtype for TSF input: " + dtype);
  }
  if (!this->streamlines) {
    throw TrxFormatError("TRX file has no streamlines to attach DPV data");
  }
  if (this->_uncompressed_folder_handle.empty()) {
    throw TrxIOError("TRX file has no backing directory to store DPV data");
  }
  const auto &lengths = this->streamlines->_lengths;
  const size_t nb_streamlines = static_cast<size_t>(lengths.size());
  const size_t nb_vertices = static_cast<size_t>(this->streamlines->_data.rows());
  // Open in binary mode: the payload may be raw floats located by a byte offset,
  // and text mode may translate line endings (or stop at 0x1A) on some platforms.
  // Text parsing still works because '\r' is whitespace for trim() and operator>>.
  std::ifstream input(path, std::ios::binary);
  if (!input.is_open()) {
    throw TrxIOError("Failed to open TSF file: " + path);
  }
  auto trim = [](std::string note) {
    const auto is_space = [](unsigned char ch) { return std::isspace(ch) != 0; };
    note.erase(note.begin(),
               std::find_if(note.begin(), note.end(), [is_space](unsigned char ch) { return !is_space(ch); }));
    note.erase(std::find_if(note.rbegin(), note.rend(), [is_space](unsigned char ch) { return !is_space(ch); }).base(),
               note.end());
    return note;
  };
  std::streampos start_pos = input.tellg();
  std::string line;
  bool binary_mode = false;
  size_t data_offset = 0;
  std::string datatype;
  if (std::getline(input, line)) {
    const std::string first_line = trim(line);
    if (first_line == "mrtrix track scalars") {
      bool found_end = false;
      while (std::getline(input, line)) {
        const std::string trimmed = trim(line);
        if (trimmed == "END") {
          found_end = true;
          break;
        }
        const auto pos = trimmed.find(':');
        if (pos == std::string::npos) {
          continue;
        }
        const std::string key = trim(trimmed.substr(0, pos));
        const std::string value = trim(trimmed.substr(pos + 1));
        if (key == "datatype") {
          datatype = value;
        } else if (key == "file") {
          // Expected form: "<name> <byte offset>", e.g. ". 128".
          std::istringstream iss(value);
          std::string dot;
          iss >> dot >> data_offset;
          if (!iss.fail()) {
            binary_mode = true;
          }
        }
      }
      if (!found_end) {
        throw TrxFormatError("Failed to parse TSF header: missing END");
      }
    } else {
      // No TSF header: rewind and parse the whole file as text.
      input.clear();
      input.seekg(start_pos);
    }
  } else {
    throw TrxFormatError("Failed to parse TSF file: " + path);
  }
  std::vector<double> values;
  values.reserve(nb_vertices);
  size_t streamline_index = 0;
  uint32_t expected_vertices = nb_streamlines > 0 ? lengths(0) : 0;
  uint32_t current_vertices = 0;
  // A NaN closes the current streamline, whose value count must equal its TRX
  // length. A NaN after the last streamline is accepted without advancing, so any
  // further values are counted against the last streamline (and rejected below).
  const auto close_streamline = [&]() {
    if (current_vertices != expected_vertices) {
      throw TrxFormatError("TSF streamline length does not match TRX streamlines");
    }
    if (streamline_index + 1 < nb_streamlines) {
      ++streamline_index;
      expected_vertices = lengths(streamline_index);
      current_vertices = 0;
    }
  };
  if (binary_mode) {
    if (datatype != "Float32LE" && datatype != "Float32BE" && datatype != "Float64LE" && datatype != "Float64BE") {
      throw TrxDTypeError("Unsupported TSF datatype: " + datatype);
    }
    auto is_little_endian = []() {
      const uint16_t value = 1;
      return *reinterpret_cast<const uint8_t *>(&value) == 1;
    };
    const bool little_endian = is_little_endian();
    const bool data_little_endian = datatype.find("LE") != std::string::npos;
    input.clear();
    input.seekg(static_cast<std::streamoff>(data_offset));
    while (input.good()) {
      double value = 0.0;
      if (datatype == "Float32LE" || datatype == "Float32BE") {
        uint32_t raw = 0;
        input.read(reinterpret_cast<char *>(&raw), sizeof(raw));
        if (!input) {
          break;
        }
        if (little_endian != data_little_endian) {
          raw = (raw >> 24) | ((raw >> 8) & 0x0000FF00) | ((raw << 8) & 0x00FF0000) | (raw << 24);
        }
        float v = 0.0f;
        std::memcpy(&v, &raw, sizeof(v));
        value = static_cast<double>(v);
      } else {
        uint64_t raw = 0;
        input.read(reinterpret_cast<char *>(&raw), sizeof(raw));
        if (!input) {
          break;
        }
        if (little_endian != data_little_endian) {
          raw = ((raw & 0x00000000000000FFULL) << 56) | ((raw & 0x000000000000FF00ULL) << 40) |
                ((raw & 0x0000000000FF0000ULL) << 24) | ((raw & 0x00000000FF000000ULL) << 8) |
                ((raw & 0x000000FF00000000ULL) >> 8) | ((raw & 0x0000FF0000000000ULL) >> 24) |
                ((raw & 0x00FF000000000000ULL) >> 40) | ((raw & 0xFF00000000000000ULL) >> 56);
        }
        double v = 0.0;
        std::memcpy(&v, &raw, sizeof(v));
        value = v;
      }
      if (std::isinf(value)) {
        break;
      }
      if (std::isnan(value)) {
        close_streamline();
        continue;
      }
      values.push_back(value);
      ++current_vertices;
    }
  } else {
    std::string token;
    while (input >> token) {
      std::string token_norm = token;
      std::transform(token_norm.begin(), token_norm.end(), token_norm.begin(), [](unsigned char c) {
        return static_cast<char>(std::tolower(c));
      });
      if (token_norm.rfind("nan", 0) == 0) {
        close_streamline();
        continue;
      }
      if (token_norm.rfind("inf", 0) == 0) {
        break;
      }
      double value = 0.0;
      try {
        size_t idx = 0;
        value = std::stod(token, &idx);
        if (idx != token.size()) {
          throw TrxArgumentError("invalid token");
        }
      } catch (const std::exception &) {
        throw TrxFormatError("Failed to parse TSF file: " + path);
      }
      // Catches signed spellings such as "-inf" / "-nan" parsed by stod.
      if (std::isinf(value)) {
        break;
      }
      if (std::isnan(value)) {
        close_streamline();
        continue;
      }
      values.push_back(value);
      ++current_vertices;
    }
    if (!input.eof() && input.fail()) {
      throw TrxFormatError("Failed to parse TSF file: " + path);
    }
  }
  if (nb_streamlines > 0) {
    if (streamline_index != nb_streamlines - 1 || current_vertices != expected_vertices) {
      throw TrxFormatError("TSF streamline count does not match TRX streamlines");
    }
  }
  if (values.size() != nb_vertices) {
    throw TrxFormatError("TSF values (" + std::to_string(values.size()) + ") do not match number of vertices (" +
                         std::to_string(nb_vertices) + ")");
  }
  std::string dpv_dirname = this->_uncompressed_folder_handle + SEPARATOR + "dpv" + SEPARATOR;
  mkdir_or_throw(dpv_dirname);
  std::string dpv_filename = dpv_dirname + name + "." + dtype_norm;
  {
    std::error_code ec;
    if (trx::fs::exists(dpv_filename, ec)) {
      trx::fs::remove(dpv_filename, ec);
    }
  }
  auto existing = this->data_per_vertex.find(name);
  if (existing != this->data_per_vertex.end()) {
    this->data_per_vertex.erase(existing);
  }
  const int rows = static_cast<int>(nb_vertices);
  const int cols = 1;
  std::tuple<int, int> shape = std::make_tuple(rows, cols);
  auto seq = std::make_unique<trx::ArraySequence<DT>>();
  seq->mmap_pos = trx::_create_memmap(dpv_filename, shape, "w+", dtype_norm);
  // NOTE(review): the file is created for `dtype_norm`, but the mapping below is
  // viewed as DT. If dtype_norm's width differs from DT (e.g. "float64" on a
  // TrxFile<float>), the stored bytes may not match the file's dtype suffix.
  // Confirm how trx::detail::remap handles this.
  trx::detail::remap(seq->_data, seq->mmap_pos.data(), rows, cols);
  for (int i = 0; i < rows; ++i) {
    seq->_data(i, 0) = static_cast<DT>(values[static_cast<size_t>(i)]);
  }
  trx::detail::remap(seq->_offsets, this->streamlines->_offsets.data(),
                     static_cast<int>(this->streamlines->_offsets.rows()),
                     static_cast<int>(this->streamlines->_offsets.cols()));
  seq->_lengths = this->streamlines->_lengths;
  this->data_per_vertex[name] = std::move(seq);
}
/// Export a one-column per-vertex (DPV) field to an MRtrix track scalar file (TSF).
///
/// File layout:
/// - An "mrtrix track scalars" text header: timestamp, native-endian datatype,
///   "file: . <offset>", count, total_count and END.
/// - Zero bytes padding the header to a multiple of 4.
/// - The binary values: each streamline's per-vertex values followed by a NaN
///   delimiter, then a final Inf.
///
/// @param name       Name of an existing DPV entry with exactly one column.
/// @param path       Output path; an existing file is truncated.
/// @param timestamp  Value of the header's "timestamp" field; must be non-empty.
/// @param dtype      Case-insensitive output dtype: "float32" or "float64".
/// @throws TrxArgumentError if `name` or `timestamp` is empty.
/// @throws TrxDTypeError if `dtype` is not float32/float64.
/// @throws TrxFormatError if there are no streamlines, or the DPV is missing, null,
///         not 1-D, or inconsistent with the streamline vertex counts.
/// @throws TrxIOError if the file cannot be opened or fully written.
template <typename DT>
void TrxFile<DT>::export_dpv_to_tsf(const std::string &name,
                                    const std::string &path,
                                    const std::string &timestamp,
                                    const std::string &dtype) const {
  if (name.empty()) {
    throw TrxArgumentError("DPV name cannot be empty");
  }
  if (timestamp.empty()) {
    throw TrxArgumentError("TSF timestamp cannot be empty");
  }
  std::string dtype_norm = dtype;
  std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  if (!trx::detail::_is_dtype_valid(dtype_norm)) {
    throw TrxDTypeError("Unsupported TSF dtype: " + dtype);
  }
  if (dtype_norm != "float32" && dtype_norm != "float64") {
    throw TrxDTypeError("Unsupported TSF dtype for output: " + dtype);
  }
  if (!this->streamlines) {
    throw TrxFormatError("TRX file has no streamlines to export DPV data");
  }
  const auto dpv_it = this->data_per_vertex.find(name);
  if (dpv_it == this->data_per_vertex.end()) {
    throw TrxFormatError("DPV entry not found: " + name);
  }
  const auto *seq = dpv_it->second.get();
  if (!seq) {
    throw TrxFormatError("DPV entry is null: " + name);
  }
  if (seq->_data.cols() != 1) {
    throw TrxFormatError("DPV must be 1D to export as TSF: " + name);
  }
  const auto &lengths = this->streamlines->_lengths;
  const size_t nb_streamlines = static_cast<size_t>(lengths.size());
  const size_t nb_vertices = static_cast<size_t>(seq->_data.rows());
  if (nb_vertices != static_cast<size_t>(this->streamlines->_data.rows())) {
    throw TrxFormatError("DPV vertex count does not match streamlines data");
  }
  // Validate the streamline lengths before touching the output file, so bad
  // metadata never leaves a truncated TSF behind. `consumed` never exceeds
  // nb_vertices, so the subtraction cannot wrap.
  size_t consumed = 0;
  for (size_t s = 0; s < nb_streamlines; ++s) {
    const size_t len = static_cast<size_t>(lengths(static_cast<Eigen::Index>(s)));
    if (len > nb_vertices - consumed) {
      throw TrxFormatError("DPV length metadata exceeds vertex count");
    }
    consumed += len;
  }
  const auto is_little_endian = []() {
    const uint16_t value = 1;
    return *reinterpret_cast<const uint8_t *>(&value) == 1;
  };
  const bool little_endian = is_little_endian();
  const bool write_f64 = dtype_norm == "float64";
  const std::string dtype_spec = write_f64 ? (little_endian ? "Float64LE" : "Float64BE")
                                           : (little_endian ? "Float32LE" : "Float32BE");
  auto build_header = [&](size_t data_offset) {
    std::ostringstream header;
    header << "mrtrix track scalars\n";
    header << "timestamp: " << timestamp << "\n";
    header << "datatype: " << dtype_spec << "\n";
    header << "file: . " << data_offset << "\n";
    header << "count: " << nb_streamlines << "\n";
    header << "total_count: " << nb_streamlines << "\n";
    header << "END\n";
    return header.str();
  };
  // The header embeds its own padded length as the data offset, and the number of
  // digits in that offset changes the length; iterate until it is self-consistent.
  size_t data_offset = 0;
  for (int i = 0; i < 4; ++i) {
    const std::string header = build_header(data_offset);
    size_t padded = header.size();
    const size_t pad = (4 - (padded % 4)) % 4;
    padded += pad;
    if (padded == data_offset) {
      break;
    }
    data_offset = padded;
  }
  const std::string header = build_header(data_offset);
  std::ofstream out(path, std::ios::binary | std::ios::trunc);
  if (!out.is_open()) {
    throw TrxIOError("Failed to open TSF file for writing: " + path);
  }
  out.write(header.data(), static_cast<std::streamsize>(header.size()));
  const size_t pad = (4 - (header.size() % 4)) % 4;
  if (pad > 0) {
    const std::array<char, 4> zeros{0, 0, 0, 0};
    out.write(zeros.data(), static_cast<std::streamsize>(pad));
  }
  const auto write_value = [&](double value) {
    if (write_f64) {
      const double cast = value;
      out.write(reinterpret_cast<const char *>(&cast), sizeof(cast));
    } else {
      const float cast = static_cast<float>(value);
      out.write(reinterpret_cast<const char *>(&cast), sizeof(cast));
    }
  };
  Eigen::Index row = 0;  // stays below nb_vertices thanks to the validation above
  for (size_t s = 0; s < nb_streamlines; ++s) {
    const uint32_t len = lengths(static_cast<Eigen::Index>(s));
    for (uint32_t i = 0; i < len; ++i, ++row) {
      write_value(static_cast<double>(seq->_data(row, 0)));
    }
    // Terminate every streamline, including the last, with a NaN delimiter.
    // MRtrix-style readers treat values followed directly by the Inf end marker
    // as an incomplete track and drop them. add_dpv_from_tsf accepts this layout.
    write_value(std::numeric_limits<double>::quiet_NaN());
  }
  write_value(std::numeric_limits<double>::infinity());
  // close() flushes; a failed flush or any earlier short write leaves the stream failed.
  out.close();
  if (!out) {
    throw TrxIOError("Failed to write TSF file: " + path);
  }
}
/// Print a TrxFile for debugging. Only the JSON header is written; the array
/// payloads are not dumped.
template <typename DT> std::ostream &operator<<(std::ostream &out, const TrxFile<DT> &trx) {
  return out << "Header (header.json):\n" << trx.header.dump();
}
/// Compute one axis-aligned bounding box per streamline and store the result in
/// aabb_cache_.
///
/// - Each box is {min_x, min_y, min_z, max_x, max_y, max_z}, converted to Eigen::half.
/// - Streamline ranges come from the offsets array when it is non-empty, and
///   otherwise from a prefix sum of the lengths.
/// - Empty streamlines get an all-zero box.
/// - An empty vector is returned, and the cache is left untouched, when there are
///   no streamlines or neither offsets nor lengths are available.
///
/// NOTE(review): each float bound is converted to half independently (presumably
/// round-to-nearest). A stored min can therefore land slightly above the true
/// minimum, and a max slightly below the true maximum. query_aabb may then miss
/// streamlines that only touch the edge of the query box, and coordinates outside
/// the half range may become +/-inf. Confirm whether outward rounding is intended.
///
/// @throws TrxFormatError if offsets decrease or exceed the number of position rows.
template <typename DT>
std::vector<std::array<Eigen::half, 6>> TrxFile<DT>::build_streamline_aabbs() const {
  std::vector<std::array<Eigen::half, 6>> boxes;
  if (!this->streamlines) {
    return boxes;
  }
  const auto &offsets_src = this->streamlines->_offsets;
  const auto &lengths_src = this->streamlines->_lengths;
  const auto &points = this->streamlines->_data;

  // bounds[i] .. bounds[i + 1] is the half-open vertex range of streamline i.
  std::vector<uint64_t> bounds;
  if (offsets_src.size() > 0) {
    bounds.reserve(static_cast<size_t>(offsets_src.size()));
    for (Eigen::Index k = 0; k < offsets_src.size(); ++k) {
      bounds.push_back(offsets_src(k, 0));
    }
  } else if (lengths_src.size() > 0) {
    bounds.reserve(static_cast<size_t>(lengths_src.size()) + 1);
    uint64_t running = 0;
    bounds.push_back(running);
    for (Eigen::Index k = 0; k < lengths_src.size(); ++k) {
      running += static_cast<uint64_t>(lengths_src(k));
      bounds.push_back(running);
    }
  } else {
    return boxes;
  }

  const size_t count = bounds.size() - 1;  // both branches above push at least one bound
  boxes.resize(count);
  const uint64_t row_count = static_cast<uint64_t>(points.rows());
  for (size_t i = 0; i < count; ++i) {
    const uint64_t first = bounds[i];
    const uint64_t last = bounds[i + 1];
    if (last < first) {
      throw TrxFormatError("Offsets must be monotonically increasing in build_streamline_aabbs");
    }
    if (last > row_count) {
      throw TrxFormatError("Offsets exceed positions row count in build_streamline_aabbs");
    }
    std::array<Eigen::half, 6> &box = boxes[i];
    if (last == first) {
      box.fill(Eigen::half(0));
      continue;
    }
    std::array<float, 3> lo;
    std::array<float, 3> hi;
    lo.fill(std::numeric_limits<float>::infinity());
    hi.fill(-std::numeric_limits<float>::infinity());
    for (uint64_t p = first; p < last; ++p) {
      const auto row = static_cast<Eigen::Index>(p);
      for (int axis = 0; axis < 3; ++axis) {
        const float v = static_cast<float>(points(row, axis));
        // Same comparisons as std::min / std::max, so NaN coordinates are skipped.
        if (v < lo[axis]) {
          lo[axis] = v;
        }
        if (hi[axis] < v) {
          hi[axis] = v;
        }
      }
    }
    for (int axis = 0; axis < 3; ++axis) {
      box[axis] = static_cast<Eigen::half>(lo[axis]);
      box[axis + 3] = static_cast<Eigen::half>(hi[axis]);
    }
  }
  this->aabb_cache_ = boxes;
  return boxes;
}
/// Return the cached per-streamline AABBs, computing them when the cache is empty.
///
/// The cache is filled by build_streamline_aabbs() as a side effect. When that
/// produces no boxes (e.g. no streamlines), the returned vector is empty and the
/// next call tries to build again.
template <typename DT>
const std::vector<std::array<Eigen::half, 6>> &TrxFile<DT>::get_or_build_streamline_aabbs() const {
  const bool need_build = this->aabb_cache_.empty();
  if (need_build) {
    static_cast<void>(this->build_streamline_aabbs());
  }
  return this->aabb_cache_;
}
/// Return a new TrxFile holding the streamlines whose AABB intersects the query
/// box [min_corner, max_corner]. All faces are inclusive, so touching boxes match.
///
/// @param min_corner, max_corner  Query box corners (x, y, z).
/// @param precomputed_aabbs       Optional boxes to use instead of the cache; must
///                                hold one entry per streamline.
/// @param build_cache_for_result  Forwarded to subset_streamlines.
/// @param max_streamlines         If > 0 and more streamlines match, keep a random
///                                subset of this size, re-sorted by index.
/// @param rng_seed                Seed of the std::mt19937 used for subsampling.
/// @return An empty-like TrxFile when there are no streamlines.
/// @throws TrxArgumentError if the number of AABBs differs from the streamline count.
template <typename DT>
std::unique_ptr<TrxFile<DT>> TrxFile<DT>::query_aabb(
    const std::array<float, 3> &min_corner,
    const std::array<float, 3> &max_corner,
    const std::vector<std::array<Eigen::half, 6>> *precomputed_aabbs,
    bool build_cache_for_result,
    size_t max_streamlines,
    uint32_t rng_seed) const {
  if (!this->streamlines) {
    return this->make_empty_like();
  }
  const size_t nb_streamlines = this->num_streamlines();
  if (nb_streamlines == 0) {
    return this->make_empty_like();
  }

  // Box source priority: caller-supplied, then the cache, then a fresh build.
  std::vector<std::array<Eigen::half, 6>> built;
  const std::vector<std::array<Eigen::half, 6>> *source = precomputed_aabbs;
  if (source == nullptr) {
    if (this->aabb_cache_.empty()) {
      built = this->build_streamline_aabbs();
      source = &built;
    } else {
      source = &this->aabb_cache_;
    }
  }
  const auto &aabbs = *source;
  if (aabbs.size() != nb_streamlines) {
    throw TrxArgumentError("AABB size does not match streamlines count");
  }

  const auto overlaps = [&](const std::array<Eigen::half, 6> &box) {
    for (int axis = 0; axis < 3; ++axis) {
      const float lo = static_cast<float>(box[axis]);
      const float hi = static_cast<float>(box[axis + 3]);
      if (!(lo <= max_corner[axis] && hi >= min_corner[axis])) {
        return false;
      }
    }
    return true;
  };

  std::vector<uint32_t> selected;
  selected.reserve(nb_streamlines);
  for (size_t i = 0; i < nb_streamlines; ++i) {
    if (overlaps(aabbs[i])) {
      selected.push_back(static_cast<uint32_t>(i));
    }
  }
  const bool subsample = max_streamlines > 0 && selected.size() > max_streamlines;
  if (subsample) {
    std::mt19937 rng(rng_seed);
    std::shuffle(selected.begin(), selected.end(), rng);
    selected.resize(max_streamlines);
    // Restore index order for sequential memory access in subset_streamlines.
    std::sort(selected.begin(), selected.end());
  }
  return this->subset_streamlines(selected, build_cache_for_result);
}
/// Drop the cached per-streamline AABBs.
///
/// The next get_or_build_streamline_aabbs() (or query_aabb() without precomputed boxes) rebuilds them.
template <typename DT>
void TrxFile<DT>::invalidate_aabb_cache() const {
  auto &cache = this->aabb_cache_;
  cache.clear();
}
/// Look up a data_per_streamline field by name.
///
/// @return Non-owning pointer to the field, or nullptr when no field has that name.
template <typename DT>
const MMappedMatrix<DT> *TrxFile<DT>::get_dps(const std::string &name) const {
  const auto found = this->data_per_streamline.find(name);
  const bool present = found != this->data_per_streamline.end();
  return present ? found->second.get() : nullptr;
}
/// Look up a data_per_vertex field by name.
///
/// @return Non-owning pointer to the field, or nullptr when no field has that name.
template <typename DT>
const ArraySequence<DT> *TrxFile<DT>::get_dpv(const std::string &name) const {
  const auto found = this->data_per_vertex.find(name);
  const bool present = found != this->data_per_vertex.end();
  return present ? found->second.get() : nullptr;
}
/// Return the member indices of a group, loading them lazily from the backing file on first access.
///
/// When `groups[name]` is still null, the group is read from `group_backing_info_[name]`.
/// `rows * cols` uint32 values are read from `filename`, starting `mem_offset` elements into the
/// file, into an owned buffer.
///
/// @return Non-owning pointer to the members, or nullptr when the group is unknown, has no backing
///         info, or its file cannot be opened, seeked, or fully read. A failed load leaves the entry
///         null, so a later call retries.
template <typename DT>
const MMappedMatrix<uint32_t> *TrxFile<DT>::get_group_members(const std::string &name) const {
  auto it = this->groups.find(name);
  if (it == this->groups.end()) {
    return nullptr;
  }
  if (it->second) {
    return it->second.get();
  }
  auto b = this->group_backing_info_.find(name);
  if (b == this->group_backing_info_.end()) {
    return nullptr;
  }
  const auto &info = b->second;
  // Clamp negative dimensions in both the element count and the mapped shape, so rows()/cols()/size()
  // of the result are never negative.
  const int raw_rows = info.rows;
  const int raw_cols = info.cols;
  const int rows = std::max(0, raw_rows);
  const int cols = std::max(0, raw_cols);
  std::tuple<int, int> shape = std::make_tuple(rows, cols);
  const size_t n = static_cast<size_t>(rows) * static_cast<size_t>(cols);

  // Load into a local object and publish it only once it is fully populated. An exception
  // (e.g. bad_alloc in resize) or an I/O failure therefore never leaves a non-null, unmapped entry.
  auto loaded = std::make_unique<MMappedMatrix<uint32_t>>();
  loaded->_matrix_owned.resize(n);
  if (n > 0) {
    std::ifstream in(info.filename, std::ios::binary);
    if (!in.is_open()) {
      return nullptr;
    }
    if (info.mem_offset > 0) {
      // mem_offset counts uint32 elements, not bytes.
      const std::streamoff byte_offset =
          static_cast<std::streamoff>(info.mem_offset) * static_cast<std::streamoff>(sizeof(uint32_t));
      in.seekg(byte_offset, std::ios::beg);
      if (!in.good()) {
        return nullptr;
      }
    }
    in.read(reinterpret_cast<char *>(loaded->_matrix_owned.data()),
            static_cast<std::streamsize>(n * sizeof(uint32_t)));
    if (!in) {
      return nullptr;
    }
    trx::detail::remap(loaded->_matrix, loaded->_matrix_owned.data(), shape);
  } else {
    trx::detail::remap(loaded->_matrix, static_cast<uint32_t *>(nullptr), shape);
  }
  it->second = std::move(loaded);
  return it->second.get();
}
/// Force every group in `groups` to be loaded via get_group_members().
///
/// Load failures are ignored here; those groups simply stay unloaded.
template <typename DT>
void TrxFile<DT>::ensure_all_groups_loaded() const {
  // Snapshot the names first so loading never runs while an iterator into `groups` is held.
  std::vector<std::string> group_names;
  group_names.reserve(this->groups.size());
  for (auto it = this->groups.begin(); it != this->groups.end(); ++it) {
    group_names.emplace_back(it->first);
  }
  for (size_t k = 0; k < group_names.size(); ++k) {
    static_cast<void>(this->get_group_members(group_names[k]));
  }
}
/// Copy the (x, y, z) positions of one streamline out of the position buffer.
///
/// The vertex range is [offsets(i), offsets(i + 1)); an empty or inverted range yields an empty vector.
///
/// @throws TrxFormatError when no streamlines / offsets are available.
/// @throws std::out_of_range when `streamline_index >= num_streamlines()`.
template <typename DT>
std::vector<std::array<DT, 3>> TrxFile<DT>::get_streamline(size_t streamline_index) const {
  const bool has_offsets = this->streamlines && this->streamlines->_offsets.size() != 0;
  if (!has_offsets) {
    throw TrxFormatError("TRX streamlines are not available");
  }
  if (streamline_index >= this->num_streamlines()) {
    throw std::out_of_range("Streamline index out of range");
  }
  const auto &offsets = this->streamlines->_offsets;
  const auto &positions = this->streamlines->_data;
  const uint64_t first = static_cast<uint64_t>(offsets(static_cast<Eigen::Index>(streamline_index), 0));
  const uint64_t last = static_cast<uint64_t>(offsets(static_cast<Eigen::Index>(streamline_index + 1), 0));

  std::vector<std::array<DT, 3>> vertices;
  if (first < last) {
    vertices.reserve(static_cast<size_t>(last - first));
    for (uint64_t v = first; v != last; ++v) {
      const auto row = static_cast<Eigen::Index>(v);
      std::array<DT, 3> xyz = {positions(row, 0), positions(row, 1), positions(row, 2)};
      vertices.push_back(xyz);
    }
  }
  return vertices;
}
/// Invoke `fn(index, first_vertex, vertex_count)` for every streamline, in index order.
///
/// `first_vertex` is offsets(index) and `vertex_count` is offsets(index + 1) - offsets(index).
/// A non-increasing offset pair is reported as zero vertices, matching get_streamline(), instead
/// of wrapping to a huge unsigned count. Does nothing when no streamlines/offsets are available.
template <typename DT>
template <typename Fn>
void TrxFile<DT>::for_each_streamline(Fn &&fn) const {
  if (!this->streamlines || this->streamlines->_offsets.size() == 0) {
    return;
  }
  const auto &offsets = this->streamlines->_offsets;
  const size_t n_streamlines = this->num_streamlines();
  for (size_t i = 0; i < n_streamlines; ++i) {
    const uint64_t start = static_cast<uint64_t>(offsets(static_cast<Eigen::Index>(i), 0));
    const uint64_t end = static_cast<uint64_t>(offsets(static_cast<Eigen::Index>(i + 1), 0));
    // Guard against unsigned underflow on corrupt (decreasing) offsets.
    const uint64_t length = end > start ? end - start : 0;
    fn(i, start, length);
  }
}
/// Store a data_per_group field for `group` from a flat row-major vector, backed by a new mapped file.
///
/// The file is created at <uncompressed folder>/dpg/<group>/<name>.<dtype>. Any existing field with
/// the same group/name is replaced, and any existing file at that path is removed first.
///
/// @param group  Group name (must be non-empty).
/// @param name   Field name (must be non-empty).
/// @param dtype  Case-insensitive dtype. Must be float16, float32 or float64, and must equal the dtype
///               this TrxFile stores DPG data in (dtype_from_scalar<DT>()).
/// @param values Row-major values; exactly rows * cols entries.
/// @param rows   Number of rows (> 0).
/// @param cols   Number of columns; a negative value means "infer as values.size() / rows".
/// @throws TrxArgumentError on empty names, non-positive dimensions or a size mismatch.
/// @throws TrxDTypeError on an unsupported dtype or one that differs from DT's dtype.
/// @throws TrxIOError when the file has no backing directory or a directory cannot be created.
template <typename DT>
template <typename T>
void TrxFile<DT>::add_dpg_from_vector(const std::string &group,
                                      const std::string &name,
                                      const std::string &dtype,
                                      const std::vector<T> &values,
                                      int rows,
                                      int cols) {
  if (group.empty()) {
    throw TrxArgumentError("DPG group cannot be empty");
  }
  if (name.empty()) {
    throw TrxArgumentError("DPG name cannot be empty");
  }
  std::string dtype_norm = dtype;
  std::transform(dtype_norm.begin(), dtype_norm.end(), dtype_norm.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  if (!trx::detail::_is_dtype_valid(dtype_norm)) {
    throw TrxDTypeError("Unsupported DPG dtype: " + dtype);
  }
  if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") {
    throw TrxDTypeError("Unsupported DPG dtype: " + dtype);
  }
  // The mapped buffer is written through an Eigen map of DT, so the on-disk element type must be
  // DT's. _create_memmap is given the dtype (presumably to size the file); a different dtype would
  // label the file with one element width while storing another.
  const std::string storage_dtype = dtype_from_scalar<DT>();
  if (dtype_norm != storage_dtype) {
    throw TrxDTypeError("DPG dtype " + dtype + " does not match TRX storage dtype " + storage_dtype);
  }
  if (this->_uncompressed_folder_handle.empty()) {
    throw TrxIOError("TRX file has no backing directory to store DPG data");
  }
  if (rows <= 0) {
    throw TrxArgumentError("DPG rows must be positive");
  }
  if (cols < 0) {
    if (values.size() % static_cast<size_t>(rows) != 0) {
      throw TrxArgumentError("DPG values size does not match rows");
    }
    cols = static_cast<int>(values.size() / static_cast<size_t>(rows));
  }
  if (cols <= 0) {
    throw TrxArgumentError("DPG cols must be positive");
  }
  const size_t n_rows = static_cast<size_t>(rows);
  const size_t n_cols = static_cast<size_t>(cols);
  if (n_rows * n_cols != values.size()) {
    throw TrxArgumentError("DPG values size does not match rows*cols");
  }
  std::string dpg_dir = this->_uncompressed_folder_handle + SEPARATOR + "dpg" + SEPARATOR;
  mkdir_or_throw(dpg_dir);
  std::string dpg_subdir = dpg_dir + group;
  mkdir_or_throw(dpg_subdir);
  std::string dpg_filename = dpg_subdir + SEPARATOR + name + "." + dtype_norm;

  auto &group_map = this->data_per_group[group];
  // Release any existing mapping BEFORE deleting its file. Windows refuses to delete a file that is
  // still memory-mapped, and the (ignored) error would leave the stale file in place.
  group_map.erase(name);
  {
    std::error_code ec;
    if (trx::fs::exists(dpg_filename, ec)) {
      trx::fs::remove(dpg_filename, ec);
    }
  }
  // Build the field completely before publishing it, so a throwing _create_memmap cannot leave an
  // unmapped entry behind in data_per_group.
  std::tuple<int, int> shape = std::make_tuple(rows, cols);
  auto field = std::make_unique<MMappedMatrix<DT>>();
  field->mmap = _create_memmap(dpg_filename, shape, "w+", dtype_norm);
  trx::detail::remap(field->_matrix, field->mmap.data(), rows, cols);
  // Explicit (row, col) indexing with size_t arithmetic: `values` is row-major, and rows * cols in
  // int arithmetic could overflow.
  for (size_t r = 0; r < n_rows; ++r) {
    for (size_t c = 0; c < n_cols; ++c) {
      field->_matrix(static_cast<Eigen::Index>(r), static_cast<Eigen::Index>(c)) =
          static_cast<DT>(values[r * n_cols + c]);
    }
  }
  group_map[name] = std::move(field);
}
/// Store a data_per_group field from an Eigen matrix expression.
///
/// The matrix is flattened row by row and forwarded to add_dpg_from_vector() with its own
/// row/column counts; all validation and file handling happen there.
///
/// @throws TrxArgumentError when the matrix has no coefficients (plus anything add_dpg_from_vector throws).
template <typename DT>
template <typename Derived>
void TrxFile<DT>::add_dpg_from_matrix(const std::string &group,
                                      const std::string &name,
                                      const std::string &dtype,
                                      const Eigen::MatrixBase<Derived> &matrix) {
  using Scalar = typename Derived::Scalar;
  if (matrix.size() == 0) {
    throw TrxArgumentError("DPG matrix cannot be empty");
  }
  const Eigen::Index n_rows = matrix.rows();
  const Eigen::Index n_cols = matrix.cols();
  std::vector<Scalar> flat(static_cast<size_t>(matrix.size()));
  size_t k = 0;
  for (Eigen::Index r = 0; r < n_rows; ++r) {
    for (Eigen::Index c = 0; c < n_cols; ++c) {
      flat[k++] = matrix(r, c);
    }
  }
  this->add_dpg_from_vector(group, name, dtype, flat, static_cast<int>(n_rows), static_cast<int>(n_cols));
}
/// Look up a data_per_group field.
///
/// @return Non-owning pointer to the field, or nullptr when the group or the field is absent.
template <typename DT>
const MMappedMatrix<DT> *TrxFile<DT>::get_dpg(const std::string &group, const std::string &name) const {
  const auto group_it = this->data_per_group.find(group);
  if (group_it == this->data_per_group.end()) {
    return nullptr;
  }
  const auto &fields = group_it->second;
  const auto field_it = fields.find(name);
  return field_it == fields.end() ? nullptr : field_it->second.get();
}
/// Names of all groups that carry data_per_group fields, in container iteration order.
template <typename DT>
std::vector<std::string> TrxFile<DT>::list_dpg_groups() const {
  std::vector<std::string> group_names;
  group_names.reserve(this->data_per_group.size());
  for (auto it = this->data_per_group.begin(); it != this->data_per_group.end(); ++it) {
    group_names.emplace_back(it->first);
  }
  return group_names;
}
/// Names of the data_per_group fields stored for `group`; empty when the group has none.
template <typename DT>
std::vector<std::string> TrxFile<DT>::list_dpg_fields(const std::string &group) const {
  std::vector<std::string> field_names;
  const auto group_it = this->data_per_group.find(group);
  if (group_it != this->data_per_group.end()) {
    const auto &fields = group_it->second;
    field_names.reserve(fields.size());
    for (auto it = fields.begin(); it != fields.end(); ++it) {
      field_names.emplace_back(it->first);
    }
  }
  return field_names;
}
/// Remove one data_per_group field from memory; the group entry is dropped once it has no fields left.
///
/// Unknown groups or fields are ignored. Only the in-memory entry is erased; this function does not
/// touch files on disk.
template <typename DT>
void TrxFile<DT>::remove_dpg(const std::string &group, const std::string &name) {
  auto group_it = this->data_per_group.find(group);
  if (group_it != this->data_per_group.end()) {
    auto &fields = group_it->second;
    fields.erase(name);
    if (fields.empty()) {
      this->data_per_group.erase(group_it);
    }
  }
}
/// Remove every data_per_group field of `group` from memory.
///
/// Erasing by key is a no-op for a missing group, so no existence check is needed.
template <typename DT>
void TrxFile<DT>::remove_dpg_group(const std::string &group) {
  auto &dpg = this->data_per_group;
  dpg.erase(group);
}
/// Build a new TrxFile holding only the requested streamlines.
///
/// - Ids are taken in the order given; repeated ids are kept once (first occurrence).
/// - Positions are copied, as are data_per_vertex / data_per_streamline entries whose names exist in
///   the freshly constructed output.
/// - Groups are re-indexed to the new numbering: members that were not selected are dropped, and
///   groups left empty are not created.
/// - data_per_group fields are copied for groups that survive.
///
/// @param streamline_ids Indices into this file's streamlines.
/// @param build_cache_for_result When true, the result's AABB cache is built before returning.
/// @return The subset, or make_empty_like() when there is nothing to copy.
/// @throws TrxArgumentError if any id is >= the number of streamlines.
template <typename DT>
std::unique_ptr<TrxFile<DT>> TrxFile<DT>::subset_streamlines(const std::vector<uint32_t> &streamline_ids,
                                                              bool build_cache_for_result) const {
  if (!this->streamlines) {
    return this->make_empty_like();
  }
  // Offsets with (N + 1) entries: copied from the stored offsets, or rebuilt from per-streamline lengths.
  std::vector<uint64_t> offsets;
  if (this->streamlines->_offsets.size() > 0) {
    offsets.resize(static_cast<size_t>(this->streamlines->_offsets.size()));
    for (Eigen::Index i = 0; i < this->streamlines->_offsets.size(); ++i) {
      offsets[static_cast<size_t>(i)] = this->streamlines->_offsets(i, 0);
    }
  } else if (this->streamlines->_lengths.size() > 0) {
    const size_t nb_streamlines = static_cast<size_t>(this->streamlines->_lengths.size());
    offsets.resize(nb_streamlines + 1);
    offsets[0] = 0;
    for (size_t i = 0; i < nb_streamlines; ++i) {
      offsets[i + 1] = offsets[i] + static_cast<uint64_t>(this->streamlines->_lengths(static_cast<Eigen::Index>(i)));
    }
  } else {
    return this->make_empty_like();
  }
  const size_t nb_streamlines = offsets.size() > 0 ? offsets.size() - 1 : 0;
  if (nb_streamlines == 0) {
    return this->make_empty_like();
  }
  // Validate ids and drop repeats while preserving the caller's order.
  std::vector<uint32_t> selected;
  selected.reserve(streamline_ids.size());
  std::vector<uint8_t> seen(nb_streamlines, 0);
  for (uint32_t id : streamline_ids) {
    if (id >= nb_streamlines) {
      throw TrxArgumentError("Streamline id out of range");
    }
    if (!seen[id]) {
      selected.push_back(id);
      seen[id] = 1;
    }
  }
  if (selected.empty()) {
    return this->make_empty_like();
  }
  // old_to_new[old] is the new index of a selected streamline, or -1 when it was not selected.
  std::vector<int> old_to_new(nb_streamlines, -1);
  size_t total_vertices = 0;
  for (size_t i = 0; i < selected.size(); ++i) {
    const uint32_t idx = selected[i];
    old_to_new[idx] = static_cast<int>(i);
    const uint64_t start = offsets[idx];
    const uint64_t end = offsets[idx + 1];
    total_vertices += static_cast<size_t>(end - start);
  }
  auto out = std::make_unique<TrxFile<DT>>(static_cast<int>(total_vertices),
                                           static_cast<int>(selected.size()),
                                           this);
  out->header = _json_set(this->header, "NB_VERTICES", static_cast<int>(total_vertices));
  out->header = _json_set(out->header, "NB_STREAMLINES", static_cast<int>(selected.size()));
  auto &out_positions = out->streamlines->_data;
  auto &out_offsets = out->streamlines->_offsets;
  auto &out_lengths = out->streamlines->_lengths;
  size_t cursor = 0;  // next free vertex row in the output
  out_offsets(0, 0) = 0;
  for (size_t new_idx = 0; new_idx < selected.size(); ++new_idx) {
    const uint32_t old_idx = selected[new_idx];
    const uint64_t start = offsets[old_idx];
    const uint64_t end = offsets[old_idx + 1];
    const uint64_t len = end - start;
    out_lengths(static_cast<Eigen::Index>(new_idx)) = static_cast<uint32_t>(len);
    out_offsets(static_cast<Eigen::Index>(new_idx + 1), 0) =
        out_offsets(static_cast<Eigen::Index>(new_idx), 0) + len;
    if (len > 0) {
      out_positions.block(static_cast<Eigen::Index>(cursor), 0,
                          static_cast<Eigen::Index>(len), 3) =
          this->streamlines->_data.block(static_cast<Eigen::Index>(start), 0,
                                         static_cast<Eigen::Index>(len), 3);
      for (const auto &kv : this->data_per_vertex) {
        const std::string &name = kv.first;
        auto out_it = out->data_per_vertex.find(name);
        if (out_it == out->data_per_vertex.end()) {
          continue;
        }
        auto &out_dpv = out_it->second->_data;
        auto &src_dpv = kv.second->_data;
        const Eigen::Index cols = src_dpv.cols();
        out_dpv.block(static_cast<Eigen::Index>(cursor), 0,
                      static_cast<Eigen::Index>(len), cols) =
            src_dpv.block(static_cast<Eigen::Index>(start), 0,
                          static_cast<Eigen::Index>(len), cols);
      }
    }
    for (const auto &kv : this->data_per_streamline) {
      const std::string &name = kv.first;
      auto out_it = out->data_per_streamline.find(name);
      if (out_it == out->data_per_streamline.end()) {
        continue;
      }
      out_it->second->_matrix.row(static_cast<Eigen::Index>(new_idx)) =
          kv.second->_matrix.row(static_cast<Eigen::Index>(old_idx));
    }
    cursor += static_cast<size_t>(len);
  }
  // Re-index group membership; members outside the selection (or out of range) are dropped.
  for (const auto &kv : this->groups) {
    const std::string &group_name = kv.first;
    std::vector<uint32_t> indices;
    const auto *group = this->get_group_members(group_name);
    if (group == nullptr) {
      continue;
    }
    const auto &matrix = group->_matrix;
    indices.reserve(static_cast<size_t>(matrix.size()));
    for (Eigen::Index r = 0; r < matrix.rows(); ++r) {
      for (Eigen::Index c = 0; c < matrix.cols(); ++c) {
        const uint32_t old_idx = matrix(r, c);
        if (old_idx >= old_to_new.size()) {
          continue;
        }
        const int new_idx = old_to_new[old_idx];
        if (new_idx >= 0) {
          indices.push_back(static_cast<uint32_t>(new_idx));
        }
      }
    }
    if (!indices.empty()) {
      out->add_group_from_indices(group_name, indices);
    }
  }
  // Copy data_per_group fields for every group that survived into the output.
  if (!this->data_per_group.empty() && !out->groups.empty()) {
    std::string dpg_dir = out->_uncompressed_folder_handle + SEPARATOR + "dpg" + SEPARATOR;
    mkdir_or_throw(dpg_dir);
    // Loop-invariant: all DPG fields are stored in DT's dtype.
    const std::string dpg_dtype = dtype_from_scalar<DT>();
    for (const auto &group_kv : out->groups) {
      const std::string &group_name = group_kv.first;
      auto src_group_it = this->data_per_group.find(group_name);
      if (src_group_it == this->data_per_group.end()) {
        continue;
      }
      std::string dpg_subdir = dpg_dir + group_name;
      mkdir_or_throw(dpg_subdir);
      // operator[] inserts an empty field map when absent; clear any fields already present.
      auto &dst_fields = out->data_per_group[group_name];
      dst_fields.clear();
      for (const auto &field_kv : src_group_it->second) {
        const std::string &field_name = field_kv.first;
        auto &src_matrix = field_kv.second->_matrix;
        std::string dpg_filename = dpg_subdir + SEPARATOR + field_name;
        dpg_filename = _generate_filename_from_data(src_matrix, dpg_filename);
        std::tuple<int, int> dpg_shape = std::make_tuple(static_cast<int>(src_matrix.rows()),
                                                         static_cast<int>(src_matrix.cols()));
        // Work through a local pointer instead of re-indexing two maps for every coefficient.
        auto dst = std::make_unique<MMappedMatrix<DT>>();
        dst->mmap = _create_memmap(dpg_filename, dpg_shape, "w+", dpg_dtype);
        trx::detail::remap(dst->_matrix, dst->mmap.data(),
                           std::get<0>(dpg_shape), std::get<1>(dpg_shape));
        // Assigning to an Eigen::Map copies coefficients into the mapped buffer.
        dst->_matrix = src_matrix;
        dst_fields[field_name] = std::move(dst);
      }
    }
  }
  if (build_cache_for_result) {
    out->build_streamline_aabbs();
  }
  return out;
}
#ifdef TRX_TPP_OPEN_NAMESPACE
} // namespace trx
#undef TRX_TPP_OPEN_NAMESPACE
#endif