summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-04-30 16:05:33 -0400
committerGitHub <noreply@github.com>2024-04-30 16:05:33 -0400
commit52b91231cdadc048f93b224f5035759cf1a96eaa (patch)
tree23d3263bc662eb96d6284266282695a9b0f1e2db /source/slang
parent70111daf43c87e182695666c34345e061e114a68 (diff)
Added diagnostics & built-in type lowering for `[CUDAKernel]` functions (#4042)
* Added diagnostics & built-in type lowering for `[CUDAKernel]` functions This PR adds - Diagnostics for non-void return from a CUDA kernel entry point - Diagnostics for using differentiable types in a differentiable CUDA kernel entry point - Logic for converting built-in types (float3, float3x3, etc.) to portable struct types and unpacking the parameter back into a built-in type on the CUDA side. This is because built-in types have different implementations in the CUDA & CPP targets, which causes a signature mismatch when linking. * Fix error codes * Add ability to lower structs and arrays that contain built-in types. + Added tests + Fix issue where the host-side was not marshalling data to lowered types. * Update slang-ir-pytorch-cpp-binding.cpp --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-check-decl.cpp36
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-c-like.cpp5
-rw-r--r--source/slang/slang-emit.cpp3
-rw-r--r--source/slang/slang-ir-insts.h4
-rw-r--r--source/slang/slang-ir-lower-cuda-builtin-types.cpp461
-rw-r--r--source/slang/slang-ir-lower-cuda-builtin-types.h51
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp278
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.h2
9 files changed, 795 insertions, 47 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 0f9da12c4..921bd38e9 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -7700,6 +7700,16 @@ namespace Slang
}
}
}
+
+ // If this method is intended to be a CUDA kernel, verify that the return type is void.
+ if (decl->findModifier<CudaKernelAttribute>())
+ {
+ if (decl->returnType.type && !decl->returnType.type->equals(m_astBuilder->getVoidType()))
+ {
+ getSink()->diagnose(decl, Diagnostics::cudaKernelMustReturnVoid);
+ }
+ }
+
checkVisibility(decl);
}
@@ -9547,6 +9557,30 @@ namespace Slang
checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
}
+ static void checkCudaKernelAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, CudaKernelAttribute*)
+ {
+ // If the method is also marked differentiable, check that the data types are either non-differentiable
+ // or marked with no_diff.
+ //
+ // Note: This is a temporary restriction until we have a more complete story for differentiability.
+ //
+ if (funcDecl->findModifier<DifferentiableAttribute>())
+ {
+ for (auto paramDecl : funcDecl->getParameters())
+ {
+ auto paramType = paramDecl->type;
+
+ if (visitor->isTypeDifferentiable(paramType))
+ {
+ if (!paramDecl->hasModifier<NoDiffModifier>())
+ {
+ visitor->getSink()->diagnose(paramDecl, Diagnostics::differentiableKernelEntryPointCannotHaveDifferentiableParams);
+ }
+ }
+ }
+ }
+ }
+
template<typename TDerivativeAttr, typename TDerivativeOfAttr>
bool tryCheckDerivativeOfAttributeImpl(
SemanticsVisitor* visitor,
@@ -9747,6 +9781,8 @@ namespace Slang
checkDerivativeAttribute(this, decl, bwdDerivativeAttr);
else if (auto primalAttr = as<PrimalSubstituteAttribute>(attr))
checkDerivativeAttribute(this, decl, primalAttr);
+ else if (auto cudaKernelAttr = as<CudaKernelAttribute>(attr))
+ checkCudaKernelAttribute(this, decl, cudaKernelAttr);
}
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index b294383bb..eb131df21 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -454,6 +454,8 @@ DIAGNOSTIC(31208, Error, requireInputDecoratedVarForParameter, "$0 expects for a
DIAGNOSTIC(31210, Error, derivativeGroupQuadMustBeMultiple2ForXYThreads, "compute derivative group quad requires thread dispatch count of X and Y to each be at a multiple of 2")
DIAGNOSTIC(31211, Error, derivativeGroupLinearMustBeMultiple4ForTotalThreadCount, "compute derivative group linear requires total thread dispatch count to be at a multiple of 4")
DIAGNOSTIC(31212, Error, onlyOneOfDerivativeGroupLinearOrQuadCanBeSet, "cannot set compute derivative group linear and compute derivative group quad at the same time")
+DIAGNOSTIC(31213, Error, cudaKernelMustReturnVoid, "return type of a CUDA kernel function cannot be non-void.")
+DIAGNOSTIC(31214, Error, differentiableKernelEntryPointCannotHaveDifferentiableParams, "differentiable kernel entry point cannot have differentiable parameters. Consider using DiffTensorView to pass differentiable data, or marking this parameter with 'no_diff'")
// Enums
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index fa380e061..626c372e9 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -3518,14 +3518,13 @@ bool CLikeSourceEmitter::isTargetIntrinsic(IRInst* inst)
return findTargetIntrinsicDefinition(inst, intrinsicDef);
}
-bool shouldWrappInExternCBlock(IRFunc* func)
+bool shouldWrapInExternCBlock(IRFunc* func)
{
for (auto decor : func->getDecorations())
{
switch (decor->getOp())
{
case kIROp_ExternCDecoration:
- case kIROp_CudaKernelDecoration:
return true;
}
}
@@ -3540,7 +3539,7 @@ void CLikeSourceEmitter::emitFunc(IRFunc* func)
if (isTargetIntrinsic(func))
return;
- bool shouldCloseExternCBlock = shouldWrappInExternCBlock(func);
+ bool shouldCloseExternCBlock = shouldWrapInExternCBlock(func);
if (shouldCloseExternCBlock)
{
// If this is a C++ `extern "C"` function, then we need to emit
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 1fa04b4be..afdd37fce 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -471,10 +471,13 @@ Result linkAndOptimizeIR(
switch (target)
{
case CodeGenTarget::PyTorchCppBinding:
+ generateHostFunctionsForAutoBindCuda(irModule, sink);
+ lowerBuiltinTypesForKernelEntryPoints(irModule, sink);
generatePyTorchCppBinding(irModule, sink);
handleAutoBindNames(irModule);
break;
case CodeGenTarget::CUDASource:
+ lowerBuiltinTypesForKernelEntryPoints(irModule, sink);
removeTorchKernels(irModule);
handleAutoBindNames(irModule);
break;
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index eae025c96..0f41f9bd9 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1189,6 +1189,10 @@ struct IRDispatchKernel : IRInst
IRInst* getDispatchSize() { return getOperand(2); }
UInt getArgCount() { return getOperandCount() - 3; }
IRInst* getArg(UInt i) { return getOperand(3 + i); }
+ IROperandList<IRInst> getArgsList()
+ {
+ return IROperandList<IRInst>(getOperands() + 3, getOperands() + getOperandCount());
+ }
IR_LEAF_ISA(DispatchKernel)
};
diff --git a/source/slang/slang-ir-lower-cuda-builtin-types.cpp b/source/slang/slang-ir-lower-cuda-builtin-types.cpp
new file mode 100644
index 000000000..675e43043
--- /dev/null
+++ b/source/slang/slang-ir-lower-cuda-builtin-types.cpp
@@ -0,0 +1,461 @@
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+#include "slang-ir-layout.h"
+#include "slang-ir-lower-cuda-builtin-types.h"
+namespace Slang
+{
+
+ IRFunc* createMatrixUnpackFunc(
+ IRMatrixType* matrixType,
+ IRStructType* structType,
+ IRStructKey* dataKey,
+ IRArrayType* arrayType)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType);
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage"));
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto rowCount = (Index)getIntVal(matrixType->getRowCount());
+ auto colCount = (Index)getIntVal(matrixType->getColumnCount());
+ auto packedParam = builder.emitParam(structType);
+ auto matrixArray = builder.emitFieldExtract(arrayType, packedParam, dataKey);
+ List<IRInst*> args;
+ args.setCount(rowCount * colCount);
+ if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+ {
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ args[(Index)(r * colCount + c)] = builder.emitElementExtract(matrixArray, (Index)(r*colCount + c));
+ }
+ else
+ {
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ args[(Index)(c * rowCount + r)] = builder.emitElementExtract(matrixArray, (Index)(r*colCount + c));
+ }
+ IRInst* result = builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer());
+ builder.emitReturn(result);
+ return func;
+ }
+
+ IRFunc* createMatrixPackFunc(
+ IRMatrixType* matrixType,
+ IRStructType* structType,
+ IRArrayType* arrayType)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType);
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix"));
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ auto originalParam = builder.emitParam(matrixType);
+ List<IRInst*> elements;
+ elements.setCount((Index)(rowCount * colCount));
+ for (IRIntegerValue r = 0; r < rowCount; r++)
+ {
+ auto vector = builder.emitElementExtract(originalParam, r);
+ for (IRIntegerValue c = 0; c < colCount; c++)
+ {
+ auto element = builder.emitElementExtract(vector, c);
+ elements[(Index)(r * colCount + c)] = element;
+ }
+ }
+
+ auto matrixArray = builder.emitMakeArray(arrayType, (UInt)elements.getCount(), elements.getBuffer());
+ auto result = builder.emitMakeStruct(structType, 1, &matrixArray);
+ builder.emitReturn(result);
+ return func;
+ }
+
+ IRFunc* createVectorUnpackFunc(
+ IRVectorType* vectorType,
+ IRStructType* structType,
+ IRStructKey* dataKey,
+ IRArrayType* arrayType)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto funcType = builder.getFuncType(1, (IRType**)&structType, vectorType);
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("unpackVector"));
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto packedParam = builder.emitParam(structType);
+ auto packedArray = builder.emitFieldExtract(arrayType, packedParam, dataKey);
+ auto count = getIntVal(vectorType->getElementCount());
+ List<IRInst*> args;
+ args.setCount((Index)count);
+ for (IRIntegerValue ii = 0; ii < count; ++ii)
+ {
+ args[(Index)ii] = builder.emitElementExtract(packedArray, ii);
+ }
+ auto result = builder.emitMakeVector(vectorType, (UInt)args.getCount(), args.getBuffer());
+ builder.emitReturn(result);
+ return func;
+ }
+
+ IRFunc* createVectorPackFunc(
+ IRVectorType* vectorType,
+ IRStructType* structType,
+ IRArrayType* arrayType)
+ {
+ IRBuilder builder(structType);
+ builder.setInsertAfter(structType);
+ auto func = builder.createFunc();
+ auto funcType = builder.getFuncType(1, (IRType**)&vectorType, structType);
+ func->setFullType(funcType);
+ builder.addNameHintDecoration(func, UnownedStringSlice("packVector"));
+ builder.setInsertInto(func);
+ builder.emitBlock();
+ auto originalParam = builder.emitParam(vectorType);
+ auto count = getIntVal(vectorType->getElementCount());
+ List<IRInst*> args;
+ args.setCount((Index)count);
+ for (IRIntegerValue ii = 0; ii < count; ++ii)
+ {
+ args[(Index)ii] = builder.emitElementExtract(originalParam, ii);
+ }
+ auto packedArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
+ auto result = builder.emitMakeStruct(structType, 1, &packedArray);
+ builder.emitReturn(result);
+ return func;
+ }
+
+ LoweredBuiltinTypeInfo lowerMatrixType(
+ IRBuilder* builder,
+ IRMatrixType* matrixType,
+ String nameSuffix)
+ {
+ LoweredBuiltinTypeInfo info;
+
+ auto loweredType = builder->createStructType();
+ StringBuilder nameSB;
+ bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR;
+ nameSB << "_MatrixStorage_";
+ getTypeNameHint(nameSB, matrixType->getElementType());
+ nameSB << getIntVal(matrixType->getRowCount()) << "x" << getIntVal(matrixType->getColumnCount());
+ if (isColMajor)
+ nameSB << "_ColMajor";
+ nameSB << nameSuffix;
+ builder->addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+ auto structKey = builder->createStructKey();
+ builder->addNameHintDecoration(structKey, UnownedStringSlice("data"));
+ //auto vectorType = builder->getVectorType(matrixType->getElementType(),
+ // isColMajor?matrixType->getRowCount():matrixType->getColumnCount());
+
+ auto arrayType =
+ builder->getArrayType(
+ matrixType->getElementType(),
+ builder->getIntValue(
+ builder->getUIntType(),
+ getIntVal(matrixType->getColumnCount()) * getIntVal(matrixType->getRowCount())));
+
+ builder->createStructField(loweredType, structKey, arrayType);
+
+ info.originalType = matrixType;
+ info.loweredType = loweredType;
+ info.loweredInnerArrayType = arrayType;
+ info.loweredInnerStructKey = structKey;
+ info.convertLoweredToOriginal = createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType);
+ info.convertOriginalToLowered = createMatrixPackFunc(matrixType, loweredType, arrayType);
+ return info;
+ }
+
+ LoweredBuiltinTypeInfo lowerVectorType(
+ IRBuilder* builder,
+ IRVectorType* vectorType,
+ String nameSuffix)
+ {
+ LoweredBuiltinTypeInfo info;
+
+ auto loweredType = builder->createStructType();
+
+ StringBuilder nameSB;
+ nameSB << "_VectorStorage_";
+ getTypeNameHint(nameSB, vectorType->getElementType());
+ nameSB << getIntVal(vectorType->getElementCount()) << "_";
+ nameSB << nameSuffix;
+ builder->addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+
+
+ info.originalType = vectorType;
+ info.loweredType = loweredType;
+
+ auto structKey = builder->createStructKey();
+ builder->addNameHintDecoration(structKey, UnownedStringSlice("data"));
+
+ auto arrayType = builder->getArrayType(
+ vectorType->getElementType(),
+ vectorType->getElementCount());
+
+ builder->createStructField(loweredType, structKey, arrayType);
+
+ info.convertLoweredToOriginal = createVectorUnpackFunc(vectorType, loweredType, structKey, arrayType);
+ info.convertOriginalToLowered = createVectorPackFunc(vectorType, loweredType, arrayType);
+
+ return info;
+ }
+
+ LoweredBuiltinTypeInfo lowerStructType(
+ BuiltinTypeLoweringEnv* env,
+ IRBuilder* builder,
+ IRStructType* structType,
+ String nameSuffix)
+ {
+ // Recursively lower the fields of the struct type
+ List<IRType*> fieldTypes;
+ List<IRStructField*> fields;
+ for (auto field : structType->getFields())
+ {
+ fieldTypes.add(field->getFieldType());
+ fields.add(field);
+ }
+
+ auto loweredType = builder->createStructType();
+ StringBuilder nameSB;
+ nameSB << "_StructStorage_";
+
+ // Find a name hint for the struct type
+ for (auto decoration : structType->getDecorations())
+ if (auto nameHint = as<IRNameHintDecoration>(decoration))
+ nameSB << nameHint->getName();
+
+ nameSB << nameSuffix;
+ builder->addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice());
+
+ bool changesRequired = false;
+
+ // Lower field types.
+ List<LoweredBuiltinTypeInfo> loweredFieldInfos;
+ for (auto field : fields)
+ {
+ // Lower the field type
+ auto loweredFieldInfo = lowerType(env, builder, field->getFieldType(), nameSuffix);
+ loweredFieldInfos.add(loweredFieldInfo);
+
+ // Add the lowered field type to the lowered struct type
+ builder->createStructField(loweredType, field->getKey(), loweredFieldInfo.loweredType);
+
+ if (loweredFieldInfo.convertLoweredToOriginal != nullptr)
+ changesRequired = true;
+ }
+
+ if (!changesRequired)
+ {
+ // If no changes are required, then we can just return the original struct type
+ LoweredBuiltinTypeInfo info;
+ info.originalType = structType;
+ info.loweredType = structType;
+ info.convertLoweredToOriginal = nullptr;
+ info.convertOriginalToLowered = nullptr;
+ return info;
+ }
+
+ LoweredBuiltinTypeInfo info;
+ info.originalType = structType;
+ info.loweredType = loweredType;
+
+ // Create the conversion function from the lowered struct type to the original struct type
+ {
+ builder->setInsertAfter(loweredType);
+ auto func = builder->createFunc();
+ auto funcType = builder->getFuncType(1, (IRType**)&loweredType, structType);
+ func->setFullType(funcType);
+ builder->addNameHintDecoration(func, UnownedStringSlice("convertLoweredToOriginal"));
+ builder->setInsertInto(func);
+ builder->emitBlock();
+ auto loweredParam = builder->emitParam(loweredType);
+ List<IRInst*> args;
+ args.setCount((Index)fields.getCount());
+ for (Index i = 0; i < fields.getCount(); i++)
+ {
+ auto loweredField = builder->emitFieldExtract(loweredFieldInfos[i].loweredType, loweredParam, fields[i]->getKey());
+ List<IRInst*> callArgs;
+ callArgs.add(loweredField);
+
+ if (loweredFieldInfos[i].convertLoweredToOriginal == nullptr)
+ {
+ args[i] = loweredField;
+ continue;
+ }
+
+ args[i] = builder->emitCallInst(
+ loweredFieldInfos[i].originalType,
+ loweredFieldInfos[i].convertLoweredToOriginal,
+ callArgs);
+ }
+
+ auto result = builder->emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer());
+ builder->emitReturn(result);
+ info.convertLoweredToOriginal = func;
+ }
+
+ // Create the conversion function from the original struct type to the lowered struct type
+ {
+ builder->setInsertAfter(structType);
+ auto func = builder->createFunc();
+ auto funcType = builder->getFuncType(1, (IRType**)&structType, loweredType);
+ func->setFullType(funcType);
+ builder->addNameHintDecoration(func, UnownedStringSlice("convertOriginalToLowered"));
+ builder->setInsertInto(func);
+ builder->emitBlock();
+ auto originalParam = builder->emitParam(structType);
+ List<IRInst*> args;
+ args.setCount((Index)fields.getCount());
+ for (Index i = 0; i < fields.getCount(); i++)
+ {
+ auto originalField = builder->emitFieldExtract(loweredFieldInfos[i].originalType, originalParam, fields[i]->getKey());
+ List<IRInst*> callArgs;
+ callArgs.add(originalField);
+
+ if (loweredFieldInfos[i].convertOriginalToLowered == nullptr)
+ {
+ args[i] = originalField;
+ continue;
+ }
+
+ args[i] = builder->emitCallInst(
+ loweredFieldInfos[i].loweredType,
+ loweredFieldInfos[i].convertOriginalToLowered,
+ callArgs);
+ }
+
+ auto result = builder->emitMakeStruct(loweredType, (UInt)args.getCount(), args.getBuffer());
+ builder->emitReturn(result);
+ info.convertOriginalToLowered = func;
+ }
+
+ return info;
+ }
+
+ LoweredBuiltinTypeInfo lowerArrayType(
+ BuiltinTypeLoweringEnv* env,
+ IRBuilder* builder,
+ IRArrayType* arrayType,
+ String nameSuffix)
+ {
+ auto loweredElementTypeInfo = lowerType(env, builder, arrayType->getElementType(), nameSuffix);
+ auto loweredType = builder->getArrayType(loweredElementTypeInfo.loweredType, arrayType->getElementCount());
+
+ LoweredBuiltinTypeInfo info;
+ info.originalType = arrayType;
+ info.loweredType = loweredType;
+
+ // If the element type was lowered, then we need to create conversion functions
+ if (loweredElementTypeInfo.convertLoweredToOriginal != nullptr)
+ {
+ builder->setInsertAfter(loweredType);
+ auto func = builder->createFunc();
+ auto funcType = builder->getFuncType(1, (IRType**)&loweredType, arrayType);
+ func->setFullType(funcType);
+ builder->addNameHintDecoration(func, UnownedStringSlice("convertLoweredToOriginal"));
+ builder->setInsertInto(func);
+ builder->emitBlock();
+ auto loweredParam = builder->emitParam(loweredType);
+
+ auto count = getIntVal(arrayType->getElementCount());
+ List<IRInst*> args;
+ args.setCount((Index)count);
+ for (IRIntegerValue ii = 0; ii < count; ++ii)
+ {
+ auto loweredElement = builder->emitElementExtract(loweredParam, ii);
+ List<IRInst*> callArgs;
+ callArgs.add(loweredElement);
+ args[(Index)ii] = builder->emitCallInst(
+ arrayType->getElementType(),
+ loweredElementTypeInfo.convertLoweredToOriginal,
+ callArgs);
+ }
+
+ auto result = builder->emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer());
+ builder->emitReturn(result);
+ info.convertLoweredToOriginal = func;
+ }
+
+ if (loweredElementTypeInfo.convertOriginalToLowered != nullptr)
+ {
+ builder->setInsertAfter(arrayType);
+ auto func = builder->createFunc();
+ auto funcType = builder->getFuncType(1, (IRType**)&arrayType, loweredType);
+ func->setFullType(funcType);
+ builder->addNameHintDecoration(func, UnownedStringSlice("convertOriginalToLowered"));
+ builder->setInsertInto(func);
+ builder->emitBlock();
+ auto originalParam = builder->emitParam(arrayType);
+ auto count = getIntVal(arrayType->getElementCount());
+ List<IRInst*> args;
+ args.setCount((Index)count);
+ for (IRIntegerValue ii = 0; ii < count; ++ii)
+ {
+ auto originalElement = builder->emitElementExtract(originalParam, ii);
+ List<IRInst*> callArgs;
+ callArgs.add(originalElement);
+ args[(Index)ii] = builder->emitCallInst(
+ loweredElementTypeInfo.loweredType,
+ loweredElementTypeInfo.convertOriginalToLowered,
+ callArgs);
+ }
+
+ auto result = builder->emitMakeArray(loweredType, (UInt)args.getCount(), args.getBuffer());
+ builder->emitReturn(result);
+ info.convertOriginalToLowered = func;
+ }
+
+ return info;
+ }
+
+ LoweredBuiltinTypeInfo lowerType(
+ BuiltinTypeLoweringEnv* env,
+ IRBuilder* builder,
+ IRType* type,
+ String nameSuffix)
+ {
+ if (env->loweredTypes.containsKey(type))
+ return env->loweredTypes[type];
+
+ if (auto matrixType = as<IRMatrixType>(type))
+ {
+ auto loweredInfo = lowerMatrixType(builder, matrixType, nameSuffix);
+ env->loweredTypes[type] = loweredInfo;
+ return loweredInfo;
+ }
+ else if (auto vectorType = as<IRVectorType>(type))
+ {
+ auto loweredInfo = lowerVectorType(builder, vectorType, nameSuffix);
+ env->loweredTypes[type] = loweredInfo;
+ return loweredInfo;
+ }
+ else if (auto structType = as<IRStructType>(type))
+ {
+ auto loweredInfo = lowerStructType(env, builder, structType, nameSuffix);
+ env->loweredTypes[type] = loweredInfo;
+ return loweredInfo;
+ }
+ else if (auto arrayType = as<IRArrayType>(type))
+ {
+ auto loweredInfo = lowerArrayType(env, builder, arrayType, nameSuffix);
+ env->loweredTypes[type] = loweredInfo;
+ return loweredInfo;
+ }
+
+ LoweredBuiltinTypeInfo info;
+ info.originalType = type;
+ info.loweredType = type;
+ info.convertLoweredToOriginal = nullptr;
+ info.convertOriginalToLowered = nullptr;
+
+ return info;
+ }
+}; \ No newline at end of file
diff --git a/source/slang/slang-ir-lower-cuda-builtin-types.h b/source/slang/slang-ir-lower-cuda-builtin-types.h
new file mode 100644
index 000000000..900e90d10
--- /dev/null
+++ b/source/slang/slang-ir-lower-cuda-builtin-types.h
@@ -0,0 +1,51 @@
+#ifndef SLANG_IR_LOWER_BUILTIN_TYPES_H
+#define SLANG_IR_LOWER_BUILTIN_TYPES_H
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+#include "slang-ir-layout.h"
+
+namespace Slang
+{
+ struct LoweredBuiltinTypeInfo // Result of lowering one IR type: the type pair plus optional conversion functions between them.
+ {
+ IRType* originalType; // The type that was passed in to be lowered.
+ IRType* loweredType; // The replacement type; identical to originalType when nothing needed lowering.
+ IRType* loweredInnerArrayType = nullptr; // Populated by lowerMatrixType: the flat element-array type of the lowered struct's `data` field. The vector/struct/array lowerings in slang-ir-lower-cuda-builtin-types.cpp leave it null.
+ IRStructKey* loweredInnerStructKey = nullptr; // Populated by lowerMatrixType: the struct key of the lowered struct's `data` field. The vector/struct/array lowerings in slang-ir-lower-cuda-builtin-types.cpp leave it null.
+ IRFunc* convertOriginalToLowered = nullptr; // Function taking an originalType value and returning a loweredType value; null when no conversion is required.
+ IRFunc* convertLoweredToOriginal = nullptr; // Function taking a loweredType value and returning an originalType value; null when no conversion is required.
+ };
+
+ struct BuiltinTypeLoweringEnv
+ {
+ Dictionary<IRType*, LoweredBuiltinTypeInfo> loweredTypes;
+ };
+
+ LoweredBuiltinTypeInfo lowerMatrixType(
+ IRBuilder* builder,
+ IRMatrixType* matrixType,
+ String nameSuffix = "");
+
+ LoweredBuiltinTypeInfo lowerVectorType(
+ IRBuilder* builder,
+ IRVectorType* vectorType,
+ String nameSuffix = "");
+
+ LoweredBuiltinTypeInfo lowerStructType(
+ BuiltinTypeLoweringEnv* env,
+ IRBuilder* builder,
+ IRStructType* structType,
+ String nameSuffix = "");
+
+ LoweredBuiltinTypeInfo lowerType(
+ BuiltinTypeLoweringEnv* env,
+ IRBuilder* builder,
+ IRType* type,
+ String nameSuffix = "");
+
+} // namespace Slang
+
+#endif // SLANG_IR_LOWER_BUILTIN_TYPES_H \ No newline at end of file
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 6a85f0324..fd885dae7 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -3,6 +3,7 @@
#include "slang-ir-insts.h"
#include "slang-diagnostics.h"
#include "slang-ir-autodiff.h"
+#include "slang-ir-lower-cuda-builtin-types.h"
namespace Slang
{
@@ -13,10 +14,31 @@ static IRType* translateToTupleType(
{
if (as<IRVoidType>(type))
return type;
- if (as<IRBasicType>(type))
+ else if (as<IRBasicType>(type))
return type;
else if (as<IRTorchTensorType>(type))
return type;
+ else if (auto matrixType = as<IRMatrixType>(type))
+ {
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto colCount = as<IRIntLit>(matrixType->getColumnCount());
+ if (!rowCount || !colCount)
+ {
+ return nullptr;
+ }
+ List<IRType*> elementTypes;
+ for (IRIntegerValue i = 0; i < rowCount->getValue(); i++)
+ {
+ elementTypes.addRange(matrixType->getElementType());
+ }
+ auto elementTupleType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ List<IRType*> rowTypes;
+ for (IRIntegerValue i = 0; i < colCount->getValue(); i++)
+ {
+ rowTypes.add(elementTupleType);
+ }
+ return builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer());
+ }
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
@@ -60,6 +82,10 @@ static IRType* translateToTupleType(
}
return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
}
+ else if (auto targetTupleType = as<IRTargetTupleType>(type))
+ {
+ return type;
+ }
else
{
return nullptr;
@@ -76,6 +102,38 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
return val;
else if (as<IRTorchTensorType>(type))
return val;
+ else if (auto matrixType = as<IRMatrixType>(type))
+ {
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto colCount = as<IRIntLit>(matrixType->getColumnCount());
+ if (!rowCount || !colCount)
+ {
+ return nullptr;
+ }
+ List<IRInst*> rowElements;
+ List<IRType*> rowTypes;
+ for (IRIntegerValue i = 0; i < rowCount->getValue(); i++)
+ {
+ List<IRInst*> colElements;
+ List<IRType*> colTypes;
+ for (IRIntegerValue j = 0; j < colCount->getValue(); j++)
+ {
+ auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = makeTargetTuple(builder, elementVal);
+ if (!tupleElement)
+ return nullptr;
+ colElements.add(tupleElement);
+ colTypes.add(tupleElement->getFullType());
+ }
+ auto rowType = builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer());
+ rowTypes.add(rowType);
+ rowElements.add(builder.emitMakeTargetTuple(rowType, (UInt)colElements.getCount(), colElements.getBuffer()));
+ }
+ return builder.emitMakeTargetTuple(
+ builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()),
+ (UInt)rowElements.getCount(),
+ rowElements.getBuffer());
+ }
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
@@ -134,6 +192,10 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
}
+ else if (auto targetTupleType = as<IRTargetTupleType>(type))
+ {
+ return val;
+ }
else
{
return nullptr;
@@ -149,6 +211,25 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
return val;
else if (as<IRTorchTensorType>(type))
return val;
+ else if (auto matrixType = as<IRMatrixType>(type))
+ {
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto colCount = as<IRIntLit>(matrixType->getColumnCount());
+ SLANG_ASSERT(rowCount && colCount);
+
+ List<IRInst*> resultElements;
+ auto rowType = builder.getTargetTupleType((UInt)colCount->getValue(), List<IRType*>().makeRepeated(matrixType->getElementType(), (Index)colCount->getValue()).getBuffer());
+ for (IRIntegerValue i = 0; i < rowCount->getValue(); i++)
+ {
+ auto rowElement = builder.emitTargetTupleGetElement(rowType, val, builder.getIntValue(builder.getIntType(), i));
+ for (IRIntegerValue j = 0; j < colCount->getValue(); j++)
+ {
+ auto element = builder.emitTargetTupleGetElement(matrixType->getElementType(), rowElement, builder.getIntValue(builder.getIntType(), j));
+ resultElements.add(element);
+ }
+ }
+ return builder.emitMakeMatrix(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
@@ -203,6 +284,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
}
return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
}
+ else if (auto targetTupleType = as<IRTargetTupleType>(type))
+ {
+ return val;
+ }
else
{
return nullptr;
@@ -318,29 +403,13 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr)
{
- if (as<IRBasicType>(type) || as<IRVectorType>(type))
+ if (as<IRBasicType>(type) || as<IRVectorType>(type) || as<IRMatrixType>(type))
return type;
switch (type->getOp())
{
case kIROp_TensorViewType:
return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType());
-#if 0
- case kIROp_VectorType:
- {
- // Create a new struct type representing the vector.
- auto hostStructType = builder->createStructType();
- const char* names[4] = { "x", "y", "z", "w" };
- for (IRIntegerValue i = 0; i < getIntVal(as<IRVectorType>(type)->getElementCount()); i++)
- {
- auto key = builder->createStructKey();
- if (i < 4)
- builder->addNameHintDecoration(key, UnownedStringSlice(names[i]));
- builder->createStructField(hostStructType, key, as<IRVectorType>(type)->getElementType());
- }
- return hostStructType;
- }
-#endif
case kIROp_StructType:
{
// Create a new struct type with translated fields.
@@ -386,18 +455,6 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
{
case kIROp_TensorViewType:
return builder->emitMakeTensorView(cudaType, inst);
-#if 0
- case kIROp_VectorType:
- {
- List<IRInst*> args;
- auto hostStructType = cast<IRStructType>(hostType);
- for (auto field : hostStructType->getFields())
- {
- args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey()));
- }
- return builder->emitMakeVector(cudaType, args);
- }
-#endif
case kIROp_StructType:
{
auto cudaStructType = cast<IRStructType>(cudaType);
@@ -858,12 +915,138 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
return hostFunc;
}
-void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
+void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) // For every function decorated with IRCudaKernelDecoration: lowers its parameter types via lowerType(), converts back inside the body, and rewrites call/dispatch sites. The sink parameter is unnamed and unused here.
 {
- List<IRFunc*> workList;
 List<IRFunc*> cudaKernels;
+ for (auto globalInst : module->getGlobalInsts()) // Pass 1: collect the kernels first; all rewriting happens in the loop further down.
+ {
+ if (auto func = as<IRFunc>(globalInst))
+ {
+ if (func->findDecoration<IRCudaKernelDecoration>())
+ {
+ cudaKernels.add(func);
+ }
+ }
+ }
+
+ BuiltinTypeLoweringEnv typeLoweringEnv; // A single environment instance is passed to every lowerType() call below.
+ IRBuilder builder(module);
+ for (auto func : cudaKernels) // Pass 2: rewrite each collected kernel and its call sites.
+ {
+ // Snapshot the first block's parameters into a list, then replace any built-in types with their lowered equivalent.
+ List<IRParam*> params;
+ for (auto param : func->getFirstBlock()->getParams())
+ {
+ params.add(param);
+ }
+
+ bool changed = false; // Becomes true once at least one parameter of this kernel has been lowered.
+ List<LoweredBuiltinTypeInfo> loweredParamTypes; // One entry per parameter, in parameter order; indexed again when call sites are rewritten.
+ for (auto param : params)
+ {
+ LoweredBuiltinTypeInfo info = lowerType(&typeLoweringEnv, &builder, param->getDataType());
+ loweredParamTypes.add(info);
+
+ if (info.convertLoweredToOriginal != nullptr) // A non-null converter is the signal used here that the parameter needs lowering.
+ {
+ // Remember the original type (needed as the conversion's result type), then retype the parameter to the lowered type.
+ auto originalType = param->getDataType();
+ param->setFullType(info.loweredType);
+
+ // Call the conversion function to convert the lowered parameter back to a value of the original type.
+ List<IRInst*> args;
+ args.add(param);
+
+ setInsertAfterOrdinaryInst(&builder, param); // NOTE(review): exact insertion point is decided by setInsertAfterOrdinaryInst (not visible here) - confirm.
+ auto convertedParam = builder.emitCallInst(originalType, info.convertLoweredToOriginal, args);
+
+ // Replace all uses of the lowered parameter with the converted value, except for the conversion call itself (it must keep taking the raw parameter).
+ for (auto use = param->firstUse; use;)
+ {
+ auto nextUse = use->nextUse; // Read the successor before use->set() is called on this use.
+
+ if (use->getUser() == convertedParam)
+ {
+ use = nextUse;
+ continue;
+ }
+
+ use->set(convertedParam);
+ use = nextUse;
+ }
+
+ changed = true;
+ }
+ }
+
+ if (!changed) // No parameter was lowered: leave the function type and all call sites untouched.
+ continue;
+
+ fixUpFuncType(func); // NOTE(review): presumably refreshes the function type to match the retyped parameters - confirm against fixUpFuncType.
+
+ // Go through any calls to this function and insert a call to convertOriginalToLowered before the call.
+ for (auto use = func->firstUse; use;)
+ {
+ auto nextUse = use->nextUse; // Captured up front because the user of this use may be removed below.
+
+ if (as<IRCall>(use->getUser()) || as<IRDispatchKernel>(use->getUser())) // NOTE(review): matches any call/dispatch that has func as an operand; it does not check that func is the callee.
+ {
+ auto user = use->getUser();
+ IROperandList<IRInst> argsList;
+ if (auto callInst = as<IRCall>(user))
+ argsList = callInst->getArgsList();
+ else if (auto dispatchInst = as<IRDispatchKernel>(user))
+ argsList = dispatchInst->getArgsList();
+
+ // Emit a convertOriginalToLowered call for each argument that needs it, immediately before the original call.
+ List<IRInst*> convertedArgs;
+ IRBuilder callBuilder(func->getModule());
+ callBuilder.setInsertBefore(user);
+ for (auto arg : argsList)
+ {
+ if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered != nullptr) // convertedArgs.getCount() is the index of the current argument. NOTE(review): assumes the site passes no more arguments than the kernel has parameters.
+ {
+ auto convertedArg = callBuilder.emitCallInst(
+ loweredParamTypes[convertedArgs.getCount()].loweredType,
+ loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered,
+ List<IRInst*>(arg));
+ convertedArgs.add(convertedArg);
+ }
+ else
+ {
+ convertedArgs.add(arg);
+ }
+ }
+
+ // Rebuild the call/dispatch inst with the converted arguments, reusing the original instruction's result type.
+ IRInst* newCall = nullptr;
+
+ if (auto callInst = as<IRCall>(user))
+ newCall = callBuilder.emitCallInst(user->getFullType(), func, convertedArgs);
+ else if (auto dispatchInst = as<IRDispatchKernel>(user))
+ newCall = callBuilder.emitDispatchKernelInst(
+ user->getFullType(),
+ func,
+ dispatchInst->getThreadGroupSize(),
+ dispatchInst->getDispatchSize(),
+ convertedArgs.getCount(),
+ convertedArgs.getBuffer());
+
+ // Redirect every use of the old call/dispatch to the rebuilt one.
+ user->replaceUsesWith(newCall);
+
+ // Remove the old call/dispatch instruction now that nothing refers to it.
+ user->removeAndDeallocate();
+ }
+
+ use = nextUse;
+ }
+ }
+}
+
+void generateHostFunctionsForAutoBindCuda(IRModule* module, DiagnosticSink* sink)
+{
List<IRFunc*> autoBindRequests;
- List<IRType*> typesToExport;
for (auto globalInst : module->getGlobalInsts())
{
if (auto func = as<IRFunc>(globalInst))
@@ -872,6 +1055,24 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
{
autoBindRequests.add(func);
}
+ }
+ }
+
+ for (auto func : autoBindRequests)
+ {
+ generateCUDAWrapperForFunc(func, sink);
+ }
+}
+
+void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
+{
+ List<IRFunc*> workList;
+ List<IRFunc*> cudaKernels;
+ List<IRType*> typesToExport;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRFunc>(globalInst))
+ {
if (func->findDecoration<IRTorchEntryPointDecoration>())
{
workList.add(func);
@@ -895,16 +1096,6 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
}
}
- // Generate CUDA wrappers for all functions that have the auto-bind decoration.
- for (auto func : autoBindRequests)
- {
- if (auto hostFunc = generateCUDAWrapperForFunc(func, sink))
- {
- // Add generated wrapper to worklist for python bindings.
- workList.add(hostFunc);
- }
- }
-
for (auto func : workList)
generateCppBindingForFunc(func, sink);
@@ -967,7 +1158,6 @@ void handleAutoBindNames(IRModule* module)
nameBuilder << "__kernel__" << externCppHint->getName();
externCppHint->removeAndDeallocate();
builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice());
- builder.addExternCDecoration(globalInst);
}
}
}
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.h b/source/slang/slang-ir-pytorch-cpp-binding.h
index dd7dcc9a4..a761dbc03 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.h
+++ b/source/slang/slang-ir-pytorch-cpp-binding.h
@@ -6,9 +6,11 @@ struct IRModule;
class DiagnosticSink;
void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink);
+void generateHostFunctionsForAutoBindCuda(IRModule* module, DiagnosticSink* sink);
void removeTorchKernels(IRModule* module);
void handleAutoBindNames(IRModule* module);
void generateDerivativeWrappers(IRModule* module, DiagnosticSink* sink);
+void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink* sink);
}