From 52b91231cdadc048f93b224f5035759cf1a96eaa Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:05:33 -0400 Subject: Added diagnostics & built-in type lowering for `[CUDAKernel]` functions (#4042) * Added diagnostics & built-in type lowering for `[CUDAKernel]` functions This PR adds - Diagnostics for non-void return from a cuda kernel entry point - Diagnostics for using differentiable types in a differentiable cuda kernel entry point - Logic for converting built-in types (float3, float3x3, etc..) to portable struct types and unpacks the parameter back into a built-in type on the CUDA side. This is because built-in types have different implementations in CUDA & CPP targets, which causes signature mis-match when linking. * Fix error codes * Add ability to lower structs and arrays that contain built-in types. + Added tests + Fix issue where the host-side was not marshalling data to lowered types. * Update slang-ir-pytorch-cpp-binding.cpp --------- Co-authored-by: Yong He --- source/slang/slang-ir-pytorch-cpp-binding.cpp | 278 ++++++++++++++++++++++---- 1 file changed, 234 insertions(+), 44 deletions(-) (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp') diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 6a85f0324..fd885dae7 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -3,6 +3,7 @@ #include "slang-ir-insts.h" #include "slang-diagnostics.h" #include "slang-ir-autodiff.h" +#include "slang-ir-lower-cuda-builtin-types.h" namespace Slang { @@ -13,10 +14,31 @@ static IRType* translateToTupleType( { if (as(type)) return type; - if (as(type)) + else if (as(type)) return type; else if (as(type)) return type; + else if (auto matrixType = as(type)) + { + auto rowCount = as(matrixType->getRowCount()); + auto colCount = as(matrixType->getColumnCount()); + if (!rowCount || !colCount) + { + return nullptr; + } + List elementTypes; + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + elementTypes.addRange(matrixType->getElementType()); + } + auto elementTupleType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + List rowTypes; + for (IRIntegerValue i = 0; i < colCount->getValue(); i++) + { + rowTypes.add(elementTupleType); + } + return builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()); + } else if (auto vectorType = as(type)) { auto count = as(vectorType->getElementCount()); @@ -60,6 +82,10 @@ static IRType* translateToTupleType( } return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); } + else if (auto targetTupleType = as(type)) + { + return type; + } else { return nullptr; @@ -76,6 +102,38 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) return val; else if (as(type)) return val; + else if (auto matrixType = as(type)) + { + auto rowCount = as(matrixType->getRowCount()); + auto colCount = as(matrixType->getColumnCount()); + if (!rowCount || !colCount) + { + return nullptr; + } + List rowElements; + List rowTypes; + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + List colElements; + List colTypes; + for (IRIntegerValue j = 0; j < colCount->getValue(); j++) + { + auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + colElements.add(tupleElement); + colTypes.add(tupleElement->getFullType()); + } + auto rowType = builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer()); + rowTypes.add(rowType); + rowElements.add(builder.emitMakeTargetTuple(rowType, (UInt)colElements.getCount(), colElements.getBuffer())); + } + return builder.emitMakeTargetTuple( + builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()), + (UInt)rowElements.getCount(), + rowElements.getBuffer()); + } else if (auto vectorType = as(type)) { auto count = as(vectorType->getElementCount()); @@ -134,6 +192,10 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } + else if (auto targetTupleType = as(type)) + { + return val; + } else { return nullptr; @@ -149,6 +211,25 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst return val; else if (as(type)) return val; + else if (auto matrixType = as(type)) + { + auto rowCount = as(matrixType->getRowCount()); + auto colCount = as(matrixType->getColumnCount()); + SLANG_ASSERT(rowCount && colCount); + + List resultElements; + auto rowType = builder.getTargetTupleType((UInt)colCount->getValue(), List().makeRepeated(matrixType->getElementType(), (Index)colCount->getValue()).getBuffer()); + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + auto rowElement = builder.emitTargetTupleGetElement(rowType, val, builder.getIntValue(builder.getIntType(), i)); + for (IRIntegerValue j = 0; j < colCount->getValue(); j++) + { + auto element = builder.emitTargetTupleGetElement(matrixType->getElementType(), rowElement, builder.getIntValue(builder.getIntType(), j)); + resultElements.add(element); + } + } + return builder.emitMakeMatrix(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } else if (auto vectorType = as(type)) { auto count = as(vectorType->getElementCount()); @@ -203,6 +284,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst } return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); } + else if (auto targetTupleType = as(type)) + { + return val; + } else { return nullptr; @@ -318,29 +403,13 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr) { - if (as(type) || as(type)) + if (as(type) || as(type) || as(type)) return type; switch (type->getOp()) { case kIROp_TensorViewType: return builder->getTorchTensorType(as(type)->getElementType()); -#if 0 - case kIROp_VectorType: - { - // Create a new struct type representing the vector. - auto hostStructType = builder->createStructType(); - const char* names[4] = { "x", "y", "z", "w" }; - for (IRIntegerValue i = 0; i < getIntVal(as(type)->getElementCount()); i++) - { - auto key = builder->createStructKey(); - if (i < 4) - builder->addNameHintDecoration(key, UnownedStringSlice(names[i])); - builder->createStructField(hostStructType, key, as(type)->getElementType()); - } - return hostStructType; - } -#endif case kIROp_StructType: { // Create a new struct type with translated fields. @@ -386,18 +455,6 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp { case kIROp_TensorViewType: return builder->emitMakeTensorView(cudaType, inst); -#if 0 - case kIROp_VectorType: - { - List args; - auto hostStructType = cast(hostType); - for (auto field : hostStructType->getFields()) - { - args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey())); - } - return builder->emitMakeVector(cudaType, args); - } -#endif case kIROp_StructType: { auto cudaStructType = cast(cudaType); @@ -858,12 +915,138 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) return hostFunc; } -void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) +void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) { - List workList; List cudaKernels; + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as(globalInst)) + { + if (func->findDecoration()) + { + cudaKernels.add(func); + } + } + } + + BuiltinTypeLoweringEnv typeLoweringEnv; + IRBuilder builder(module); + for (auto func : cudaKernels) + { + // Go through parameters and replace any built-in types with their equivalent. + List params; + for (auto param : func->getFirstBlock()->getParams()) + { + params.add(param); + } + + bool changed = false; + List loweredParamTypes; + for (auto param : params) + { + LoweredBuiltinTypeInfo info = lowerType(&typeLoweringEnv, &builder, param->getDataType()); + loweredParamTypes.add(info); + + if (info.convertLoweredToOriginal != nullptr) + { + // Replace parameter with the lowered type. + auto originalType = param->getDataType(); + param->setFullType(info.loweredType); + + // Call the conversion function to convert the lowered parameter to the original parameter. + List args; + args.add(param); + + setInsertAfterOrdinaryInst(&builder, param); + auto convertedParam = builder.emitCallInst(originalType, info.convertLoweredToOriginal, args); + + // Replace all uses of the lowered parameter with the converted parameter, except for the call instruction. + for (auto use = param->firstUse; use;) + { + auto nextUse = use->nextUse; + + if (use->getUser() == convertedParam) + { + use = nextUse; + continue; + } + + use->set(convertedParam); + use = nextUse; + } + + changed = true; + } + } + + if (!changed) + continue; + + fixUpFuncType(func); + + // Go through any calls to this function and insert a call to converOriginalToLowered before the call. + for (auto use = func->firstUse; use;) + { + auto nextUse = use->nextUse; + + if (as(use->getUser()) || as(use->getUser())) + { + auto user = use->getUser(); + IROperandList argsList; + if (auto callInst = as(user)) + argsList = callInst->getArgsList(); + else if (auto dispatchInst = as(user)) + argsList = dispatchInst->getArgsList(); + + // Insert a call to convertOriginalToLowered before the call. + List convertedArgs; + IRBuilder callBuilder(func->getModule()); + callBuilder.setInsertBefore(user); + for (auto arg : argsList) + { + if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered != nullptr) + { + auto convertedArg = callBuilder.emitCallInst( + loweredParamTypes[convertedArgs.getCount()].loweredType, + loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered, + List(arg)); + convertedArgs.add(convertedArg); + } + else + { + convertedArgs.add(arg); + } + } + + // Rebuild the call/dispatch inst. + IRInst* newCall = nullptr; + + if (auto callInst = as(user)) + newCall = callBuilder.emitCallInst(user->getFullType(), func, convertedArgs); + else if (auto dispatchInst = as(user)) + newCall = callBuilder.emitDispatchKernelInst( + user->getFullType(), + func, + dispatchInst->getThreadGroupSize(), + dispatchInst->getDispatchSize(), + convertedArgs.getCount(), + convertedArgs.getBuffer()); + + // Replace the call instruction. + user->replaceUsesWith(newCall); + + // Remove the call instruction. + user->removeAndDeallocate(); + } + + use = nextUse; + } + } +} + +void generateHostFunctionsForAutoBindCuda(IRModule* module, DiagnosticSink* sink) +{ List autoBindRequests; - List typesToExport; for (auto globalInst : module->getGlobalInsts()) { if (auto func = as(globalInst)) @@ -872,6 +1055,24 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) { autoBindRequests.add(func); } + } + } + + for (auto func : autoBindRequests) + { + generateCUDAWrapperForFunc(func, sink); + } +} + +void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) +{ + List workList; + List cudaKernels; + List typesToExport; + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as(globalInst)) + { if (func->findDecoration()) { workList.add(func); @@ -895,16 +1096,6 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) } } - // Generate CUDA wrappers for all functions that have the auto-bind decoration. - for (auto func : autoBindRequests) - { - if (auto hostFunc = generateCUDAWrapperForFunc(func, sink)) - { - // Add generated wrapper to worklist for python bindings. - workList.add(hostFunc); - } - } - for (auto func : workList) generateCppBindingForFunc(func, sink); @@ -967,7 +1158,6 @@ void handleAutoBindNames(IRModule* module) nameBuilder << "__kernel__" << externCppHint->getName(); externCppHint->removeAndDeallocate(); builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice()); - builder.addExternCDecoration(globalInst); } } } -- cgit v1.2.3