mlpack  1.0.12
validation_RMSE_termination.hpp
Go to the documentation of this file.
1 
14 #ifndef VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
15 #define VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
16 
17 #include <mlpack/core.hpp>
18 
19 namespace mlpack
20 {
21 namespace amf
22 {
/**
 * AMF termination policy that validates against held-out entries of the data
 * matrix. The constructor moves num_test_points pseudo-randomly chosen nonzero
 * entries of V into an internal test set (zeroing them in V). IsConverged()
 * then measures the RMSE of W * H on those entries and decides whether to
 * stop based on how that RMSE changes between calls.
 *
 * NOTE(review): this text is a Doxygen source listing with the original line
 * numbers embedded. Several original lines are absent: 24, 27, 32-33, 35, 94,
 * 107, 110, 121, 142-143 and 151-152. Per the Doxygen tooltip, the class is
 * named ValidationRMSETermination (presumably declared on the missing line 24).
 * Comments below mark each gap. Anything said about missing lines is an
 * inference to be confirmed against the original header.
 *
 * @tparam MatType Type of the data matrix V. It must provide n_rows, n_cols
 *     and writable element access V(row, col).
 */
23 template <class MatType>
25 {
26  public:
  /**
   * Create the policy and carve a validation set out of V.
   *
   * Signature per the Doxygen tooltip:
   *   ValidationRMSETermination(MatType& V, size_t num_test_points,
   *                             double tolerance = 1e-5,
   *                             size_t maxIterations = 10000,
   *                             size_t reverseStepTolerance = 3)
   * NOTE(review): line 27, which opens this signature, is missing here.
   *
   * @param V Data matrix. Modified in place: every entry picked for the
   *     validation set is set to 0.
   * @param num_test_points Number of nonzero entries to hold out.
   * @param tolerance Threshold on the relative RMSE improvement. Presumably
   *     stored in the `tolerance` member used by IsConverged(); its
   *     initializer is on a missing line.
   * @param maxIterations Iteration limit. Presumably stored in the member
   *     returned by MaxIterations().
   * @param reverseStepTolerance Presumably the number of consecutive
   *     non-improving steps tolerated before stopping. Its use is on missing
   *     lines. TODO confirm.
   */
28  size_t num_test_points,
29  double tolerance = 1e-5,
30  size_t maxIterations = 10000,
31  size_t reverseStepTolerance = 3)
  // NOTE(review): lines 32-33 of the mem-initializer list are missing here.
  // Presumably they initialize tolerance and maxIterations from the
  // parameters of the same names.
34  num_test_points(num_test_points),
  // NOTE(review): line 35 is missing. Presumably it initializes
  // reverseStepTolerance.
36  {
37  size_t n = V.n_rows;
38  size_t m = V.n_cols;
39 
    // One row per held-out entry: (row index, column index, value). The
    // indices are stored as doubles and converted back to size_t in
    // IsConverged().
40  test_points.zeros(num_test_points, 3);
41 
42  for(size_t i = 0; i < num_test_points; i++)
43  {
44  double t_val;
45  size_t t_row;
46  size_t t_col;
      // Draw pseudo-random positions until a nonzero entry is found.
      // NOTE(review):
      // - Picked entries are zeroed below. If V has fewer than
      //   num_test_points nonzero entries, this loop never terminates.
      // - An empty V makes `rand() % n` / `rand() % m` divide by zero.
      // - rand() is not seeded here.
      // - rand() never exceeds RAND_MAX, so rows or columns beyond RAND_MAX
      //   can never be drawn.
      // - The modulo introduces a small bias.
47  do
48  {
49  t_row = rand() % n;
50  t_col = rand() % m;
51  } while((t_val = V(t_row, t_col)) == 0);
52 
53  test_points(i, 0) = t_row;
54  test_points(i, 1) = t_col;
55  test_points(i, 2) = t_val;
      // Zero the entry in V so it cannot be drawn again by the loop above,
      // and (presumably) so the factorization does not train on it.
56  V(t_row, t_col) = 0;
57  }
58  }
59 
  /**
   * Reset all per-run state before a new factorization. The matrix argument
   * is unused, and the validation set chosen in the constructor is kept.
   */
60  void Initialize(const MatType& /* V */)
61  {
62  iteration = 1;
63 
    // DBL_MAX acts as an "infinitely bad" previous RMSE, so the first
    // relative-improvement check in IsConverged() sees a large improvement.
64  rmse = DBL_MAX;
65  rmseOld = DBL_MAX;
66 
67  c_index = 0;
68  c_indexOld = 0;
69 
70  reverseStepCount = 0;
71  isCopy = false;
72  }
73 
  /**
   * Evaluate the current factorization W * H on the held-out entries and
   * decide whether to stop.
   *
   * @param W Current left factor. When returning true with a saved snapshot,
   *     it is overwritten with the snapshot.
   * @param H Current right factor. Same behaviour as W.
   * @return true if the iteration should stop, false otherwise.
   */
74  bool IsConverged(arma::mat& W, arma::mat& H)
75  {
76  // Calculate norm of WH after each iteration.
77  arma::mat WH;
78 
79  WH = W * H;
    // NOTE(review): despite the comment above, no norm is taken. The full
    // dense product W * H is formed although only num_test_points entries of
    // it are read below. Computing each needed entry as a row-column dot
    // product would avoid forming the whole matrix.
80 
    // NOTE(review): Initialize() sets iteration to 1 and it is only ever
    // incremented, so this branch is always taken after Initialize().
81  if (iteration != 0)
82  {
83  rmseOld = rmse;
84  rmse = 0;
      // Accumulate the squared error over the held-out entries.
85  for(size_t i = 0; i < num_test_points; i++)
86  {
87  size_t t_row = test_points(i, 0);
88  size_t t_col = test_points(i, 1);
89  double t_val = test_points(i, 2);
90  double temp = (t_val - WH(t_row, t_col));
91  temp *= temp;
92  rmse += temp;
93  }
      // NOTE(review): line 94 is missing here. For the result to be a
      // root-*mean*-square error, it presumably divides rmse by
      // num_test_points. Confirm.
95  rmse = sqrt(rmse);
96  }
97 
98  iteration++;
99 
    // Relative improvement of the validation RMSE over the previous call.
    // A worsening RMSE gives a negative value, which also counts as "below
    // tolerance". Because iteration starts at 1 and was just incremented,
    // the `iteration > 4` guard skips the first three calls after
    // Initialize().
100  if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
101  {
      // On the first step of a non-improving run, save the current factors
      // and their RMSE so they can be restored if the run ends in
      // termination.
102  if(reverseStepCount == 0 && isCopy == false)
103  {
104  isCopy = true;
105  this->W = W;
106  this->H = H;
        // NOTE(review): line 107 is missing. c_indexOld is read in the else
        // branch below, but no visible line other than Initialize() writes
        // it, so line 107 presumably sets it (e.g. from rmseOld). Confirm.
108  c_index = rmse;
109  }
      // NOTE(review): line 110 is missing. No visible line increments
      // reverseStepCount, so it presumably does so here. Confirm.
111  }
112  else
113  {
      // The RMSE improved enough, so reset the non-improving streak. If the
      // new RMSE is no worse than c_indexOld, drop the snapshot flag so the
      // saved factors will not be restored.
114  reverseStepCount = 0;
115  if(rmse <= c_indexOld && isCopy == true)
116  {
117  isCopy = false;
118  }
119  }
120 
    // NOTE(review): line 121, the condition guarding this block, is missing.
    // Presumably it stops when reverseStepCount reaches reverseStepTolerance
    // or iteration exceeds maxIterations. Confirm.
122  {
      // Roll back to the saved factors so the caller gets the snapshot, and
      // make Index() report the snapshot's validation RMSE.
123  if(isCopy)
124  {
125  W = this->W;
126  H = this->H;
127  rmse = c_index;
128  }
129  return true;
130  }
131  else return false;
132  }
133 
  /**
   * Validation RMSE computed by the most recent IsConverged() call, or the
   * restored snapshot's RMSE after termination. Before any call it is
   * DBL_MAX (set by Initialize()).
   */
134  const double& Index() { return rmse; }
135 
  /**
   * Iteration counter. It is 1 after Initialize() and is incremented by each
   * IsConverged() call.
   */
136  const size_t& Iteration() { return iteration; }
137 
  /**
   * Iteration limit. Presumably the constructor's maxIterations; its
   * initializer is on a missing line.
   */
138  const size_t& MaxIterations() { return maxIterations; }
139 
140  private:
  // Threshold on the relative RMSE improvement.
141  double tolerance;
  // NOTE(review): lines 142-143 are missing. From usage (MaxIterations() and
  // the constructor's num_test_points(...) initializer), they presumably
  // declare the maxIterations and num_test_points members.
  // Counter advanced on every IsConverged() call.
144  size_t iteration;
145 
  // num_test_points x 3 matrix: (row index, column index, value) per
  // held-out entry.
146  arma::Mat<double> test_points;
147 
  // Validation RMSE from the previous and from the current IsConverged()
  // call.
148  double rmseOld;
149  double rmse;
150 
  // NOTE(review): lines 151-152 are missing. From usage, they presumably
  // declare reverseStepTolerance and the reverseStepCount streak counter.
153 
  // Whether W/H below hold a snapshot that will be restored on termination.
154  bool isCopy;
  // Snapshot of the factors taken at the start of a non-improving run.
155  arma::mat W;
156  arma::mat H;
  // Compared with the current RMSE to decide whether to drop the snapshot.
  // Its assignment in IsConverged() is on a missing line.
157  double c_indexOld;
  // Validation RMSE of the snapshot.
158  double c_index;
159 };
160 
161 } // namespace amf
162 } // namespace mlpack
163 
164 
165 #endif // VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
166 
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:23
ValidationRMSETermination(MatType &V, size_t num_test_points, double tolerance=1e-5, size_t maxIterations=10000, size_t reverseStepTolerance=3)