ChannelShuffle< InputDataType, OutputDataType > Class Template Reference

Definition and implementation of the Channel Shuffle Layer.

Public Member Functions

 ChannelShuffle ()
 Create the Channel Shuffle object.


 ChannelShuffle (const size_t inRowSize, const size_t inColSize, const size_t depth, const size_t groupCount)
 The constructor for the Channel Shuffle.

 
template<typename eT>
void Backward (const arma::Mat< eT > &, const arma::Mat< eT > &gradient, arma::Mat< eT > &output)
 Ordinary feed backward pass of a neural network: propagate the gradient with respect to the output backwards through f to obtain the gradient with respect to the input.

 
OutputDataType const & Delta () const
 Get the delta.


OutputDataType & Delta ()
 Modify the delta.

 
template<typename eT>
void Forward (const arma::Mat< eT > &input, arma::Mat< eT > &output)
 Forward pass through the layer.

 
size_t const & InColSize () const
 Get the column size of the input.


size_t & InColSize ()
 Modify the column size of the input.


size_t const & InDepth () const
 Get the depth of the input.


size_t & InDepth ()
 Modify the depth of the input.

 
size_t const & InGroupCount () const
 Get the number of groups the channels are divided into.


size_t & InGroupCount ()
 Modify the number of groups the channels are divided into.

 
size_t InputShape () const
 Get the shape of the input.


size_t const & InRowSize () const
 Get the row size of the input.


size_t & InRowSize ()
 Modify the row size of the input.


OutputDataType const & OutputParameter () const
 Get the output parameter.


OutputDataType & OutputParameter ()
 Modify the output parameter.

 
template<typename Archive>
void serialize (Archive &ar, const uint32_t)
 Serialize the layer.

 

Detailed Description


template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
class mlpack::ann::ChannelShuffle< InputDataType, OutputDataType >

Definition and implementation of the Channel Shuffle Layer.

Channel Shuffle divides the channels (units) of a tensor into groups and interleaves them, so that each new group receives channels from every original group, while the original tensor shape is kept.
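The rearrangement can be pictured as the "reshape to (groups, channels per group), transpose, flatten" step described in the ShuffleNet paper. The standalone C++ sketch below is not mlpack code; `ChannelShuffleOrder` is a hypothetical helper that lists, for every output channel, the input channel it is copied from, assuming the number of channels is divisible by the number of groups.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack). Element o of the result is the
// input channel that output channel o is copied from. Equivalent to viewing
// the channels as a (groupCount x groupSize) grid, transposing it, and
// reading it back out row by row.
std::vector<std::size_t> ChannelShuffleOrder(std::size_t depth,
                                             std::size_t groupCount)
{
  assert(groupCount > 0 && depth % groupCount == 0);
  const std::size_t groupSize = depth / groupCount;
  std::vector<std::size_t> order(depth);
  for (std::size_t i = 0; i < groupSize; ++i)     // position inside a group
    for (std::size_t g = 0; g < groupCount; ++g)  // group index
      order[i * groupCount + g] = g * groupSize + i;
  return order;
}
```

For 6 channels in 2 groups this yields 0, 3, 1, 4, 2, 5: the channels a0 a1 a2 | b0 b1 b2 become a0 b0 a1 b1 a2 b2, so every consecutive pair of output channels mixes both groups.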

For more information, refer to the following paper,

@article{zhang2018shufflenet,
  author = {Xiangyu Zhang and Xinyu Zhou and Mengxiao Lin and Jian Sun},
  title  = {ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices},
  year   = {2018},
  url    = {https://arxiv.org/pdf/1707.01083},
}
Template Parameters
InputDataType: Type of the input data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
OutputDataType: Type of the output data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).

Definition at line 46 of file channel_shuffle.hpp.

Constructor & Destructor Documentation

◆ ChannelShuffle() [1/2]

Create the Channel Shuffle object.

◆ ChannelShuffle() [2/2]

ChannelShuffle (const size_t inRowSize, const size_t inColSize, const size_t depth, const size_t groupCount)

The constructor for the Channel Shuffle.

Parameters
inRowSize: Number of input rows.
inColSize: Number of input columns.
depth: Number of input slices (channels).
groupCount: Number of groups for shuffling channels.
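The four constructor arguments describe how one column of the input matrix is interpreted. The sketch below is a hypothetical helper, not part of mlpack. It assumes that each column stores one sample as `depth` consecutive `inRowSize` x `inColSize` slices (the column-major layout Armadillo uses for `arma::cube`). It also assumes that `depth` must be divisible by `groupCount`, so that all groups have the same size.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of mlpack) mirroring the four constructor
// arguments of ChannelShuffle.
struct ChannelShuffleShape
{
  std::size_t inRowSize;   // Number of input rows.
  std::size_t inColSize;   // Number of input columns.
  std::size_t depth;       // Number of input slices (channels).
  std::size_t groupCount;  // Number of groups for shuffling channels.

  // Assumed number of rows of one input column: `depth` slices of
  // inRowSize x inColSize values each.
  std::size_t ColumnLength() const { return inRowSize * inColSize * depth; }

  // Channels per group; assumes the channels split into equal groups.
  std::size_t GroupSize() const
  {
    assert(groupCount > 0 && depth % groupCount == 0);
    return depth / groupCount;
  }
};
```

Under these assumptions, a 4 x 4 input with 8 channels split into 2 groups means each input column has 128 rows and each group holds 4 channels.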

Member Function Documentation

◆ Backward()

void Backward (const arma::Mat< eT > &, const arma::Mat< eT > &gradient, arma::Mat< eT > &output)

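Ordinary feed backward pass of a neural network: the gradient with respect to the layer's output is propagated backwards through f to obtain the gradient with respect to the layer's input.

The layer has no learnable parameters, and the forward pass only reorders channels. The backward pass therefore routes each channel of the gradient back to the input channel it was copied from, i.e. it applies the inverse channel permutation. The resulting gradient has the same size as the input.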

Parameters
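(input): The input matrix (unnamed; not needed, because the gradient of a channel permutation does not depend on the input values).
gradient: The computed backward gradient.
output: The resulting gradient with respect to the input.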
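Because the forward pass is a pure permutation of channels, its gradient is obtained by applying the inverse permutation. The following standalone C++ sketch illustrates this; it is not mlpack's implementation. It works on a plain `std::vector<double>` instead of `arma::Mat`, and it assumes that each sample occupies one column made of `depth` consecutive slices of `sliceSize` (= inRowSize * inColSize) elements, with columns stored one after another. `ShuffleBackward` and `sliceSize` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch (not mlpack code) of the channel-shuffle backward pass.
// The forward pass copies input channel g * groupSize + i to output channel
// i * groupCount + g, so the gradient of that output channel is copied back
// to the same input channel.
std::vector<double> ShuffleBackward(const std::vector<double>& gradient,
                                    std::size_t sliceSize,
                                    std::size_t depth,
                                    std::size_t groupCount)
{
  assert(groupCount > 0 && depth % groupCount == 0);
  const std::size_t groupSize = depth / groupCount;
  const std::size_t columnLength = sliceSize * depth;
  assert(columnLength > 0 && gradient.size() % columnLength == 0);

  std::vector<double> inputGradient(gradient.size());
  const std::size_t batchSize = gradient.size() / columnLength;
  for (std::size_t b = 0; b < batchSize; ++b)       // one column per sample
    for (std::size_t g = 0; g < groupCount; ++g)    // group index
      for (std::size_t i = 0; i < groupSize; ++i)   // position inside group
      {
        const std::size_t inChannel = g * groupSize + i;
        const std::size_t outChannel = i * groupCount + g;
        for (std::size_t e = 0; e < sliceSize; ++e)
          inputGradient[b * columnLength + inChannel * sliceSize + e] =
              gradient[b * columnLength + outChannel * sliceSize + e];
      }
  return inputGradient;
}
```

For example, with `sliceSize = 1`, `depth = 4` and `groupCount = 2`, a gradient of {10, 20, 30, 40} is routed back as {10, 30, 20, 40}.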

◆ Delta() [1/2]

OutputDataType const& Delta ( ) const
inline

Get the delta.

Definition at line 96 of file channel_shuffle.hpp.

◆ Delta() [2/2]

OutputDataType& Delta ( )
inline

Modify the delta.

Definition at line 98 of file channel_shuffle.hpp.

◆ Forward()

void Forward (const arma::Mat< eT > &input, arma::Mat< eT > &output)

Forward pass through the layer.

Parameters
input: The input matrix.
output: The resulting channel-shuffled output matrix, which has the same shape as the input.
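The forward pass can be illustrated with a standalone C++ sketch; it is not mlpack's implementation. It uses a plain `std::vector<double>` instead of `arma::Mat` and assumes that each sample occupies one column made of `depth` consecutive slices of `sliceSize` (= inRowSize * inColSize) elements, with columns stored one after another. `ShuffleForward` and `sliceSize` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch (not mlpack code) of the channel-shuffle forward pass.
// Whole slices are copied, so the output has exactly the input's size.
std::vector<double> ShuffleForward(const std::vector<double>& input,
                                   std::size_t sliceSize,
                                   std::size_t depth,
                                   std::size_t groupCount)
{
  assert(groupCount > 0 && depth % groupCount == 0);
  const std::size_t groupSize = depth / groupCount;
  const std::size_t columnLength = sliceSize * depth;
  assert(columnLength > 0 && input.size() % columnLength == 0);

  std::vector<double> output(input.size());
  const std::size_t batchSize = input.size() / columnLength;
  for (std::size_t b = 0; b < batchSize; ++b)  // one column per sample
  {
    std::size_t outChannel = 0;
    // Positions inside a group in the outer loop and groups in the inner
    // loop, so consecutive output channels come from different groups.
    for (std::size_t i = 0; i < groupSize; ++i)
      for (std::size_t g = 0; g < groupCount; ++g, ++outChannel)
      {
        const std::size_t inChannel = g * groupSize + i;
        for (std::size_t e = 0; e < sliceSize; ++e)
          output[b * columnLength + outChannel * sliceSize + e] =
              input[b * columnLength + inChannel * sliceSize + e];
      }
  }
  return output;
}
```

With `sliceSize = 1`, `depth = 6` and `groupCount = 2`, the input {0, 1, 2, 3, 4, 5} becomes {0, 3, 1, 4, 2, 5}. Only the channel order changes; the output keeps the input's shape.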

◆ InColSize() [1/2]

size_t const& InColSize ( ) const
inline

Get the column size of the input.

Definition at line 106 of file channel_shuffle.hpp.

◆ InColSize() [2/2]

size_t& InColSize ( )
inline

Modify the column size of the input.

Definition at line 108 of file channel_shuffle.hpp.

◆ InDepth() [1/2]

size_t const& InDepth ( ) const
inline

Get the depth of the input.

Definition at line 111 of file channel_shuffle.hpp.

◆ InDepth() [2/2]

size_t& InDepth ( )
inline

Modify the depth of the input.

Definition at line 113 of file channel_shuffle.hpp.

◆ InGroupCount() [1/2]

size_t const& InGroupCount ( ) const
inline

Get the number of groups the channels are divided into.

Definition at line 116 of file channel_shuffle.hpp.

◆ InGroupCount() [2/2]

size_t& InGroupCount ( )
inline

Modify the number of groups the channels are divided into.

Definition at line 118 of file channel_shuffle.hpp.

◆ InputShape()

size_t InputShape ( ) const
inline

Get the shape of the input.

Definition at line 121 of file channel_shuffle.hpp.

References ChannelShuffle< InputDataType, OutputDataType >::serialize().

◆ InRowSize() [1/2]

size_t const& InRowSize ( ) const
inline

Get the row size of the input.

Definition at line 101 of file channel_shuffle.hpp.

◆ InRowSize() [2/2]

size_t& InRowSize ( )
inline

Modify the row size of the input.

Definition at line 103 of file channel_shuffle.hpp.

◆ OutputParameter() [1/2]

OutputDataType const& OutputParameter ( ) const
inline

Get the output parameter.

Definition at line 91 of file channel_shuffle.hpp.

◆ OutputParameter() [2/2]

OutputDataType& OutputParameter ( )
inline

Modify the output parameter.

Definition at line 93 of file channel_shuffle.hpp.

◆ serialize()

void serialize (Archive &ar, const uint32_t)

Serialize the layer.

The documentation for this class was generated from the following file: