ViennaCL - The Vienna Computing Library  1.5.2
matrix_solve.hpp
Go to the documentation of this file.
1 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_SOLVE_HPP
2 #define VIENNACL_LINALG_OPENCL_KERNELS_MATRIX_SOLVE_HPP
3 
7 #include "viennacl/ocl/utils.hpp"
8 
10 
13 namespace viennacl
14 {
15  namespace linalg
16  {
17  namespace opencl
18  {
19  namespace kernels
20  {
21 
/** @brief Builds the OpenCL expression addressing one entry of a strided (sub-)matrix kernel argument.
 *
 * Inside the generated kernel a matrix X is accompanied by the arguments X_start1/2, X_inc1/2 and
 * X_internal_size1/2. This helper writes the flat buffer index built from them, so that the
 * layout-dependent formula exists in exactly one place.
 *
 * @param mat        Name of the matrix kernel argument ("A" or "B")
 * @param row_major  true: the first index is scaled by X_internal_size2; false: the second index is scaled by X_internal_size1
 * @param index1     OpenCL expression used as first index
 * @param index2     OpenCL expression used as second index
 * @return The expression "X[...]" without surrounding whitespace or terminating semicolon
 */
inline std::string matrix_solve_blas3_entry(std::string const & mat, bool row_major,
                                            std::string const & index1, std::string const & index2)
{
  std::string const first  = "(" + index1 + " * " + mat + "_inc1 + " + mat + "_start1)";
  std::string const second = "(" + index2 + " * " + mat + "_inc2 + " + mat + "_start2)";

  if (row_major)
    return mat + "[" + first + " * " + mat + "_internal_size2 + " + second + "]";
  return mat + "[" + first + " + " + second + " * " + mat + "_internal_size1]";
}

/** @brief Appends the OpenCL source of one in-place triangular solve kernel to 'source'.
 *
 * The kernel name is composed as [trans_][unit_](upper_|lower_)[trans_]solve, where the first
 * "trans_" refers to A and the second one to B. In the emitted kernel each work group addresses B
 * through get_group_id(0), while the work items of the group share the elimination loop.
 *
 * @param source          String-like object providing append() for C strings and std::string; the kernel is appended to it
 * @param numeric_string  OpenCL scalar type name used for A, B and the temporary (e.g. "float")
 * @param row_major_A     Whether A is indexed in row-major fashion
 * @param row_major_B     Whether B is indexed in row-major fashion
 * @param transpose_A     Whether the kernel indexes A with swapped indices, i.e. works on trans(A)
 * @param transpose_B     Whether the kernel indexes B with swapped indices, i.e. works on trans(B)
 * @param upper_solve     true: rows are processed from the last to the first (upper triangular); false: first to last (lower triangular)
 * @param unit_diagonal   true: the division by the diagonal entry of A is omitted
 */
template <typename StringType>
void generate_matrix_solve_blas3(StringType & source, std::string const & numeric_string,
                                 bool row_major_A, bool row_major_B,
                                 bool transpose_A, bool transpose_B,
                                 bool upper_solve, bool unit_diagonal)
{
  std::string const rhs_col = "get_group_id(0)";
  std::string const row     = "row";
  std::string const elim    = "elim";

  // Entries of op(B) in the current row and in the row being eliminated.
  // Transposition only swaps the roles of the two indices.
  std::string const B_row  = transpose_B ? matrix_solve_blas3_entry("B", row_major_B, rhs_col, row)
                                         : matrix_solve_blas3_entry("B", row_major_B, row, rhs_col);
  std::string const B_elim = transpose_B ? matrix_solve_blas3_entry("B", row_major_B, rhs_col, elim)
                                         : matrix_solve_blas3_entry("B", row_major_B, elim, rhs_col);
  // Entry (elim, row) of op(A)
  std::string const A_elim = transpose_A ? matrix_solve_blas3_entry("A", row_major_A, row, elim)
                                         : matrix_solve_blas3_entry("A", row_major_A, elim, row);

  //start OpenCL code:
  source.append("__kernel void ");
  if (transpose_A)
    source.append("trans_");
  if (unit_diagonal)
    source.append("unit_");
  if (upper_solve)
    source.append("upper_");
  else
    source.append("lower_");
  if (transpose_B)
    source.append("trans_");
  source.append("solve");

  source.append("( \n");
  source.append(" __global const "); source.append(numeric_string); source.append(" * A, \n");
  source.append(" unsigned int A_start1, unsigned int A_start2, \n");
  source.append(" unsigned int A_inc1, unsigned int A_inc2, \n");
  source.append(" unsigned int A_size1, unsigned int A_size2, \n");
  source.append(" unsigned int A_internal_size1, unsigned int A_internal_size2, \n");
  source.append(" __global "); source.append(numeric_string); source.append(" * B, \n");
  source.append(" unsigned int B_start1, unsigned int B_start2, \n");
  source.append(" unsigned int B_inc1, unsigned int B_inc2, \n");
  source.append(" unsigned int B_size1, unsigned int B_size2, \n");
  source.append(" unsigned int B_internal_size1, unsigned int B_internal_size2) { \n");
  source.append(" "); source.append(numeric_string); source.append(" temp; \n");
  if (upper_solve)
  {
    //Note: A is square, thus A_rows == A_cols and no dispatch for transposedness needed
    source.append(" for (unsigned int row_cnt = 0; row_cnt < A_size1; ++row_cnt) \n");
    source.append(" { \n");
    source.append(" unsigned int row = A_size1 - 1 - row_cnt; \n");
  }
  else //lower triangular solve
  {
    source.append(" for (unsigned int row = 0; row < A_size1; ++row) \n");
    source.append(" { \n");
  }

  if (!unit_diagonal)
  {
    // A single work item per group divides the current entry of op(B) by the diagonal entry of A
    source.append(" barrier(CLK_GLOBAL_MEM_FENCE); \n");
    source.append(" if (get_local_id(0) == 0) \n");
    source.append(" "); source.append(B_row); source.append(" /= ");

    //Note: A is square, thus A_internal_rows == A_internal_cols and no dispatch for transposedness needed
    // The literals are kept verbatim (including the missing blanks around '*' in the second one)
    // so that the emitted kernel text stays byte-identical.
    if (row_major_A)
      source.append("A[(row * A_inc1 + A_start1) * A_internal_size2 + (row * A_inc2 + A_start2)]; \n");
    else
      source.append("A[(row * A_inc1 + A_start1) + (row * A_inc2 + A_start2)*A_internal_size1]; \n");
  }

  source.append(" barrier(CLK_GLOBAL_MEM_FENCE); \n");

  source.append(" temp = "); source.append(B_row); source.append("; \n");

  source.append(" //eliminate column of op(A) with index 'row' in parallel: \n");
  if (upper_solve)
    source.append(" for (unsigned int elim = get_local_id(0); elim < row; elim += get_local_size(0)) \n");
  else
    source.append(" for (unsigned int elim = row + get_local_id(0) + 1; elim < A_size1; elim += get_local_size(0)) \n");

  source.append(" "); source.append(B_elim); source.append(" -= temp * ");
  source.append(A_elim); source.append("; \n");

  source.append(" } \n");
  source.append("} \n");
}
125 
126 
127  // main kernel class
// Main kernel class for the generation of matrix solve kernels.
//
// Template parameters, as used by the visible code:
//   NumericT - scalar type; its OpenCL type name is obtained from viennacl::ocl::type_to_string<NumericT>
//   F1       - layout tag queried through viennacl::is_row_major<F1>; passed on as layout of A
//   F2       - layout tag queried through viennacl::is_row_major<F2>; passed on as layout of B
//
// NOTE(review): the listing numbers jump 133 -> 135, 137 -> 139 and 142 -> 144, so the class-head
// line, the body of program_name() and one statement at the start of init() are not visible in
// this extract. Restore them from the original header before compiling.
133  template <class NumericT, typename F1, typename F2>
135  {
// Returns the name under which init() registers the generated program.
// NOTE(review): the statement building the name is missing from this extract (see above).
136  static std::string program_name()
137  {
139  }
140 
// Generates all matrix solve kernels for NumericT/F1/F2 and registers them as one program
// with 'ctx'. The work is skipped if it was already done for the same OpenCL context handle.
141  static void init(viennacl::ocl::context & ctx)
142  {
144  std::string numeric_string = viennacl::ocl::type_to_string<NumericT>::apply();
145  bool matrix_row_major = viennacl::is_row_major<F1>::value;
146  bool rhs_row_major = viennacl::is_row_major<F2>::value;
147 
148 
// One flag per raw cl_context; operator[] default-inserts 'false' for a context not seen before.
// NOTE(review): no locking is visible around this function-local static - confirm that callers serialize init().
149  static std::map<cl_context, bool> init_done;
150  if (!init_done[ctx.handle().get()])
151  {
152  std::string source;
153  source.reserve(8192);
154 
155  viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
156 
157  // only generate for floating points (forces error for integers)
158  if (numeric_string == "float" || numeric_string == "double")
159  {
// The four trailing flags are (transpose_A, transpose_B, upper_solve, unit_diagonal).
// All 16 combinations are emitted, grouped by the two transposition flags.

// A and B not transposed:
160  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
161  false, false, false, false);
162  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
163  false, false, false, true);
164  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
165  false, false, true, false);
166  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
167  false, false, true, true);
168 
// only B transposed:
169  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
170  false, true, false, false);
171  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
172  false, true, false, true);
173  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
174  false, true, true, false);
175  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
176  false, true, true, true);
177 
// only A transposed:
178  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
179  true, false, false, false);
180  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
181  true, false, false, true);
182  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
183  true, false, true, false);
184  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
185  true, false, true, true);
186 
// A and B transposed:
187  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
188  true, true, false, false);
189  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
190  true, true, false, true);
191  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
192  true, true, true, false);
193  generate_matrix_solve_blas3(source, numeric_string, matrix_row_major, rhs_row_major,
194  true, true, true, true);
195  }
196 
// The program is registered even if no kernel was generated above; 'source' then holds only
// whatever append_double_precision_pragma() added.
197  std::string prog_name = program_name();
198  #ifdef VIENNACL_BUILD_INFO
199  std::cout << "Creating program " << prog_name << std::endl;
200  #endif
201  ctx.add_program(source, prog_name);
// Mark this context as done so that later calls of init() skip the generation.
202  init_done[ctx.handle().get()] = true;
203  } //if
204  } //init
205  };
206 
207  } // namespace kernels
208  } // namespace opencl
209  } // namespace linalg
210 } // namespace viennacl
211 #endif
212 
Implements an OpenCL platform within ViennaCL.
Helper class for checking whether a matrix has a row-major layout.
Definition: forwards.h:399
Various little tools used here and there in ViennaCL.
Manages an OpenCL context and provides the respective convenience functions for creating buffers...
Definition: context.hpp:51
Provides OpenCL-related utilities.
const viennacl::ocl::handle< cl_context > & handle() const
Returns the context handle.
Definition: context.hpp:476
void generate_matrix_solve_blas3(StringType &source, std::string const &numeric_string, bool row_major_A, bool row_major_B, bool transpose_A, bool transpose_B, bool upper_solve, bool unit_diagonal)
Definition: matrix_solve.hpp:23
const OCL_TYPE & get() const
Definition: handle.hpp:189
Main kernel class for the generation of matrix solve kernels.
Definition: matrix_solve.hpp:134
static void apply(viennacl::ocl::context const &)
Definition: utils.hpp:40
Representation of an OpenCL kernel in ViennaCL.
std::string type_to_string(viennacl::row_major)
Definition: matrix.hpp:868
static void init(viennacl::ocl::context &ctx)
Definition: matrix_solve.hpp:141
static std::string program_name()
Definition: matrix_solve.hpp:136
Helper class for converting a type to its string representation.
Definition: utils.hpp:57
Runtime generation of OpenCL kernels for matrix operations.