mlpack  1.0.12
svd_complete_incremental_learning.hpp
Go to the documentation of this file.
1 #ifndef SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
2 #define SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
3 
4 #include <mlpack/core.hpp>
5 
6 namespace mlpack
7 {
8 namespace amf
9 {
10 
11 template <class MatType>
13 {
14  public:
16  double kw = 0,
17  double kh = 0)
18  : u(u), kw(kw), kh(kh)
19  {}
20 
21  void Initialize(const MatType& dataset, const size_t rank)
22  {
23  (void)rank;
24  n = dataset.n_rows;
25  m = dataset.n_cols;
26 
27  currentUserIndex = 0;
28  currentItemIndex = 0;
29  }
30 
47  inline void WUpdate(const MatType& V,
48  arma::mat& W,
49  const arma::mat& H)
50  {
51  arma::mat deltaW(1, W.n_cols);
52  deltaW.zeros();
53  while(true)
54  {
55  double val;
56  if((val = V(currentItemIndex, currentUserIndex)) != 0)
57  {
58  deltaW += (val - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
59  * arma::trans(H.col(currentUserIndex));
60  if(kw != 0) deltaW -= kw * W.row(currentItemIndex);
61  break;
62  }
64  if(currentUserIndex == n)
65  {
66  currentUserIndex = 0;
68  }
69  }
70 
71  W.row(currentItemIndex) += u*deltaW;
72  }
73 
83  inline void HUpdate(const MatType& V,
84  const arma::mat& W,
85  arma::mat& H)
86  {
87  arma::mat deltaH(H.n_rows, 1);
88  deltaH.zeros();
89 
90  while(true)
91  {
92  double val;
93  if((val = V(currentItemIndex, currentUserIndex)) != 0)
94  deltaH += (val - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
95  * arma::trans(W.row(currentItemIndex));
96  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
97 
99  if(currentUserIndex == n)
100  {
101  currentUserIndex = 0;
103  }
104  }
105 
106  H.col(currentUserIndex++) += u * deltaH;
107  }
108 
109  private:
110  double u;
111  double kw;
112  double kh;
113 
114  size_t n;
115  size_t m;
116 
119 };
120 
121 template<>
123 {
124  public:
126  double kw = 0,
127  double kh = 0)
128  : u(u), kw(kw), kh(kh), it(NULL)
129  {}
130 
132  {
133  delete it;
134  }
135 
136  void Initialize(const arma::sp_mat& dataset, const size_t rank)
137  {
138  (void)rank;
139  n = dataset.n_rows;
140  m = dataset.n_cols;
141 
142  it = new arma::sp_mat::const_iterator(dataset.begin());
143  isStart = true;
144  }
145 
155  inline void WUpdate(const arma::sp_mat& V,
156  arma::mat& W,
157  const arma::mat& H)
158  {
159  if(!isStart) (*it)++;
160  else isStart = false;
161 
162  if(*it == V.end())
163  {
164  delete it;
165  it = new arma::sp_mat::const_iterator(V.begin());
166  }
167 
168  size_t currentUserIndex = it->col();
169  size_t currentItemIndex = it->row();
170 
171  arma::mat deltaW(1, W.n_cols);
172  deltaW.zeros();
173 
174  deltaW += (**it - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
175  * arma::trans(H.col(currentUserIndex));
176  if(kw != 0) deltaW -= kw * W.row(currentItemIndex);
177 
178  W.row(currentItemIndex) += u*deltaW;
179  }
180 
190  inline void HUpdate(const arma::sp_mat& V,
191  const arma::mat& W,
192  arma::mat& H)
193  {
194  (void)V;
195 
196  arma::mat deltaH(H.n_rows, 1);
197  deltaH.zeros();
198 
199  size_t currentUserIndex = it->col();
200  size_t currentItemIndex = it->row();
201 
202  deltaH += (**it - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
203  * arma::trans(W.row(currentItemIndex));
204  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
205 
206  H.col(currentUserIndex++) += u * deltaH;
207  }
208 
209  private:
210  double u;
211  double kw;
212  double kh;
213 
214  size_t n;
215  size_t m;
216 
217  arma::sp_mat dummy;
218  arma::sp_mat::const_iterator* it;
219 
220  bool isStart;
221 };
222 
223 }
224 }
225 
226 
227 #endif // SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
228 
void HUpdate(const arma::sp_mat &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:23
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void WUpdate(const arma::sp_mat &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
SVDCompleteIncrementalLearning(double u=0.0001, double kw=0, double kh=0)
void Initialize(const MatType &dataset, const size_t rank)
void Initialize(const arma::sp_mat &dataset, const size_t rank)