diff options
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj.filters | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 36 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-cuda-builtin-types.cpp | 461 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-cuda-builtin-types.h | 51 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 278 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.h | 2 | ||||
| -rw-r--r-- | tests/autodiff/autobind-plain-matrix-input.slang | 21 | ||||
| -rw-r--r-- | tests/autodiff/autobind-plain-vector-input.slang | 21 | ||||
| -rw-r--r-- | tests/autodiff/autobind-struct-with-array-of-builtins.slang | 22 | ||||
| -rw-r--r-- | tests/autodiff/autobind-struct-with-builtin-types.slang | 32 | ||||
| -rw-r--r-- | tests/diagnostics/cuda-kernel-differentiable-params.slang | 18 | ||||
| -rw-r--r-- | tests/diagnostics/cuda-kernel-non-void-return.slang | 17 |
17 files changed, 934 insertions, 47 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index 6ab460faf..f5e1c4c5d 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -426,6 +426,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-lower-buffer-element-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-com-methods.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-combined-texture-sampler.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-cuda-builtin-types.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-error-handling.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-existential.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-call.h" />
@@ -658,6 +659,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-ir-lower-buffer-element-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-com-methods.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-combined-texture-sampler.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-cuda-builtin-types.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-error-handling.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-existential.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-call.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 0112bb818..8e33fcfbb 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -366,6 +366,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-combined-texture-sampler.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-cuda-builtin-types.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-error-handling.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -1058,6 +1061,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-combined-texture-sampler.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-cuda-builtin-types.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-error-handling.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 0f9da12c4..921bd38e9 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7700,6 +7700,16 @@ namespace Slang } } } + + // If this method is intended to be a CUDA kernel, verify that the return type is void. + if (decl->findModifier<CudaKernelAttribute>()) + { + if (decl->returnType.type && !decl->returnType.type->equals(m_astBuilder->getVoidType())) + { + getSink()->diagnose(decl, Diagnostics::cudaKernelMustReturnVoid); + } + } + checkVisibility(decl); } @@ -9547,6 +9557,30 @@ namespace Slang checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions); } + static void checkCudaKernelAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, CudaKernelAttribute*) + { + // If the method is also marked differentiable, check that the data types are either non-differentiable + // or marked with no_diff. + // + // Note: This is a temporary restriction until we have a more complete story for differentiability. + // + if (funcDecl->findModifier<DifferentiableAttribute>()) + { + for (auto paramDecl : funcDecl->getParameters()) + { + auto paramType = paramDecl->type; + + if (visitor->isTypeDifferentiable(paramType)) + { + if (!paramDecl->hasModifier<NoDiffModifier>()) + { + visitor->getSink()->diagnose(paramDecl, Diagnostics::differentiableKernelEntryPointCannotHaveDifferentiableParams); + } + } + } + } + } + template<typename TDerivativeAttr, typename TDerivativeOfAttr> bool tryCheckDerivativeOfAttributeImpl( SemanticsVisitor* visitor, @@ -9747,6 +9781,8 @@ namespace Slang checkDerivativeAttribute(this, decl, bwdDerivativeAttr); else if (auto primalAttr = as<PrimalSubstituteAttribute>(attr)) checkDerivativeAttribute(this, decl, primalAttr); + else if (auto cudaKernelAttr = as<CudaKernelAttribute>(attr)) + checkCudaKernelAttribute(this, decl, cudaKernelAttr); } } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index b294383bb..eb131df21 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -454,6 +454,8 @@ DIAGNOSTIC(31208, Error, requireInputDecoratedVarForParameter, "$0 expects for a DIAGNOSTIC(31210, Error, derivativeGroupQuadMustBeMultiple2ForXYThreads, "compute derivative group quad requires thread dispatch count of X and Y to each be at a multiple of 2") DIAGNOSTIC(31211, Error, derivativeGroupLinearMustBeMultiple4ForTotalThreadCount, "compute derivative group linear requires total thread dispatch count to be at a multiple of 4") DIAGNOSTIC(31212, Error, onlyOneOfDerivativeGroupLinearOrQuadCanBeSet, "cannot set compute derivative group linear and compute derivative group quad at the same time") +DIAGNOSTIC(31213, Error, cudaKernelMustReturnVoid, "return type of a CUDA kernel function cannot be non-void.") +DIAGNOSTIC(31214, Error, differentiableKernelEntryPointCannotHaveDifferentiableParams, "differentiable kernel entry point cannot have differentiable parameters. Consider using DiffTensorView to pass differentiable data, or marking this parameter with 'no_diff'") // Enums diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index fa380e061..626c372e9 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3518,14 +3518,13 @@ bool CLikeSourceEmitter::isTargetIntrinsic(IRInst* inst) return findTargetIntrinsicDefinition(inst, intrinsicDef); } -bool shouldWrappInExternCBlock(IRFunc* func) +bool shouldWrapInExternCBlock(IRFunc* func) { for (auto decor : func->getDecorations()) { switch (decor->getOp()) { case kIROp_ExternCDecoration: - case kIROp_CudaKernelDecoration: return true; } } @@ -3540,7 +3539,7 @@ void CLikeSourceEmitter::emitFunc(IRFunc* func) if (isTargetIntrinsic(func)) return; - bool shouldCloseExternCBlock = shouldWrappInExternCBlock(func); + bool shouldCloseExternCBlock = shouldWrapInExternCBlock(func); if (shouldCloseExternCBlock) { // If this is a C++ `extern "C"` function, then we need to emit diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 1fa04b4be..afdd37fce 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -471,10 +471,13 @@ Result linkAndOptimizeIR( switch (target) { case CodeGenTarget::PyTorchCppBinding: + generateHostFunctionsForAutoBindCuda(irModule, sink); + lowerBuiltinTypesForKernelEntryPoints(irModule, sink); generatePyTorchCppBinding(irModule, sink); handleAutoBindNames(irModule); break; case CodeGenTarget::CUDASource: + lowerBuiltinTypesForKernelEntryPoints(irModule, sink); removeTorchKernels(irModule); handleAutoBindNames(irModule); break; diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index eae025c96..0f41f9bd9 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1189,6 +1189,10 @@ struct IRDispatchKernel : IRInst IRInst* getDispatchSize() { return getOperand(2); } UInt getArgCount() { return getOperandCount() - 3; } IRInst* getArg(UInt i) { return getOperand(3 + i); } + IROperandList<IRInst> getArgsList() + { + return IROperandList<IRInst>(getOperands() + 3, getOperands() + getOperandCount()); + } IR_LEAF_ISA(DispatchKernel) }; diff --git a/source/slang/slang-ir-lower-cuda-builtin-types.cpp b/source/slang/slang-ir-lower-cuda-builtin-types.cpp new file mode 100644 index 000000000..675e43043 --- /dev/null +++ b/source/slang/slang-ir-lower-cuda-builtin-types.cpp @@ -0,0 +1,461 @@ +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" +#include "slang-ir-layout.h" +#include "slang-ir-lower-cuda-builtin-types.h" +namespace Slang +{ + + IRFunc* createMatrixUnpackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRStructKey* dataKey, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = (Index)getIntVal(matrixType->getRowCount()); + auto colCount = (Index)getIntVal(matrixType->getColumnCount()); + auto packedParam = builder.emitParam(structType); + auto matrixArray = builder.emitFieldExtract(arrayType, packedParam, dataKey); + List<IRInst*> args; + args.setCount(rowCount * colCount); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + for (IRIntegerValue r = 0; r < rowCount; r++) + args[(Index)(r * colCount + c)] = builder.emitElementExtract(matrixArray, (Index)(r*colCount + c)); + } + else + { + for (IRIntegerValue c = 0; c < colCount; c++) + for (IRIntegerValue r = 0; r < rowCount; r++) + args[(Index)(c * rowCount + r)] = builder.emitElementExtract(matrixArray, (Index)(r*colCount + c)); + } + IRInst* result = builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); + builder.emitReturn(result); + return func; + } + + IRFunc* createMatrixPackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + auto originalParam = builder.emitParam(matrixType); + List<IRInst*> elements; + elements.setCount((Index)(rowCount * colCount)); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(originalParam, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + elements[(Index)(r * colCount + c)] = element; + } + } + + auto matrixArray = builder.emitMakeArray(arrayType, (UInt)elements.getCount(), elements.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &matrixArray); + builder.emitReturn(result); + return func; + } + + IRFunc* createVectorUnpackFunc( + IRVectorType* vectorType, + IRStructType* structType, + IRStructKey* dataKey, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, vectorType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackVector")); + builder.setInsertInto(func); + builder.emitBlock(); + auto packedParam = builder.emitParam(structType); + auto packedArray = builder.emitFieldExtract(arrayType, packedParam, dataKey); + auto count = getIntVal(vectorType->getElementCount()); + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + args[(Index)ii] = builder.emitElementExtract(packedArray, ii); + } + auto result = builder.emitMakeVector(vectorType, (UInt)args.getCount(), args.getBuffer()); + builder.emitReturn(result); + return func; + } + + IRFunc* createVectorPackFunc( + IRVectorType* vectorType, + IRStructType* structType, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&vectorType, structType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packVector")); + builder.setInsertInto(func); + builder.emitBlock(); + auto originalParam = builder.emitParam(vectorType); + auto count = getIntVal(vectorType->getElementCount()); + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + args[(Index)ii] = builder.emitElementExtract(originalParam, ii); + } + auto packedArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &packedArray); + builder.emitReturn(result); + return func; + } + + LoweredBuiltinTypeInfo lowerMatrixType( + IRBuilder* builder, + IRMatrixType* matrixType, + String nameSuffix) + { + LoweredBuiltinTypeInfo info; + + auto loweredType = builder->createStructType(); + StringBuilder nameSB; + bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; + nameSB << "_MatrixStorage_"; + getTypeNameHint(nameSB, matrixType->getElementType()); + nameSB << getIntVal(matrixType->getRowCount()) << "x" << getIntVal(matrixType->getColumnCount()); + if (isColMajor) + nameSB << "_ColMajor"; + nameSB << nameSuffix; + builder->addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder->createStructKey(); + builder->addNameHintDecoration(structKey, UnownedStringSlice("data")); + //auto vectorType = builder->getVectorType(matrixType->getElementType(), + // isColMajor?matrixType->getRowCount():matrixType->getColumnCount()); + + auto arrayType = + builder->getArrayType( + matrixType->getElementType(), + builder->getIntValue( + builder->getUIntType(), + getIntVal(matrixType->getColumnCount()) * getIntVal(matrixType->getRowCount()))); + + builder->createStructField(loweredType, structKey, arrayType); + + info.originalType = matrixType; + info.loweredType = loweredType; + info.loweredInnerArrayType = arrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType); + info.convertOriginalToLowered = createMatrixPackFunc(matrixType, loweredType, arrayType); + return info; + } + + LoweredBuiltinTypeInfo lowerVectorType( + IRBuilder* builder, + IRVectorType* vectorType, + String nameSuffix) + { + LoweredBuiltinTypeInfo info; + + auto loweredType = builder->createStructType(); + + StringBuilder nameSB; + nameSB << "_VectorStorage_"; + getTypeNameHint(nameSB, vectorType->getElementType()); + nameSB << getIntVal(vectorType->getElementCount()) << "_"; + nameSB << nameSuffix; + builder->addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + + + info.originalType = vectorType; + info.loweredType = loweredType; + + auto structKey = builder->createStructKey(); + builder->addNameHintDecoration(structKey, UnownedStringSlice("data")); + + auto arrayType = builder->getArrayType( + vectorType->getElementType(), + vectorType->getElementCount()); + + builder->createStructField(loweredType, structKey, arrayType); + + info.convertLoweredToOriginal = createVectorUnpackFunc(vectorType, loweredType, structKey, arrayType); + info.convertOriginalToLowered = createVectorPackFunc(vectorType, loweredType, arrayType); + + return info; + } + + LoweredBuiltinTypeInfo lowerStructType( + BuiltinTypeLoweringEnv* env, + IRBuilder* builder, + IRStructType* structType, + String nameSuffix) + { + // Recursively lower the fields of the struct type + List<IRType*> fieldTypes; + List<IRStructField*> fields; + for (auto field : structType->getFields()) + { + fieldTypes.add(field->getFieldType()); + fields.add(field); + } + + auto loweredType = builder->createStructType(); + StringBuilder nameSB; + nameSB << "_StructStorage_"; + + // Find a name hint for the struct type + for (auto decoration : structType->getDecorations()) + if (auto nameHint = as<IRNameHintDecoration>(decoration)) + nameSB << nameHint->getName(); + + nameSB << nameSuffix; + builder->addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + + bool changesRequired = false; + + // Lower field types. + List<LoweredBuiltinTypeInfo> loweredFieldInfos; + for (auto field : fields) + { + // Lower the field type + auto loweredFieldInfo = lowerType(env, builder, field->getFieldType(), nameSuffix); + loweredFieldInfos.add(loweredFieldInfo); + + // Add the lowered field type to the lowered struct type + builder->createStructField(loweredType, field->getKey(), loweredFieldInfo.loweredType); + + if (loweredFieldInfo.convertLoweredToOriginal != nullptr) + changesRequired = true; + } + + if (!changesRequired) + { + // If no changes are required, then we can just return the original struct type + LoweredBuiltinTypeInfo info; + info.originalType = structType; + info.loweredType = structType; + info.convertLoweredToOriginal = nullptr; + info.convertOriginalToLowered = nullptr; + return info; + } + + LoweredBuiltinTypeInfo info; + info.originalType = structType; + info.loweredType = loweredType; + + // Create the conversion function from the lowered struct type to the original struct type + { + builder->setInsertAfter(loweredType); + auto func = builder->createFunc(); + auto funcType = builder->getFuncType(1, (IRType**)&loweredType, structType); + func->setFullType(funcType); + builder->addNameHintDecoration(func, UnownedStringSlice("convertLoweredToOriginal")); + builder->setInsertInto(func); + builder->emitBlock(); + auto loweredParam = builder->emitParam(loweredType); + List<IRInst*> args; + args.setCount((Index)fields.getCount()); + for (Index i = 0; i < fields.getCount(); i++) + { + auto loweredField = builder->emitFieldExtract(loweredFieldInfos[i].loweredType, loweredParam, fields[i]->getKey()); + List<IRInst*> callArgs; + callArgs.add(loweredField); + + if (loweredFieldInfos[i].convertLoweredToOriginal == nullptr) + { + args[i] = loweredField; + continue; + } + + args[i] = builder->emitCallInst( + loweredFieldInfos[i].originalType, + loweredFieldInfos[i].convertLoweredToOriginal, + callArgs); + } + + auto result = builder->emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer()); + builder->emitReturn(result); + info.convertLoweredToOriginal = func; + } + + // Create the conversion function from the original struct type to the lowered struct type + { + builder->setInsertAfter(structType); + auto func = builder->createFunc(); + auto funcType = builder->getFuncType(1, (IRType**)&structType, loweredType); + func->setFullType(funcType); + builder->addNameHintDecoration(func, UnownedStringSlice("convertOriginalToLowered")); + builder->setInsertInto(func); + builder->emitBlock(); + auto originalParam = builder->emitParam(structType); + List<IRInst*> args; + args.setCount((Index)fields.getCount()); + for (Index i = 0; i < fields.getCount(); i++) + { + auto originalField = builder->emitFieldExtract(loweredFieldInfos[i].originalType, originalParam, fields[i]->getKey()); + List<IRInst*> callArgs; + callArgs.add(originalField); + + if (loweredFieldInfos[i].convertOriginalToLowered == nullptr) + { + args[i] = originalField; + continue; + } + + args[i] = builder->emitCallInst( + loweredFieldInfos[i].loweredType, + loweredFieldInfos[i].convertOriginalToLowered, + callArgs); + } + + auto result = builder->emitMakeStruct(loweredType, (UInt)args.getCount(), args.getBuffer()); + builder->emitReturn(result); + info.convertOriginalToLowered = func; + } + + return info; + } + + LoweredBuiltinTypeInfo lowerArrayType( + BuiltinTypeLoweringEnv* env, + IRBuilder* builder, + IRArrayType* arrayType, + String nameSuffix) + { + auto loweredElementTypeInfo = lowerType(env, builder, arrayType->getElementType(), nameSuffix); + auto loweredType = builder->getArrayType(loweredElementTypeInfo.loweredType, arrayType->getElementCount()); + + LoweredBuiltinTypeInfo info; + info.originalType = arrayType; + info.loweredType = loweredType; + + // If the element type was lowered, then we need to create conversion functions + if (loweredElementTypeInfo.convertLoweredToOriginal != nullptr) + { + builder->setInsertAfter(loweredType); + auto func = builder->createFunc(); + auto funcType = builder->getFuncType(1, (IRType**)&loweredType, arrayType); + func->setFullType(funcType); + builder->addNameHintDecoration(func, UnownedStringSlice("convertLoweredToOriginal")); + builder->setInsertInto(func); + builder->emitBlock(); + auto loweredParam = builder->emitParam(loweredType); + + auto count = getIntVal(arrayType->getElementCount()); + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + auto loweredElement = builder->emitElementExtract(loweredParam, ii); + List<IRInst*> callArgs; + callArgs.add(loweredElement); + args[(Index)ii] = builder->emitCallInst( + arrayType->getElementType(), + loweredElementTypeInfo.convertLoweredToOriginal, + callArgs); + } + + auto result = builder->emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); + builder->emitReturn(result); + info.convertLoweredToOriginal = func; + } + + if (loweredElementTypeInfo.convertOriginalToLowered != nullptr) + { + builder->setInsertAfter(arrayType); + auto func = builder->createFunc(); + auto funcType = builder->getFuncType(1, (IRType**)&arrayType, loweredType); + func->setFullType(funcType); + builder->addNameHintDecoration(func, UnownedStringSlice("convertOriginalToLowered")); + builder->setInsertInto(func); + builder->emitBlock(); + auto originalParam = builder->emitParam(arrayType); + auto count = getIntVal(arrayType->getElementCount()); + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + auto originalElement = builder->emitElementExtract(originalParam, ii); + List<IRInst*> callArgs; + callArgs.add(originalElement); + args[(Index)ii] = builder->emitCallInst( + loweredElementTypeInfo.loweredType, + loweredElementTypeInfo.convertOriginalToLowered, + callArgs); + } + + auto result = builder->emitMakeArray(loweredType, (UInt)args.getCount(), args.getBuffer()); + builder->emitReturn(result); + info.convertOriginalToLowered = func; + } + + return info; + } + + LoweredBuiltinTypeInfo lowerType( + BuiltinTypeLoweringEnv* env, + IRBuilder* builder, + IRType* type, + String nameSuffix) + { + if (env->loweredTypes.containsKey(type)) + return env->loweredTypes[type]; + + if (auto matrixType = as<IRMatrixType>(type)) + { + auto loweredInfo = lowerMatrixType(builder, matrixType, nameSuffix); + env->loweredTypes[type] = loweredInfo; + return loweredInfo; + } + else if (auto vectorType = as<IRVectorType>(type)) + { + auto loweredInfo = lowerVectorType(builder, vectorType, nameSuffix); + env->loweredTypes[type] = loweredInfo; + return loweredInfo; + } + else if (auto structType = as<IRStructType>(type)) + { + auto loweredInfo = lowerStructType(env, builder, structType, nameSuffix); + env->loweredTypes[type] = loweredInfo; + return loweredInfo; + } + else if (auto arrayType = as<IRArrayType>(type)) + { + auto loweredInfo = lowerArrayType(env, builder, arrayType, nameSuffix); + env->loweredTypes[type] = loweredInfo; + return loweredInfo; + } + + LoweredBuiltinTypeInfo info; + info.originalType = type; + info.loweredType = type; + info.convertLoweredToOriginal = nullptr; + info.convertOriginalToLowered = nullptr; + + return info; + } +};
\ No newline at end of file diff --git a/source/slang/slang-ir-lower-cuda-builtin-types.h b/source/slang/slang-ir-lower-cuda-builtin-types.h new file mode 100644 index 000000000..900e90d10 --- /dev/null +++ b/source/slang/slang-ir-lower-cuda-builtin-types.h @@ -0,0 +1,51 @@ +#ifndef SLANG_IR_LOWER_BUILTIN_TYPES_H +#define SLANG_IR_LOWER_BUILTIN_TYPES_H + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" +#include "slang-ir-layout.h" + +namespace Slang +{ + struct LoweredBuiltinTypeInfo + { + IRType* originalType; + IRType* loweredType; + IRType* loweredInnerArrayType = nullptr; // For matrix/array types that are lowered into a struct type, this is the inner array type of the data field. + IRStructKey* loweredInnerStructKey = nullptr; // For matrix/array types that are lowered into a struct type, this is the struct key of the data field. + IRFunc* convertOriginalToLowered = nullptr; + IRFunc* convertLoweredToOriginal = nullptr; + }; + + struct BuiltinTypeLoweringEnv + { + Dictionary<IRType*, LoweredBuiltinTypeInfo> loweredTypes; + }; + + LoweredBuiltinTypeInfo lowerMatrixType( + IRBuilder* builder, + IRMatrixType* matrixType, + String nameSuffix = ""); + + LoweredBuiltinTypeInfo lowerVectorType( + IRBuilder* builder, + IRVectorType* vectorType, + String nameSuffix = ""); + + LoweredBuiltinTypeInfo lowerStructType( + BuiltinTypeLoweringEnv* env, + IRBuilder* builder, + IRStructType* structType, + String nameSuffix = ""); + + LoweredBuiltinTypeInfo lowerType( + BuiltinTypeLoweringEnv* env, + IRBuilder* builder, + IRType* type, + String nameSuffix = ""); + +} // namespace Slang + +#endif // SLANG_IR_LOWER_BUILTIN_TYPES_H
\ No newline at end of file diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 6a85f0324..fd885dae7 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -3,6 +3,7 @@ #include "slang-ir-insts.h" #include "slang-diagnostics.h" #include "slang-ir-autodiff.h" +#include "slang-ir-lower-cuda-builtin-types.h" namespace Slang { @@ -13,10 +14,31 @@ static IRType* translateToTupleType( { if (as<IRVoidType>(type)) return type; - if (as<IRBasicType>(type)) + else if (as<IRBasicType>(type)) return type; else if (as<IRTorchTensorType>(type)) return type; + else if (auto matrixType = as<IRMatrixType>(type)) + { + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto colCount = as<IRIntLit>(matrixType->getColumnCount()); + if (!rowCount || !colCount) + { + return nullptr; + } + List<IRType*> elementTypes; + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + elementTypes.addRange(matrixType->getElementType()); + } + auto elementTupleType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + List<IRType*> rowTypes; + for (IRIntegerValue i = 0; i < colCount->getValue(); i++) + { + rowTypes.add(elementTupleType); + } + return builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()); + } else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); @@ -60,6 +82,10 @@ static IRType* translateToTupleType( } return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); } + else if (auto targetTupleType = as<IRTargetTupleType>(type)) + { + return type; + } else { return nullptr; @@ -76,6 +102,38 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) return val; else if (as<IRTorchTensorType>(type)) return val; + else if (auto matrixType = as<IRMatrixType>(type)) + { + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto colCount = as<IRIntLit>(matrixType->getColumnCount()); + if (!rowCount || !colCount) + { + return nullptr; + } + List<IRInst*> rowElements; + List<IRType*> rowTypes; + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + List<IRInst*> colElements; + List<IRType*> colTypes; + for (IRIntegerValue j = 0; j < colCount->getValue(); j++) + { + auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + colElements.add(tupleElement); + colTypes.add(tupleElement->getFullType()); + } + auto rowType = builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer()); + rowTypes.add(rowType); + rowElements.add(builder.emitMakeTargetTuple(rowType, (UInt)colElements.getCount(), colElements.getBuffer())); + } + return builder.emitMakeTargetTuple( + builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()), + (UInt)rowElements.getCount(), + rowElements.getBuffer()); + } else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); @@ -134,6 +192,10 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } + else if (auto targetTupleType = as<IRTargetTupleType>(type)) + { + return val; + } else { return nullptr; @@ -149,6 +211,25 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst return val; else if (as<IRTorchTensorType>(type)) return val; + else if (auto matrixType = as<IRMatrixType>(type)) + { + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto colCount = as<IRIntLit>(matrixType->getColumnCount()); + SLANG_ASSERT(rowCount && colCount); + + List<IRInst*> resultElements; + auto rowType = builder.getTargetTupleType((UInt)colCount->getValue(), List<IRType*>().makeRepeated(matrixType->getElementType(), (Index)colCount->getValue()).getBuffer()); + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + auto rowElement = builder.emitTargetTupleGetElement(rowType, val, builder.getIntValue(builder.getIntType(), i)); + for (IRIntegerValue j = 0; j < colCount->getValue(); j++) + { + auto element = builder.emitTargetTupleGetElement(matrixType->getElementType(), rowElement, builder.getIntValue(builder.getIntType(), j)); + resultElements.add(element); + } + } + return builder.emitMakeMatrix(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); @@ -203,6 +284,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst } return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); } + else if (auto targetTupleType = as<IRTargetTupleType>(type)) + { + return val; + } else { return nullptr; @@ -318,29 +403,13 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr) { - if (as<IRBasicType>(type) || as<IRVectorType>(type)) + if (as<IRBasicType>(type) || as<IRVectorType>(type) || as<IRMatrixType>(type)) return type; switch (type->getOp()) { case kIROp_TensorViewType: return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType()); -#if 0 - case kIROp_VectorType: - { - // Create a new struct type representing the vector. - auto hostStructType = builder->createStructType(); - const char* names[4] = { "x", "y", "z", "w" }; - for (IRIntegerValue i = 0; i < getIntVal(as<IRVectorType>(type)->getElementCount()); i++) - { - auto key = builder->createStructKey(); - if (i < 4) - builder->addNameHintDecoration(key, UnownedStringSlice(names[i])); - builder->createStructField(hostStructType, key, as<IRVectorType>(type)->getElementType()); - } - return hostStructType; - } -#endif case kIROp_StructType: { // Create a new struct type with translated fields. @@ -386,18 +455,6 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp { case kIROp_TensorViewType: return builder->emitMakeTensorView(cudaType, inst); -#if 0 - case kIROp_VectorType: - { - List<IRInst*> args; - auto hostStructType = cast<IRStructType>(hostType); - for (auto field : hostStructType->getFields()) - { - args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey())); - } - return builder->emitMakeVector(cudaType, args); - } -#endif case kIROp_StructType: { auto cudaStructType = cast<IRStructType>(cudaType); @@ -858,12 +915,138 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) return hostFunc; } -void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) +void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) { - List<IRFunc*> workList; List<IRFunc*> cudaKernels; + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(globalInst)) + { + if (func->findDecoration<IRCudaKernelDecoration>()) + { + cudaKernels.add(func); + } + } + } + + BuiltinTypeLoweringEnv typeLoweringEnv; + IRBuilder builder(module); + for (auto func : cudaKernels) + { + // Go through parameters and replace any built-in types with their equivalent. + List<IRParam*> params; + for (auto param : func->getFirstBlock()->getParams()) + { + params.add(param); + } + + bool changed = false; + List<LoweredBuiltinTypeInfo> loweredParamTypes; + for (auto param : params) + { + LoweredBuiltinTypeInfo info = lowerType(&typeLoweringEnv, &builder, param->getDataType()); + loweredParamTypes.add(info); + + if (info.convertLoweredToOriginal != nullptr) + { + // Replace parameter with the lowered type. + auto originalType = param->getDataType(); + param->setFullType(info.loweredType); + + // Call the conversion function to convert the lowered parameter to the original parameter. + List<IRInst*> args; + args.add(param); + + setInsertAfterOrdinaryInst(&builder, param); + auto convertedParam = builder.emitCallInst(originalType, info.convertLoweredToOriginal, args); + + // Replace all uses of the lowered parameter with the converted parameter, except for the call instruction. + for (auto use = param->firstUse; use;) + { + auto nextUse = use->nextUse; + + if (use->getUser() == convertedParam) + { + use = nextUse; + continue; + } + + use->set(convertedParam); + use = nextUse; + } + + changed = true; + } + } + + if (!changed) + continue; + + fixUpFuncType(func); + + // Go through any calls to this function and insert a call to converOriginalToLowered before the call. + for (auto use = func->firstUse; use;) + { + auto nextUse = use->nextUse; + + if (as<IRCall>(use->getUser()) || as<IRDispatchKernel>(use->getUser())) + { + auto user = use->getUser(); + IROperandList<IRInst> argsList; + if (auto callInst = as<IRCall>(user)) + argsList = callInst->getArgsList(); + else if (auto dispatchInst = as<IRDispatchKernel>(user)) + argsList = dispatchInst->getArgsList(); + + // Insert a call to convertOriginalToLowered before the call. + List<IRInst*> convertedArgs; + IRBuilder callBuilder(func->getModule()); + callBuilder.setInsertBefore(user); + for (auto arg : argsList) + { + if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered != nullptr) + { + auto convertedArg = callBuilder.emitCallInst( + loweredParamTypes[convertedArgs.getCount()].loweredType, + loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered, + List<IRInst*>(arg)); + convertedArgs.add(convertedArg); + } + else + { + convertedArgs.add(arg); + } + } + + // Rebuild the call/dispatch inst. + IRInst* newCall = nullptr; + + if (auto callInst = as<IRCall>(user)) + newCall = callBuilder.emitCallInst(user->getFullType(), func, convertedArgs); + else if (auto dispatchInst = as<IRDispatchKernel>(user)) + newCall = callBuilder.emitDispatchKernelInst( + user->getFullType(), + func, + dispatchInst->getThreadGroupSize(), + dispatchInst->getDispatchSize(), + convertedArgs.getCount(), + convertedArgs.getBuffer()); + + // Replace the call instruction. + user->replaceUsesWith(newCall); + + // Remove the call instruction. + user->removeAndDeallocate(); + } + + use = nextUse; + } + } +} + +void generateHostFunctionsForAutoBindCuda(IRModule* module, DiagnosticSink* sink) +{ List<IRFunc*> autoBindRequests; - List<IRType*> typesToExport; for (auto globalInst : module->getGlobalInsts()) { if (auto func = as<IRFunc>(globalInst)) @@ -872,6 +1055,24 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) { autoBindRequests.add(func); } + } + } + + for (auto func : autoBindRequests) + { + generateCUDAWrapperForFunc(func, sink); + } +} + +void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) +{ + List<IRFunc*> workList; + List<IRFunc*> cudaKernels; + List<IRType*> typesToExport; + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(globalInst)) + { if (func->findDecoration<IRTorchEntryPointDecoration>()) { workList.add(func); @@ -895,16 +1096,6 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) } } - // Generate CUDA wrappers for all functions that have the auto-bind decoration. - for (auto func : autoBindRequests) - { - if (auto hostFunc = generateCUDAWrapperForFunc(func, sink)) - { - // Add generated wrapper to worklist for python bindings. - workList.add(hostFunc); - } - } - for (auto func : workList) generateCppBindingForFunc(func, sink); @@ -967,7 +1158,6 @@ void handleAutoBindNames(IRModule* module) nameBuilder << "__kernel__" << externCppHint->getName(); externCppHint->removeAndDeallocate(); builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice()); - builder.addExternCDecoration(globalInst); } } } diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h index dd7dcc9a4..a761dbc03 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.h +++ b/source/slang/slang-ir-pytorch-cpp-binding.h @@ -6,9 +6,11 @@ struct IRModule; class DiagnosticSink; void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink); +void generateHostFunctionsForAutoBindCuda(IRModule* module, DiagnosticSink* sink); void removeTorchKernels(IRModule* module); void handleAutoBindNames(IRModule* module); void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink); +void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink* sink); } diff --git a/tests/autodiff/autobind-plain-matrix-input.slang b/tests/autodiff/autobind-plain-matrix-input.slang new file mode 100644 index 000000000..d09e77fc1 --- /dev/null +++ b/tests/autodiff/autobind-plain-matrix-input.slang @@ -0,0 +1,21 @@ +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none +//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none + +[AutoPyBindCUDA] +[CUDAKernel] +void plain_copy(float3x3 input, TensorView<float> output) +{ + // CUDA: __global__ void __kernel__plain_copy(_MatrixStorage_float3x3_ColMajor_0 input_0, TensorView output_0) + // TORCH: void __kernel__plain_copy(_MatrixStorage_float3x3_ColMajor_0 _0, TensorView _1); + + // Get the 'global' index of this thread. + uint3 dispatchIdx = cudaThreadIdx() + cudaBlockIdx() * cudaBlockDim(); + + // If the thread index is beyond the input size, exit early. + if (dispatchIdx.x >= 1) + return; + + output[0] = input[0][0]; + output[1] = input[1][1]; + output[2] = input[2][2]; +} diff --git a/tests/autodiff/autobind-plain-vector-input.slang b/tests/autodiff/autobind-plain-vector-input.slang new file mode 100644 index 000000000..216585093 --- /dev/null +++ b/tests/autodiff/autobind-plain-vector-input.slang @@ -0,0 +1,21 @@ +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none +//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none + +[AutoPyBindCUDA] +[CUDAKernel] +void plain_copy(float3 input, TensorView<float> output) +{ + // CUDA: __global__ void __kernel__plain_copy(_VectorStorage_float3_0 input_0, TensorView output_0) + // TORCH: void __kernel__plain_copy(_VectorStorage_float3_0 _0, TensorView _1); + + // Get the 'global' index of this thread. + uint3 dispatchIdx = cudaThreadIdx() + cudaBlockIdx() * cudaBlockDim(); + + // If the thread index is beyond the input size, exit early. + if (dispatchIdx.x >= 1) + return; + + output[0] = input.x; + output[1] = input.y; + output[2] = input.z; +} diff --git a/tests/autodiff/autobind-struct-with-array-of-builtins.slang b/tests/autodiff/autobind-struct-with-array-of-builtins.slang new file mode 100644 index 000000000..69904fadd --- /dev/null +++ b/tests/autodiff/autobind-struct-with-array-of-builtins.slang @@ -0,0 +1,22 @@ +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none +//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none + + +[AutoPyBindCUDA] +[CUDAKernel] +void plain_copy(float3[4] input, TensorView<float> output) +{ + // CUDA: __global__ void __kernel__plain_copy(FixedArray<_VectorStorage_float3_0, 4> input_0, TensorView output_0) + // TORCH: void __kernel__plain_copy(FixedArray<_VectorStorage_float3_0, 4> _0, TensorView _1); + + // Get the 'global' index of this thread. + uint3 dispatchIdx = cudaThreadIdx() + cudaBlockIdx() * cudaBlockDim(); + + // If the thread index is beyond the input size, exit early. + if (dispatchIdx.x >= 1) + return; + + output[0] = input[0].x; + output[1] = input[2].y; + output[2] = input[3].z; +} diff --git a/tests/autodiff/autobind-struct-with-builtin-types.slang b/tests/autodiff/autobind-struct-with-builtin-types.slang new file mode 100644 index 000000000..70832cc40 --- /dev/null +++ b/tests/autodiff/autobind-struct-with-builtin-types.slang @@ -0,0 +1,32 @@ +//TEST:SIMPLE(filecheck=CUDA): -target cuda -line-directive-mode none +//TEST:SIMPLE(filecheck=TORCH): -target torch -line-directive-mode none + +struct MyStruct +{ + float3x3 data; + float3 vec; +}; + +struct MyStruct2 +{ + float data; +}; + +[AutoPyBindCUDA] +[CUDAKernel] +void plain_copy(MyStruct input, MyStruct2 input2, TensorView<float> output) +{ + // CUDA: __global__ void __kernel__plain_copy(U_StructStorage_MyStruct_0 input_0, MyStruct2_0 input2_0, TensorView output_0) + // TORCH: void __kernel__plain_copy(U_StructStorage_MyStruct_0 _0, MyStruct2_0 _1, TensorView _2); + + // Get the 'global' index of this thread. + uint3 dispatchIdx = cudaThreadIdx() + cudaBlockIdx() * cudaBlockDim(); + + // If the thread index is beyond the input size, exit early. + if (dispatchIdx.x >= 1) + return; + + output[0] = input.data[0][0]; + output[1] = input.vec[1]; + output[2] = input.data[2][2]; +} diff --git a/tests/diagnostics/cuda-kernel-differentiable-params.slang b/tests/diagnostics/cuda-kernel-differentiable-params.slang new file mode 100644 index 000000000..0e7604b3d --- /dev/null +++ b/tests/diagnostics/cuda-kernel-differentiable-params.slang @@ -0,0 +1,18 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + +// Simple check to see if the compiler throws an error if a CUDA kernel is declared with non-void return type. + +[CudaKernel] +[Differentiable] +void myBadKernel(float x, TensorView<float> t1, TensorView<float> t2) +{ + // CHECK: tests/diagnostics/cuda-kernel-differentiable-params.slang([[@LINE-2]]): error 31214: differentiable kernel entry point cannot have differentiable parameters. Consider using DiffTensorView to pass differentiable data, or marking this parameter with 'no_diff' + // CHECK-NEXT: void myBadKernel(float x, TensorView<float> t1, TensorView<float> t2) + // CHECK-NEXT: ^ +} + +[CudaKernel] +void myGoodKernel(float x, TensorView<float> t1, TensorView<float> t2) +{ + +}
\ No newline at end of file diff --git a/tests/diagnostics/cuda-kernel-non-void-return.slang b/tests/diagnostics/cuda-kernel-non-void-return.slang new file mode 100644 index 000000000..75c8bc6d4 --- /dev/null +++ b/tests/diagnostics/cuda-kernel-non-void-return.slang @@ -0,0 +1,17 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): + +// Simple check to see if the compiler throws an error if a CUDA kernel is declared with non-void return type. + +[CudaKernel] +float myBadKernel(TensorView<float> t1, TensorView<float> t2) +{ + // CHECK: tests/diagnostics/cuda-kernel-non-void-return.slang([[@LINE-2]]): error 31213: return type of a CUDA kernel function cannot be non-void. + // CHECK-NEXT: float myBadKernel(TensorView<float> t1, TensorView<float> t2) + // CHECK-NEXT: ^~~~~~~~~~~ +} + +[CudaKernel] +void myGoodKernel(TensorView<float> t1, TensorView<float> t2) +{ + +}
\ No newline at end of file |
