Умножение матрицы с ее транспонированием с использованием cuBlas

Я пытаюсь умножить матрицу с помощью ее транспонирования, но мне не удалось сделать правильный вызов sgemm. Sgemm принимает много параметров. Некоторые из них, такие как lda, ldb, сбивают меня с толку. Если я вызываю функцию ниже с квадратной матрицей, она работает, иначе она не работает.

/*param inMatrix: contains the matrix data in major order like [1 2 3 1 2 3]
param rowNum: Number of rows in a matrix eg if matrix is
|1  1|
|2  2|
|3  3| than rowNum should be 3*/
void matrixtTransposeMult(std::vector<float>& inMatrix, int rowNum)
{
cublasHandle_t handle;
cublasCreate(&handle);

int colNum = (int)inMatrix.size() / rowNum;
thrust::device_vector<float> d_InMatrix(inMatrix);
thrust::device_vector<float> d_outputMatrix(rowNum*rowNum);
float alpha = 1.0f;
float beta = 0.0f;

cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, rowNum, rowNum, colNum, &alpha,
thrust::raw_pointer_cast(d_InMatrix.data()), colNum, thrust::raw_pointer_cast(d_InMatrix.data()), colNum, &beta,
thrust::raw_pointer_cast(d_outputMatrix.data()), rowNum);

thrust::host_vector<float> result = d_outputMatrix;
for (auto elem : result)
std::cout << elem << ",";
std::cout << std::endl;

cublasDestroy(handle);
}

Что мне не хватает? Как сделать правильный вызов sgemm для matrix * matrixTranspose?

1

Решение

Ниже настройки работали для меня, если я что-то упустил, пожалуйста, предупредите меня. Надеюсь это кому-нибудь пригодится

void matrixtTransposeMult(std::vector<float>& inMatrix, int rowNum)
{
cublasHandle_t handle;
cublasCreate(&handle);

int colNum = (int)inMatrix.size() / rowNum;
thrust::device_vector<float> d_InMatrix(inMatrix);
thrust::device_vector<float> d_outputMatrix(rowNum*rowNum);
float alpha = 1.0f;
float beta = 0.0f;

cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, rowNum, rowNum, colNum, &alpha,
thrust::raw_pointer_cast(d_InMatrix.data()), rowNum, thrust::raw_pointer_cast(d_InMatrix.data()), rowNum, &beta,
thrust::raw_pointer_cast(d_outputMatrix.data()), rowNum);

thrust::host_vector<float> result = d_outputMatrix;
for (auto elem : result)
std::cout << elem << ",";
std::cout << std::endl;

cublasDestroy(handle);
}
1

Другие решения


По вопросам рекламы ammmcru@yandex.ru
Adblock
detector