summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-cuda-immutable-load.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-10-15 20:59:47 -0700
committerGitHub <noreply@github.com>2025-10-16 03:59:47 +0000
commit01510f2c922af8629c7a730ef92a31fa83bd9f49 (patch)
treebbec0cd5424e99670573dc3fa10fdf441320b684 /source/slang/slang-ir-cuda-immutable-load.cpp
parentd1a935c683ac1eb93d95587ee26bdaae7eb17e31 (diff)
Immutable access qualifier for pointers and use `__ldg` on cuda. (#8710)
This PR implements `Access.Immutable` to allow pointers to immutable data. The new type `ImmutablePtr<T>` is defined as an alias of `Ptr<T, Address.Immutable>`. By forming an immutable pointer, the programmer is conveying to the compiler that the data at the pointer address will never change during the execution of the current program. Therefore loads from immutable pointers can be deduplicated by the compiler, and will translate to `__ldg` when generating code for CUDA. The SPIRV backend is not changed in this PR, since the current SPIRV spec makes it very difficult to specify loads from immutable addresses without generating tons of wrappers and boilerplate type declarations. We would like to see the spec evolve a bit around its support for `NonWritable` physical storage pointers or immutable loads before we attempt to express such immutability in SPIRV. For now we simply emit ordinary pointers and loads when generating SPIRV. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-ir-cuda-immutable-load.cpp')
-rw-r--r--source/slang/slang-ir-cuda-immutable-load.cpp375
1 files changed, 375 insertions, 0 deletions
diff --git a/source/slang/slang-ir-cuda-immutable-load.cpp b/source/slang/slang-ir-cuda-immutable-load.cpp
new file mode 100644
index 000000000..713d5cb5a
--- /dev/null
+++ b/source/slang/slang-ir-cuda-immutable-load.cpp
@@ -0,0 +1,375 @@
+#include "slang-ir-cuda-immutable-load.h"
+
+#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-layout.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+
+/// Discriminator for `LoadMethod`: selects which member of its union is active.
+///
+/// Declared as a scoped enum so that the generic names `Func` and `Opcode` do not
+/// leak into `namespace Slang`; every use in this file is already qualified.
+enum class LoadMethodKind
+{
+    /// The load is performed by calling a function (`LoadMethod::func`).
+    Func,
+    /// The load is performed by emitting a single instruction (`LoadMethod::op`).
+    Opcode
+};
+
+/// Describes how a value of some type is loaded through a pointer: either by calling
+/// a function that takes the pointer, or by emitting a single intrinsic instruction
+/// with the given opcode.
+///
+/// A default-constructed `LoadMethod` is "empty" and tests as `false`.
+struct LoadMethod
+{
+    /// Selects the active member of the union below.
+    LoadMethodKind kind = LoadMethodKind::Func;
+    union
+    {
+        IRFunc* func; ///< Callee to invoke; valid when `kind` is `Func`.
+        IROp op;      ///< Opcode to emit; valid when `kind` is `Opcode`.
+    };
+
+    /// Creates an empty method (function kind with a null callee).
+    LoadMethod()
+        : func(nullptr)
+    {
+    }
+
+    // The two converting constructors are intentionally implicit: code in this file
+    // returns an `IRFunc*` or an `IROp` directly from functions returning `LoadMethod`.
+    LoadMethod(IRFunc* f)
+        : kind(LoadMethodKind::Func), func(f)
+    {
+    }
+    LoadMethod(IROp irop)
+        : kind(LoadMethodKind::Opcode), op(irop)
+    {
+    }
+
+    /// Switches this method to "call `f`".
+    LoadMethod& operator=(IRFunc* f)
+    {
+        kind = LoadMethodKind::Func;
+        this->func = f;
+        return *this;
+    }
+
+    /// Switches this method to "emit opcode `irop`".
+    LoadMethod& operator=(IROp irop)
+    {
+        kind = LoadMethodKind::Opcode;
+        this->op = irop;
+        return *this;
+    }
+
+    /// True when the method is usable: a non-null callee for the function kind, or an
+    /// opcode other than `kIROp_Nop` for the opcode kind.
+    ///
+    /// Explicit so that a `LoadMethod` cannot silently take part in integer arithmetic
+    /// or comparisons; `if (m)` and `if (!m)` continue to work.
+    explicit operator bool() const
+    {
+        if (kind == LoadMethodKind::Func)
+            return func != nullptr;
+        return op != kIROp_Nop;
+    }
+
+    /// Emits the load at the builder's current insertion point.
+    ///
+    /// @param builder     Builder used to emit the call or intrinsic instruction.
+    /// @param resultType  Type of the loaded value.
+    /// @param operandAddr Pointer operand; passed as the single argument.
+    /// @return The emitted call (function kind) or intrinsic instruction (opcode kind).
+    IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr) const
+    {
+        if (kind == LoadMethodKind::Func)
+            return builder.emitCallInst(resultType, func, 1, &operandAddr);
+        return builder.emitIntrinsicInst(resultType, op, 1, &operandAddr);
+    }
+};
+
+/// Pass context that rewrites loads from immutable locations into `kIROp_CUDALDG`
+/// instructions, or into calls to synthesized `slang_ldg` helper functions for types
+/// that have to be loaded piece by piece (vectors without a direct mapping, matrices,
+/// arrays and structs).
+struct ImmutableBufferLoadLoweringContext : InstPassBase
+{
+    /// Cache of the load method chosen for each value type. An empty `LoadMethod` is
+    /// cached as well, recording that no lowering is available for that type.
+    Dictionary<IRType*, LoadMethod> loadFuncs;
+
+    /// Target whose option set is handed to the layout query in the vector path.
+    /// Assigned by `lowerImmutableBufferLoadForCUDA` after construction.
+    TargetProgram* targetProgram = nullptr;
+
+    explicit ImmutableBufferLoadLoweringContext(IRModule* inModule)
+        : InstPassBase(inModule)
+    {
+    }
+
+    /// Creates the shell of a load helper: a function name-hinted "slang_ldg" with a
+    /// single block and a single pointer parameter name-hinted "ptr".
+    ///
+    /// On return `builder` is positioned inside the new block (after the parameter), so
+    /// the caller can emit the body and the return instruction.
+    ///
+    /// @param builder   Builder used to create the function; its insertion point is changed.
+    /// @param valueType Type the helper returns; the parameter type is a pointer to it.
+    /// @param outParam  Receives the pointer parameter.
+    /// @return The newly created function.
+    IRFunc* createLoadFunc(IRBuilder& builder, IRType* valueType, IRParam*& outParam)
+    {
+        auto func = builder.createFunc();
+        builder.addNameHintDecoration(func, toSlice("slang_ldg"));
+        builder.setInsertInto(func);
+        auto block = builder.emitBlock();
+        auto ptrType = builder.getPtrType(valueType);
+        builder.setInsertInto(block);
+        outParam = builder.emitParam(ptrType);
+        builder.addNameHintDecoration(outParam, toSlice("ptr"));
+        func->setFullType(builder.getFuncType(ptrType, valueType));
+        return func;
+    }
+
+    /// Chooses the load method for a vector type: a direct `kIROp_CUDALDG` where the
+    /// vector shape is handled by the instruction, otherwise a helper that loads each
+    /// element with `kIROp_CUDALDG` and reassembles the vector.
+    LoadMethod createVectorLoadMethod(IRBuilder& builder, IRType* type)
+    {
+        auto vectorType = as<IRVectorType>(type);
+        auto elementType = vectorType->getElementType();
+        auto elementCount = getIntVal(vectorType->getElementCount());
+
+        // NOTE(review): the queried size is not consulted below. The call is kept
+        // as-is because it may have side effects on the IR that are not visible
+        // from this file -- confirm before removing.
+        IRSizeAndAlignment elementSize;
+        getNaturalSizeAndAlignment(targetProgram->getOptionSet(), elementType, &elementSize);
+
+        // Vectors of up to two elements use the instruction directly.
+        if (elementCount <= 2)
+            return kIROp_CUDALDG;
+
+        // Four-element vectors use the instruction directly only for these element types.
+        if (elementCount == 4)
+        {
+            switch (elementType->getOp())
+            {
+            case kIROp_FloatType:
+            case kIROp_UIntType:
+            case kIROp_IntType:
+            case kIROp_Int8Type:
+            case kIROp_UInt8Type:
+            case kIROp_Int16Type:
+            case kIROp_UInt16Type:
+                return kIROp_CUDALDG;
+            default:
+                break;
+            }
+        }
+
+        // For other vector types, generate a function that loads element by element.
+        IRParam* ptrParam = nullptr;
+        auto func = createLoadFunc(builder, type, ptrParam);
+        List<IRInst*> args;
+        for (UInt i = 0; i < (UInt)elementCount; i++)
+        {
+            auto elementPtr = builder.emitElementAddress(
+                builder.getPtrType(elementType),
+                ptrParam,
+                builder.getIntValue(builder.getIntType(), i));
+            auto loadedElement =
+                builder.emitIntrinsicInst(elementType, kIROp_CUDALDG, 1, &elementPtr);
+            args.add(loadedElement);
+        }
+        auto result = builder.emitMakeVector(type, args);
+        builder.emitReturn(result);
+        return func;
+    }
+
+    /// Generates a helper that loads a matrix one vector at a time: by column for a
+    /// column-major layout, by row otherwise. The operands given to `MakeMatrix` are
+    /// always in row-major order.
+    ///
+    /// Returns an empty `LoadMethod` (and removes the partially built helper) if one of
+    /// the vector loads cannot be lowered, matching the array and struct paths.
+    LoadMethod createMatrixLoadMethod(IRBuilder& builder, IRType* type)
+    {
+        auto matrixType = as<IRMatrixType>(type);
+        auto elementType = matrixType->getElementType();
+        auto rowCount = getIntVal(matrixType->getRowCount());
+        auto colCount = getIntVal(matrixType->getColumnCount());
+        auto layout = static_cast<MatrixLayoutMode>(getIntVal(matrixType->getLayout()));
+        const bool isColumnMajor = (layout == kMatrixLayoutMode_ColumnMajor);
+
+        IRParam* ptrParam = nullptr;
+        auto func = createLoadFunc(builder, type, ptrParam);
+
+        // Column major: `colCount` vectors of `rowCount` elements each.
+        // Row major:    `rowCount` vectors of `colCount` elements each.
+        const auto vectorLength = isColumnMajor ? rowCount : colCount;
+        const auto vectorCount = isColumnMajor ? colCount : rowCount;
+
+        // Reinterpret the matrix pointer as a pointer to its first vector and step
+        // through the vectors with pointer offsets.
+        auto vectorType = builder.getVectorType(elementType, vectorLength);
+        auto vectorPtrType = builder.getPtrType(vectorType);
+        auto elementBasePtr = builder.emitBitCast(vectorPtrType, ptrParam);
+
+        List<IRInst*> vectors;
+        for (UInt i = 0; i < (UInt)vectorCount; i++)
+        {
+            auto vectorPtr = builder.emitGetOffsetPtr(
+                elementBasePtr,
+                builder.getIntValue(builder.getIntType(), i));
+            auto loadedVector = emitImmutableLoad(builder, vectorPtr);
+            if (!loadedVector)
+            {
+                // Do not leave a half-built helper behind in the module.
+                func->removeAndDeallocate();
+                return LoadMethod();
+            }
+            vectors.add(loadedVector);
+        }
+
+        IRInst* result = nullptr;
+        if (isColumnMajor)
+        {
+            // Rearrange the loaded columns into scalars in row-major order.
+            List<IRInst*> elements;
+            for (UInt row = 0; row < (UInt)rowCount; row++)
+            {
+                for (UInt col = 0; col < (UInt)colCount; col++)
+                {
+                    elements.add(builder.emitElementExtract(
+                        elementType,
+                        vectors[col],
+                        builder.getIntValue(builder.getIntType(), row)));
+                }
+            }
+            result =
+                builder.emitMakeMatrix(type, (UInt)elements.getCount(), elements.getBuffer());
+        }
+        else
+        {
+            // The loaded rows can be used as the matrix operands directly.
+            result = builder.emitMakeMatrix(type, (UInt)vectors.getCount(), vectors.getBuffer());
+        }
+        builder.emitReturn(result);
+        return func;
+    }
+
+    /// Generates a helper that loads an array element by element. Returns an empty
+    /// `LoadMethod` (and removes the partially built helper) if an element load cannot
+    /// be lowered.
+    LoadMethod createArrayLoadMethod(IRBuilder& builder, IRType* type)
+    {
+        auto arrayType = as<IRArrayType>(type);
+        auto elementType = arrayType->getElementType();
+        auto elementCount = getIntVal(arrayType->getElementCount());
+
+        IRParam* ptrParam = nullptr;
+        auto func = createLoadFunc(builder, type, ptrParam);
+        List<IRInst*> args;
+        for (UInt i = 0; i < (UInt)elementCount; i++)
+        {
+            auto elementPtr = builder.emitElementAddress(
+                builder.getPtrType(elementType),
+                ptrParam,
+                builder.getIntValue(builder.getIntType(), i));
+            auto loadedElement = emitImmutableLoad(builder, elementPtr);
+            if (!loadedElement)
+            {
+                func->removeAndDeallocate();
+                return LoadMethod();
+            }
+            args.add(loadedElement);
+        }
+        auto result = builder.emitMakeArray(type, (UInt)args.getCount(), args.getBuffer());
+        builder.emitReturn(result);
+        return func;
+    }
+
+    /// Generates a helper that loads a struct field by field. Returns an empty
+    /// `LoadMethod` (and removes the partially built helper) if a field load cannot be
+    /// lowered.
+    LoadMethod createStructLoadMethod(IRBuilder& builder, IRType* type)
+    {
+        auto structType = as<IRStructType>(type);
+
+        IRParam* ptrParam = nullptr;
+        auto func = createLoadFunc(builder, type, ptrParam);
+        List<IRInst*> args;
+        for (auto field : structType->getFields())
+        {
+            auto fieldType = field->getFieldType();
+            auto fieldPtr = builder.emitFieldAddress(
+                builder.getPtrType(fieldType),
+                ptrParam,
+                field->getKey());
+            auto loadedField = emitImmutableLoad(builder, fieldPtr);
+            if (!loadedField)
+            {
+                func->removeAndDeallocate();
+                return LoadMethod();
+            }
+            args.add(loadedField);
+        }
+        auto result = builder.emitMakeStruct(type, args);
+        builder.emitReturn(result);
+        return func;
+    }
+
+    /// Decides how a value of `type` is loaded from immutable memory.
+    ///
+    /// Scalar types map directly to `kIROp_CUDALDG`; vectors, matrices, arrays and
+    /// structs are delegated to the helpers above, which insert any generated function
+    /// right after `type`. Any other type yields an empty `LoadMethod`.
+    LoadMethod createLoadFuncForType(IRType* type)
+    {
+        IRBuilder builder(type);
+        builder.setInsertAfter(type);
+        switch (type->getOp())
+        {
+        case kIROp_FloatType:
+        case kIROp_HalfType:
+        case kIROp_DoubleType:
+        case kIROp_Int8Type:
+        case kIROp_Int16Type:
+        case kIROp_IntType:
+        case kIROp_Int64Type:
+        case kIROp_IntPtrType:
+        case kIROp_UInt8Type:
+        case kIROp_UInt16Type:
+        case kIROp_UIntType:
+        case kIROp_UInt64Type:
+        case kIROp_UIntPtrType:
+        case kIROp_BoolType:
+        case kIROp_CharType:
+            return kIROp_CUDALDG;
+        case kIROp_VectorType:
+            return createVectorLoadMethod(builder, type);
+        case kIROp_MatrixType:
+            return createMatrixLoadMethod(builder, type);
+        case kIROp_ArrayType:
+            return createArrayLoadMethod(builder, type);
+        case kIROp_StructType:
+            return createStructLoadMethod(builder, type);
+        default:
+            return LoadMethod();
+        }
+    }
+
+    /// Returns the cached load method for `type`, creating and caching it on first use.
+    LoadMethod getOrCreateLoadFuncForType(IRType* type)
+    {
+        if (auto cached = loadFuncs.tryGetValue(type))
+            return *cached;
+        auto created = createLoadFuncForType(type);
+        loadFuncs[type] = created;
+        return created;
+    }
+
+    /// Emits an immutable load of the value `ptr` points to at the builder's current
+    /// insertion point.
+    ///
+    /// @return The emitted instruction, or `nullptr` when the pointed-to type cannot be
+    ///         determined or no load method is available for it.
+    IRInst* emitImmutableLoad(IRBuilder& builder, IRInst* ptr)
+    {
+        IRType* valueType = tryGetPointedToType(&builder, ptr->getDataType());
+        if (!valueType)
+            return nullptr;
+        auto loadFunc = getOrCreateLoadFuncForType(valueType);
+        if (!loadFunc)
+            return nullptr;
+        return loadFunc.apply(builder, valueType, ptr);
+    }
+
+    /// Lowers a single instruction if it is one of the load forms handled by this pass.
+    void processInst(IRInst* inst)
+    {
+        // For every load instruction we see in the module, if it is loading from
+        // an immutable location, try to lower it into a series of __ldg calls.
+        // We need to handle both ordinary loads and structured buffer loads.
+        //
+        switch (inst->getOp())
+        {
+        case kIROp_Load:
+            {
+                auto load = as<IRLoad>(inst);
+                if (!isPointerToImmutableLocation(getRootAddr(load->getPtr())))
+                    break;
+                IRBuilder builder(load);
+                builder.setInsertBefore(load);
+                if (auto newLoad = emitImmutableLoad(builder, load->getPtr()))
+                {
+                    load->replaceUsesWith(newLoad);
+                    load->removeAndDeallocate();
+                }
+            }
+            break;
+        case kIROp_StructuredBufferLoad:
+            {
+                // Form a pointer to the element (operand 0 and operand 1 are forwarded
+                // to the element-pointer builder call) and load through it.
+                IRBuilder builder(inst);
+                builder.setInsertBefore(inst);
+                auto ptr = builder.emitRWStructuredBufferGetElementPtr(
+                    inst->getOperand(0),
+                    inst->getOperand(1));
+                if (auto newLoad = emitImmutableLoad(builder, ptr))
+                {
+                    inst->replaceUsesWith(newLoad);
+                    inst->removeAndDeallocate();
+                }
+                else
+                {
+                    // For some reason this load cannot be lowered, remove the ptr we just created.
+                    ptr->removeAndDeallocate();
+                }
+            }
+            break;
+        case kIROp_CUDALDG:
+            {
+                // Does the load need lowering? If so, insert the lowered loads.
+                IRBuilder builder(inst);
+                builder.setInsertBefore(inst);
+                auto ptr = inst->getOperand(0);
+                auto valueType = tryGetPointedToType(&builder, ptr->getDataType());
+                if (!valueType)
+                    break;
+                auto loadFunc = getOrCreateLoadFuncForType(valueType);
+                if (!loadFunc)
+                    break;
+                // If the type doesn't need further lowering, we don't need to do anything.
+                if (loadFunc.kind == LoadMethodKind::Opcode && loadFunc.op == kIROp_CUDALDG)
+                    break;
+                auto newLoad = loadFunc.apply(builder, valueType, ptr);
+                inst->replaceUsesWith(newLoad);
+                inst->removeAndDeallocate();
+            }
+            break;
+        default:
+            break;
+        }
+    }
+
+    /// Runs `processInst` on the instructions visited by `processAllInsts`.
+    void processModule()
+    {
+        processAllInsts([&](IRInst* inst) { processInst(inst); });
+    }
+};
+
+/// Entry point of the pass: builds a lowering context for `module`, hands it
+/// `targetProgram`, and processes the module's instructions.
+void lowerImmutableBufferLoadForCUDA(TargetProgram* targetProgram, IRModule* module)
+{
+    ImmutableBufferLoadLoweringContext loweringPass(module);
+    loweringPass.targetProgram = targetProgram;
+    loweringPass.processModule();
+}
+
+} // namespace Slang