#ifndef VIENNACL_GENERATOR_GENERATE_MATRIX_PRODUCT_HPP
#define VIENNACL_GENERATOR_GENERATE_MATRIX_PRODUCT_HPP
lmem_used += (ml_ + 1) * (cache_width_ + 1) * scalartype_size;
lmem_used += (cache_width_ + 1) * (nl_ + 1) * scalartype_size;
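// Stand-alone worked example of the two updates above. The tuning values
// (ml_ = 64, cache_width_ = 32, nl_ = 64) and the float scalartype are
// hypothetical and only serve to make the local-memory estimate concrete.
#include <cstddef>
#include <iostream>

int main()
{
  std::size_t const ml = 64, cache_width = 32, nl = 64;
  std::size_t const scalartype_size = sizeof(float);            // 4 bytes

  std::size_t lmem_used = 0;
  lmem_used += (ml + 1) * (cache_width + 1) * scalartype_size;  // lhs tile: 8580 bytes
  lmem_used += (cache_width + 1) * (nl + 1) * scalartype_size;  // rhs tile: 8580 bytes

  std::cout << lmem_used << " bytes of __local memory" << std::endl;  // prints 17160
}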
virtual void print(std::ostream & s) const{
  s << "{vector_type, local_size1, cache_width, local_size2, ms, ks, ns, use_lhs_shared, use_rhs_shared} = {"
    << local_size1_ << ", "
    << cache_width_ << ", "
    << local_size2_ << ", "
    << use_lhs_shared_ << ", " << use_rhs_shared_ << "}";
static const unsigned int alignment = 128;
return ml_ > alignment
    || cache_width_ > alignment
               , unsigned int ms, unsigned int ks, unsigned int ns
               , bool use_lhs_shared, bool use_rhs_shared) : profile_base(vectorization, local_size1, local_size2, 1){
  local_size1_ = local_size1;
  local_size2_ = local_size2;
  cache_width_ = cache_width;
  use_lhs_shared_ = use_lhs_shared;
  use_rhs_shared_ = use_rhs_shared;
107 return "Vec,LSize1,CacheWidth,LSize2,mS,kS,nS,NumGroups";
std::ostringstream oss;
oss << vector_size_
    << "," << local_size1_
    << "," << cache_width_
    << "," << local_size2_
    << "," << use_lhs_shared_
    << "," << use_rhs_shared_;
k.arg(n_arg++, cl_uint(M));
k.arg(n_arg++, cl_uint(N));
for(statements_type::const_iterator it = statements.begin() ; it != statements.end() ; ++it){
  for(scheduler::statement::container_type::iterator iit = exprs.begin() ; iit != exprs.end() ; ++iit){
    assert(false && bool("unexpected expression tree"));
static std::string size1() { return "M"; }
static std::string size2() { return "K"; }
static std::string size3() { return "N"; }
arguments_string += detail::generate_value_kernel_argument("unsigned int", "M");
arguments_string += detail::generate_value_kernel_argument("unsigned int", "N");
arguments_string += detail::generate_value_kernel_argument("unsigned int", "K");
                     , unsigned int & large_block_1, unsigned int & large_block_2
                     , unsigned int & small_block_1, unsigned int & small_block_2
                     , access_flow flow) const {
std::string helper_variable(utils::kernel_generation_stream & stream
                            , bool store_in_register
                            , std::string const & type
                            , std::string const & name
                            , std::string const & expr) const {
  // Either substitute the expression verbatim at every use, or emit a named
  // temporary once and return its name.
  if(!store_in_register)
    return expr;
  stream << type << " " << name << " = " << expr << ";" << std::endl;
  return name;
}
void fetch_element_to_local_mem(utils::kernel_generation_stream & stream,
                                std::string const & lmem_name,
                                vcl_size_t lmem_size2,
                                std::string const & global_ptr,
                                detail::mapped_matrix const & mat,
                                access_flow flow,
                                std::string const & i,
                                std::string const & j) const {
  // REGULAR flow: read contiguously along the minor dimension of mat ...
  stream << "val = *(" << global_ptr << " + " << j << " + " << mat.size2() << "*" << i << ");" << std::endl;
  // ... then scatter the vector components into the padded local tile,
  stream << lmem_name << "[" << i << "*" << lmem_size2 << " + " << j << "*" << vector_size_ << " + " << a << "] = val.s" << a << ";" << std::endl;
  // or store the scalar value directly when no vectorization is used.
  stream << lmem_name << "[" << i << "*" << lmem_size2 << " + " << j << "*" << vector_size_ << "] = val" << ";" << std::endl;
  // STRIDED flow: read with a stride of mat.size1() ...
  stream << "val = *(" << global_ptr << "+ " << j << "*" << mat.size1() << " + " << i << ");" << std::endl;
  stream << lmem_name << "[" << i << "*" << vector_size_*lmem_size2 << " + " << j << " + " << a*lmem_size2 << "] = val.s" << a << ";" << std::endl;
  stream << lmem_name << "[" << i << "*" << vector_size_*lmem_size2 << " + " << j << "] = val" << ";" << std::endl;
void fetch_to_local_mem(utils::kernel_generation_stream & stream,
                        std::string const & lmem_name,
                        vcl_size_t lmem_size2,
                        std::string const & global_ptr,
                        unsigned int bound1,
                        unsigned int bound2,
                        detail::mapped_matrix const & mat,
                        access_flow flow) const {
  std::string aligned_scalartype = mat.scalartype();
  stream << "barrier(CLK_LOCAL_MEM_FENCE);" << std::endl;
  stream << "{" << std::endl;
  stream << aligned_scalartype << " val;" << std::endl;
  // If the tile bounds are multiples of the work-group size, the copy loop is
  // fully unrolled at generation time:
  if(bound2%local_size2_==0 && bound1%local_size1_==0){
    for(unsigned int j = 0 ; j < bound2 ; j+=static_cast<unsigned int>(local_size2_)){
      for(unsigned int i = 0 ; i < bound1 ; i+=static_cast<unsigned int>(local_size1_)){
        fetch_element_to_local_mem(stream,lmem_name,lmem_size2,global_ptr,mat,flow,indi,indj);
  // Otherwise a runtime loop over the tile is emitted into the kernel:
  stream << "for(unsigned int j = get_local_id(1)" << " ; j < " << bound2 << "; j+= " << local_size2_ << "){" << std::endl;
  stream << "for(unsigned int i = get_local_id(0)" << " ; i < " << bound1 << "; i+= " << local_size1_ << "){" << std::endl;
  fetch_element_to_local_mem(stream,lmem_name,lmem_size2,global_ptr,mat,flow,"i","j");
  stream << "}" << std::endl;
  stream << "}" << std::endl;
  stream << "}" << std::endl;
  stream << "barrier(CLK_LOCAL_MEM_FENCE);" << std::endl;
void core(vcl_size_t, utils::kernel_generation_stream & stream,
          statements_type const & statements, std::vector<detail::mapping_type> const & mapping) const {
  detail::mapped_matrix const * assigned = static_cast<detail::mapped_matrix const *>(
        at(mapping.at(0), std::make_pair(&statements.front().second, detail::LHS_NODE_TYPE)).get());
  detail::mapped_matrix_product* prod = NULL;
  detail::mapped_matrix const * lhs = NULL;
  detail::mapped_matrix const * rhs = NULL;
  bool is_lhs_transposed = false;
  bool is_rhs_transposed = false;

  for(statements_type::const_iterator it = statements.begin() ; it != statements.end() ; ++it){
    vcl_size_t i = std::distance(statements.begin(), it);
    for(scheduler::statement::container_type::const_iterator iit = exprs.begin() ; iit != exprs.end() ; ++iit){
      is_lhs_transposed = true;
      lhs = (detail::mapped_matrix const *)at(mapping.at(i), std::make_pair(&exprs[iit->lhs.node_index], detail::LHS_NODE_TYPE)).get();
      is_lhs_transposed = false;
      is_rhs_transposed = true;
      rhs = (detail::mapped_matrix const *)at(mapping.at(i), std::make_pair(&exprs[iit->rhs.node_index], detail::LHS_NODE_TYPE)).get();
      is_rhs_transposed = false;
for(detail::mapping_type::const_iterator it = mapping.front().begin() ; it != mapping.front().end() ; ++it){
  if(detail::mapped_matrix const * p = dynamic_cast<detail::mapped_matrix const *>(it->second.get())){
    if(p->is_row_major())
      p->bind_sizes("M", "N"+StrV);
    else
      p->bind_sizes("M"+StrV, "N");
  }
}

if(lhs->is_row_major())
  if(is_lhs_transposed)
    lhs->bind_sizes("M"+StrV, "K");
  else
    lhs->bind_sizes("M", "K"+StrV);
else
  if(is_lhs_transposed)
    lhs->bind_sizes("M", "K"+StrV);
  else
    lhs->bind_sizes("M"+StrV, "K");

if(rhs->is_row_major())
  if(is_rhs_transposed)
    rhs->bind_sizes("K"+StrV, "N");
  else
    rhs->bind_sizes("K", "N"+StrV);
else
  if(is_rhs_transposed)
    rhs->bind_sizes("K", "N"+StrV);
  else
    rhs->bind_sizes("K"+StrV, "N");

for(detail::mapping_type::const_iterator it = mapping.front().begin() ; it != mapping.front().end() ; ++it){
  if(detail::mapped_matrix const * p = dynamic_cast<detail::mapped_matrix const *>(it->second.get())){
    p->bind_sizes("M", "N");
  }
}
lhs->bind_sizes("M", "K");
rhs->bind_sizes("K", "N");
std::string aligned_scalartype = assigned->scalartype();

access_flow result_access_flow;
if(assigned->is_row_major())
  result_access_flow = REGULAR;
else
  result_access_flow = STRIDED;

access_flow lhs_access_flow;
if((lhs->is_row_major() && !is_lhs_transposed)
   ||(!lhs->is_row_major() && is_lhs_transposed))
  lhs_access_flow = REGULAR;
else
  lhs_access_flow = STRIDED;

access_flow rhs_access_flow;
if((rhs->is_row_major() && !is_rhs_transposed)
   ||(!rhs->is_row_major() && is_rhs_transposed))
  rhs_access_flow = REGULAR;
else
  rhs_access_flow = STRIDED;

std::string lhs_value_scalartype;
if(use_lhs_shared_)
  lhs_value_scalartype = lhs->scalartype();
else
  lhs_value_scalartype = aligned_scalartype;

std::string rhs_value_scalartype;
if(use_rhs_shared_)
  rhs_value_scalartype = rhs->scalartype();
else
  rhs_value_scalartype = aligned_scalartype;

unsigned int ml_res = static_cast<unsigned int>(ml_), nl_res = static_cast<unsigned int>(nl_), ms_res = static_cast<unsigned int>(ms_), ns_res = static_cast<unsigned int>(ns_);
unsigned int ml_lhs = static_cast<unsigned int>(ml_), cache_width_lhs = static_cast<unsigned int>(cache_width_), ms_lhs = static_cast<unsigned int>(ms_), ks_lhs = static_cast<unsigned int>(ks_);
unsigned int cache_width_rhs = static_cast<unsigned int>(cache_width_), nl_rhs = static_cast<unsigned int>(nl_), ks_rhs = static_cast<unsigned int>(ks_), ns_rhs = static_cast<unsigned int>(ns_);

transform_block(*assigned, false, ml_res, nl_res, ms_res, ns_res, result_access_flow);
transform_block(*lhs, use_lhs_shared_, ml_lhs, cache_width_lhs, ms_lhs, ks_lhs, lhs_access_flow);
transform_block(*rhs, use_rhs_shared_, cache_width_rhs, nl_rhs, ks_rhs, ns_rhs, rhs_access_flow);
vcl_size_t local_lhs_size2 = cache_width_ + 1;

// One private accumulator per entry of the (ms_res x ns_res) micro-tile:
for(unsigned int m=0; m < ms_res; ++m)
  for(unsigned int n=0; n < ns_res ; ++n)
    stream << aligned_scalartype << " " << "res" << m << "_" << n << " = (" << aligned_scalartype << ")(0) ;" << std::endl;

// __local staging buffers for the lhs and rhs tiles:
stream << "__local " << lhs->scalartype() << " lhs_buf[" << local_lhs_size1*local_lhs_size2 << "]" << ";" << std::endl;
stream << "__local " << rhs->scalartype() << " rhs_buf[" << local_rhs_size1*local_rhs_size2 << "]" << ";" << std::endl;
stream << "__global " << aligned_scalartype << "* global_lhs_ptr = " << lhs->name() << " + ";
if(lhs_access_flow==REGULAR)
  stream << "(" << i << ")" << "*" << lhs->size2();
stream << ";" << std::endl;
if(lhs_access_flow==REGULAR)
  for(unsigned int m=0; m<ms_lhs; ++m)
    stream << "__global " << aligned_scalartype << "* " << "lhs_ptr_" << m << " = " << lhs->name() << " + "
           << lhs->size2() << "* ("
           << "get_group_id(0)*" << ml_lhs << "+" << "get_local_id(0)*" << ms_lhs << "+" << m
           << " );" << std::endl;
else
  for(unsigned int k=0; k<ks_lhs; ++k)
    stream << "__global " << aligned_scalartype << "* " << "lhs_ptr_" << k << " = " << lhs->name() << " + "
           << "(" << lhs->size1() << ")*" << k
           << "+ " << "get_group_id(0)*" << ml_lhs << "+" << "get_local_id(0)*" << ms_lhs << ";" << std::endl;
stream << "__global " << aligned_scalartype << "* global_rhs_ptr = " << rhs->name() << " + ";
if(rhs_access_flow==REGULAR)
stream << "(" << j << ")" << "*" << rhs->size1();
stream << ";" << std::endl;
if(rhs_access_flow==REGULAR)
  for(unsigned int k = 0 ; k < ks_rhs ; ++k)
    stream << "__global " << aligned_scalartype << "* " << "rhs_ptr_" << k << " = " << rhs->name() << " + "
           << "(" << k << ")" << "*" << rhs->size2()
           << " + " << "get_local_id(1)*" << ns_rhs << " + get_group_id(1)*" << nl_rhs

for(unsigned int n = 0 ; n < ns_rhs ; ++n)
  stream << "__global " << aligned_scalartype << "* " << "rhs_ptr_" << n << " = " << rhs->name() << " + "
         << "(" << "get_local_id(1)*" << ns_rhs << " + get_group_id(1)*" << nl_rhs << " + " << n << ")" << "*" << rhs->size1()
std::string block_num = helper_variable(stream, false, "unsigned int", "block_num", "K/" + utils::to_string(cache_width_));

stream << "for(unsigned int bl=0 ; bl<" << block_num << " ; ++bl){" << std::endl;
fetch_to_local_mem(stream, "lhs_buf", local_lhs_size2, "global_lhs_ptr", ml_lhs, cache_width_lhs, *lhs, lhs_access_flow);
for(unsigned int m=0; m<ms_lhs; ++m)
  stream << "__local " << lhs_value_scalartype << "* lhs_ptr_" << m << " = lhs_buf + "
         << "(" << "get_local_id(0)*" << ms_lhs << "+" << m << ")" << "*" << local_lhs_size2
fetch_to_local_mem(stream, "rhs_buf", local_rhs_size2, "global_rhs_ptr", cache_width_rhs, nl_rhs, *rhs, rhs_access_flow);
for(unsigned int k=0; k<ks_rhs; ++k)
  stream << "__local " << rhs_value_scalartype << "* rhs_ptr_" << k << " = rhs_buf + "
         << k*local_rhs_size2 << " + " << "get_local_id(1)*" << ns_rhs

stream << " for(unsigned int bs=0 ; bs < " << cache_width_/ks_ << " ; ++bs){" << std::endl;
for(unsigned int k = 0 ; k < ks_rhs ; ++k){
  for(unsigned int n=0 ; n < ns_rhs ; ++n){
    stream << rhs_value_scalartype << " val_rhs_" << k << "_" << n << " = " ;
    if(use_rhs_shared_)
      stream << "* rhs_ptr_" << k << "++";
    else if(rhs_access_flow==REGULAR)
      stream << "* rhs_ptr_" << k << "++";
    else
      stream << "* rhs_ptr_" << n << "++";
for(unsigned int k = 0 ; k < ks_lhs ; ++k){
  for(unsigned int m=0 ; m < ms_lhs ; ++m){
    stream << lhs_value_scalartype << " " << "val_lhs_" << m << "_" << k << " = ";
    if(use_lhs_shared_)
      stream << "* lhs_ptr_" << m << "++" ;
    else if(lhs_access_flow==REGULAR)
      stream << "* lhs_ptr_" << m << "++";
    else
      stream << "* lhs_ptr_" << k << "++";
for(unsigned int n=0 ; n < ns_res ; ++n){
  for(unsigned int k = 0 ; k < ks_ ; ++k){
    for(unsigned int m=0 ; m < ms_res ; ++m){

      if(result_access_flow==REGULAR){
        if(!use_lhs_shared_){
          if(lhs_access_flow==REGULAR){
            ind_lhs_1 = ind_lhs_1*vector_size_+a;
      if(lhs_access_flow==REGULAR){
        ind_lhs_1 = ind_lhs_1*vector_size_+a;
      if(result_access_flow==REGULAR){
        ind_rhs_2 = ind_rhs_2*vector_size_+a;
      if(rhs_access_flow==STRIDED){
        ind_rhs_2 = ind_rhs_2*vector_size_+a;
      if(!use_rhs_shared_){
        if(rhs_access_flow==REGULAR){
std::ostringstream res_oss;
std::ostringstream lhs_oss;
std::ostringstream rhs_oss;

res_oss << "res" << m << "_" << n ;
if(vector_size_>1) res_oss << ".s" << a;

lhs_oss << "val_lhs_" << ind_lhs_1 << "_" << ind_lhs_2;
if(!use_lhs_shared_ && vector_size_>1) lhs_oss << ".s" << ind_s_lhs;

rhs_oss << "val_rhs_" << ind_rhs_1 << "_" << ind_rhs_2;
if(!use_rhs_shared_ && vector_size_>1) rhs_oss << ".s" << ind_s_rhs;

stream << res_oss.str() << "+=" << lhs_oss.str() << "*" << rhs_oss.str() << ";" << std::endl;
for(unsigned int k=0 ; k<ks_ ; ++k)
  stream << "rhs_ptr_" << k << " += " << ks_rhs*local_rhs_size2 - ns_rhs << ";" << std::endl;

if(rhs_access_flow==REGULAR)
  for(unsigned int k=0 ; k<ks_ ; ++k)
    stream << "rhs_ptr_" << k << " += " << ks_rhs << "*" << rhs->size2() << " - " << ns_rhs << ";" << std::endl;
if(!use_lhs_shared_){
  if(lhs_access_flow==STRIDED)
    for(unsigned int k=0 ; k<ks_lhs ; ++k)
      stream << "lhs_ptr_" << k << " += " << ks_lhs << "*" << lhs->size1() << " - " << ms_lhs << ";" << std::endl;

stream << "}" << std::endl;
if(lhs_access_flow==REGULAR)
  stream << "global_lhs_ptr += " << cache_width_lhs << ";" << std::endl;
else
  stream << "global_lhs_ptr += " << cache_width_lhs << "*" << lhs->size1() << ";" << std::endl;

if(rhs_access_flow==REGULAR)
  stream << "global_rhs_ptr += " << cache_width_rhs << "*" << rhs->size2() << ";" << std::endl;
else
  stream << "global_rhs_ptr += " << cache_width_rhs << ";" << std::endl;

stream << "}" << std::endl;
for(unsigned int m=0 ; m < ms_res ; ++m){
  for(unsigned int n=0 ; n < ns_res ; ++n){
    detail::traverse(statements.front().first, statements.front().second,
                     detail::expression_generation_traversal(std::make_pair(i, j), -1, str, mapping[0]), false);
    stream << str << ";" << std::endl;
bool use_lhs_shared_;
bool use_rhs_shared_;