ViennaCL - The Vienna Computing Library  1.5.2
execute_elementwise.hpp
Go to the documentation of this file.
1 #ifndef VIENNACL_SCHEDULER_EXECUTE_ELEMENTWISE_HPP
2 #define VIENNACL_SCHEDULER_EXECUTE_ELEMENTWISE_HPP
3 
4 /* =========================================================================
5  Copyright (c) 2010-2014, Institute for Microelectronics,
6  Institute for Analysis and Scientific Computing,
7  TU Wien.
8  Portions of this software are copyright by UChicago Argonne, LLC.
9 
10  -----------------
11  ViennaCL - The Vienna Computing Library
12  -----------------
13 
14  Project Head: Karl Rupp rupp@iue.tuwien.ac.at
15 
16  (A list of authors and contributors can be found in the PDF manual)
17 
18  License: MIT (X11), see file LICENSE in the base directory
19 ============================================================================= */
20 
21 
26 #include "viennacl/forwards.h"
31 
32 namespace viennacl
33 {
34  namespace scheduler
35  {
36  namespace detail
37  {
38  // result = element_op(x) for a vector or matrix x (unary element-wise operation)
/** @brief Applies the unary element-wise operation op_type to x and stores the outcome in result.
 *
 *  Dispatch order: x.subtype (dense vector, row-major dense matrix, column-major dense matrix),
 *  then x.numeric_type (float or double), then op_type. The per-operation switch cases are
 *  generated by VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP.
 *
 *  @param result   Target operand; must share numeric_type and type_family with x (asserted).
 *  @param x        Source operand.
 *  @param op_type  Unary operation to apply.
 *  @throws statement_not_supported_exception for an unsupported op_type or numeric type.
 *
 *  NOTE(review): this listing omits a number of source lines (see the gaps in the line
 *  numbers), presumably including the macro invocations that produce the case labels.
 */
39  inline void element_op(lhs_rhs_element result,
40  lhs_rhs_element const & x,
41  operation_node_type op_type)
42  {
43  assert( result.numeric_type == x.numeric_type && bool("Numeric type not the same!"));
44  assert( result.type_family == x.type_family && bool("Subtype not the same!"));
45 
// Dense vector operand (the only branch that also asserts result.subtype == x.subtype):
46  if (x.subtype == DENSE_VECTOR_TYPE)
47  {
48  assert( result.subtype == x.subtype && bool("result not of vector type for unary elementwise operation"));
49  if (x.numeric_type == FLOAT_TYPE)
50  {
51  switch (op_type)
52  {
// Each macro invocation yields one 'case OPNAME:' that calls linalg::element_op with a
// unary vector_expression passing x as both of its operands.
53 #define VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP(OPNAME, SCALARTYPE, OPTAG) \
54  case OPNAME: viennacl::linalg::element_op(*result.vector_##SCALARTYPE, \
55  viennacl::vector_expression<const vector_base<SCALARTYPE>, const vector_base<SCALARTYPE>, \
56  op_element_unary<OPTAG> >(*x.vector_##SCALARTYPE, *x.vector_##SCALARTYPE)); break;
57 
75  default:
76  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
77  }
78  }
79  else if (x.numeric_type == DOUBLE_TYPE)
80  {
81  switch (op_type)
82  {
100 
101 #undef VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP
102  default:
103  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
104  }
105  }
106  else
107  throw statement_not_supported_exception("Invalid numeric type in unary elementwise operator");
108  }
// Row-major dense matrix operand (result.subtype is not checked in this branch):
109  else if (x.subtype == DENSE_ROW_MATRIX_TYPE)
110  {
111  if (x.numeric_type == FLOAT_TYPE)
112  {
113  switch (op_type)
114  {
// Same case generator as above, but for row-major matrix_base operands.
115 #define VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP(OPNAME, SCALARTYPE, OPTAG) \
116  case OPNAME: viennacl::linalg::element_op(*result.matrix_row_##SCALARTYPE, \
117  viennacl::matrix_expression<const matrix_base<SCALARTYPE, viennacl::row_major>, const matrix_base<SCALARTYPE, viennacl::row_major>, \
118  op_element_unary<OPTAG> >(*x.matrix_row_##SCALARTYPE, *x.matrix_row_##SCALARTYPE)); break;
119 
137  default:
138  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
139  }
140 
141  }
142  else if (x.numeric_type == DOUBLE_TYPE)
143  {
144  switch (op_type)
145  {
163  default:
164  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
165  }
166  }
167  else
168  throw statement_not_supported_exception("Invalid numeric type in unary elementwise operator");
169 
170 #undef VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP
171 
172  }
// Column-major dense matrix operand (result.subtype is not checked in this branch):
173  else if (x.subtype == DENSE_COL_MATRIX_TYPE)
174  {
175  if (x.numeric_type == FLOAT_TYPE)
176  {
177  switch (op_type)
178  {
// Same case generator as above, but for column-major matrix_base operands.
179 #define VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP(OPNAME, SCALARTYPE, OPTAG) \
180  case OPNAME: viennacl::linalg::element_op(*result.matrix_col_##SCALARTYPE, \
181  viennacl::matrix_expression<const matrix_base<SCALARTYPE, viennacl::column_major>, const matrix_base<SCALARTYPE, viennacl::column_major>, \
182  op_element_unary<OPTAG> >(*x.matrix_col_##SCALARTYPE, *x.matrix_col_##SCALARTYPE)); break;
183 
201  default:
202  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
203  }
204 
205  }
206  else if (x.numeric_type == DOUBLE_TYPE)
207  {
208  switch (op_type)
209  {
227  default:
228  throw statement_not_supported_exception("Invalid op_type in unary elementwise operations");
229  }
230  }
231  else
232  throw statement_not_supported_exception("Invalid numeric type in unary elementwise operator");
233 
234 #undef VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP
235  }
// NOTE(review): there is no final 'else' after the three subtype branches, so any other
// x.subtype returns without computing anything or throwing. The binary element_op overload
// throws statement_not_supported_exception in that situation; consider doing the same here.
236  }
237 
238  // result = element_op(x,y) for vectors or matrices x, y
/** @brief Applies the binary element-wise operation op_type to x and y and stores the outcome in result.
 *
 *  Dispatch order: op_type, then x.subtype (dense vector, row-major dense matrix,
 *  column-major dense matrix), then x.numeric_type (float or double). Only x is inspected
 *  for dispatch; the asserts require x, y and result to share numeric_type and
 *  type_family, but their subtypes are never compared.
 *
 *  @throws statement_not_supported_exception for an unsupported op_type, operand subtype
 *          or numeric type.
 *
 *  NOTE(review): the case labels of the two supported operations and the calls inside each
 *  numeric-type case are omitted from this listing (see the gaps in the line numbers).
 */
239  inline void element_op(lhs_rhs_element result,
240  lhs_rhs_element const & x,
241  lhs_rhs_element const & y,
242  operation_node_type op_type)
243  {
244  assert( x.numeric_type == y.numeric_type && bool("Numeric type not the same!"));
245  assert( result.numeric_type == y.numeric_type && bool("Numeric type not the same!"));
246 
// NOTE(review): these two asserts compare type_family, although their message says "Subtype".
247  assert( x.type_family == y.type_family && bool("Subtype not the same!"));
248  assert( result.type_family == y.type_family && bool("Subtype not the same!"));
249 
250  switch (op_type)
251  {
252 
// First supported operation (its case label is on a line omitted from this listing).
254  if (x.subtype == DENSE_VECTOR_TYPE)
255  {
256  switch (x.numeric_type)
257  {
258  case FLOAT_TYPE:
261  const vector_base<float>,
263  break;
264  case DOUBLE_TYPE:
267  const vector_base<double>,
269  break;
270  default:
271  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
272  }
273  }
274  else if (x.subtype == DENSE_ROW_MATRIX_TYPE)
275  {
276  switch (x.numeric_type)
277  {
278  case FLOAT_TYPE:
283  break;
284  case DOUBLE_TYPE:
289  break;
290  default:
291  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
292  }
293  }
294  else if (x.subtype == DENSE_COL_MATRIX_TYPE)
295  {
296  switch (x.numeric_type)
297  {
298  case FLOAT_TYPE:
303  break;
304  case DOUBLE_TYPE:
309  break;
310  default:
311  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
312  }
313  }
314  else
315  throw statement_not_supported_exception("Invalid operand type for binary elementwise division");
316  break;
317 
318 
// Second supported operation (its case label is on a line omitted from this listing).
// NOTE(review): the error messages in both operation cases say "division"; if one of them
// is the element-wise product case, its messages are misleading - confirm.
320  if (x.subtype == DENSE_VECTOR_TYPE)
321  {
322  switch (x.numeric_type)
323  {
324  case FLOAT_TYPE:
327  const vector_base<float>,
329  break;
330  case DOUBLE_TYPE:
333  const vector_base<double>,
335  break;
336  default:
337  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
338  }
339  }
340  else if (x.subtype == DENSE_ROW_MATRIX_TYPE)
341  {
342  switch (x.numeric_type)
343  {
344  case FLOAT_TYPE:
349  break;
350  case DOUBLE_TYPE:
355  break;
356  default:
357  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
358  }
359  }
360  else if (x.subtype == DENSE_COL_MATRIX_TYPE)
361  {
362  switch (x.numeric_type)
363  {
364  case FLOAT_TYPE:
369  break;
370  case DOUBLE_TYPE:
375  break;
376  default:
377  throw statement_not_supported_exception("Invalid numeric type for binary elementwise division");
378  }
379  }
380  else
381  throw statement_not_supported_exception("Invalid operand type for binary elementwise division");
382  break;
383  default:
384  throw statement_not_supported_exception("Invalid operation type for binary elementwise operations");
385  }
386  }
387  }
388 
/** @brief Evaluates root_node, whose rhs refers (by node_index) to an element-wise operation
 *         node 'leaf' in s, and writes the outcome into root_node.lhs.
 *
 *  Operands of leaf whose type_family is COMPOSITE_OPERATION_FAMILY are first evaluated into
 *  temporaries through detail::execute_composite, using an assignment node whose rhs points at
 *  the operand's subexpression. The temporaries are created with detail::new_element from
 *  root_node.lhs and released with detail::delete_element. One branch uses the 4-argument
 *  detail::element_op (x, y) and the unary type family uses the 3-argument one. Any other leaf
 *  type family throws statement_not_supported_exception.
 *
 *  NOTE(review): several guarding conditions are omitted from this listing (see the gaps in the
 *  line numbers). Judging by the ternaries that select x and y, they presumably test the
 *  operands' type_family against COMPOSITE_OPERATION_FAMILY. The branch before the unary one
 *  presumably tests for the binary type family. Confirm against the real header.
 *  NOTE(review): temporaries are not freed if execute_composite or element_op throws. The
 *  "Unsupported elementwise operation." throw also skips the clean-up at the end (see the TODOs).
 */
390  inline void execute_element_composite(statement const & s, statement_node const & root_node)
391  {
// The element-wise operation node referenced by the rhs of root_node.
392  statement_node const & leaf = s.array()[root_node.rhs.node_index];
393 
// Assignment nodes used to evaluate composite operands into temporaries.
394  statement_node new_root_lhs;
395  statement_node new_root_rhs;
396 
397  // check for temporary on lhs:
// (guarding condition omitted from this listing)
399  {
400  detail::new_element(new_root_lhs.lhs, root_node.lhs);
401 
403  new_root_lhs.op.type = OPERATION_BINARY_ASSIGN_TYPE;
404 
// The rhs carries no data of its own; it only points at the leaf's lhs subexpression.
406  new_root_lhs.rhs.subtype = INVALID_SUBTYPE;
407  new_root_lhs.rhs.numeric_type = INVALID_NUMERIC_TYPE;
408  new_root_lhs.rhs.node_index = leaf.lhs.node_index;
409 
410  // work on subexpression:
411  // TODO: Catch exception, free temporary, then rethrow
412  detail::execute_composite(s, new_root_lhs);
413  }
414 
// Two-operand branch (condition omitted from this listing; presumably tests leaf.op.type_family).
416  {
417  // check for temporary on rhs:
419  {
420  detail::new_element(new_root_rhs.lhs, root_node.lhs);
421 
423  new_root_rhs.op.type = OPERATION_BINARY_ASSIGN_TYPE;
424 
426  new_root_rhs.rhs.subtype = INVALID_SUBTYPE;
427  new_root_rhs.rhs.numeric_type = INVALID_NUMERIC_TYPE;
428  new_root_rhs.rhs.node_index = leaf.rhs.node_index;
429 
430  // work on subexpression:
431  // TODO: Catch exception, free temporary, then rethrow
432  detail::execute_composite(s, new_root_rhs);
433  }
434 
// Use the evaluated temporary for composite operands, the original operand otherwise.
435  lhs_rhs_element x = (leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY) ? new_root_lhs.lhs : leaf.lhs;
436  lhs_rhs_element y = (leaf.rhs.type_family == COMPOSITE_OPERATION_FAMILY) ? new_root_rhs.lhs : leaf.rhs;
437 
438  // compute element-wise operation:
439  detail::element_op(root_node.lhs, x, y, leaf.op.type);
440 
// Release the rhs temporary (guarding condition omitted from this listing).
442  detail::delete_element(new_root_rhs.lhs);
443  }
444  else if (leaf.op.type_family == OPERATION_UNARY_TYPE_FAMILY)
445  {
446  lhs_rhs_element x = (leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY) ? new_root_lhs.lhs : leaf.lhs;
447 
448  // compute element-wise operation:
449  detail::element_op(root_node.lhs, x, leaf.op.type);
450  }
451  else
452  throw statement_not_supported_exception("Unsupported elementwise operation.");
453 
454  // clean up:
// Release the lhs temporary (guarding condition omitted from this listing).
456  detail::delete_element(new_root_lhs.lhs);
457 
458  }
459 
460 
461  } // namespace scheduler
462 
463 } // namespace viennacl
464 
465 #endif
466 
statement_node_subtype subtype
Definition: forwards.h:270
viennacl::matrix_base< float > * matrix_row_float
Definition: forwards.h:339
A tag class representing the cosh() function.
Definition: forwards.h:107
A tag class representing the tan() function.
Definition: forwards.h:133
Definition: forwards.h:182
Implementations of dense matrix related operations including matrix-vector products.
vcl_size_t node_index
Definition: forwards.h:276
void new_element(lhs_rhs_element &new_elem, lhs_rhs_element const &old_element)
Definition: execute_util.hpp:102
Implementations of vector operations.
lhs_rhs_element lhs
Definition: forwards.h:422
Definition: forwards.h:217
A dense matrix class.
Definition: forwards.h:290
A tag class representing the modulus function for integers.
Definition: forwards.h:93
Expression template class for representing a tree of expressions which ultimately result in a matrix...
Definition: forwards.h:283
A tag class representing the ceil() function.
Definition: forwards.h:103
This file provides the forward declarations for the main types used within ViennaCL.
A class representing the 'data' for the LHS or RHS operand of the respective node.
Definition: forwards.h:267
operation_node_type_family type_family
Definition: forwards.h:415
void execute_element_composite(statement const &s, statement_node const &root_node)
Deals with x = RHS where RHS is an element-wise operation on vectors or matrices.
Definition: execute_elementwise.hpp:390
An expression template class that represents a binary operation that yields a vector.
Definition: forwards.h:181
A tag class representing the log() function.
Definition: forwards.h:123
lhs_rhs_element rhs
Definition: forwards.h:424
A tag class representing the tanh() function.
Definition: forwards.h:135
A tag class representing the fabs() function.
Definition: forwards.h:111
viennacl::matrix_base< float, viennacl::column_major > * matrix_col_float
Definition: forwards.h:351
void delete_element(lhs_rhs_element &elem)
Definition: execute_util.hpp:179
A tag class representing the atan() function.
Definition: forwards.h:99
A tag class representing the sinh() function.
Definition: forwards.h:129
viennacl::matrix_base< double, viennacl::column_major > * matrix_col_double
Definition: forwards.h:352
viennacl::vector_base< float > * vector_float
Definition: forwards.h:315
statement_node_numeric_type numeric_type
Definition: forwards.h:271
A tag class representing the exp() function.
Definition: forwards.h:109
viennacl::vector_base< double > * vector_double
Definition: forwards.h:316
#define VIENNACL_SCHEDULER_GENERATE_UNARY_ELEMENT_OP(OPNAME, SCALARTYPE, OPTAG)
A tag class representing the sqrt() function.
Definition: forwards.h:131
void element_op(matrix_base< T, F > &A, matrix_expression< const matrix_base< T, F >, const matrix_base< T, F >, OP > const &proxy)
Implementation of the element-wise operation A = B .* C and A = B ./ C for matrices (using MATLAB syntax).
Definition: matrix_operations.hpp:598
Provides the datastructures for dealing with a single statement such as 'x = y + z;'.
operation_node_type
Enumeration for identifying the possible operations.
Definition: forwards.h:61
A tag class representing the sin() function.
Definition: forwards.h:127
container_type const & array() const
Definition: forwards.h:473
void execute_composite(statement const &s, statement_node const &root_node)
Deals with x = RHS where RHS is an expression and x is either a scalar, a vector, or a matrix...
Definition: execute.hpp:41
A tag class representing the floor() function.
Definition: forwards.h:115
A tag class representing the asin() function.
Definition: forwards.h:97
viennacl::matrix_base< double > * matrix_row_double
Definition: forwards.h:340
A tag class representing element-wise binary operations (like multiplication) on vectors or matrices.
Definition: forwards.h:86
statement_node_type_family type_family
Definition: forwards.h:269
The main class for representing a statement such as x = inner_prod(y,z); at runtime.
Definition: forwards.h:447
A tag class representing the acos() function.
Definition: forwards.h:95
Definition: forwards.h:187
A tag class representing the log10() function.
Definition: forwards.h:125
op_element op
Definition: forwards.h:423
void element_op(lhs_rhs_element result, lhs_rhs_element const &x, operation_node_type op_type)
Definition: execute_elementwise.hpp:39
Definition: forwards.h:216
Provides various utilities for implementing the execution of statements.
A tag class representing the cos() function.
Definition: forwards.h:105
Main data structure for a node in the statement tree.
Definition: forwards.h:420
Exception for the case the scheduler is unable to deal with the operation.
Definition: forwards.h:36
operation_node_type type
Definition: forwards.h:416