12 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
13 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
44 double min = -DBL_MIN,
49 template<
typename MatType>
50 void Initialize(
const MatType& dataset,
const size_t rank)
52 const size_t n = dataset.n_rows;
53 const size_t m = dataset.n_cols;
68 template<
typename MatType>
80 arma::mat deltaW(n, r);
83 for(
size_t i = 0;i < n;i++)
85 for(
size_t j = 0;j < m;j++)
88 if((val = V(i, j)) != 0)
89 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
90 arma::trans(H.col(j));
92 if(
kw != 0) deltaW.row(i) -=
kw * W.row(i);
108 template<
typename MatType>
120 arma::mat deltaH(r, m);
123 for(
size_t j = 0;j < m;j++)
125 for(
size_t i = 0;i < n;i++)
128 if((val = V(i, j)) != 0)
129 deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
130 arma::trans(W.row(i));
132 if(
kh != 0) deltaH.col(j) -=
kh * H.col(j);
158 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(
const arma::sp_mat& V,
168 arma::mat deltaW(n, r);
171 for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
173 size_t row = it.row();
174 size_t col = it.col();
175 deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
176 arma::trans(H.col(col));
179 if(kw != 0)
for(
size_t i = 0; i < n; i++)
181 deltaW.row(i) -= kw * W.row(i);
189 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(
const arma::sp_mat& V,
199 arma::mat deltaH(r, m);
202 for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
204 size_t row = it.row();
205 size_t col = it.col();
206 deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
207 arma::trans(W.row(row));
210 if(kh != 0)
for(
size_t j = 0; j < m; j++)
212 deltaH.col(j) -= kh * H.col(j);
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
void Initialize(const MatType &dataset, const size_t rank)
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9, double min=-DBL_MIN, double max=DBL_MAX)
SVD Batch learning constructor.
This class implements SVD batch learning with momentum.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.