mlpack  1.0.12
svd_batch_learning.hpp
Go to the documentation of this file.
1 
12 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
13 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
14 
15 #include <mlpack/core.hpp>
16 
17 namespace mlpack {
18 namespace amf {
19 
30 {
31  public:
40  SVDBatchLearning(double u = 0.0002,
41  double kw = 0,
42  double kh = 0,
43  double momentum = 0.9,
44  double min = -DBL_MIN,
45  double max = DBL_MAX)
46  : u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
47  {}
48 
49  template<typename MatType>
50  void Initialize(const MatType& dataset, const size_t rank)
51  {
52  const size_t n = dataset.n_rows;
53  const size_t m = dataset.n_cols;
54 
55  mW.zeros(n, rank);
56  mH.zeros(rank, m);
57  }
58 
68  template<typename MatType>
69  inline void WUpdate(const MatType& V,
70  arma::mat& W,
71  const arma::mat& H)
72  {
73  size_t n = V.n_rows;
74  size_t m = V.n_cols;
75 
76  size_t r = W.n_cols;
77 
78  mW = momentum * mW;
79 
80  arma::mat deltaW(n, r);
81  deltaW.zeros();
82 
83  for(size_t i = 0;i < n;i++)
84  {
85  for(size_t j = 0;j < m;j++)
86  {
87  double val;
88  if((val = V(i, j)) != 0)
89  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
90  arma::trans(H.col(j));
91  }
92  if(kw != 0) deltaW.row(i) -= kw * W.row(i);
93  }
94 
95  mW += u * deltaW;
96  W += mW;
97  }
98 
108  template<typename MatType>
109  inline void HUpdate(const MatType& V,
110  const arma::mat& W,
111  arma::mat& H)
112  {
113  size_t n = V.n_rows;
114  size_t m = V.n_cols;
115 
116  size_t r = W.n_cols;
117 
118  mH = momentum * mH;
119 
120  arma::mat deltaH(r, m);
121  deltaH.zeros();
122 
123  for(size_t j = 0;j < m;j++)
124  {
125  for(size_t i = 0;i < n;i++)
126  {
127  double val;
128  if((val = V(i, j)) != 0)
129  deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
130  arma::trans(W.row(i));
131  }
132  if(kh != 0) deltaH.col(j) -= kh * H.col(j);
133  }
134 
135  mH += u*deltaH;
136  H += mH;
137  }
138 
139  private:
140  double u;
141  double kw;
142  double kh;
143  double min;
144  double max;
145  double momentum;
146 
147  arma::mat mW;
148  arma::mat mH;
149 };
150 
153 
157 template<>
158 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
159  arma::mat& W,
160  const arma::mat& H)
161 {
162  size_t n = V.n_rows;
163 
164  size_t r = W.n_cols;
165 
166  mW = momentum * mW;
167 
168  arma::mat deltaW(n, r);
169  deltaW.zeros();
170 
171  for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
172  {
173  size_t row = it.row();
174  size_t col = it.col();
175  deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
176  arma::trans(H.col(col));
177  }
178 
179  if(kw != 0) for(size_t i = 0; i < n; i++)
180  {
181  deltaW.row(i) -= kw * W.row(i);
182  }
183 
184  mW += u * deltaW;
185  W += mW;
186 }
187 
188 template<>
189 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
190  const arma::mat& W,
191  arma::mat& H)
192 {
193  size_t m = V.n_cols;
194 
195  size_t r = W.n_cols;
196 
197  mH = momentum * mH;
198 
199  arma::mat deltaH(r, m);
200  deltaH.zeros();
201 
202  for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
203  {
204  size_t row = it.row();
205  size_t col = it.col();
206  deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
207  arma::trans(W.row(row));
208  }
209 
210  if(kh != 0) for(size_t j = 0; j < m; j++)
211  {
212  deltaH.col(j) -= kh * H.col(j);
213  }
214 
215  mH += u*deltaH;
216  H += mH;
217 }
218 
219 } // namespace amf
220 } // namespace mlpack
221 
222 #endif
223 
224 
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:23
void Initialize(const MatType &dataset, const size_t rank)
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9, double min=-DBL_MIN, double max=DBL_MAX)
SVD Batch learning constructor.
This class implements SVD batch learning with momentum.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.