14 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP 15 #define MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP 50 State(
const arma::colvec& data) : data(data)
//! Get the internal representation of the state. The return type is a
//! non-reference arma::colvec, so the caller receives its own copy of `data`.
54 arma::colvec
Data()
const {
return data; }
//! Modify the internal representation of the state. Returns a mutable
//! reference to `data`, so writes through it change this object directly.
56 arma::colvec&
Data() {
return data; }
//! Get the angle of the i-th pole, read by value from element 2 * i of
//! `data`. Other code in this listing calls it with i = 1 and i = 2, i.e.
//! elements 2 and 4 of the state vector.
69 double Angle(
const size_t i)
const {
return data[2 * i]; }
//! Modify the angle of the i-th pole. Returns a mutable reference to element
//! 2 * i of `data`, so assigning through it updates the stored state.
71 double&
Angle(
const size_t i) {
return data[2 * i]; }
//! Encode the state to a vector. Returns a const reference to `data` itself;
//! no copy or transformation is performed here.
79 const arma::colvec&
Encode()
const {
return data; }
117 const double m2 = 0.01,
118 const double l1 = 0.5,
119 const double l2 = 0.05,
120 const double gravity = 9.8,
121 const double massCart = 1.0,
122 const double forceMag = 10.0,
123 const double tau = 0.02,
124 const double thetaThresholdRadians = 36 * 2 *
126 const double xThreshold = 2.4,
127 const double doneReward = 0.0,
128 const size_t maxSteps = 0) :
137 thetaThresholdRadians(thetaThresholdRadians),
138 xThreshold(xThreshold),
139 doneReward(doneReward),
160 arma::vec dydx(6, arma::fill::zeros);
164 Dsdt(state, action, dydx);
165 RK4(state, action, dydx, nextState);
171 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
195 double totalForce = action.
action[0];
196 double totalMass = massCart;
199 double sinTheta1 = std::sin(state.
Angle(1));
200 double sinTheta2 = std::sin(state.
Angle(2));
201 double cosTheta1 = std::cos(state.
Angle(1));
202 double cosTheta2 = std::cos(state.
Angle(2));
205 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
206 std::sin(2 * state.
Angle(1));
207 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
208 std::sin(2 * state.
Angle(2));
211 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
212 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
215 double xAcc = totalForce / totalMass;
219 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
220 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
237 const double hh = tau * 0.5;
238 const double h6 = tau / 6;
243 yt = state.
Data() + (hh * dydx);
248 yt = state.
Data() + (hh * dyt);
254 yt = state.
Data() + (tau * dym);
261 nextState.
Data() = state.
Data() + h6 * (dydx + dyt + 2 * dym);
275 return Sample(state, action, nextState);
286 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
297 if (maxSteps != 0 && stepsPerformed >= maxSteps)
299 Log::Info <<
"Episode terminated due to the maximum number of steps" 303 if (std::abs(state.
Position()) > xThreshold)
305 Log::Info <<
"Episode terminated due to cart crossing threshold";
308 if (std::abs(state.
Angle(1)) > thetaThresholdRadians ||
309 std::abs(state.
Angle(2)) > thetaThresholdRadians)
311 Log::Info <<
"Episode terminated due to pole falling";
351 double thetaThresholdRadians;
363 size_t stepsPerformed;
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
size_t & MaxSteps()
Set the maximum number of steps allowed.
State(const arma::colvec &data)
Construct a state instance from given data.
size_t MaxSteps() const
Get the maximum number of steps allowed.
double & Velocity()
Modify the velocity of the cart.
Linear algebra utility functions, generally performed on matrices or vectors.
Implementation of action of Continuous Double Pole Cart.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t StepsPerformed() const
Get the number of steps performed.
const arma::colvec & Encode() const
Encode the state to a vector.
Implementation of the state of Continuous Double Pole Cart.
static constexpr size_t dimension
Dimension of the encoded state.
State()
Construct a state instance.
double & Position()
Modify the position of the cart.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
double Position() const
Get the position of the cart.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required for estimating the next state through the RK4 method.
double Velocity() const
Get the velocity of the cart.
arma::colvec Data() const
Get the internal representation of the state.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Continuous Double Pole Cart instance.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
ContinuousDoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Continuous Double Pole Cart Balancing task.
double Sample(const State &state, const Action &action)
Dynamics of Continuous Double Pole Cart.