summaryrefslogtreecommitdiffstats
path: root/tools/gfx/cuda/render-cuda.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-06-10 00:30:19 -0700
committerGitHub <noreply@github.com>2021-06-10 00:30:19 -0700
commit0d9bd79e8fd4d57e1a723ca6b6a45efec2b42872 (patch)
treed9e23abd1b51044b12b556cd063916f0b44362c0 /tools/gfx/cuda/render-cuda.cpp
parent86b0d74e58259c1a1c964acf18923303d9e93148 (diff)
Support timestamp queries in `gfx`. (#1880)
* Support timestamp queries in `gfx`. * Fix tab Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tools/gfx/cuda/render-cuda.cpp')
-rw-r--r--tools/gfx/cuda/render-cuda.cpp83
1 files changed, 83 insertions, 0 deletions
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp
index ed7f44ed2..3e93c090a 100644
--- a/tools/gfx/cuda/render-cuda.cpp
+++ b/tools/gfx/cuda/render-cuda.cpp
@@ -707,6 +707,58 @@ public:
}
};
+class CUDAQueryPool : public IQueryPool, public ComObject
+{
+public:
+ SLANG_COM_OBJECT_IUNKNOWN_ALL;
+ IQueryPool* getInterface(const Guid& guid)
+ {
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IQueryPool)
+ return static_cast<IQueryPool*>(this);
+ return nullptr;
+ }
+public:
+ // The event object for each query. Owned by the pool.
+ List<CUevent> m_events;
+
+ // The event that marks the starting point.
+ CUevent m_startEvent;
+
+ Result init(const IQueryPool::Desc& desc)
+ {
+ cuEventCreate(&m_startEvent, 0);
+ cuEventRecord(m_startEvent, 0);
+ m_events.setCount(desc.count);
+ for (SlangInt i = 0; i < m_events.getCount(); i++)
+ {
+ cuEventCreate(&m_events[i], 0);
+ }
+ return SLANG_OK;
+ }
+
+ ~CUDAQueryPool()
+ {
+ for (auto& e : m_events)
+ {
+ cuEventDestroy(e);
+ }
+ cuEventDestroy(m_startEvent);
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL getResult(
+ SlangInt queryIndex, SlangInt count, uint64_t* data) override
+ {
+ for (SlangInt i = 0; i < count; i++)
+ {
+ float time = 0.0f;
+ cuEventSynchronize(m_events[i + queryIndex]);
+ cuEventElapsedTime(&time, m_startEvent, m_events[i + queryIndex]);
+ data[i] = (uint64_t)((double)time * 1000.0f);
+ }
+ return SLANG_OK;
+ }
+};
+
class CUDADevice : public RendererBase
{
private:
@@ -906,6 +958,11 @@ public:
m_writer->bindRootShaderObject(m_rootObject);
m_writer->dispatchCompute(x, y, z);
}
+
+ virtual SLANG_NO_THROW void SLANG_MCALL writeTimestamp(IQueryPool* pool, SlangInt index) override
+ {
+ m_writer->writeTimestamp(pool, index);
+ }
};
ComputeCommandEncoderImpl m_computeCommandEncoder;
@@ -959,6 +1016,11 @@ public:
{
m_writer->uploadBufferData(dst, offset, size, data);
}
+
+ virtual SLANG_NO_THROW void SLANG_MCALL writeTimestamp(IQueryPool* pool, SlangInt index) override
+ {
+ m_writer->writeTimestamp(pool, index);
+ }
};
ResourceCommandEncoderImpl m_resourceCommandEncoder;
@@ -1139,6 +1201,12 @@ public:
cudaMemcpy((uint8_t*)dstImpl->m_cudaMemory + offset, data, size, cudaMemcpyDefault);
}
+ void writeTimestamp(IQueryPool* pool, SlangInt index)
+ {
+ auto poolImpl = static_cast<CUDAQueryPool*>(pool);
+ cuEventRecord(poolImpl->m_events[index], stream);
+ }
+
void execute(CommandBufferImpl* commandBuffer)
{
for (auto& cmd : commandBuffer->m_commands)
@@ -1171,6 +1239,10 @@ public:
cmd.operands[2],
commandBuffer->getData<uint8_t>(cmd.operands[3]));
break;
+ case CommandName::WriteTimestamp:
+ writeTimestamp(
+ commandBuffer->getObject<IQueryPool>(cmd.operands[0]),
+ (SlangInt)cmd.operands[1]);
}
}
}
@@ -1218,6 +1290,7 @@ public:
cudaGetDeviceProperties(&deviceProperties, m_deviceIndex);
m_adapterName = deviceProperties.name;
m_info.adapterName = m_adapterName.begin();
+ m_info.timestampFrequency = 1000000;
}
return SLANG_OK;
@@ -1691,6 +1764,16 @@ public:
return SLANG_OK;
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool(
+ const IQueryPool::Desc& desc,
+ IQueryPool** outPool) override
+ {
+ RefPtr<CUDAQueryPool> pool = new CUDAQueryPool();
+ SLANG_RETURN_ON_FAIL(pool->init(desc));
+ returnComPtr(outPool, pool);
+ return SLANG_OK;
+ }
+
virtual Result createShaderObjectLayout(
slang::TypeLayoutReflection* typeLayout,
ShaderObjectLayoutBase** outLayout) override