Это довольно простой вопрос C ++ для вычисления умножения матриц на GPU. Следующий код технически MSL, но синтаксис почти идентичен.
Apple предоставляет пример умножения матриц для вычислений A^T * B
, Я ищу помощь, чтобы изменить его, чтобы просто вычислить A * B
,
Каждый вызов этого шейдера работает в секторе 8 х 8 C
, а также gid
позиция этого сектора в сетке. Вот источник:
// Note:
//
// (1) m is the number of rows in matrices A and C.
//
// (2) n is the number of columns in matrix A; number of rows in matrix B.
//
// (3) k is the number of columns in matrices B and C.
//
// (4) Matrix multiple computes C = A^T * B where A is m x n matrix (so
// that, A^T is n x m), B is n x k .
//
// (5) pbytes is stride in bytes from row to another of matrix A.
// pbytes should be multiple of 32, i.e. A is padded to be
// M x k matrix where M > m and P is multiple of 8.
//
// (6) Similarly qbytes is stride in bytes from one row to another
// of B, i.e. B is n x K matrix where K > k matrix where K is
// multiple of 8.
//
// (7) The output matrix C is the M x K matrix.
typedef struct
{
ushort m, k, n, pbytes, qbytes;
} MetalMatrixDim;kernel void MatrixMultiply(const device float* A [[ buffer(0) ]],
const device float* B [[ buffer(1) ]],
device float* C [[ buffer(2) ]],
constant MetalMatrixDim& dims [[ buffer(3) ]],
ushort2 gid [[ thread_position_in_grid ]])
{
ushort m = dims.m;
ushort k = dims.k;
ushort n = dims.n;
ushort pbytes = dims.pbytes;
ushort qbytes = dims.qbytes;
// Multiply gid by 8 to get the absolute position in C
ushort2 gidIn = ushort2(gid.x << 3, gid.y << 3);
if (gidIn.x >= m || gidIn.y >= k) return;
const device float4* a = (const device float4*)(A + gidIn.x);
const device float4* b = (const device float4*)(B + gidIn.y);
C = (device float*)((device char*)C + gidIn.x*qbytes);
device float4* c = (device float4*)(C + gidIn.y);
const device float4* Bend = (const device float4*)((const device char*)B + qbytes*n);
float4 s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f;
float4 s4 = 0.0f, s5 = 0.0f, s6 = 0.0f, s7 = 0.0f;
float4 s8 = 0.0f, s9 = 0.0f, s10 = 0.0f, s11 = 0.0f;
float4 s12 = 0.0f, s13 = 0.0f, s14 = 0.0f, s15 = 0.0f;
do
{
float4 aCurr0 = a[0];
float4 aCurr1 = a[1];
float4 bCurr0 = b[0];
float4 bCurr1 = b[1];
s0 += (aCurr0.x * bCurr0);
s2 += (aCurr0.y * bCurr0);
s4 += (aCurr0.z * bCurr0);
s6 += (aCurr0.w * bCurr0);
s1 += (aCurr0.x * bCurr1);
s3 += (aCurr0.y * bCurr1);
s5 += (aCurr0.z * bCurr1);
s7 += (aCurr0.w * bCurr1);
s8 += (aCurr1.x * bCurr0);
s10 += (aCurr1.y * bCurr0);
s12 += (aCurr1.z * bCurr0);
s14 += (aCurr1.w * bCurr0);
s9 += (aCurr1.x * bCurr1);
s11 += (aCurr1.y * bCurr1);
s13 += (aCurr1.z * bCurr1);
s15 += (aCurr1.w * bCurr1);
a = (device float4*)((device char*)a + pbytes);
b = (device float4*)((device char*)b + qbytes);
} while(b < Bend);
c[0] = s0; c[1] = s1; c = (device float4*)((device char*)c + qbytes);
c[0] = s2; c[1] = s3; c = (device float4*)((device char*)c + qbytes);
c[0] = s4; c[1] = s5; c = (device float4*)((device char*)c + qbytes);
c[0] = s6; c[1] = s7; c = (device float4*)((device char*)c + qbytes);
c[0] = s8; c[1] = s9; c = (device float4*)((device char*)c + qbytes);
c[0] = s10; c[1] = s11; c = (device float4*)((device char*)c + qbytes);
c[0] = s12; c[1] = s13; c = (device float4*)((device char*)c + qbytes);
c[0] = s14; c[1] = s15;
}
Я потратил немало времени на это, но лучшее, что я придумал, — это наивное решение, которое не учитывает задержку памяти. Вместо этого я надеюсь изменить код Apple, чтобы исключить A
, в то же время позволяя графическому процессору оптимизировать чтение / запись памяти.
Может ли кто-нибудь помочь мне здесь?
Редактировать: Вот моя (очень) наивная реализация. Он работает примерно в 100 раз медленнее, чем ядро Apple:
int pbytes = (int)dims.pbytes;
int qbytes = (int)dims.qbytes;
for (int row = 0; row < 8; row++) {
int aStart = (gidIn.y + row) * pbytes / 4;
for (int col = 0; col < 8; col++) {
int cIdx = gidIn.y + (row * qbytes / 4) + gidIn.x + col;
int bStart = gidIn.x + col;
float sum = 0.0f;
for (int i = 0; i < (pbytes / 4); i++) {
float prod = A[aStart + i] * B[bStart + (i * qbytes / 4)];
sum += prod;
}
C[cIdx] = sum;
}
}
Проблема с этой реализацией заключается в том, что она вообще не оптимизируется для доступа к памяти. В идеале вы должны читать и записывать как можно больше данных за один раз, что позволит компилятору векторизовать операцию.
Фреймворк MetalPerformanceShaders имеет встроенное ядро умножения матриц, которое вы можете просто кодировать в буфер команд металла. Я рекомендую делать это вместо того, чтобы тратить здесь много времени.
Других решений пока нет …