mlpack  1.0.12
svd_incomplete_incremental_learning.hpp
Go to the documentation of this file.
1 #ifndef SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
2 #define SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
3 
4 namespace mlpack
5 {
6 namespace amf
7 {
9 {
10  public:
12  double kw = 0,
13  double kh = 0)
14  : u(u), kw(kw), kh(kh)
15  {}
16 
17  template<typename MatType>
18  void Initialize(const MatType& dataset, const size_t rank)
19  {
20  (void)rank;
21 
22  n = dataset.n_rows;
23  m = dataset.n_cols;
24 
25  currentUserIndex = 0;
26  }
27 
44  template<typename MatType>
45  inline void WUpdate(const MatType& V,
46  arma::mat& W,
47  const arma::mat& H)
48  {
49  arma::mat deltaW(n, W.n_cols);
50  deltaW.zeros();
51  for(size_t i = 0;i < n;i++)
52  {
53  double val;
54  if((val = V(i, currentUserIndex)) != 0)
55  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
56  arma::trans(H.col(currentUserIndex));
57  if(kw != 0) deltaW.row(i) -= kw * W.row(i);
58  }
59 
60  W += u*deltaW;
61  }
62 
72  template<typename MatType>
73  inline void HUpdate(const MatType& V,
74  const arma::mat& W,
75  arma::mat& H)
76  {
77  arma::mat deltaH(H.n_rows, 1);
78  deltaH.zeros();
79 
80  for(size_t i = 0;i < n;i++)
81  {
82  double val;
83  if((val = V(i, currentUserIndex)) != 0)
84  deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
85  arma::trans(W.row(i));
86  }
87  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
88 
89  H.col(currentUserIndex++) += u * deltaH;
91  }
92 
93  private:
94  double u;
95  double kw;
96  double kh;
97 
98  size_t n;
99  size_t m;
100 
102 };
103 
104 template<>
105 inline void SVDIncompleteIncrementalLearning::
106  WUpdate<arma::sp_mat>(const arma::sp_mat& V,
107  arma::mat& W,
108  const arma::mat& H)
109 {
110  arma::mat deltaW(n, W.n_cols);
111  deltaW.zeros();
112  for(arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
113  it != V.end_col(currentUserIndex);it++)
114  {
115  double val = *it;
116  size_t i = it.row();
117  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
118  arma::trans(H.col(currentUserIndex));
119  if(kw != 0) deltaW.row(i) -= kw * W.row(i);
120  }
121 
122  W += u*deltaW;
123 }
124 
125 template<>
126 inline void SVDIncompleteIncrementalLearning::
127  HUpdate<arma::sp_mat>(const arma::sp_mat& V,
128  const arma::mat& W,
129  arma::mat& H)
130 {
131  arma::mat deltaH(H.n_rows, 1);
132  deltaH.zeros();
133 
134  for(arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
135  it != V.end_col(currentUserIndex);it++)
136  {
137  double val = *it;
138  size_t i = it.row();
139  if((val = V(i, currentUserIndex)) != 0)
140  deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
141  arma::trans(W.row(i));
142  }
143  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
144 
145  H.col(currentUserIndex++) += u * deltaH;
146  currentUserIndex = currentUserIndex % m;
147 }
148 
149 }; // namepsace amf
150 }; // namespace mlpack
151 
152 
153 #endif // SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
154 
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:23
void Initialize(const MatType &dataset, const size_t rank)
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
SVDIncompleteIncrementalLearning(double u=0.001, double kw=0, double kh=0)