15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP 16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP 47 State(
const arma::colvec& data) : data(data)
51 arma::colvec&
Data() {
return data; }
64 double Angle()
const {
return data[2]; }
66 double&
Angle() {
return data[2]; }
74 const arma::colvec&
Encode()
const {
return data; }
//! Size of the action space. Declared constexpr to match the state's
//! `dimension` constant and to stay a compile-time constant for callers.
static constexpr size_t size = 2;
118 const double gravity = 9.8,
119 const double massCart = 1.0,
120 const double massPole = 0.1,
121 const double length = 0.5,
122 const double forceMag = 10.0,
123 const double tau = 0.02,
124 const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
125 const double xThreshold = 2.4,
126 const double doneReward = 1.0) :
131 totalMass(massCart + massPole),
133 poleMassLength(massPole * length),
136 thetaThresholdRadians(thetaThresholdRadians),
137 xThreshold(xThreshold),
138 doneReward(doneReward),
159 double force = action.
action ? forceMag : -forceMag;
160 double cosTheta = std::cos(state.
Angle());
161 double sinTheta = std::sin(state.
Angle());
164 double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
165 (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
166 double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
178 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
199 return Sample(state, action, nextState);
210 return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
221 if (maxSteps != 0 && stepsPerformed >= maxSteps)
223 Log::Info <<
"Episode terminated due to the maximum number of steps" 227 else if (std::abs(state.
Position()) > xThreshold ||
228 std::abs(state.
Angle()) > thetaThresholdRadians)
230 Log::Info <<
"Episode terminated due to agent failing.";
264 double poleMassLength;
273 double thetaThresholdRadians;
282 size_t stepsPerformed;
double Velocity() const
Get the velocity.
State(const arma::colvec &data)
Construct a state instance from given data.
double AngularVelocity() const
Get the angular velocity.
double & Velocity()
Modify the velocity.
Linear algebra utility functions, generally performed on matrices or vectors.
double Sample(const State &state, const Action &action)
Dynamics of Cart Pole.
Implementation of action of Cart Pole.
State()
Construct a state instance.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Cart Pole instance.
CartPole(const size_t maxSteps=200, const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=1.0)
Construct a Cart Pole instance using the given constants.
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
double Position() const
Get the position.
Implementation of the state of Cart Pole.
size_t & MaxSteps()
Set the maximum number of steps allowed.
double Angle() const
Get the angle.
double & Angle()
Modify the angle.
double & Position()
Modify the position.
double & AngularVelocity()
Modify the angular velocity.
arma::colvec & Data()
Modify the internal representation of the state.
static constexpr size_t dimension
Dimension of the encoded state.
const arma::colvec & Encode() const
Encode the state to a column vector.
Implementation of Cart Pole task.
size_t StepsPerformed() const
Get the number of steps performed.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
size_t MaxSteps() const
Get the maximum number of steps allowed.