13 #ifndef MLPACK_METHODS_RL_SUMTREE_HPP 14 #define MLPACK_METHODS_RL_SUMTREE_HPP 46 SumTree(
const size_t capacity) : capacity(capacity)
48 element = std::vector<T>(2 * capacity);
57 void Set(
size_t idx,
const T value)
64 element[idx] = element[2 * idx] + element[2 * idx + 1];
75 void BatchUpdate(
const arma::ucolvec& indices,
const arma::Col<T>& data)
77 for (
size_t i = 0; i < indices.n_rows; ++i)
79 element[indices[i] + capacity] = data[i];
82 for (
size_t i = capacity - 1; i > 0; i--)
84 element[i] = element[2 * i] + element[2 * i + 1];
111 const size_t nodeStart,
112 const size_t nodeEnd)
114 if (start == nodeStart && end == nodeEnd)
116 return element[node];
118 size_t mid = (nodeStart + nodeEnd) / 2;
121 return SumHelper(start, end, 2 * node, nodeStart, mid);
125 if (mid + 1 <= start)
127 return SumHelper(start, end, 2 * node + 1, mid + 1 , nodeEnd);
131 return SumHelper(start, mid, 2 * node, nodeStart, mid) +
132 SumHelper(mid + 1, end, 2 * node + 1, mid + 1 , nodeEnd);
143 T
Sum(
const size_t start,
size_t end)
146 return SumHelper(start, end, 1, 0, capacity - 1);
154 return Sum(0, capacity);
166 while (idx < capacity)
168 if (element[2 * idx] > mass)
174 mass -= element[2 * idx];
178 return idx - capacity;
186 std::vector<T> element;
SumTree(const size_t capacity)
Construct an instance of the SumTree class.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
T SumHelper(const size_t start, const size_t end, const size_t node, const size_t nodeStart, const size_t nodeEnd)
Helper function for the Sum() function.
T Get(size_t idx)
Get the element of the data array at index idx.
void BatchUpdate(const arma::ucolvec &indices, const arma::Col< T > &data)
Update the data in a single batch rather than looping over the indices with the Set() method.
SumTree()
Default constructor.
size_t FindPrefixSum(T mass)
Find the highest index idx in the array such that arr[0] + arr[1] + ... + arr[idx - 1] <= mass.
T Sum(const size_t start, size_t end)
Calculate the sum of the contiguous subsequence [start, end) of the array.
T Sum()
Shortcut for calculating the sum of the whole array.
void Set(size_t idx, const T value)
Set the element of the data array at index idx.
Implementation of SumTree.