diff options
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 278 |
1 files changed, 234 insertions, 44 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 6a85f0324..fd885dae7 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -3,6 +3,7 @@ #include "slang-ir-insts.h" #include "slang-diagnostics.h" #include "slang-ir-autodiff.h" +#include "slang-ir-lower-cuda-builtin-types.h" namespace Slang { @@ -13,10 +14,31 @@ static IRType* translateToTupleType( { if (as<IRVoidType>(type)) return type; - if (as<IRBasicType>(type)) + else if (as<IRBasicType>(type)) return type; else if (as<IRTorchTensorType>(type)) return type; + else if (auto matrixType = as<IRMatrixType>(type)) + { + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto colCount = as<IRIntLit>(matrixType->getColumnCount()); + if (!rowCount || !colCount) + { + return nullptr; + } + List<IRType*> elementTypes; + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + elementTypes.addRange(matrixType->getElementType()); + } + auto elementTupleType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + List<IRType*> rowTypes; + for (IRIntegerValue i = 0; i < colCount->getValue(); i++) + { + rowTypes.add(elementTupleType); + } + return builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()); + } else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); @@ -60,6 +82,10 @@ static IRType* translateToTupleType( } return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); } + else if (auto targetTupleType = as<IRTargetTupleType>(type)) + { + return type; + } else { return nullptr; @@ -76,6 +102,38 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) return val; else if (as<IRTorchTensorType>(type)) return val; + else if (auto matrixType = as<IRMatrixType>(type)) + { + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto colCount = as<IRIntLit>(matrixType->getColumnCount()); + if (!rowCount || !colCount) + { + return nullptr; + } + List<IRInst*> rowElements; + List<IRType*> rowTypes; + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + List<IRInst*> colElements; + List<IRType*> colTypes; + for (IRIntegerValue j = 0; j < colCount->getValue(); j++) + { + auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + colElements.add(tupleElement); + colTypes.add(tupleElement->getFullType()); + } + auto rowType = builder.getTargetTupleType((UInt)colTypes.getCount(), colTypes.getBuffer()); + rowTypes.add(rowType); + rowElements.add(builder.emitMakeTargetTuple(rowType, (UInt)colElements.getCount(), colElements.getBuffer())); + } + return builder.emitMakeTargetTuple( + builder.getTargetTupleType((UInt)rowTypes.getCount(), rowTypes.getBuffer()), + (UInt)rowElements.getCount(), + rowElements.getBuffer()); + } else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); @@ -134,6 +192,10 @@ static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } + else if (auto targetTupleType = as<IRTargetTupleType>(type)) + { + return val; + } else { return nullptr; @@ -149,6 +211,25 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst return val; else if (as<IRTorchTensorType>(type)) return val; + else if (auto matrixType = as<IRMatrixType>(type)) + { + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto colCount = as<IRIntLit>(matrixType->getColumnCount()); + SLANG_ASSERT(rowCount && colCount); + + List<IRInst*> resultElements; + auto rowType = builder.getTargetTupleType((UInt)colCount->getValue(), List<IRType*>().makeRepeated(matrixType->getElementType(), (Index)colCount->getValue()).getBuffer()); + for (IRIntegerValue i = 0; i < rowCount->getValue(); i++) + { + auto rowElement = builder.emitTargetTupleGetElement(rowType, val, builder.getIntValue(builder.getIntType(), i)); + for (IRIntegerValue j = 0; j < colCount->getValue(); j++) + { + auto element = builder.emitTargetTupleGetElement(matrixType->getElementType(), rowElement, builder.getIntValue(builder.getIntType(), j)); + resultElements.add(element); + } + } + return builder.emitMakeMatrix(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); @@ -203,6 +284,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst } return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); } + else if (auto targetTupleType = as<IRTargetTupleType>(type)) + { + return val; + } else { return nullptr; @@ -318,29 +403,13 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, DiagnosticSink* sink = nullptr) { - if (as<IRBasicType>(type) || as<IRVectorType>(type)) + if (as<IRBasicType>(type) || as<IRVectorType>(type) || as<IRMatrixType>(type)) return type; switch (type->getOp()) { case kIROp_TensorViewType: return builder->getTorchTensorType(as<IRTensorViewType>(type)->getElementType()); -#if 0 - case kIROp_VectorType: - { - // Create a new struct type representing the vector. - auto hostStructType = builder->createStructType(); - const char* names[4] = { "x", "y", "z", "w" }; - for (IRIntegerValue i = 0; i < getIntVal(as<IRVectorType>(type)->getElementCount()); i++) - { - auto key = builder->createStructKey(); - if (i < 4) - builder->addNameHintDecoration(key, UnownedStringSlice(names[i])); - builder->createStructField(hostStructType, key, as<IRVectorType>(type)->getElementType()); - } - return hostStructType; - } -#endif case kIROp_StructType: { // Create a new struct type with translated fields. @@ -386,18 +455,6 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp { case kIROp_TensorViewType: return builder->emitMakeTensorView(cudaType, inst); -#if 0 - case kIROp_VectorType: - { - List<IRInst*> args; - auto hostStructType = cast<IRStructType>(hostType); - for (auto field : hostStructType->getFields()) - { - args.add(builder->emitFieldExtract(field->getFieldType(), inst, field->getKey())); - } - return builder->emitMakeVector(cudaType, args); - } -#endif case kIROp_StructType: { auto cudaStructType = cast<IRStructType>(cudaType); @@ -858,12 +915,138 @@ IRFunc* generateCUDAWrapperForFunc(IRFunc* func, DiagnosticSink* sink) return hostFunc; } -void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) +void lowerBuiltinTypesForKernelEntryPoints(IRModule* module, DiagnosticSink*) { - List<IRFunc*> workList; List<IRFunc*> cudaKernels; + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(globalInst)) + { + if (func->findDecoration<IRCudaKernelDecoration>()) + { + cudaKernels.add(func); + } + } + } + + BuiltinTypeLoweringEnv typeLoweringEnv; + IRBuilder builder(module); + for (auto func : cudaKernels) + { + // Go through parameters and replace any built-in types with their equivalent. + List<IRParam*> params; + for (auto param : func->getFirstBlock()->getParams()) + { + params.add(param); + } + + bool changed = false; + List<LoweredBuiltinTypeInfo> loweredParamTypes; + for (auto param : params) + { + LoweredBuiltinTypeInfo info = lowerType(&typeLoweringEnv, &builder, param->getDataType()); + loweredParamTypes.add(info); + + if (info.convertLoweredToOriginal != nullptr) + { + // Replace parameter with the lowered type. + auto originalType = param->getDataType(); + param->setFullType(info.loweredType); + + // Call the conversion function to convert the lowered parameter to the original parameter. + List<IRInst*> args; + args.add(param); + + setInsertAfterOrdinaryInst(&builder, param); + auto convertedParam = builder.emitCallInst(originalType, info.convertLoweredToOriginal, args); + + // Replace all uses of the lowered parameter with the converted parameter, except for the call instruction. + for (auto use = param->firstUse; use;) + { + auto nextUse = use->nextUse; + + if (use->getUser() == convertedParam) + { + use = nextUse; + continue; + } + + use->set(convertedParam); + use = nextUse; + } + + changed = true; + } + } + + if (!changed) + continue; + + fixUpFuncType(func); + + // Go through any calls to this function and insert a call to converOriginalToLowered before the call. + for (auto use = func->firstUse; use;) + { + auto nextUse = use->nextUse; + + if (as<IRCall>(use->getUser()) || as<IRDispatchKernel>(use->getUser())) + { + auto user = use->getUser(); + IROperandList<IRInst> argsList; + if (auto callInst = as<IRCall>(user)) + argsList = callInst->getArgsList(); + else if (auto dispatchInst = as<IRDispatchKernel>(user)) + argsList = dispatchInst->getArgsList(); + + // Insert a call to convertOriginalToLowered before the call. + List<IRInst*> convertedArgs; + IRBuilder callBuilder(func->getModule()); + callBuilder.setInsertBefore(user); + for (auto arg : argsList) + { + if (loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered != nullptr) + { + auto convertedArg = callBuilder.emitCallInst( + loweredParamTypes[convertedArgs.getCount()].loweredType, + loweredParamTypes[convertedArgs.getCount()].convertOriginalToLowered, + List<IRInst*>(arg)); + convertedArgs.add(convertedArg); + } + else + { + convertedArgs.add(arg); + } + } + + // Rebuild the call/dispatch inst. + IRInst* newCall = nullptr; + + if (auto callInst = as<IRCall>(user)) + newCall = callBuilder.emitCallInst(user->getFullType(), func, convertedArgs); + else if (auto dispatchInst = as<IRDispatchKernel>(user)) + newCall = callBuilder.emitDispatchKernelInst( + user->getFullType(), + func, + dispatchInst->getThreadGroupSize(), + dispatchInst->getDispatchSize(), + convertedArgs.getCount(), + convertedArgs.getBuffer()); + + // Replace the call instruction. + user->replaceUsesWith(newCall); + + // Remove the call instruction. + user->removeAndDeallocate(); + } + + use = nextUse; + } + } +} + +void generateHostFunctionsForAutoBindCuda(IRModule* module, DiagnosticSink* sink) +{ List<IRFunc*> autoBindRequests; - List<IRType*> typesToExport; for (auto globalInst : module->getGlobalInsts()) { if (auto func = as<IRFunc>(globalInst)) @@ -872,6 +1055,24 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) { autoBindRequests.add(func); } + } + } + + for (auto func : autoBindRequests) + { + generateCUDAWrapperForFunc(func, sink); + } +} + +void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) +{ + List<IRFunc*> workList; + List<IRFunc*> cudaKernels; + List<IRType*> typesToExport; + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(globalInst)) + { if (func->findDecoration<IRTorchEntryPointDecoration>()) { workList.add(func); @@ -895,16 +1096,6 @@ void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) } } - // Generate CUDA wrappers for all functions that have the auto-bind decoration. - for (auto func : autoBindRequests) - { - if (auto hostFunc = generateCUDAWrapperForFunc(func, sink)) - { - // Add generated wrapper to worklist for python bindings. - workList.add(hostFunc); - } - } - for (auto func : workList) generateCppBindingForFunc(func, sink); @@ -967,7 +1158,6 @@ void handleAutoBindNames(IRModule* module) nameBuilder << "__kernel__" << externCppHint->getName(); externCppHint->removeAndDeallocate(); builder.addExternCppDecoration(globalInst, nameBuilder.getUnownedSlice()); - builder.addExternCDecoration(globalInst); } } } |
