13 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP 14 #define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP 49 State(
const arma::colvec& data) : data(data)
//! Get the internal representation of the state. The vector is returned by
//! value, so the caller receives a copy and changes to it do not affect this
//! state.
53 arma::colvec
Data()
const {
return data; }
//! Modify the internal representation of the state. The returned reference
//! aliases the stored vector, so writes through it change this state.
55 arma::colvec&
Data() {
return data; }
//! Get the angle of the i-th pole, which is stored at data[2 * i]. Visible
//! callers pass i = 1 and i = 2 (i.e. data[2] and data[4]).
//! NOTE(review): Armadillo documents operator[] as unchecked, so an
//! out-of-range i is not detected here -- confirm callers only pass 1 or 2.
68 double Angle(
const size_t i)
const {
return data[2 * i]; }
//! Modify the angle of the i-th pole: returns a reference to data[2 * i], so
//! assigning through it updates this state.
70 double&
Angle(
const size_t i) {
return data[2 * i]; }
//! Encode the state to a vector. This returns a const reference to the
//! internal data, so no copy is made.
78 const arma::colvec&
Encode()
const {
return data; }
103 static const size_t size = 2;
124 const double m1 = 0.1,
125 const double m2 = 0.01,
126 const double l1 = 0.5,
127 const double l2 = 0.05,
128 const double gravity = 9.8,
129 const double massCart = 1.0,
130 const double forceMag = 10.0,
131 const double tau = 0.02,
132 const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
133 const double xThreshold = 2.4,
134 const double doneReward = 0.0) :
144 thetaThresholdRadians(thetaThresholdRadians),
145 xThreshold(xThreshold),
146 doneReward(doneReward),
166 arma::vec dydx(6, arma::fill::zeros);
170 Dsdt(state, action, dydx);
171 RK4(state, action, dydx, nextState);
177 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
// Visible part of the ODE right-hand side (Dsdt): computes the effective force
// and mass of the system, the cart acceleration, and the angular acceleration
// of each pole. These derivatives are consumed by RK4.
201 double totalForce = action.
action ? forceMag : -forceMag;
202 double totalMass = massCart;
205 double sinTheta1 = std::sin(state.
Angle(1));
206 double sinTheta2 = std::sin(state.
Angle(2));
207 double cosTheta1 = std::cos(state.
Angle(1));
208 double cosTheta2 = std::cos(state.
Angle(2));
// Effective force of pole i:
//   m_i * l_i * omega_i^2 * sin(theta_i) + 0.375 * m_i * g * sin(2 * theta_i),
// where 0.375 * sin(2 * theta) == 0.75 * sin(theta) * cos(theta).
// Every angle-dependent factor must use pole i's own angle.
211 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
212 std::sin(2 * state.
Angle(1));
// Pole 2 uses sinTheta2 in its centripetal term. Using sinTheta1 here was
// inconsistent with this statement's own gravity term (Angle(2)) and with the
// pole-2 mass and angular-acceleration terms below.
213 totalForce += m2 * l2 * omega2 * omega2 * sinTheta2 + 0.375 * m2 * gravity *
214 std::sin(2 * state.
Angle(2));
// Effective mass of pole i: m_i * (0.25 + 0.75 * sin^2(theta_i)),
// equivalently m_i * (1 - 0.75 * cos^2(theta_i)).
217 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
218 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
221 double xAcc = totalForce / totalMass;
// Angular acceleration of pole i:
//   -0.75 * (xAcc * cos(theta_i) + g * sin(theta_i)) / l_i.
225 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
226 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
243 const double hh = tau * 0.5;
244 const double h6 = tau / 6;
249 yt = state.
Data() + (hh * dydx);
254 yt = state.
Data() + (hh * dyt);
260 yt = state.
Data() + (tau * dym);
267 nextState.
Data() = state.
Data() + h6 * (dydx + dyt + 2 * dym);
281 return Sample(state, action, nextState);
292 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
303 if (maxSteps != 0 && stepsPerformed >= maxSteps)
305 Log::Info <<
"Episode terminated due to the maximum number of steps" 309 if (std::abs(state.
Position()) > xThreshold)
311 Log::Info <<
"Episode terminated due to cart crossing threshold";
314 if (std::abs(state.
Angle(1)) > thetaThresholdRadians ||
315 std::abs(state.
Angle(2)) > thetaThresholdRadians)
317 Log::Info <<
"Episode terminated due to pole falling";
360 double thetaThresholdRadians;
369 size_t stepsPerformed;
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
State(const arma::colvec &data)
Construct a state instance from given data.
size_t MaxSteps() const
Get the maximum number of steps allowed.
double & Velocity()
Modify the velocity of the cart.
Linear algebra utility functions, generally performed on matrices or vectors.
DoublePoleCart(const size_t maxSteps=0, const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0)
Construct a Double Pole Cart instance using the given constants.
Implementation of Double Pole Cart Balancing task.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
Computes the ordinary differential equations (the time derivatives of the state) that the RK4 method uses to estimate the next state.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Position() const
Get the position of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
size_t StepsPerformed() const
Get the number of steps performed.
size_t & MaxSteps()
Modify the maximum number of steps allowed.
State InitialSample()
Generate an initial state whose components are each drawn uniformly at random from [-0.05, 0.05].
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function uses the fourth-order Runge-Kutta (RK4) method to estimate the next state based on the given ordinary differential equations.
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.
arma::colvec & Data()
Modify the internal representation of the state.
const arma::colvec & Encode() const
Encode the state to a vector.
Implementation of action of Double Pole Cart.
Implementation of the state of Double Pole Cart.
State()
Construct a state instance.
double Velocity() const
Get the velocity of the cart.
static constexpr size_t dimension
Dimension of the encoded state.
double & Position()
Modify the position of the cart.
arma::colvec Data() const
Get the internal representation of the state.