44 #ifndef KOKKOS_MV_GEMM_HPP 45 #define KOKKOS_MV_GEMM_HPP 52 #include<Teuchos_BLAS.hpp> 54 #ifdef KOKKOS_HAVE_CUDA 68 class BLAS<int, ::Kokkos::complex<float> > {
// Type aliases for this BLAS<int, ::Kokkos::complex<float> > class.
// NOTE(review): this listing carries embedded line numbers and is missing
// lines around the class header (e.g. any access specifier) -- confirm
// against the original header before relying on it.
//
// mag_type: float (presumably the real/magnitude type of the scalar -- confirm).
70 typedef float mag_type;
// val_type: the Kokkos complex type used in this class's method signatures.
71 typedef ::Kokkos::complex<float> val_type;
// impl_type: the std::complex type that GEMV and GEMM in this class cast to
// before forwarding to BLAS<int, impl_type>.
72 typedef std::complex<float> impl_type;
// Constructor taking another BLAS<int, val_type> by const reference.
// The parameter is unnamed and the body is empty, so nothing is read
// from the argument.
75 BLAS (
const BLAS<int, val_type>&) {}
// GEMV for val_type (::Kokkos::complex<float>): a thin forwarding wrapper.
// It creates a BLAS<int, impl_type> object and calls its GEMV with:
//   - trans, m, n, lda, incx, incy passed through unchanged;
//   - alpha and beta converted to impl_type via static_cast;
//   - A, x (read-only) and y (written) reinterpret_cast from val_type*
//     to impl_type*.
// NOTE(review): the pointer casts assume val_type and impl_type have the
// same memory layout -- confirm.
// NOTE(review): the return type and the braces around the body are not
// visible in this listing.
89 GEMV (ETransp trans,
const int m,
const int n,
const val_type alpha,
90 const val_type* A,
const int lda,
const val_type* x,
const int incx,
91 const val_type beta, val_type* y,
const int incy)
const 93 BLAS<int, impl_type> blas;
// Forward to the std::complex<float> implementation.
94 blas.GEMV (trans, m, n, static_cast<impl_type> (alpha),
95 reinterpret_cast<const impl_type*> (A), lda,
96 reinterpret_cast<const impl_type*> (x), incx,
97 static_cast<impl_type> (beta),
98 reinterpret_cast<impl_type*> (y), incy);
// GEMM for val_type (::Kokkos::complex<float>): a thin forwarding wrapper.
// It creates a BLAS<int, impl_type> object and calls its GEMM with:
//   - transa, transb, m, n, k, lda, ldb, ldc passed through unchanged;
//   - alpha and beta converted to impl_type via static_cast;
//   - A, B (read-only) and C (written) reinterpret_cast from val_type*
//     to impl_type*.
// NOTE(review): the pointer casts assume val_type and impl_type have the
// same memory layout -- confirm.
// NOTE(review): `ldc` is used in the call below but its parameter
// declaration, the return type and the body braces are not visible in
// this listing.
105 GEMM (ETransp transa, ETransp transb,
const int m,
const int n,
const int k,
106 const val_type alpha,
const val_type* A,
const int lda,
107 const val_type* B,
const int ldb,
const val_type beta, val_type* C,
110 BLAS<int, impl_type> blas;
// Forward to the std::complex<float> implementation.
111 blas.GEMM (transa, transb, m, n, k,
112 static_cast<impl_type> (alpha),
113 reinterpret_cast<const impl_type*> (A), lda,
114 reinterpret_cast<const impl_type*> (B), ldb,
115 static_cast<impl_type> (beta),
116 reinterpret_cast<impl_type*> (C), ldc);
// BLAS class for ordinal type int and scalar type ::Kokkos::complex<double>.
// NOTE(review): lines around the class header (e.g. any `template<>` or
// access specifier) are not visible in this listing.
126 class BLAS<int, ::Kokkos::complex<double> > {
// mag_type: double (presumably the real/magnitude type of the scalar -- confirm).
128 typedef double mag_type;
// val_type: the Kokkos complex type used in this class's method signatures.
129 typedef ::Kokkos::complex<double> val_type;
// impl_type: the std::complex type that GEMV and GEMM in this class cast to
// before forwarding to BLAS<int, impl_type>.
130 typedef std::complex<double> impl_type;
// Constructor taking another BLAS<int, val_type> by const reference.
// The parameter is unnamed and the body is empty, so nothing is read
// from the argument.
133 BLAS (
const BLAS<int, val_type>&) {}
// GEMV for val_type (::Kokkos::complex<double>): a thin forwarding wrapper.
// It creates a BLAS<int, impl_type> object and calls its GEMV with:
//   - trans, m, n, lda, incx, incy passed through unchanged;
//   - alpha and beta converted to impl_type via static_cast;
//   - A, x (read-only) and y (written) reinterpret_cast from val_type*
//     to impl_type*.
// NOTE(review): the pointer casts assume val_type and impl_type have the
// same memory layout -- confirm.
// NOTE(review): the return type and the braces around the body are not
// visible in this listing.
147 GEMV (ETransp trans,
const int m,
const int n,
const val_type alpha,
148 const val_type* A,
const int lda,
const val_type* x,
const int incx,
149 const val_type beta, val_type* y,
const int incy)
const 151 BLAS<int, impl_type> blas;
// Forward to the std::complex<double> implementation.
152 blas.GEMV (trans, m, n, static_cast<impl_type> (alpha),
153 reinterpret_cast<const impl_type*> (A), lda,
154 reinterpret_cast<const impl_type*> (x), incx,
155 static_cast<impl_type> (beta),
156 reinterpret_cast<impl_type*> (y), incy);
// GEMM for val_type (::Kokkos::complex<double>): a thin forwarding wrapper.
// It creates a BLAS<int, impl_type> object and calls its GEMM with:
//   - transa, transb, m, n, k, lda, ldb, ldc passed through unchanged;
//   - alpha and beta converted to impl_type via static_cast;
//   - A, B (read-only) and C (written) reinterpret_cast from val_type*
//     to impl_type*.
// NOTE(review): the pointer casts assume val_type and impl_type have the
// same memory layout -- confirm.
// NOTE(review): `ldc` is used in the call below but its parameter
// declaration, the return type and the body braces are not visible in
// this listing.
163 GEMM (ETransp transa, ETransp transb,
const int m,
const int n,
const int k,
164 const val_type alpha,
const val_type* A,
const int lda,
165 const val_type* B,
const int ldb,
const val_type beta, val_type* C,
168 BLAS<int, impl_type> blas;
// Forward to the std::complex<double> implementation.
169 blas.GEMM (transa, transb, m, n, k,
170 static_cast<impl_type> (alpha),
171 reinterpret_cast<const impl_type*> (A), lda,
172 reinterpret_cast<const impl_type*> (B), ldb,
173 static_cast<impl_type> (beta),
174 reinterpret_cast<impl_type*> (C), ldc);
// Returns the value the GEMM routines in this file pass as a BLAS leading
// dimension (lda / ldb / ldc) for the view A:
//   - stride[1] when A has more than one column (A.dimension_1() > 1);
//   - A.dimension_0() otherwise.
// A is taken by value.
// NOTE(review): the declaration and filling of `stride` are not visible in
// this listing -- confirm against the original header.
189 template<
class ViewType>
190 size_t getStride2DView (ViewType A) {
193 return A.dimension_1 () > 1 ? stride[1] : A.dimension_0 ();
// GEMM over LayoutLeft Kokkos views, templated on Scalar and DeviceType,
// implemented through Teuchos::BLAS<int,Scalar>.
//
// Sizes passed to BLAS are derived from the views:
//   m, n          -- dimensions 0 and 1 of C;
//   k             -- A.dimension_1() if transA is NO_TRANS, else A.dimension_0();
//   lda, ldb, ldc -- Impl::getStride2DView of A, B, C.
//
// Two BLAS calls are visible: blas.GEMV (guarded by n == 1 and
// transB == NO_TRANS) and blas.GEMM.
// NOTE(review): the enclosing struct, the return type, the declarations of
// `alpha` and `beta`, the body braces and the `else` that presumably
// separates the two calls are not visible in this listing -- confirm against
// the original header.
203 template <
typename Scalar,
typename DeviceType>
207 GEMM (
const Teuchos::ETransp transA,
208 const Teuchos::ETransp transB,
210 View<const Scalar**, LayoutLeft, DeviceType> A,
211 View<const Scalar**, LayoutLeft, DeviceType> B,
213 View<Scalar**, LayoutLeft, DeviceType> C)
215 Teuchos::BLAS<int,Scalar> blas;
216 const int m =
static_cast<int> (C.dimension_0 ()),
217 n = static_cast<int> (C.dimension_1 ()),
// NOTE(review): unlike m and n, k is initialized without an explicit
// static_cast<int>; the conversion to int is implicit here.
218 k = (transA == Teuchos::NO_TRANS ? A.dimension_1 () : A.dimension_0 ()),
219 lda = static_cast<int> (Impl::getStride2DView (A)),
220 ldb = static_cast<int> (Impl::getStride2DView (B)),
221 ldc = static_cast<int> (Impl::getStride2DView (C));
// Single-column C with untransposed B: use GEMV, treating B and C as
// vectors with unit increment.
// NOTE(review): A.dimension_0 () / A.dimension_1 () are passed without the
// explicit static_cast<int> used for the other sizes.
224 if (n == 1 && transB == Teuchos::NO_TRANS) {
225 blas.GEMV (transA, A.dimension_0 (), A.dimension_1 (), alpha,
226 A.ptr_on_device(), lda,
227 B.ptr_on_device(),
static_cast<int> (1),
228 beta, C.ptr_on_device(),
static_cast<int> (1));
// General case: full GEMM with the sizes computed above.
231 blas.GEMM (transA, transB, m, n, k, alpha,
232 A.ptr_on_device(), lda,
233 B.ptr_on_device(), ldb,
234 beta, C.ptr_on_device(), ldc);
// Compiled only when KOKKOS_HAVE_CUDA is defined.
// GEMM for Kokkos::Cuda LayoutLeft views with a generic Scalar.
// The only statement is TEUCHOS_TEST_FOR_EXCEPTION with a test of `true`
// and std::logic_error, so this overload only reports that GEMM is not
// supported for this Scalar. The message includes Teuchos::typeName(alpha).
// None of transA, transB, A, B, beta or C is used.
// NOTE(review): the enclosing struct declaration and the closing brace are
// not visible in this listing.
299 #ifdef KOKKOS_HAVE_CUDA 300 template <
typename Scalar>
303 static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB, Scalar alpha,
304 View<const Scalar**,LayoutLeft,Cuda> A, View<const Scalar**,LayoutLeft,Cuda> B,
305 Scalar beta, View<Scalar**,LayoutLeft,Cuda> C) {
306 TEUCHOS_TEST_FOR_EXCEPTION(
true, std::logic_error,
"DeviceGEMM: Kokkos::Cuda has no support for GEMM operations over Scalar=" << Teuchos::typeName(alpha) <<
".");
// GEMM for float on Kokkos::Cuda LayoutLeft views, implemented with
// cublasSgemm.
//
// Sizes passed to cuBLAS are derived from the views:
//   m, n          -- dimensions 0 and 1 of C;
//   k             -- A.dimension_1() if transA is NO_TRANS, else A.dimension_0();
//   lda, ldb, ldc -- Impl::getStride2DView of A, B, C.
// NOTE(review): the enclosing struct, the closing brace and the #endif
// matching HAVE_KOKKOS_DEBUG are not visible in this listing.
314 static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB,
float alpha,
315 View<const float**,LayoutLeft,Cuda> A, View<const float**,LayoutLeft,Cuda> B,
316 float beta, View<float**,LayoutLeft,Cuda> C) {
317 const int m =
static_cast<int>(C.dimension_0()),
318 n = static_cast<int>(C.dimension_1()),
// NOTE(review): k has no explicit static_cast<int>, unlike m and n.
319 k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
320 lda = static_cast<int>(Impl::getStride2DView(A)),
321 ldb = static_cast<int>(Impl::getStride2DView(B)),
322 ldc = static_cast<int>(Impl::getStride2DView(C));
// Map the Teuchos transpose flags to the characters passed to cuBLAS:
// 'N' for NO_TRANS, 'T' for every other value.
323 const char char_transA = (transA == Teuchos::NO_TRANS ?
'N' :
'T'),
324 char_transB = (transB == Teuchos::NO_TRANS ?
'N' :
'T');
325 cublasSgemm(char_transA, char_transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
// Only when HAVE_KOKKOS_DEBUG is defined: query cublasGetError() and report
// std::runtime_error if the status is not CUBLAS_STATUS_SUCCESS.
326 #ifdef HAVE_KOKKOS_DEBUG 327 cublasStatus info = cublasGetError();
328 TEUCHOS_TEST_FOR_EXCEPTION( info != CUBLAS_STATUS_SUCCESS, std::runtime_error,
"cublasSgemm failed with status " << info <<
"." );
// GEMM for double on Kokkos::Cuda LayoutLeft views, implemented with
// cublasDgemm.
//
// Sizes passed to cuBLAS are derived from the views:
//   m, n          -- dimensions 0 and 1 of C;
//   k             -- A.dimension_1() if transA is NO_TRANS, else A.dimension_0();
//   lda, ldb, ldc -- Impl::getStride2DView of A, B, C.
// NOTE(review): the enclosing struct, the closing brace and the #endif
// matching HAVE_KOKKOS_DEBUG are not visible in this listing.
336 static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB,
double alpha,
337 View<const double**,LayoutLeft,Cuda> A, View<const double**,LayoutLeft,Cuda> B,
338 double beta, View<double**,LayoutLeft,Cuda> C) {
339 const int m =
static_cast<int>(C.dimension_0()),
340 n = static_cast<int>(C.dimension_1()),
// NOTE(review): k has no explicit static_cast<int>, unlike m and n.
341 k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
342 lda = static_cast<int>(Impl::getStride2DView(A)),
343 ldb = static_cast<int>(Impl::getStride2DView(B)),
344 ldc = static_cast<int>(Impl::getStride2DView(C));
// Map the Teuchos transpose flags to the characters passed to cuBLAS:
// 'N' for NO_TRANS, 'T' for every other value.
345 const char char_transA = (transA == Teuchos::NO_TRANS ?
'N' :
'T'),
346 char_transB = (transB == Teuchos::NO_TRANS ?
'N' :
'T');
347 cublasDgemm(char_transA, char_transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
// Only when HAVE_KOKKOS_DEBUG is defined: query cublasGetError() and report
// std::runtime_error if the status is not CUBLAS_STATUS_SUCCESS.
348 #ifdef HAVE_KOKKOS_DEBUG 349 cublasStatus info = cublasGetError();
350 TEUCHOS_TEST_FOR_EXCEPTION( info != CUBLAS_STATUS_SUCCESS, std::runtime_error,
"cublasDgemm failed with status " << info <<
"." );
358 #endif // KOKKOS_MV_GEMM_HPP
void GEMV(const CoefficientType &alpha, const LittleBlockType &A, const LittleVectorType1 &x, const LittleVectorType2 &y)
Dense matrix-vector multiply: y := y + alpha * A * x.
void GEMM(const char transA[], const char transB[], const CoefficientType &alpha, const ViewType1 &A, const ViewType2 &B, const CoefficientType &beta, const ViewType3 &C)
Small dense matrix-matrix multiply: C := alpha*A*B + beta*C.
Class that provides GEMM for a particular Kokkos Device.