diff options
-rw-r--r-- | src/gpu-compute/compute_unit.cc | 50 | ||||
-rw-r--r-- | src/gpu-compute/compute_unit.hh | 6 | ||||
-rw-r--r-- | src/gpu-compute/wavefront.cc | 11 | ||||
-rw-r--r-- | src/gpu-compute/wavefront.hh | 6 |
4 files changed, 39 insertions, 34 deletions
diff --git a/src/gpu-compute/compute_unit.cc b/src/gpu-compute/compute_unit.cc index b937584eb..97e018713 100644 --- a/src/gpu-compute/compute_unit.cc +++ b/src/gpu-compute/compute_unit.cc @@ -174,7 +174,7 @@ ComputeUnit::~ComputeUnit() } void -ComputeUnit::FillKernelState(Wavefront *w, NDRange *ndr) +ComputeUnit::fillKernelState(Wavefront *w, NDRange *ndr) { w->resizeRegFiles(ndr->q.cRegCount, ndr->q.sRegCount, ndr->q.dRegCount); @@ -190,6 +190,7 @@ ComputeUnit::FillKernelState(Wavefront *w, NDRange *ndr) w->spillSizePerItem = ndr->q.spillMemPerItem; w->roBase = ndr->q.roMemStart; w->roSize = ndr->q.roMemTotal; + w->computeActualWgSz(ndr); } void @@ -220,19 +221,16 @@ ComputeUnit::updateEvents() { void -ComputeUnit::StartWF(Wavefront *w, int trueWgSize[], int trueWgSizeTotal, - int waveId, LdsChunk *ldsChunk, NDRange *ndr) +ComputeUnit::startWavefront(Wavefront *w, int waveId, LdsChunk *ldsChunk, + NDRange *ndr) { static int _n_wave = 0; - // Fill in Kernel state - FillKernelState(w, ndr); - VectorMask init_mask; init_mask.reset(); for (int k = 0; k < wfSize(); ++k) { - if (k + waveId * wfSize() < trueWgSizeTotal) + if (k + waveId * wfSize() < w->actualWgSzTotal) init_mask[k] = 1; } @@ -241,18 +239,18 @@ ComputeUnit::StartWF(Wavefront *w, int trueWgSize[], int trueWgSizeTotal, w->initMask = init_mask.to_ullong(); for (int k = 0; k < wfSize(); ++k) { - w->workItemId[0][k] = (k + waveId * wfSize()) % trueWgSize[0]; - w->workItemId[1][k] = - ((k + waveId * wfSize()) / trueWgSize[0]) % trueWgSize[1]; - w->workItemId[2][k] = - (k + waveId * wfSize()) / (trueWgSize[0] * trueWgSize[1]); - - w->workItemFlatId[k] = w->workItemId[2][k] * trueWgSize[0] * - trueWgSize[1] + w->workItemId[1][k] * trueWgSize[0] + + w->workItemId[0][k] = (k + waveId * wfSize()) % w->actualWgSz[0]; + w->workItemId[1][k] = ((k + waveId * wfSize()) / w->actualWgSz[0]) % + w->actualWgSz[1]; + w->workItemId[2][k] = (k + waveId * wfSize()) / + (w->actualWgSz[0] * w->actualWgSz[1]); + + 
w->workItemFlatId[k] = w->workItemId[2][k] * w->actualWgSz[0] * + w->actualWgSz[1] + w->workItemId[1][k] * w->actualWgSz[0] + w->workItemId[0][k]; } - w->barrierSlots = divCeil(trueWgSizeTotal, wfSize()); + w->barrierSlots = divCeil(w->actualWgSzTotal, wfSize()); w->barCnt.resize(wfSize(), 0); @@ -294,8 +292,8 @@ ComputeUnit::StartWF(Wavefront *w, int trueWgSize[], int trueWgSizeTotal, // is this the last wavefront in the workgroup // if set the spillWidth to be the remaining work-items // so that the vector access is correct - if ((waveId + 1) * wfSize() >= trueWgSizeTotal) { - w->spillWidth = trueWgSizeTotal - (waveId * wfSize()); + if ((waveId + 1) * wfSize() >= w->actualWgSzTotal) { + w->spillWidth = w->actualWgSzTotal - (waveId * wfSize()); } else { w->spillWidth = wfSize(); } @@ -328,17 +326,6 @@ ComputeUnit::StartWorkgroup(NDRange *ndr) injectGlobalMemFence(gpuDynInst, true); } - // Get true size of workgroup (after clamping to grid size) - int trueWgSize[3]; - int trueWgSizeTotal = 1; - - for (int d = 0; d < 3; ++d) { - trueWgSize[d] = std::min(ndr->q.wgSize[d], ndr->q.gdSize[d] - - ndr->wgId[d] * ndr->q.wgSize[d]); - - trueWgSizeTotal *= trueWgSize[d]; - } - // calculate the number of 32-bit vector registers required by wavefront int vregDemand = ndr->q.sRegCount + (2 * ndr->q.dRegCount); int wave_id = 0; @@ -350,9 +337,10 @@ ComputeUnit::StartWorkgroup(NDRange *ndr) // It must be stopped and not waiting // for a release to complete S_RETURNING if (w->status == Wavefront::S_STOPPED) { + fillKernelState(w, ndr); // if we have scheduled all work items then stop // scheduling wavefronts - if (wave_id * wfSize() >= trueWgSizeTotal) + if (wave_id * wfSize() >= w->actualWgSzTotal) break; // reserve vector registers for the scheduled wavefront @@ -365,7 +353,7 @@ ComputeUnit::StartWorkgroup(NDRange *ndr) w->reservedVectorRegs = normSize; vectorRegsReserved[m % numSIMDs] += w->reservedVectorRegs; - StartWF(w, trueWgSize, trueWgSizeTotal, wave_id, ldsChunk, ndr); + 
startWavefront(w, wave_id, ldsChunk, ndr); ++wave_id; } } diff --git a/src/gpu-compute/compute_unit.hh b/src/gpu-compute/compute_unit.hh index 34b710cd6..a3547402a 100644 --- a/src/gpu-compute/compute_unit.hh +++ b/src/gpu-compute/compute_unit.hh @@ -254,10 +254,10 @@ class ComputeUnit : public MemObject void exec(); void initiateFetch(Wavefront *wavefront); void fetch(PacketPtr pkt, Wavefront *wavefront); - void FillKernelState(Wavefront *w, NDRange *ndr); + void fillKernelState(Wavefront *w, NDRange *ndr); - void StartWF(Wavefront *w, int trueWgSize[], int trueWgSizeTotal, - int cnt, LdsChunk *ldsChunk, NDRange *ndr); + void startWavefront(Wavefront *w, int waveId, LdsChunk *ldsChunk, + NDRange *ndr); void StartWorkgroup(NDRange *ndr); int ReadyWorkgroup(NDRange *ndr); diff --git a/src/gpu-compute/wavefront.cc b/src/gpu-compute/wavefront.cc index 42739a7b0..c677cbe41 100644 --- a/src/gpu-compute/wavefront.cc +++ b/src/gpu-compute/wavefront.cc @@ -1066,3 +1066,14 @@ Wavefront::setContext(const void *in) ldsChunk->write<char>(i, val); } } + +void +Wavefront::computeActualWgSz(NDRange *ndr) +{ + actualWgSzTotal = 1; + for (int d = 0; d < 3; ++d) { + actualWgSz[d] = std::min(workGroupSz[d], + gridSz[d] - ndr->wgId[d] * workGroupSz[d]); + actualWgSzTotal *= actualWgSz[d]; + } +} diff --git a/src/gpu-compute/wavefront.hh b/src/gpu-compute/wavefront.hh index ef8c80989..0df8a6c82 100644 --- a/src/gpu-compute/wavefront.hh +++ b/src/gpu-compute/wavefront.hh @@ -47,6 +47,7 @@ #include "gpu-compute/condition_register_state.hh" #include "gpu-compute/lds_state.hh" #include "gpu-compute/misc.hh" +#include "gpu-compute/ndrange.hh" #include "params/Wavefront.hh" #include "sim/sim_object.hh" @@ -189,11 +190,16 @@ class Wavefront : public SimObject std::vector<Addr> lastAddr; std::vector<uint32_t> workItemId[3]; std::vector<uint32_t> workItemFlatId; + /* kernel launch parameters */ uint32_t workGroupId[3]; uint32_t workGroupSz[3]; uint32_t gridSz[3]; uint32_t wgId; uint32_t wgSz; + 
/* the actual WG size can differ from the maximum size */ + uint32_t actualWgSz[3]; + uint32_t actualWgSzTotal; + void computeActualWgSz(NDRange *ndr); // wavefront id within a workgroup uint32_t wfId; uint32_t maxDynWaveId; |