(1) template <class RealType, size_t PanelSizeK, size_t PanelSizeiA, |
(2) size_t PanelSizejB, class VecType> |
(3) void ScalarGemmIna(const RealType __restrict__ A, const RealType __restrict__ B, |
(4) RealType __restrict__ C, const size_t size) |
(5) |
(6) const int BlockSize = VecType::VecLength; |
(7) |
(8) static_assert(PanelSizeK >= BlockSize, "PanelSizeK must be greater than block"); |
(9) static_assert(PanelSizeiA >= BlockSize, "PanelSizeiA must be greater than block"); |
(10) static_assert(PanelSizejB >= BlockSize, "PanelSizejB must be greater than block"); |
(11) static_assert((PanelSizeK/BlockSize)BlockSize == PanelSizeK, "PanelSizeK must be a … multiple of block"); |
(12) static_assert((PanelSizeiA/BlockSize)BlockSize == PanelSizeiA, "PanelSizeiA must be a … multiple of block"); |
(13) static_assert((PanelSizejB/BlockSize)BlockSize == PanelSizejB, "PanelSizejB must be a … multiple of block"); |
(14) // Restrict to a multiple of panelsize for simplcity |
(15) assert((size/PanelSizeK)PanelSizeK == size); |
(16) assert((size/PanelSizeiA)PanelSizeiA == size); |
(17) assert((size/PanelSizejB)PanelSizejB == size); |
(18) |
(19) for(size_t ip = 0 ; ip < size ; ip += PanelSizeiA) |
(20) for(size_t jp = 0 ; jp < size ; jp += PanelSizejB) |
(21) |
(22) for(size_t kp = 0 ; kp < size ; kp += PanelSizeK) |
(23) |
(24) alignas(64) RealType panelA[PanelSizeiAPanelSizeK]; |
(25) alignas(64) RealType panelB[PanelSizeKBlockSize]; |
(26) |
(27) for(size_t jb = 0 ; jb < PanelSizejB ; jb += BlockSize) |
(28) |
(29) CopyMat<RealType, BlockSize>(panelB, PanelSizeK, &B[jpsize + kp], size); |
(30) |
(31) for(size_t ib = 0 ; ib < PanelSizeiA ; ib += BlockSize) |
(32) |
(33) if(jb == 0) |
(34) CopyMat<RealType, BlockSize>(&panelA[PanelSizeKib], PanelSizeK, … |
&A[(ib+ip)size + kp], size); |
(35) |
(36) |
(37) for(size_t idxRow = 0 ; idxRow < BlockSize ; ++idxRow) |
(38) for(size_t idxCol = 0 ; idxCol < BlockSize ; ++idxCol) |
(39) VecType sum = 0.; |
(40) for(size_t idxK = 0 ; idxK < PanelSizeK ; idxK += BlockSize) |
(41) sum += VecType(&panelA[(idxRow+ib)PanelSizeK+ idxK]) |
(42) VecType(&panelB[idxColPanelSizeK+ idxK]); |
(43) |
(44) C[(jp+jb+idxCol)size + ip + ib + idxRow] += sum.horizontalSum(); |
(45) |
(46) |
(47) |
(48) |
(49) |
(50) |
(51) |
(52) |