[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]
[Getfem-commits] (no subject)
From: |
Yves Renard |
Subject: |
[Getfem-commits] (no subject) |
Date: |
Wed, 31 Oct 2018 14:51:48 -0400 (EDT) |
branch: master
commit 98b56a36a7ffa5419a8d373390409c037dd57f3b
Author: Yves Renard <address@hidden>
Date: Wed Oct 31 19:51:35 2018 +0100
adding the scaled case
---
src/gmm/gmm_blas_interface.h | 57 ++++++++++++++++++++++++++++++++++++++++----
1 file changed, 53 insertions(+), 4 deletions(-)
diff --git a/src/gmm/gmm_blas_interface.h b/src/gmm/gmm_blas_interface.h
index f051c74..8eb2c0f 100644
--- a/src/gmm/gmm_blas_interface.h
+++ b/src/gmm/gmm_blas_interface.h
@@ -367,6 +367,46 @@ namespace gmm {
}
}
+ template<size_type N, class V1, class V2, class T>  // N: number of leading entries to update, fixed at compile time
+ inline void add_fixed(const V1 &x, V2 &y, const T &a)  // performs y[i] += a*x[i] for i = 0 .. N-1
+ {
+ for(size_type i = 0; i != N; ++i) y[i] += a*x[i];  // no size check here: x and y must both be indexable up to N-1
+ }
+
+ template<class V1, class V2, class T>  // y[i] += a*x[i] for i in [0, n), selected by the runtime size n
+ inline void add_for_short_vectors(const V1 &x, V2 &y, const T &a, size_type n)
+ {
+ switch(n)  // each case forwards to add_fixed<n>, whose loop bound is a compile-time constant
+ {
+ case 1: add_fixed<1>(x, y, a); break;
+ case 2: add_fixed<2>(x, y, a); break;
+ case 3: add_fixed<3>(x, y, a); break;
+ case 4: add_fixed<4>(x, y, a); break;
+ case 5: add_fixed<5>(x, y, a); break;
+ case 6: add_fixed<6>(x, y, a); break;
+ case 7: add_fixed<7>(x, y, a); break;
+ case 8: add_fixed<8>(x, y, a); break;
+ case 9: add_fixed<9>(x, y, a); break;
+ case 10: add_fixed<10>(x, y, a); break;
+ case 11: add_fixed<11>(x, y, a); break;
+ case 12: add_fixed<12>(x, y, a); break;
+ case 13: add_fixed<13>(x, y, a); break;
+ case 14: add_fixed<14>(x, y, a); break;
+ case 15: add_fixed<15>(x, y, a); break;
+ case 16: add_fixed<16>(x, y, a); break;
+ case 17: add_fixed<17>(x, y, a); break;
+ case 18: add_fixed<18>(x, y, a); break;
+ case 19: add_fixed<19>(x, y, a); break;
+ case 20: add_fixed<20>(x, y, a); break;
+ case 21: add_fixed<21>(x, y, a); break;
+ case 22: add_fixed<22>(x, y, a); break;
+ case 23: add_fixed<23>(x, y, a); break;
+ case 24: add_fixed<24>(x, y, a); break;
+ default: GMM_ASSERT2(false, "add_for_short_vectors used with unsupported size"); break;  // any n outside 1..24 is rejected
+ }
+ }
+
+
# define axpy_interface(param1, trans1, blas_name, base_type) \
inline void add(param1(base_type), std::vector<base_type > &y) { \
GMMLAPACK_TRACE("axpy_interface"); \
@@ -376,6 +416,15 @@ namespace gmm {
else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
}
+# define axpy2_interface(param1, trans1, blas_name, base_type) /* generates an add(..., y) overload that accumulates a*x into y */ \
+ inline void add(param1(base_type), std::vector<base_type > &y) { \
+ GMMLAPACK_TRACE("axpy_interface"); /* same trace label as the one used by axpy_interface */ \
+ long inc(1), n(long(vect_size(y))); trans1(base_type); /* trans1 is expected to bring the scalar a (and, for some variants, x) into scope - confirm against the axpy_trans1* macros */ \
+ if(n == 0) return; /* empty y: nothing to add */ \
+ else if(n < 25) add_for_short_vectors(x, y, a, n); /* sizes 1..24: switch-dispatched fixed-size loops, no BLAS call */ \
+ else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); /* larger sizes: BLAS routine, inc = 1 passed as the increment for both x and y */ \
+ }
+
# define axpy_p1(base_type) const std::vector<base_type > &x
# define axpy_trans1(base_type) base_type a(1)
# define axpy_p1_s(base_type) \
@@ -390,10 +439,10 @@ namespace gmm {
axpy_interface(axpy_p1, axpy_trans1, caxpy_, BLAS_C)
axpy_interface(axpy_p1, axpy_trans1, zaxpy_, BLAS_Z)
- axpy_interface(axpy_p1_s, axpy_trans1_s, saxpy_, BLAS_S)
- axpy_interface(axpy_p1_s, axpy_trans1_s, daxpy_, BLAS_D)
- axpy_interface(axpy_p1_s, axpy_trans1_s, caxpy_, BLAS_C)
- axpy_interface(axpy_p1_s, axpy_trans1_s, zaxpy_, BLAS_Z)
+ axpy2_interface(axpy_p1_s, axpy_trans1_s, saxpy_, BLAS_S)  // the four axpy_p1_s overloads now come from axpy2_interface, so sizes below 25 bypass BLAS
+ axpy2_interface(axpy_p1_s, axpy_trans1_s, daxpy_, BLAS_D)
+ axpy2_interface(axpy_p1_s, axpy_trans1_s, caxpy_, BLAS_C)
+ axpy2_interface(axpy_p1_s, axpy_trans1_s, zaxpy_, BLAS_Z)
/* ********************************************************************* */