summary | refs | log | tree | commit | diff
path: root/src/gpu-compute/compute_unit.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu-compute/compute_unit.cc')
-rw-r--r-- src/gpu-compute/compute_unit.cc 99
1 files changed, 28 insertions, 71 deletions
diff --git a/src/gpu-compute/compute_unit.cc b/src/gpu-compute/compute_unit.cc
index 5ec061172..83e2414db 100644
--- a/src/gpu-compute/compute_unit.cc
+++ b/src/gpu-compute/compute_unit.cc
@@ -193,50 +193,6 @@ ComputeUnit::FillKernelState(Wavefront *w, NDRange *ndr)
}
void
-ComputeUnit::InitializeWFContext(WFContext *wfCtx, NDRange *ndr, int cnt,
- int trueWgSize[], int trueWgSizeTotal,
- LdsChunk *ldsChunk, uint64_t origSpillMemStart)
-{
- wfCtx->cnt = cnt;
-
- VectorMask init_mask;
- init_mask.reset();
-
- for (int k = 0; k < wfSize(); ++k) {
- if (k + cnt * wfSize() < trueWgSizeTotal)
- init_mask[k] = 1;
- }
-
- wfCtx->init_mask = init_mask.to_ullong();
- wfCtx->exec_mask = init_mask.to_ullong();
-
- wfCtx->bar_cnt.resize(wfSize(), 0);
-
- wfCtx->max_bar_cnt = 0;
- wfCtx->old_barrier_cnt = 0;
- wfCtx->barrier_cnt = 0;
-
- wfCtx->privBase = ndr->q.privMemStart;
- ndr->q.privMemStart += ndr->q.privMemPerItem * wfSize();
-
- wfCtx->spillBase = ndr->q.spillMemStart;
- ndr->q.spillMemStart += ndr->q.spillMemPerItem * wfSize();
-
- wfCtx->pc = 0;
- wfCtx->rpc = UINT32_MAX;
-
- // set the wavefront context to have a pointer to this section of the LDS
- wfCtx->ldsChunk = ldsChunk;
-
- // WG state
- wfCtx->wg_id = ndr->globalWgId;
- wfCtx->barrier_id = barrier_id;
-
- // Kernel wide state
- wfCtx->ndr = ndr;
-}
-
-void
ComputeUnit::updateEvents() {
if (!timestampVec.empty()) {
@@ -264,19 +220,25 @@ ComputeUnit::updateEvents() {
void
-ComputeUnit::StartWF(Wavefront *w, WFContext *wfCtx, int trueWgSize[],
- int trueWgSizeTotal)
+ComputeUnit::StartWF(Wavefront *w, int trueWgSize[], int trueWgSizeTotal,
+ int cnt, LdsChunk *ldsChunk, NDRange *ndr)
{
static int _n_wave = 0;
- int cnt = wfCtx->cnt;
- NDRange *ndr = wfCtx->ndr;
// Fill in Kernel state
FillKernelState(w, ndr);
+ VectorMask init_mask;
+ init_mask.reset();
+
+ for (int k = 0; k < wfSize(); ++k) {
+ if (k + cnt * wfSize() < trueWgSizeTotal)
+ init_mask[k] = 1;
+ }
+
w->kern_id = ndr->dispatchId;
w->dynwaveid = cnt;
- w->init_mask = wfCtx->init_mask;
+ w->init_mask = init_mask.to_ullong();
for (int k = 0; k < wfSize(); ++k) {
w->workitemid[0][k] = (k+cnt*wfSize()) % trueWgSize[0];
@@ -290,32 +252,34 @@ ComputeUnit::StartWF(Wavefront *w, WFContext *wfCtx, int trueWgSize[],
w->workitemid[0][k];
}
- w->old_barrier_cnt = wfCtx->old_barrier_cnt;
- w->barrier_cnt = wfCtx->barrier_cnt;
w->barrier_slots = divCeil(trueWgSizeTotal, wfSize());
- for (int i = 0; i < wfSize(); ++i) {
- w->bar_cnt[i] = wfCtx->bar_cnt[i];
- }
+ w->bar_cnt.resize(wfSize(), 0);
+
+ w->max_bar_cnt = 0;
+ w->old_barrier_cnt = 0;
+ w->barrier_cnt = 0;
+
+ w->privBase = ndr->q.privMemStart;
+ ndr->q.privMemStart += ndr->q.privMemPerItem * wfSize();
- w->max_bar_cnt = wfCtx->max_bar_cnt;
- w->privBase = wfCtx->privBase;
- w->spillBase = wfCtx->spillBase;
+ w->spillBase = ndr->q.spillMemStart;
+ ndr->q.spillMemStart += ndr->q.spillMemPerItem * wfSize();
- w->pushToReconvergenceStack(wfCtx->pc, wfCtx->rpc, wfCtx->exec_mask);
+ w->pushToReconvergenceStack(0, UINT32_MAX, init_mask.to_ulong());
// WG state
- w->wg_id = wfCtx->wg_id;
- w->dispatchid = wfCtx->ndr->dispatchId;
+ w->wg_id = ndr->globalWgId;
+ w->dispatchid = ndr->dispatchId;
w->workgroupid[0] = w->wg_id % ndr->numWg[0];
w->workgroupid[1] = (w->wg_id / ndr->numWg[0]) % ndr->numWg[1];
w->workgroupid[2] = w->wg_id / (ndr->numWg[0] * ndr->numWg[1]);
- w->barrier_id = wfCtx->barrier_id;
+ w->barrier_id = barrier_id;
w->stalledAtBarrier = false;
- // move this from the context into the actual wavefront
- w->ldsChunk = wfCtx->ldsChunk;
+ // set the wavefront context to have a pointer to this section of the LDS
+ w->ldsChunk = ldsChunk;
int32_t refCount M5_VAR_USED =
lds.increaseRefCounter(w->dispatchid, w->wg_id);
@@ -340,7 +304,6 @@ ComputeUnit::StartWF(Wavefront *w, WFContext *wfCtx, int trueWgSize[],
"WF[%d][%d]\n", _n_wave, barrier_id, cu_id, w->simdId, w->wfSlotId);
w->start(++_n_wave, ndr->q.code_ptr);
- wfCtx->bar_cnt.clear();
}
void
@@ -376,7 +339,6 @@ ComputeUnit::StartWorkgroup(NDRange *ndr)
trueWgSizeTotal *= trueWgSize[d];
}
- uint64_t origSpillMemStart = ndr->q.spillMemStart;
// calculate the number of 32-bit vector registers required by wavefront
int vregDemand = ndr->q.sRegCount + (2 * ndr->q.dRegCount);
int cnt = 0;
@@ -403,12 +365,7 @@ ComputeUnit::StartWorkgroup(NDRange *ndr)
w->reservedVectorRegs = normSize;
vectorRegsReserved[m % numSIMDs] += w->reservedVectorRegs;
- WFContext wfCtx;
-
- InitializeWFContext(&wfCtx, ndr, cnt, trueWgSize, trueWgSizeTotal,
- ldsChunk, origSpillMemStart);
-
- StartWF(w, &wfCtx, trueWgSize, trueWgSizeTotal);
+ StartWF(w, trueWgSize, trueWgSizeTotal, cnt, ldsChunk, ndr);
++cnt;
}
}