summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-pytorch-cpp-binding.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-04-30 16:05:33 -0400
committerGitHub <noreply@github.com>2024-04-30 16:05:33 -0400
commit52b91231cdadc048f93b224f5035759cf1a96eaa (patch)
tree23d3263bc662eb96d6284266282695a9b0f1e2db /source/slang/slang-ir-pytorch-cpp-binding.cpp
parent70111daf43c87e182695666c34345e061e114a68 (diff)
Added diagnostics & built-in type lowering for `[CUDAKernel]` functions (#4042)
* Added diagnostics & built-in type lowering for `[CUDAKernel]` functions This PR adds - Diagnostics for non-void return from a cuda kernel entry point - Diagnostics for using differentiable types in a differentiable cuda kernel entry point - Logic for converting built-in types (float3, float3x3, etc..) to portable struct types and unpacks the parameter back into a built-in type on the CUDA side. This is because built-in types have different implementations in CUDA & CPP targets, which causes signature mis-match when linking. * Fix error codes * Add ability to lower structs and arrays that contain built-in types. + Added tests + Fix issue where the host-side was not marshalling data to lowered types. * Update slang-ir-pytorch-cpp-binding.cpp --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
-rw-r--r--source/slang/slang-ir-pytorch-cpp-binding.cpp278
1 files changed, 234 insertions, 44 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp
index 6a85f0324..fd885dae7 100644
--- a/source/slang/slang-ir-pytorch-cpp-binding.cpp
+++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp
@@ -3,6 +3,7 @@
#include "slang-ir-insts.h"
#include "slang-diagnostics.h"
#include "slang-ir-autodiff.h"
+#include "slang-ir-lower-cuda-builtin-types.h"
namespace Slang
{
@@ -13,10 +14,31 @@ static IRType* translateToTupleType(
{
if (as<IRVoidType>(type))
return type;
- if (as<IRBasicType>(type))
+ else if (as<IRBasicType>(type))
return type;
else if (as<IRTorchTensorType>(type))
return type;
+ else if (auto matrixType = as<IRMatrixType>(type))
+ {
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto colCount = as<IRIntLit>(matrixType->getColumnCount());
+ if (!rowCount || !colCount)
+ {
+ return nullptr;
+ }
+ List<IRType*> elementTypes;
+ for (IRIntegerValue i = 0; i < rowCount->getValue(); i++)
+ {
+ elementTypes.addRange(matrixType->getElementType());
+ }
+ auto elementTupleType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
+ List<IRType*> rowTypes;
+ for (IRIntegerValue i = 0; i < colCount->getValue(); i++)
+ {
+ rowTypes.add(elementTupleType);
+ }
+ return builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer());
+ }
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
@@ -60,6 +82,10 @@ static IRType* translateToTupleType(
}
return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
}
+ else if (auto targetTupleType = as<IRTargetTupleType>(type))
+ {
+ return type;
+ }
else
{
return nullptr;
@@ -76,6 +102,38 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
return val;
else if (as<IRTorchTensorType>(type))
return val;
+ else if (auto matrixType = as<IRMatrixType>(type))
+ {
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto colCount = as<IRIntLit>(matrixType->getColumnCount());
+ if (!rowCount || !colCount)
+ {
+ return nullptr;
+ }
+ List<IRInst*> rowElements;
+ List<IRType*> rowTypes;
+ for (IRIntegerValue i = 0; i < rowCount->getValue(); i++)
+ {
+ List<IRInst*> colElements;
+ List<IRType*> colTypes;
+ for (IRIntegerValue j = 0; j < colCount->getValue(); j++)
+ {
+ auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i));
+ auto tupleElement = makeTargetTuple(builder, elementVal);
+ if (!tupleElement)
+ return nullptr;
+ colElements.add(tupleElement);
+ colTypes.add(tupleElement->getFullType());
+ }
+ auto rowType = builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer());
+ rowTypes.add(rowType);
+ rowElements.add(builder.emitMakeTargetTuple(rowType, (UInt)colElements.getCount(), colElements.getBuffer()));
+ }
+ return builder.emitMakeTargetTuple(
+ builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()),
+ (UInt)rowElements.getCount(),
+ rowElements.getBuffer());
+ }
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
@@ -134,6 +192,10 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val)
auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer());
}
+ else if (auto targetTupleType = as<IRTargetTupleType>(type))
+ {
+ return val;
+ }
else
{
return nullptr;
@@ -149,6 +211,25 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
return val;
else if (as<IRTorchTensorType>(type))
return val;
+ else if (auto matrixType = as<IRMatrixType>(type))
+ {
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto colCount = as<IRIntLit>(matrixType->getColumnCount());
+ SLANG_ASSERT(rowCount && colCount);
+
+ List<IRInst*> resultElements;
+ auto rowType = builder.getTargetTupleType((UInt)colCount->getValue(), List<IRType*>().makeRepeated(matrixType->getElementType(), (Index)colCount->getValue()).getBuffer());
+ for (IRIntegerValue i = 0; i < rowCount->getValue(); i++)
+ {
+ auto rowElement = builder.emitTargetTupleGetElement(rowType, val, builder.getIntValue(builder.getIntType(), i));
+ for (IRIntegerValue j = 0; j < colCount->getValue(); j++)
+ {
+ auto element = builder.emitTargetTupleGetElement(matrixType->getElementType(), rowElement, builder.getIntValue(builder.getIntType(), j));
+ resultElements.add(element);
+ }
+ }
+ return builder.emitMakeMatrix(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
+ }
else if (auto vectorType = as<IRVectorType>(type))
{
auto count = as<IRIntLit>(vectorType->getElementCount());
@@ -203,6 +284,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst
}
return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer());
}
+ else if (auto targetTupleType = as<IRTargetTupleType>(type))
+ {
+ return val;
+ }
else
{
return nullptr;
@@ -318,29 +403,13 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink)
IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr)
{
- if (as<IRBasicType>(type) || as<IRVectorType>(type))
+ if (as<IRBasicType>(type) || as<IRVectorType>(type) || as<IRMatrixType>(type))
return type;
switch (type->getOp())
{
case kIROp_TensorViewType:
return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType());
-#if 0
- case kIROp_VectorType:
- {
- // Create a new struct type representing the vector.
- auto hostStructType = builder->createStructType();
- const char* names[4] = { "x", "y", "z", "w" };
- for (IRIntegerValue i = 0; i < getIntVal(as<IRVectorType>(type)->getElementCount()); i++)
- {
- auto key = builder->createStructKey();
- if (i < 4)
- builder->addNameHintDecoration(key, UnownedStringSlice(names[i]));
- builder->createStructField(hostStructType, key, as<IRVectorType>(type)->getElementType());
- }
- return hostStructType;
- }
-#endif
case kIROp_StructType:
{
// Create a new struct type with translated fields.
@@ -386,18 +455,6 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp
{
case kIROp_TensorViewType:
return builder->emitMakeTensorView(cudaType, inst);
-#if 0
- case kIROp_VectorType:
- {
- List<IRInst*> args;
- auto hostStructType = cast<IRStructType>(hostType);
- for (auto field : hostStructType->getFields())
- {
- args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey()));
- }
- return builder->emitMakeVector(cudaType, args);
- }
-#endif
case kIROp_StructType:
{
auto cudaStructType = cast<IRStructType>(cudaType);
@@ -858,12 +915,138 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink)
return hostFunc;
}
-void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
+void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*)
{
- List<IRFunc*> workList;
List<IRFunc*> cudaKernels;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRFunc>(globalInst))
+ {
+ if (func->findDecoration<IRCudaKernelDecoration>())
+ {
+ cudaKernels.add(func);
+ }
+ }
+ }
+
+ BuiltinTypeLoweringEnv typeLoweringEnv;
+ IRBuilder builder(module);
+ for (auto func : cudaKernels)
+ {
+ // Go through parameters and replace any built-in types with their equivalent.
+ List<IRParam*> params;
+ for (auto param : func->getFirstBlock()->getParams())
+ {
+ params.add(param);
+ }
+
+ bool changed = false;
+ List<LoweredBuiltinTypeInfo> loweredParamTypes;
+ for (auto param : params)
+ {
+ LoweredBuiltinTypeInfo info = lowerType(&typeLoweringEnv, &builder, param->getDataType());
+ loweredParamTypes.add(info);
+
+ if (info.convertLoweredToOriginal != nullptr)
+ {
+ // Replace parameter with the lowered type.
+ auto originalType = param->getDataType();
+ param->setFullType(info.loweredType);
+
+ // Call the conversion function to convert the lowered parameter to the original parameter.
+ List<IRInst*> args;
+ args.add(param);
+
+ setInsertAfterOrdinaryInst(&builder, param);
+ auto convertedParam = builder.emitCallInst(originalType, info.convertLoweredToOriginal, args);
+
+ // Replace all uses of the lowered parameter with the converted parameter, except for the call instruction.
+ for (auto use = param->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+
+ if (use->getUser() == convertedParam)
+ {
+ use = nextUse;
+ continue;
+ }
+
+ use->set(convertedParam);
+ use = nextUse;
+ }
+
+ changed = true;
+ }
+ }
+
+ if (!changed)
+ continue;
+
+ fixUpFuncType(func);
+
+ // Go through any calls to this function and insert a call to converOriginalToLowered before the call.
+ for (auto use = func->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+
+ if (as<IRCall>(use->getUser()) || as<IRDispatchKernel>(use->getUser()))
+ {
+ auto user = use->getUser();
+ IROperandList<IRInst> argsList;
+ if (auto callInst = as<IRCall>(user))
+ argsList = callInst->getArgsList();
+ else if (auto dispatchInst = as<IRDispatchKernel>(user))
+ argsList = dispatchInst->getArgsList();
+
+ // Insert a call to convertOriginalToLowered before the call.
+ List<IRInst*> convertedArgs;
+ IRBuilder callBuilder(func->getModule());
+ callBuilder.setInsertBefore(user);
+ for (auto arg : argsList)
+ {
+ if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered != nullptr)
+ {
+ auto convertedArg = callBuilder.emitCallInst(
+ loweredParamTypes[convertedArgs.getCount()].loweredType,
+ loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered,
+ List<IRInst*>(arg));
+ convertedArgs.add(convertedArg);
+ }
+ else
+ {
+ convertedArgs.add(arg);
+ }
+ }
+
+ // Rebuild the call/dispatch inst.
+ IRInst* newCall = nullptr;
+
+ if (auto callInst = as<IRCall>(user))
+ newCall = callBuilder.emitCallInst(user->getFullType(), func, convertedArgs);
+ else if (auto dispatchInst = as<IRDispatchKernel>(user))
+ newCall = callBuilder.emitDispatchKernelInst(
+ user->getFullType(),
+ func,
+ dispatchInst->getThreadGroupSize(),
+ dispatchInst->getDispatchSize(),
+ convertedArgs.getCount(),
+ convertedArgs.getBuffer());
+
+ // Replace the call instruction.
+ user->replaceUsesWith(newCall);
+
+ // Remove the call instruction.
+ user->removeAndDeallocate();
+ }
+
+ use = nextUse;
+ }
+ }
+}
+
+void generateHostFunctionsForAutoBindCuda(IRModule* module, DiagnosticSink* sink)
+{
List<IRFunc*> autoBindRequests;
- List<IRType*> typesToExport;
for (auto globalInst : module->getGlobalInsts())
{
if (auto func = as<IRFunc>(globalInst))
@@ -872,6 +1055,24 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
{
autoBindRequests.add(func);
}
+ }
+ }
+
+ for (auto func : autoBindRequests)
+ {
+ generateCUDAWrapperForFunc(func, sink);
+ }
+}
+
+void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
+{
+ List<IRFunc*> workList;
+ List<IRFunc*> cudaKernels;
+ List<IRType*> typesToExport;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRFunc>(globalInst))
+ {
if (func->findDecoration<IRTorchEntryPointDecoration>())
{
workList.add(func);
@@ -895,16 +1096,6 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink)
}
}
- // Generate CUDA wrappers for all functions that have the auto-bind decoration.
- for (auto func : autoBindRequests)
- {
- if (auto hostFunc = generateCUDAWrapperForFunc(func, sink))
- {
- // Add generated wrapper to worklist for python bindings.
- workList.add(hostFunc);
- }
- }
-
for (auto func : workList)
generateCppBindingForFunc(func, sink);
@@ -967,7 +1158,6 @@ void handleAutoBindNames(IRModule* module)
nameBuilder << "__kernel__" << externCppHint->getName();
externCppHint->removeAndDeallocate();
builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice());
- builder.addExternCDecoration(globalInst);
}
}
}