sumtree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_SUMTREE_HPP
14 #define MLPACK_METHODS_RL_SUMTREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace rl {
20 
31 template<typename T>
32 class SumTree
33 {
34  public:
38  SumTree() : capacity(0)
39  { /* Nothing to do here. */ }
40 
46  SumTree(const size_t capacity) : capacity(capacity)
47  {
48  element = std::vector<T>(2 * capacity);
49  }
50 
57  void Set(size_t idx, const T value)
58  {
59  idx += capacity;
60  element[idx] = value;
61  idx /= 2;
62  while (idx >= 1)
63  {
64  element[idx] = element[2 * idx] + element[2 * idx + 1];
65  idx /= 2;
66  }
67  }
68 
75  void BatchUpdate(const arma::ucolvec& indices, const arma::Col<T>& data)
76  {
77  for (size_t i = 0; i < indices.n_rows; ++i)
78  {
79  element[indices[i] + capacity] = data[i];
80  }
81  // update the total tree with bottom-up technique.
82  for (size_t i = capacity - 1; i > 0; i--)
83  {
84  element[i] = element[2 * i] + element[2 * i + 1];
85  }
86  }
87 
93  T Get(size_t idx)
94  {
95  idx += capacity;
96  return element[idx];
97  }
98 
108  T SumHelper(const size_t start,
109  const size_t end,
110  const size_t node,
111  const size_t nodeStart,
112  const size_t nodeEnd)
113  {
114  if (start == nodeStart && end == nodeEnd)
115  {
116  return element[node];
117  }
118  size_t mid = (nodeStart + nodeEnd) / 2;
119  if (end <= mid)
120  {
121  return SumHelper(start, end, 2 * node, nodeStart, mid);
122  }
123  else
124  {
125  if (mid + 1 <= start)
126  {
127  return SumHelper(start, end, 2 * node + 1, mid + 1 , nodeEnd);
128  }
129  else
130  {
131  return SumHelper(start, mid, 2 * node, nodeStart, mid) +
132  SumHelper(mid + 1, end, 2 * node + 1, mid + 1 , nodeEnd);
133  }
134  }
135  }
136 
143  T Sum(const size_t start, size_t end)
144  {
145  end -= 1;
146  return SumHelper(start, end, 1, 0, capacity - 1);
147  }
148 
152  T Sum()
153  {
154  return Sum(0, capacity);
155  }
156 
163  size_t FindPrefixSum(T mass)
164  {
165  size_t idx = 1;
166  while (idx < capacity)
167  {
168  if (element[2 * idx] > mass)
169  {
170  idx = 2 * idx;
171  }
172  else
173  {
174  mass -= element[2 * idx];
175  idx = 2 * idx + 1;
176  }
177  }
178  return idx - capacity;
179  }
180 
181  private:
183  size_t capacity;
184 
186  std::vector<T> element;
187 };
188 
189 } // namespace rl
190 } // namespace mlpack
191 
192 #endif
SumTree(const size_t capacity)
Construct an instance of the SumTree class.
Definition: sumtree.hpp:46
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
T SumHelper(const size_t start, const size_t end, const size_t node, const size_t nodeStart, const size_t nodeEnd)
Helper function for the Sum function.
Definition: sumtree.hpp:108
T Get(size_t idx)
Get the element of the data array at index idx.
Definition: sumtree.hpp:93
void BatchUpdate(const arma::ucolvec &indices, const arma::Col< T > &data)
Update the data in a batch rather than looping over the indices with the Set method.
Definition: sumtree.hpp:75
SumTree()
Default constructor.
Definition: sumtree.hpp:38
size_t FindPrefixSum(T mass)
Find the highest index idx in the array such that sum(arr[0] + arr[1] + ... + arr[idx - 1]) <= mass.
Definition: sumtree.hpp:163
T Sum(const size_t start, size_t end)
Calculate the sum of a contiguous subsequence of the array.
Definition: sumtree.hpp:143
T Sum()
Shortcut for calculating the sum of the whole array.
Definition: sumtree.hpp:152
void Set(size_t idx, const T value)
Set the element of the data array at index idx.
Definition: sumtree.hpp:57
Implementation of SumTree.
Definition: sumtree.hpp:32