summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-13 09:49:22 -0700
committerGitHub <noreply@github.com>2023-04-13 09:49:22 -0700
commit813892cd023e216f6f6560eb47566522d3a82609 (patch)
tree07cbf8851e0c178cbc895be73e17e6340cc22685
parent352a460fc866998da5f45a8c117d891c51ab5a47 (diff)
Set sharedMem argument to 0 when launching cuda kernel. (#2799)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--prelude/slang-torch-prelude.h8
-rw-r--r--source/slang/slang-emit-torch.cpp4
-rw-r--r--tools/gfx/cuda/cuda-command-queue.cpp8
3 files changed, 2 insertions, 18 deletions
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 03a00719d..2f5273e1f 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -114,12 +114,4 @@ TensorView make_tensor_view(torch::Tensor val, const char* name, torch::ScalarTy
return res;
}
-size_t slangGetCudaKernelSharedMemSize(const void* func)
-{
- cudaFuncAttributes attr = {};
- cudaFuncGetAttributes(&attr, func);
- AT_CUDA_CHECK(cudaGetLastError());
- return attr.sharedSizeBytes;
-}
-
#define SLANG_PRELUDE_EXPORT
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index 819a6a136..bdb650607 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -94,9 +94,7 @@ void TorchCppSourceEmitter::emitInstStmtImpl(IRInst* inst)
m_writer->emit(", ");
// shared mem
- m_writer->emit("slangGetCudaKernelSharedMemSize((const void*)(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(")), ");
+ m_writer->emit("0, ");
// stream
m_writer->emit("((cudaStream_t)");
diff --git a/tools/gfx/cuda/cuda-command-queue.cpp b/tools/gfx/cuda/cuda-command-queue.cpp
index 0c17a418e..4b0ab7d94 100644
--- a/tools/gfx/cuda/cuda-command-queue.cpp
+++ b/tools/gfx/cuda/cuda-command-queue.cpp
@@ -93,12 +93,6 @@ void CommandQueueImpl::dispatchCompute(int x, int y, int z)
UInt threadGroupSize[3];
programLayout->getKernelThreadGroupSize(kernelId, threadGroupSize);
- int sharedSizeInBytes;
- cuFuncGetAttribute(
- &sharedSizeInBytes,
- CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
- currentPipeline->shaderProgram->cudaKernel);
-
// Copy global parameter data to the `SLANG_globalParams` symbol.
{
CUdeviceptr globalParamsSymbol = 0;
@@ -144,7 +138,7 @@ void CommandQueueImpl::dispatchCompute(int x, int y, int z)
int(threadGroupSize[0]),
int(threadGroupSize[1]),
int(threadGroupSize[2]),
- sharedSizeInBytes,
+ 0,
stream,
nullptr,
extraOptions);