Program Listing for File trx.h

Return to documentation for file (include/trx/trx.h)

#ifndef TRX_H // include guard
#define TRX_H

#include <Eigen/Core>
#include <algorithm>
#include <array>
#include <cctype>
#include <cerrno>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <functional>
#include <iostream>
#include <json11.hpp>
#include <limits>
#include <map>
#include <memory>
#include <optional>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <system_error>
#include <thread>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <mio/mmap.hpp>
#include <mio/shared_mmap.hpp>

#include <trx/detail/exceptions.h>


// Library-wide shorthand: trx::fs refers to std::filesystem.
namespace trx {
namespace fs = std::filesystem;
}

// JSON value type (json11) used for the `header` members declared in this header.
// NOTE(review): the alias lives at global scope, so every includer sees the name `json`.
using json = json11::Json;

namespace trx {
/// Output layout requested from save(); carried in TrxSaveOptions::mode (default Auto).
/// NOTE(review): how Auto picks between the two layouts is decided by the save
/// implementation, which is not part of this listing.
enum class TrxSaveMode { Auto, Archive, Directory };

/// Compression selector for the writing functions; None is the default argument
/// everywhere it appears in this header.
enum class TrxCompression { None, Deflate };

/// Floating-point element types; scalar_type_name() spells them "float16", "float32"
/// and "float64".
enum class TrxScalarType { Float16, Float32, Float64 };

/// Weighting used by TrxFile::compute_group_connectivity().
enum class ConnectivityMeasure {
  StreamlineCount, ///< each streamline contributes a weight of 1
  DpsSum           ///< each streamline contributes its value from a data-per-streamline field
};

/// Result of TrxFile::compute_group_connectivity().
///
/// With G = group_names.size(), both vectors hold G * (G + 1) / 2 entries: the upper
/// triangle (diagonal included) of the G x G matrix, packed row by row.
struct ConnectivityMatrixResult {
  std::vector<std::string> group_names;          ///< group names in the iteration order of TrxFile::groups
  std::vector<uint64_t> streamline_count_upper;  ///< number of streamlines belonging to both groups of a pair
  std::vector<double> value_upper;               ///< accumulated weight per pair (equals the count for StreamlineCount)
};

/// Option bundle accepted by the save()/finalize() overloads that take `const TrxSaveOptions &`.
/// NOTE(review): the effect of each field is implemented outside this listing; only the
/// defaults below are established here.
struct TrxSaveOptions {
  TrxCompression compression = TrxCompression::None; ///< defaults to no compression
  TrxSaveMode mode = TrxSaveMode::Auto;               ///< defaults to Auto
  size_t memory_limit_bytes = 0; // Reserved for future save-path tuning.
  bool overwrite_existing = true;                     ///< defaults to true
  std::optional<TrxScalarType> positions_dtype;       ///< empty by default (presumably "keep current dtype" — confirm)
};

/// Returns a copy of the key/value map held by `value`, or an empty map when
/// `value` is not a JSON object.
inline json::object _json_object(const json &value) {
  return value.is_object() ? value.object_items() : json::object();
}

/// Returns a new JSON object equal to `value` with `key` set to `field`.
/// A non-object `value` is treated as an empty object; `value` itself is not modified.
inline json _json_set(const json &value, const std::string &key, const json &field) {
  json::object updated = _json_object(value);
  updated.insert_or_assign(key, field);
  return json(updated);
}
/// Returns the last component of `path`, ignoring trailing '/' or '\\' characters.
/// An empty path, or one made only of separators, yields an empty string.
inline std::string path_basename(const std::string &path) {
  const char *const separators = "/\\";

  // Index of the last character that belongs to the final component.
  const std::size_t last = path.find_last_not_of(separators);
  if (last == std::string::npos) {
    return std::string();
  }

  // The component starts right after the preceding separator, or at the
  // beginning of the string when there is none.
  const std::size_t sep = path.find_last_of(separators, last);
  const std::size_t first = (sep == std::string::npos) ? 0 : sep + 1;
  return path.substr(first, last - first + 1);
}

/// Returns the directory part of `path` (both '/' and '\\' count as separators).
///
/// - Trailing separators are ignored: "a/b/" -> "a".
/// - A path with no directory part, an empty path, or a path made only of
///   separators yields ".".
/// - When the directory part is the root, the leading separator is returned
///   ("/file" -> "/"); on Windows "C:\\file" yields the drive root "C:\\".
/// - The whole run of separators in front of the last component is removed, so
///   "a//b" yields "a" rather than "a/".
inline std::string path_dirname(const std::string &path) {
  const char *const separators = "/\\";

  // Last character of the final component (skips trailing separators).
  const size_t end = path.find_last_not_of(separators);
  if (end == std::string::npos)
    return ".";

  // Separator immediately before the final component.
  const size_t sep = path.find_last_of(separators, end);
  if (sep == std::string::npos)
    return ".";

  // Step back over every separator of that run so the result carries no
  // trailing separator when the input contains doubled ones.
  const size_t dir_end = path.find_last_not_of(separators, sep);
  if (dir_end == std::string::npos)
    return std::string(1, path[0]); // only separators precede the name: root

#if defined(_WIN32) || defined(_WIN64)
  // "C:" on its own names the drive's current directory; keep the separator
  // so the result is the drive root.
  if (dir_end == 1 && path[1] == ':')
    return path.substr(0, 3);
#endif
  return path.substr(0, dir_end + 1);
}

/// Converts a filesystem path to a std::string holding its UTF-8 encoding.
///
/// When the standard library provides char8_t, path::u8string() yields a
/// std::u8string, whose bytes are copied unchanged into a std::string; otherwise
/// u8string() already returns a std::string.
inline std::string to_utf8_string(const trx::fs::path &path) {
  auto encoded = path.u8string();
#if defined(__cpp_lib_char8_t)
  const char *const bytes = reinterpret_cast<const char *>(encoded.data());
  const std::size_t byte_count = encoded.size();
  return std::string(bytes, byte_count);
#else
  return encoded;
#endif
}



/// Compile-time mapping from a C++ scalar type to its TRX dtype string.
///
/// The primary template marks a type as unsupported and reports an empty name;
/// each specialization below sets `supported` and returns the dtype spelling.
template <typename T> struct DTypeName {
  static constexpr bool supported = false;
  static constexpr std::string_view value() { return ""; }
};

// --- floating-point types ---------------------------------------------------

template <> struct DTypeName<Eigen::half> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "float16"; }
};

template <> struct DTypeName<float> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "float32"; }
};

template <> struct DTypeName<double> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "float64"; }
};

// --- signed integer types ---------------------------------------------------

template <> struct DTypeName<int8_t> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "int8"; }
};

template <> struct DTypeName<int16_t> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "int16"; }
};

template <> struct DTypeName<int32_t> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "int32"; }
};

template <> struct DTypeName<int64_t> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "int64"; }
};

// --- unsigned integer types -------------------------------------------------

template <> struct DTypeName<uint8_t> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "uint8"; }
};

template <> struct DTypeName<uint16_t> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "uint16"; }
};

template <> struct DTypeName<uint32_t> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "uint32"; }
};

template <> struct DTypeName<uint64_t> {
  static constexpr bool supported = true;
  static constexpr std::string_view value() { return "uint64"; }
};

template <typename T> inline std::string dtype_from_scalar() {
  using CleanT = std::remove_cv_t<std::remove_reference_t<T>>;
  static_assert(DTypeName<CleanT>::supported, "Unsupported dtype for TRX scalar.");
  return std::string(DTypeName<CleanT>::value());
}

/// The string "/" (presumably the separator for entry names inside a TRX
/// container — confirm against its users, which are outside this listing).
inline constexpr const char *SEPARATOR = "/";
/// dtype names known to the library: the eleven spellings produced by DTypeName
/// plus the extra name "ushort".
inline const std::array<std::string_view, 12> dtypes = {"float16",
                                                        "uint8",
                                                        "uint16",
                                                        "ushort",
                                                        "uint32",
                                                        "uint64",
                                                        "int8",
                                                        "int16",
                                                        "int32",
                                                        "int64",
                                                        "float32",
                                                        "float64"};

/// Ragged sequence stored as one row-major data matrix plus per-item offsets and lengths.
///
/// `_data` and `_offsets` are Eigen::Map views: they do not own memory themselves.
/// The struct also carries two possible backing stores for them — the `*_owned`
/// vectors and the two memory mappings.
/// NOTE(review): which store backs the views at any time is decided by code outside
/// this listing; a plain copy of this struct would leave the copied views pointing
/// at the source's storage — confirm copies are never made.
template <typename DT> struct ArraySequence {
  // Public accessors
  auto &data() { return _data; }
  const auto &data() const { return _data; }
  auto &offsets() { return _offsets; }
  const auto &offsets() const { return _offsets; }
  auto &lengths() { return _lengths; }
  const auto &lengths() const { return _lengths; }

  Eigen::Map<Eigen::Matrix<DT, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> _data;          // view, row-major
  Eigen::Map<Eigen::Matrix<uint64_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> _offsets; // view, row-major
  Eigen::Matrix<uint32_t, Eigen::Dynamic, 1> _lengths;                                           // owning column vector
  std::vector<DT> _data_owned;
  std::vector<uint64_t> _offsets_owned;
  mio::shared_mmap_sink mmap_pos;
  mio::shared_mmap_sink mmap_off;

  // Both views start as 1x1 maps over a null pointer: their size() is 1 even
  // though there is no storage behind them, so they must be re-seated before
  // any element is read.
  ArraySequence() : _data(nullptr, 1, 1), _offsets(nullptr, 1, 1) {}
};

/// Dense matrix view together with its two possible backing stores.
///
/// `_matrix` is an Eigen::Map (a non-owning view). No storage-order option is given,
/// so Eigen's default order applies — unlike ArraySequence, which requests RowMajor.
/// NOTE(review): whether `_matrix_owned` or `mmap` backs the view is decided by code
/// outside this listing.
template <typename DT> struct MMappedMatrix {
  // Public accessor
  auto &matrix() { return _matrix; }
  const auto &matrix() const { return _matrix; }

  Eigen::Map<Eigen::Matrix<DT, Eigen::Dynamic, Eigen::Dynamic>> _matrix;
  std::vector<DT> _matrix_owned;
  mio::shared_mmap_sink mmap;

  // Starts as a 1x1 map over a null pointer (size() == 1, no storage); the view
  // must be re-seated before any element is read.
  MMappedMatrix() : _matrix(nullptr, 1, 1) {}
};

/// Handle for one TRX file whose streamline data uses the scalar type DT.
///
/// Only num_vertices(), num_streamlines() and the uncompressed_folder_handle()
/// accessors are defined in the class body; every other member function is declared
/// here and defined elsewhere (compute_group_connectivity further down this header).
template <typename DT> class TrxFile {
  // Access specifier
public:
  /// Describes where a group's index array lives on disk when it has not been
  /// loaded into `groups`. compute_group_connectivity() reads `rows * cols`
  /// uint32 values from `filename`, skipping `mem_offset` uint32 elements first.
  struct GroupBackingInfo {
    std::string filename;
    int rows = 0;
    int cols = 0;
    std::string dtype;
    long long mem_offset = 0;
  };

  // Data Members
  json header; // JSON header; "NB_VERTICES" / "NB_STREAMLINES" are read by the size accessors below
  std::unique_ptr<ArraySequence<DT>> streamlines;

  // A null mapped value is treated by compute_group_connectivity() as "not loaded":
  // the ids are then read through group_backing_info_.
  mutable std::map<std::string, std::unique_ptr<MMappedMatrix<uint32_t>>> groups; // vector of indices

  // int or float --check python float precision (singletons)
  std::map<std::string, std::unique_ptr<MMappedMatrix<DT>>> data_per_streamline;
  std::map<std::string, std::unique_ptr<ArraySequence<DT>>> data_per_vertex;
  std::map<std::string, std::map<std::string, std::unique_ptr<MMappedMatrix<DT>>>> data_per_group; // group -> field -> matrix
  std::string _uncompressed_folder_handle;
  // NOTE(review): no in-class initializer — the value is indeterminate unless the
  // constructor (defined outside this listing) assigns it. Confirm.
  bool _copy_safe;
  bool _owns_uncompressed_folder = false;

  // Member Functions()
  // TrxFile(int nb_vertices = 0, int nb_streamlines = 0);
  TrxFile(int nb_vertices = 0,
          int nb_streamlines = 0,
          const TrxFile<DT> *init_as = nullptr,
          std::string reference = "");
  ~TrxFile();

  static std::unique_ptr<TrxFile<DT>>
  _create_trx_from_pointer(json header,
                           std::map<std::string, std::tuple<long long, long long>> dict_pointer_size,
                           std::string root_zip = "",
                           std::string root = "");

  template <typename> friend class TrxReader;
  template <typename U> friend std::unique_ptr<TrxFile<U>> load(const std::string &path);

  std::unique_ptr<TrxFile<DT>> deepcopy();

  void resize(int nb_streamlines = -1, int nb_vertices = -1, bool delete_dpg = false);

  void save(const std::string &filename, TrxCompression compression = TrxCompression::None);
  void save(const std::string &filename, const TrxSaveOptions &options);
  void normalize_for_save();

  // --- adding data-per-streamline (dps), data-per-vertex (dpv) and groups ---
  void add_dps_from_text(const std::string &name, const std::string &dtype, const std::string &path);
  template <typename T>
  void add_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector<T> &values);
  template <typename T>
  void add_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector<T> &values);
  void add_group_from_indices(const std::string &name, const std::vector<uint32_t> &indices);
  void set_voxel_to_rasmm(const Eigen::Matrix4f &affine);
  void add_dpv_from_tsf(const std::string &name, const std::string &dtype, const std::string &path);
  void export_dpv_to_tsf(const std::string &name,
                         const std::string &path,
                         const std::string &timestamp,
                         const std::string &dtype = "float32") const;

  void close();
  void _cleanup_temporary_directory();

  /// Number of vertices, taken from the first source that is available:
  /// the last entry of the streamline offsets, else the row count of the
  /// streamline data, else the header's numeric "NB_VERTICES", else 0.
  /// NOTE(review): a default-constructed ArraySequence reports size 1 for its
  /// null-backed views, so the first branch assumes `streamlines` has been
  /// attached to real storage — confirm.
  size_t num_vertices() const {
    if (streamlines && streamlines->_offsets.size() > 0) {
      const auto last = streamlines->_offsets(streamlines->_offsets.size() - 1);
      return static_cast<size_t>(last);
    }
    if (streamlines && streamlines->_data.size() > 0) {
      return static_cast<size_t>(streamlines->_data.rows());
    }
    if (header["NB_VERTICES"].is_number()) {
      return static_cast<size_t>(header["NB_VERTICES"].int_value());
    }
    return 0;
  }

  /// Number of streamlines, taken from the first source that is available:
  /// the number of offsets minus one, else the number of stored lengths, else
  /// the header's numeric "NB_STREAMLINES", else 0.
  size_t num_streamlines() const {
    if (streamlines && streamlines->_offsets.size() > 0) {
      return static_cast<size_t>(streamlines->_offsets.size() - 1);
    }
    if (streamlines && streamlines->_lengths.size() > 0) {
      return static_cast<size_t>(streamlines->_lengths.size());
    }
    if (header["NB_STREAMLINES"].is_number()) {
      return static_cast<size_t>(header["NB_STREAMLINES"].int_value());
    }
    return 0;
  }

  std::unique_ptr<TrxFile<DT>> make_empty_like() const;

  // --- per-streamline bounding boxes: six half-precision values each ---
  // (presumably min x/y/z followed by max x/y/z — confirm in the definition)
  std::vector<std::array<Eigen::half, 6>> build_streamline_aabbs() const;
  const std::vector<std::array<Eigen::half, 6>> &get_or_build_streamline_aabbs() const;
  void invalidate_aabb_cache() const;

  std::unique_ptr<TrxFile<DT>>
  query_aabb(const std::array<float, 3> &min_corner,
             const std::array<float, 3> &max_corner,
             const std::vector<std::array<Eigen::half, 6>> *precomputed_aabbs = nullptr,
             bool build_cache_for_result = false,
             size_t max_streamlines = 0,
             uint32_t rng_seed = 42) const;

  std::unique_ptr<TrxFile<DT>>
  subset_streamlines(const std::vector<uint32_t> &streamline_ids,
                     bool build_cache_for_result = false) const;

  // --- read access ---
  // compute_group_connectivity() treats a null get_dps() result as "field not found".
  const MMappedMatrix<DT> *get_dps(const std::string &name) const;
  const ArraySequence<DT> *get_dpv(const std::string &name) const;
  const MMappedMatrix<uint32_t> *get_group_members(const std::string &name) const;
  std::vector<std::array<DT, 3>> get_streamline(size_t streamline_index) const;
  template <typename Fn> void for_each_streamline(Fn &&fn) const;

  ConnectivityMatrixResult compute_group_connectivity(ConnectivityMeasure measure = ConnectivityMeasure::StreamlineCount,
                                                      const std::string &dps_field_name = "") const;

  // --- data-per-group (dpg) ---
  template <typename T>
  void add_dpg_from_vector(const std::string &group,
                           const std::string &name,
                           const std::string &dtype,
                           const std::vector<T> &values,
                           int rows = 1,
                           int cols = -1);

  template <typename Derived>
  void add_dpg_from_matrix(const std::string &group,
                           const std::string &name,
                           const std::string &dtype,
                           const Eigen::MatrixBase<Derived> &matrix);

  const MMappedMatrix<DT> *get_dpg(const std::string &group, const std::string &name) const;

  std::vector<std::string> list_dpg_groups() const;

  std::vector<std::string> list_dpg_fields(const std::string &group) const;

  void remove_dpg(const std::string &group, const std::string &name);

  void remove_dpg_group(const std::string &group);

  // --- loading ---
  static std::unique_ptr<TrxFile<DT>> load_from_zip(const std::string &path);

  static std::unique_ptr<TrxFile<DT>> load_from_directory(const std::string &path);

  static std::unique_ptr<TrxFile<DT>> load(const std::string &path);

  // Read-only and mutable access to _uncompressed_folder_handle.
  const std::string &uncompressed_folder_handle() const { return _uncompressed_folder_handle; }
  std::string &uncompressed_folder_handle() { return _uncompressed_folder_handle; }

  std::tuple<int, int>
  _copy_fixed_arrays_from(TrxFile<DT> *trx, int strs_start = 0, int pts_start = 0, int nb_strs_to_copy = -1);
  int len();

private:
  void ensure_all_groups_loaded() const;
  std::map<std::string, GroupBackingInfo> group_backing_info_; // keyed by group name (looked up with the keys of `groups`)
  mutable std::vector<std::array<Eigen::half, 6>> aabb_cache_;
  std::tuple<int, int> _get_real_len();
};

namespace detail {
int _sizeof_dtype(const std::string &dtype);

/// Scope guard that removes the file named by `path` when it is destroyed.
///
/// An empty `path` disarms the guard. Removal errors are ignored (the
/// error_code overload of remove is used). The guard cannot be copied.
struct TempFileGuard {
  std::string path;

  TempFileGuard() = default;
  TempFileGuard(const TempFileGuard &) = delete;
  TempFileGuard &operator=(const TempFileGuard &) = delete;

  ~TempFileGuard() {
    if (path.empty()) {
      return;
    }
    std::error_code ignored;
    trx::fs::remove(path, ignored);
  }
};

inline std::string make_unique_temp_path(const std::string &prefix) {
  std::error_code ec;
  trx::fs::path tmp = trx::fs::temp_directory_path(ec);
  if (ec)
    tmp = trx::fs::path(".");
  thread_local std::mt19937_64 rng(std::random_device{}());
  std::uniform_int_distribution<uint64_t> dist;
  return (tmp / (prefix + "_" + std::to_string(dist(rng)) + ".bin")).string();
}

} // namespace detail

/// Type-erased `rows` x `cols` array whose element type is named at run time by `dtype`.
///
/// The bytes live either in `owned` or in the memory mapping `mmap`; whenever
/// `owned` is non-empty it takes precedence over the mapping.
struct TypedArray {
  std::string dtype;
  int rows = 0;
  int cols = 0;
  mio::shared_mmap_sink mmap;
  std::vector<std::uint8_t> owned;

  /// True when either dimension is zero or neither backing store holds data.
  bool empty() const { return rows == 0 || cols == 0 || (owned.empty() && mmap.data() == nullptr); }
  /// Number of elements (rows * cols), not bytes.
  size_t size() const { return static_cast<size_t>(rows) * static_cast<size_t>(cols); }

  /// Mutable row-major view of the data as elements of type T.
  /// @throws std::invalid_argument if T does not match `dtype`.
  template <typename T> Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> as_matrix() {
    return Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(data_as<T>(), rows, cols);
  }

  /// Read-only row-major view of the data as elements of type T.
  /// @throws std::invalid_argument if T does not match `dtype`.
  template <typename T>
  Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> as_matrix() const {
    return Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
        data_as<T>(), rows, cols);
  }

  /// Non-owning read-only byte range.
  struct ByteView {
    const std::uint8_t *data = nullptr;
    size_t size = 0;
  };

  /// Non-owning writable byte range.
  struct MutableByteView {
    std::uint8_t *data = nullptr;
    size_t size = 0;
  };

  /// Raw bytes of the array (element size of `dtype` times size()); a
  /// null/zero view when the array is empty().
  ByteView to_bytes() const {
    if (empty()) {
      return {};
    }
    return {static_cast<const std::uint8_t *>(data()),
            static_cast<size_t>(detail::_sizeof_dtype(dtype)) * size()};
  }

  /// Writable counterpart of to_bytes().
  MutableByteView to_bytes_mutable() {
    if (empty()) {
      return {};
    }
    return {static_cast<std::uint8_t *>(data()), static_cast<size_t>(detail::_sizeof_dtype(dtype)) * size()};
  }

  /// Copies the current bytes into `owned` and releases the memory mapping.
  /// An empty array ends up with an empty `owned`.
  void materialize_to_owned() {
    const auto bytes = to_bytes();
    if (bytes.size > 0 && bytes.data != nullptr) {
      // Build the copy in a separate buffer: when the array is already backed
      // by `owned`, `bytes.data` points into `owned` itself, and calling
      // vector::assign() with a range taken from the same vector is undefined.
      std::vector<std::uint8_t> copy(bytes.data, bytes.data + bytes.size);
      owned.swap(copy);
    } else {
      owned.clear();
    }
    mmap.unmap();
  }

private:
  // Start of the active backing store: `owned` when it holds data, else the mapping.
  const void *data() const {
    if (!owned.empty()) {
      return owned.data();
    }
    return mmap.data();
  }
  void *data() {
    if (!owned.empty()) {
      return owned.data();
    }
    return mmap.data();
  }

  // Typed pointer to the data after checking that T's dtype name equals `dtype`.
  template <typename T> T *data_as() {
    const std::string expected = dtype_from_scalar<T>();
    if (dtype != expected) {
      throw std::invalid_argument("TypedArray dtype mismatch: expected " + expected + " got " + dtype);
    }
    return static_cast<T *>(data());
  }

  template <typename T> const T *data_as() const {
    const std::string expected = dtype_from_scalar<T>();
    if (dtype != expected) {
      throw std::invalid_argument("TypedArray dtype mismatch: expected " + expected + " got " + dtype);
    }
    return static_cast<const T *>(data());
  }
};

// Redeclaration of the scoped enum already defined near the top of this header;
// harmless, kept for the declarations that follow.
enum class TrxScalarType;

/// TRX file handle whose arrays are held as run-time typed TypedArray values
/// instead of being fixed to one scalar type at compile time.
///
/// Move-only. NOTE(review): the move operations are defaulted while the destructor
/// is user-declared; a defaulted move copies `_owns_uncompressed_folder`, so check
/// that the out-of-line destructor/cleanup tolerates a moved-from object.
class AnyTrxFile {
public:
  AnyTrxFile() = default;
  ~AnyTrxFile();

  AnyTrxFile(const AnyTrxFile &) = delete;
  AnyTrxFile &operator=(const AnyTrxFile &) = delete;
  AnyTrxFile(AnyTrxFile &&) noexcept = default;
  AnyTrxFile &operator=(AnyTrxFile &&) noexcept = default;

  json header;
  TypedArray positions;
  TypedArray offsets;
  // NOTE(review): presumably widened / derived copies of `offsets` — confirm in the loader.
  std::vector<uint64_t> offsets_u64;
  std::vector<uint32_t> lengths;

  // Named arrays, keyed by name; data_per_group is keyed by group, then by field.
  std::map<std::string, TypedArray> groups;
  std::map<std::string, TypedArray> data_per_streamline;
  std::map<std::string, TypedArray> data_per_vertex;
  std::map<std::string, std::map<std::string, TypedArray>> data_per_group;

  size_t num_vertices() const;
  size_t num_streamlines() const;
  void close();
  void save(const std::string &filename, TrxCompression compression = TrxCompression::None);
  void save(const std::string &filename, const TrxSaveOptions &options);

  const TypedArray *get_dps(const std::string &name) const;
  const TypedArray *get_dpv(const std::string &name) const;
  std::vector<std::array<double, 3>> get_streamline(size_t streamline_index) const;

  // Callbacks receive the positions dtype, an untyped pointer to a chunk, and the
  // chunk's point offset and point count.
  using PositionsChunkCallback =
      std::function<void(TrxScalarType dtype, const void *data, size_t point_offset, size_t point_count)>;
  using PositionsChunkMutableCallback =
      std::function<void(TrxScalarType dtype, void *data, size_t point_offset, size_t point_count)>;

  void for_each_positions_chunk(size_t chunk_bytes, const PositionsChunkCallback &fn) const;
  void for_each_positions_chunk_mutable(size_t chunk_bytes, const PositionsChunkMutableCallback &fn);

  static AnyTrxFile load(const std::string &path);
  static AnyTrxFile load_from_zip(const std::string &path);
  static AnyTrxFile load_from_directory(const std::string &path);

  // Read-only and mutable access to the private path strings.
  const std::string &backing_directory() const { return _backing_directory; }
  std::string &backing_directory() { return _backing_directory; }

  const std::string &uncompressed_folder_handle() const { return _uncompressed_folder_handle; }
  std::string &uncompressed_folder_handle() { return _uncompressed_folder_handle; }

private:
  std::string _uncompressed_folder_handle;
  bool _owns_uncompressed_folder = false;
  std::string _backing_directory;

  static std::string _normalize_dtype(const std::string &dtype);
  static AnyTrxFile
  _create_from_pointer(json header,
                       const std::map<std::string, std::tuple<long long, long long>> &dict_pointer_size,
                       const std::string &root);
  void _cleanup_temporary_directory();
};

/// Free-function spelling of AnyTrxFile::load(path); forwards the argument unchanged.
inline AnyTrxFile load_any(const std::string &path) { return AnyTrxFile::load(path); }

/// Incremental writer: streamlines and metadata are pushed one at a time and the
/// result is written out by one of the finalize*() functions.
///
/// Move-only. All non-inline members are defined outside this listing.
class TrxStream {
public:
  /// @param positions_dtype dtype name for positions; defaults to "float32".
  explicit TrxStream(std::string positions_dtype = "float32");
  ~TrxStream();

  TrxStream(const TrxStream &) = delete;
  TrxStream &operator=(const TrxStream &) = delete;
  TrxStream(TrxStream &&) noexcept = default;
  TrxStream &operator=(TrxStream &&) noexcept = default;

  // Three ways to hand over one streamline: raw pointer + point count, a flat
  // float vector, or a vector of 3-element points.
  void push_streamline(const float *xyz, size_t point_count);
  void push_streamline(const std::vector<float> &xyz_flat);
  void push_streamline(const std::vector<std::array<float, 3>> &points);

  void set_positions_buffer_max_bytes(std::size_t max_bytes);

  enum class MetadataMode { InMemory, OnDisk };

  // The setters below return *this-typed references, allowing calls to be chained.
  TrxStream &set_metadata_mode(MetadataMode mode);

  TrxStream &set_metadata_buffer_max_bytes(std::size_t max_bytes);

  TrxStream &set_voxel_to_rasmm(const Eigen::Matrix4f &affine);

  TrxStream &set_dimensions(const std::array<uint16_t, 3> &dims);

  template <typename T>
  void push_dps_from_vector(const std::string &name, const std::string &dtype, const std::vector<T> &values);
  template <typename T>
  void push_dpv_from_vector(const std::string &name, const std::string &dtype, const std::vector<T> &values);
  void push_group_from_indices(const std::string &name, const std::vector<uint32_t> &indices);

  template <typename DT> void finalize(const std::string &filename, TrxCompression compression = TrxCompression::None);
  void finalize(const std::string &filename,
                TrxScalarType output_dtype,
                TrxCompression compression = TrxCompression::None);
  void finalize(const std::string &filename, const TrxSaveOptions &options);

  void finalize_directory(const std::string &directory);

  void finalize_directory_persistent(const std::string &directory);

  /// Number of streamlines pushed so far (one stored length per streamline).
  size_t num_streamlines() const { return lengths_.size(); }
  /// Running vertex total maintained in total_vertices_.
  size_t num_vertices() const { return total_vertices_; }

  json header;

private:
  // A metadata field buffered as doubles together with its target dtype name.
  struct FieldValues {
    std::string dtype;
    std::vector<double> values;
  };

  // A metadata file recorded by both its relative and its absolute path.
  struct MetadataFile {
    std::string relative_path;
    std::string absolute_path;
  };

  void ensure_positions_stream();
  void flush_positions_buffer();
  void cleanup_tmp();
  void ensure_metadata_dir(const std::string &subdir);
  void finalize_directory_impl(const std::string &directory, bool remove_existing);

  std::string positions_dtype_;
  std::string tmp_dir_;
  std::string positions_path_;
  std::ofstream positions_out_;
  std::vector<float> positions_buffer_float_;
  std::vector<Eigen::half> positions_buffer_half_;
  std::size_t positions_buffer_max_entries_ = 0;
  std::vector<uint32_t> lengths_;
  size_t total_vertices_ = 0;
  bool finalized_ = false;

  std::map<std::string, std::vector<uint32_t>> groups_;
  std::map<std::string, FieldValues> dps_;
  std::map<std::string, FieldValues> dpv_;
  MetadataMode metadata_mode_ = MetadataMode::InMemory;
  std::vector<MetadataFile> metadata_files_;
  std::size_t metadata_buffer_max_bytes_ = 8 * 1024 * 1024; // 8 MiB default
};

json assignHeader(const json &root);


template <typename DT> std::unique_ptr<TrxFile<DT>> load_from_zip(const std::string &path);

template <typename DT> std::unique_ptr<TrxFile<DT>> load_from_directory(const std::string &path);

std::string detect_positions_dtype(const std::string &path);

void write_positions_as_dtype(const AnyTrxFile &source,
                               TrxScalarType target_dtype,
                               const std::string &out_path,
                               size_t chunk_bytes = 64 * 1024 * 1024);

/// Returns the dtype string for `dtype`: "float16", "float32" or "float64".
/// Any value outside the three enumerators is reported as "float32".
inline std::string scalar_type_name(TrxScalarType dtype) {
  if (dtype == TrxScalarType::Float16) {
    return "float16";
  }
  if (dtype == TrxScalarType::Float64) {
    return "float64";
  }
  // Float32 and every unexpected value share the same spelling.
  return "float32";
}

/// Value returned by prepare_positions_output().
/// NOTE(review): field meanings are set by that function's definition, which is
/// outside this listing; `points` defaults to 0.
struct PositionsOutputInfo {
  std::string directory;
  std::string positions_path;
  std::string dtype;
  size_t points = 0;
};

/// Options accepted by prepare_positions_output(); overwriting is enabled by default.
struct PrepareOutputOptions {
  bool overwrite_existing = true;
};

PositionsOutputInfo prepare_positions_output(const AnyTrxFile &input,
                                             const std::string &output_directory,
                                             const PrepareOutputOptions &options = {});

/// Options accepted by merge_trx_shards().
/// NOTE(review): how each field is interpreted is defined outside this listing; only
/// the defaults are established here.
struct MergeTrxShardsOptions {
  std::vector<std::string> shard_directories;
  std::string output_path;
  TrxCompression compression = TrxCompression::None; ///< no compression by default
  bool output_directory = false;                     ///< false by default (presumably "write an archive" — confirm)
  bool overwrite_existing = true;                    ///< true by default
};

void merge_trx_shards(const MergeTrxShardsOptions &options);

TrxScalarType detect_positions_scalar_type(const std::string &path, TrxScalarType fallback = TrxScalarType::Float32);

/// Options accepted by load_float32_positions().
struct LoadFloat32Options {
  // Number of rows copied per chunk when converting non-float32 positions.
  // Lower values reduce peak transient memory at the cost of more loop overhead.
  // Default: 1 << 20 (1,048,576) rows.
  size_t chunk_rows = 1 << 20;
};

bool is_trx_directory(const std::string &path);

template <typename DT> std::unique_ptr<TrxFile<DT>> load(const std::string &path);

std::unique_ptr<TrxFile<float>> load_float32_positions(const std::string &path, const LoadFloat32Options &options = {});

/// Move-only owner of a TrxFile<DT> created from a path, with pointer-like access.
///
/// The constructor and the move operations are defined outside this listing.
/// NOTE(review): operator* dereferences `trx_` without a null check, so it must not
/// be used on a reader that holds no file (e.g. possibly after being moved from).
template <typename DT> class TrxReader {
public:
  explicit TrxReader(const std::string &path);
  ~TrxReader() = default;

  TrxReader(const TrxReader &) = delete;
  TrxReader &operator=(const TrxReader &) = delete;
  TrxReader(TrxReader &&other) noexcept;
  TrxReader &operator=(TrxReader &&other) noexcept;

  // Non-owning access to the managed file; ownership stays with the reader.
  TrxFile<DT> *get() const { return trx_.get(); }
  TrxFile<DT> &operator*() const { return *trx_; }
  TrxFile<DT> *operator->() const { return trx_.get(); }

private:
  std::unique_ptr<TrxFile<DT>> trx_;
};

template <typename Fn>
auto with_trx_reader(const std::string &path, Fn &&fn)
    -> decltype(fn(std::declval<TrxReader<float> &>(), TrxScalarType::Float32));

void get_reference_info(const std::string &reference,
                        const Eigen::MatrixXf &affine,
                        const Eigen::RowVectorXi &dimensions);

template <typename DT> std::ostream &operator<<(std::ostream &out, const TrxFile<DT> &trx);

/// Builds a group-by-group connectivity summary.
///
/// A streamline that belongs to groups i and j (i <= j, including i == j) adds one
/// to the count of pair (i, j) and adds its weight to the value of that pair.
/// Duplicate memberships of a streamline in the same group are counted once, and
/// streamline ids outside [0, num_streamlines()) are ignored. Groups that are not
/// loaded in memory are read from their backing file; a group whose backing entry
/// is missing or whose file cannot be read is skipped.
///
/// @param measure StreamlineCount gives every streamline weight 1; DpsSum uses the
///        streamline's value in the data-per-streamline field `dps_field_name`.
/// @param dps_field_name Field name, only consulted for DpsSum.
/// @return Group names in map order plus the packed row-major upper triangles
///         (G * (G + 1) / 2 entries) of counts and values. When there are no groups
///         or no streamlines the zero-filled result is returned before any
///         argument validation.
/// @throws TrxArgumentError if DpsSum is requested with an empty field name.
/// @throws TrxFormatError if the field does not exist or is not NB_STREAMLINES x 1.
template <typename DT>
inline ConnectivityMatrixResult
TrxFile<DT>::compute_group_connectivity(ConnectivityMeasure measure, const std::string &dps_field_name) const {
  const size_t n_streamlines = this->num_streamlines();
  const bool use_dps_weights = (measure == ConnectivityMeasure::DpsSum);

  ConnectivityMatrixResult result;
  result.group_names.reserve(this->groups.size());
  for (auto it = this->groups.cbegin(); it != this->groups.cend(); ++it) {
    result.group_names.push_back(it->first);
  }

  const size_t G = result.group_names.size();
  const size_t packed_size = (G * (G + 1)) / 2;
  result.streamline_count_upper.assign(packed_size, 0);
  result.value_upper.assign(packed_size, 0.0);
  if (G == 0 || n_streamlines == 0) {
    return result;
  }

  const MMappedMatrix<DT> *dps = nullptr;
  if (use_dps_weights) {
    if (dps_field_name.empty()) {
      throw TrxArgumentError("compute_group_connectivity: dps_field_name is required for DpsSum mode.");
    }
    dps = this->get_dps(dps_field_name);
    if (dps == nullptr) {
      throw TrxFormatError("compute_group_connectivity: DPS field not found: " + dps_field_name);
    }
    const bool has_expected_rows = static_cast<size_t>(dps->_matrix.rows()) == n_streamlines;
    const bool is_single_column = dps->_matrix.cols() == 1;
    if (!has_expected_rows || !is_single_column) {
      throw TrxFormatError("compute_group_connectivity: DPS field must be 1D with length NB_STREAMLINES.");
    }
  }

  const auto packed_index = [G](size_t row, size_t col) -> size_t {
    // Row-major upper triangle packing for row <= col. For row == 0 the unsigned
    // (row - 1) wraps around but is multiplied by zero, so the offset stays 0.
    const size_t row_start = row * G - (row * (row - 1)) / 2;
    return row_start + (col - row);
  };

  // Reads the member ids of a group that is not mapped in memory.
  // Returns false when the file cannot be opened, positioned or read in full.
  const auto read_backing_ids = [](const GroupBackingInfo &info, std::vector<uint32_t> &out) -> bool {
    const size_t n_rows = static_cast<size_t>(std::max(0, info.rows));
    const size_t n_cols = static_cast<size_t>(std::max(0, info.cols));
    out.resize(n_rows * n_cols);
    if (out.empty()) {
      return true;
    }
    std::ifstream in(info.filename, std::ios::binary);
    if (!in.is_open()) {
      return false;
    }
    if (info.mem_offset > 0) {
      // The offset is expressed in uint32 elements, not in bytes.
      const std::streamoff byte_offset =
          static_cast<std::streamoff>(info.mem_offset) * static_cast<std::streamoff>(sizeof(uint32_t));
      in.seekg(byte_offset, std::ios::beg);
      if (!in.good()) {
        return false;
      }
    }
    in.read(reinterpret_cast<char *>(out.data()), static_cast<std::streamsize>(out.size() * sizeof(uint32_t)));
    return static_cast<bool>(in);
  };

  // Invert the membership: for every streamline, the ids of the groups it is in.
  // Group ids follow the iteration order of `groups`, i.e. the order of group_names.
  std::vector<std::vector<uint32_t>> groups_of_streamline(n_streamlines);
  size_t next_group_id = 0;
  for (const auto &entry : this->groups) {
    const uint32_t group_id = static_cast<uint32_t>(next_group_id);
    ++next_group_id;

    std::vector<uint32_t> loaded_ids;
    const uint32_t *ids = nullptr;
    size_t id_count = 0;

    if (entry.second != nullptr) {
      const auto &members = entry.second->_matrix;
      ids = members.data();
      id_count = static_cast<size_t>(members.size());
    } else {
      const auto backing = this->group_backing_info_.find(entry.first);
      if (backing == this->group_backing_info_.end()) {
        continue;
      }
      if (!read_backing_ids(backing->second, loaded_ids)) {
        continue;
      }
      ids = loaded_ids.data();
      id_count = loaded_ids.size();
    }

    for (size_t k = 0; k != id_count; ++k) {
      const size_t streamline_id = static_cast<size_t>(ids[k]);
      if (streamline_id < n_streamlines) {
        groups_of_streamline[streamline_id].push_back(group_id);
      }
    }
  }

  for (size_t sid = 0; sid < n_streamlines; ++sid) {
    auto &member_groups = groups_of_streamline[sid];
    if (member_groups.empty()) {
      continue;
    }
    // Sorted, duplicate-free ids guarantee row <= col below.
    std::sort(member_groups.begin(), member_groups.end());
    member_groups.erase(std::unique(member_groups.begin(), member_groups.end()), member_groups.end());

    double weight = 1.0;
    if (use_dps_weights) {
      weight = static_cast<double>(dps->_matrix(static_cast<Eigen::Index>(sid), 0));
    }

    for (auto row_it = member_groups.cbegin(); row_it != member_groups.cend(); ++row_it) {
      for (auto col_it = row_it; col_it != member_groups.cend(); ++col_it) {
        const size_t slot = packed_index(static_cast<size_t>(*row_it), static_cast<size_t>(*col_it));
        result.streamline_count_upper[slot] += 1;
        result.value_upper[slot] += weight;
      }
    }
  }

  if (measure == ConnectivityMeasure::StreamlineCount) {
    // In count mode the value matrix is defined as the count matrix itself.
    for (size_t slot = 0; slot < packed_size; ++slot) {
      result.value_upper[slot] = static_cast<double>(result.streamline_count_upper[slot]);
    }
  }

  return result;
}
// private:

void allocate_file(const std::string &path, std::size_t size);

// Known limitations: only row-major order supported; shape uses tuple (sufficient for 2D);
// dtype parameter is used only for byte-size computation.
mio::shared_mmap_sink _create_memmap(std::string filename,
                                     const std::tuple<int, int> &shape,
                                     const std::string &mode = "r",
                                     const std::string &dtype = "float32",
                                     long long offset = 0);

template <typename DT>
std::string _generate_filename_from_data(const Eigen::MatrixBase<DT> &arr, const std::string filename);

template <typename DT>
std::unique_ptr<TrxFile<DT>>
_initialize_empty_trx(int nb_streamlines, int nb_vertices, const TrxFile<DT> *init_as = nullptr);

template <typename DT>
void ediff1d(Eigen::Matrix<DT, Eigen::Dynamic, 1> &lengths,
             const Eigen::Matrix<DT, Eigen::Dynamic, Eigen::Dynamic> &tmp,
             uint32_t to_end);

std::string extract_trx_archive(const std::string &zip_path);

void write_trx_archive(const std::string &filename,
                        const std::string &source_dir,
                        TrxCompression compression,
                        const std::string &converted_positions_path = "",
                        const std::string &converted_positions_entry = "",
                        const std::unordered_set<std::string> *skip = nullptr);

std::string get_base(const std::string &delimiter, const std::string &str);
std::string get_ext(const std::string &str);
void populate_fps(const std::string &name, std::map<std::string, std::tuple<long long, long long>> &files_pointer_size);

void copy_dir(const std::string &src, const std::string &dst);
void copy_file(const std::string &src, const std::string &dst);
int rm_dir(const std::string &d);
std::string make_temp_dir(const std::string &prefix);

std::string rm_root(const std::string &root, const std::string &path);

void append_groups_to_zip(const std::string &path,
                          const std::map<std::string, std::vector<uint32_t>> &groups,
                          TrxCompression compression = TrxCompression::None, bool overwrite = true);

void append_groups_to_directory(const std::string &directory,
                                const std::map<std::string, std::vector<uint32_t>> &groups,
                                bool overwrite = true);

void append_dps_to_zip(const std::string &path, const std::map<std::string, TypedArray> &dps,
                       TrxCompression compression = TrxCompression::None, bool overwrite = true);

void append_dps_to_directory(const std::string &directory,
                              const std::map<std::string, TypedArray> &dps,
                              bool overwrite = true);

void append_dpv_to_zip(const std::string &path, const std::map<std::string, TypedArray> &dpv,
                       TrxCompression compression = TrxCompression::None, bool overwrite = true);

void append_dpv_to_directory(const std::string &directory,
                              const std::map<std::string, TypedArray> &dpv,
                              bool overwrite = true);

std::string format_groups_summary(const std::map<std::string, size_t> &groups, int prefix_depth = 0,
                                   const std::string &line_prefix = "");

#ifndef TRX_TPP_STANDALONE
#endif

} // namespace trx

#include <trx/detail/dtype_helpers.h>

namespace trx {
#ifndef TRX_TPP_STANDALONE
#include <trx/trx.tpp>
#endif
} // namespace trx

#endif /* TRX_H */