diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-27 23:00:42 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-27 23:00:42 -0700 |
| commit | 0a6926003fd2300858e3089fe82f421543852395 (patch) | |
| tree | 19865fa9eb69373f0c0c16b7fac4993f67aa2b20 /source/slang/slang-ir-pytorch-cpp-binding.cpp | |
| parent | d120fec7e81bbd5e8cf2c551b573feaf6678b43d (diff) | |
Translate all composed types into tuple types in pyBind. (#2744)
* Translate all composed types into tuple types in pyBind.
* Delete temp file.
* Fix get tuple element code emit logic.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 202 |
1 files changed, 164 insertions, 38 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index e33adec1d..eb81bfd8c 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -5,113 +5,204 @@ namespace Slang { -static bool getHostReturnTypeImpl(List<IRType*>& elementTypes, IRBuilder& builder, IRType* type) +// Convert a type to a target tuple type. +static IRType* translateToTupleType(IRBuilder& builder, IRType* type) { - bool isValid = true; if (as<IRVoidType>(type)) - return true; + return type; if (as<IRBasicType>(type)) - elementTypes.add(type); + return type; else if (as<IRTorchTensorType>(type)) - elementTypes.add(type); + return type; else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); if (!count) { - return false; + return nullptr; } + List<IRType*> elementTypes; for (IRIntegerValue i = 0; i < count->getValue(); i++) { elementTypes.addRange(vectorType->getElementType()); } + return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); } else if (auto arrayType = as<IRArrayType>(type)) { auto arraySize = as<IRIntLit>(arrayType->getElementCount()); if (!arraySize) { - return false; + return nullptr; } List<IRType*> subElementTypes; - isValid &= getHostReturnTypeImpl(subElementTypes, builder, arrayType->getElementType()); + auto subElementType = translateToTupleType(builder, arrayType->getElementType()); for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { - elementTypes.addRange(subElementTypes); + subElementTypes.addRange(subElementType); } + return builder.getTargetTupleType((UInt)subElementTypes.getCount(), subElementTypes.getBuffer()); } else if (auto structType = as<IRStructType>(type)) { + List<IRType*> elementTypes; for (auto field : structType->getFields()) { - isValid &= getHostReturnTypeImpl(elementTypes, builder, field->getFieldType()); + auto fieldType = translateToTupleType(builder, field->getFieldType()); + if (!fieldType) + { + return nullptr; + } + elementTypes.addRange(fieldType); } + return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); } else { - return false; + return nullptr; } - return isValid; -} - -static IRType* getHostReturnType(IRBuilder& builder, IRType* type) -{ - List<IRType*> types; - bool isValid = getHostReturnTypeImpl(types, builder, type); - if (isValid) - return builder.getTargetTupleType((UInt)types.getCount(), types.getBuffer()); - return nullptr; } -static void flattenToTupleImpl(List<IRInst*>& result, IRBuilder& builder, IRInst* val) +// Convert a value to a target tuple type. +static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) { auto type = val->getDataType(); if (as<IRVoidType>(type)) - return; + return val; if (as<IRBasicType>(type)) - result.add(val); + return val; else if (as<IRTorchTensorType>(type)) - result.add(val); + return val; else if (auto vectorType = as<IRVectorType>(type)) { auto count = as<IRIntLit>(vectorType->getElementCount()); if (!count) { - return; + return nullptr; } + List<IRInst*> resultElements; + List<IRType*> elementTypes; for (IRIntegerValue i = 0; i < count->getValue(); i++) { - result.add(builder.emitElementExtract(vectorType->getElementType(), builder.getIntValue(builder.getIntType(), i))); + auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + resultElements.add(tupleElement); + elementTypes.add(tupleElement->getFullType()); } + auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else if (auto arrayType = as<IRArrayType>(type)) { auto arraySize = as<IRIntLit>(arrayType->getElementCount()); if (!arraySize) { - return; + return nullptr; } + List<IRInst*> resultElements; + List<IRType*> elementTypes; for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); - flattenToTupleImpl(result, builder, elementVal); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + resultElements.add(tupleElement); + elementTypes.add(tupleElement->getFullType()); } + auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else if (auto structType = as<IRStructType>(type)) { + List<IRInst*> resultElements; + List<IRType*> elementTypes; for (auto field : structType->getFields()) { auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey()); - flattenToTupleImpl(result, builder, elementVal); + auto tupleElement = makeTargetTuple(builder, elementVal); + if (!tupleElement) + return nullptr; + resultElements.add(tupleElement); + elementTypes.add(tupleElement->getFullType()); } + auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); + return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } + else + { + return nullptr; } } -static IRInst* flattenToHostReturnTuple(IRBuilder& builder, IRType* type, IRInst* val) +// Convert a target tuple type to a value. +static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst* val) { - List<IRInst*> vals; - flattenToTupleImpl(vals, builder, val); - return builder.emitMakeTargetTuple(type, (UInt)vals.getCount(), vals.getBuffer()); + if (as<IRVoidType>(type)) + return val; + if (as<IRBasicType>(type)) + return val; + else if (as<IRTorchTensorType>(type)) + return val; + else if (auto vectorType = as<IRVectorType>(type)) + { + auto count = as<IRIntLit>(vectorType->getElementCount()); + if (!count) + { + return nullptr; + } + List<IRInst*> resultElements; + auto elementType = vectorType->getElementType(); + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); + auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); + if (!convertedElement) + return nullptr; + resultElements.add(convertedElement); + } + return builder.emitMakeVector(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } + else if (auto arrayType = as<IRArrayType>(type)) + { + auto arraySize = as<IRIntLit>(arrayType->getElementCount()); + if (!arraySize) + { + return nullptr; + } + List<IRInst*> resultElements; + auto elementType = arrayType->getElementType(); + for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) + { + auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); + auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); + if (!convertedElement) + return nullptr; + resultElements.add(convertedElement); + } + return builder.emitMakeArray(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } + else if (auto structType = as<IRStructType>(type)) + { + List<IRInst*> resultElements; + IRIntegerValue i = 0; + for (auto field : structType->getFields()) + { + auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i)); + auto convertedElement = makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); + if (!convertedElement) + return nullptr; + resultElements.add(convertedElement); + i++; + } + return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } + else + { + return nullptr; + } } static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) @@ -119,7 +210,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) IRBuilder builder(func); builder.setInsertBefore(func); - auto hostReturnType = getHostReturnType(builder, func->getResultType()); + auto hostReturnType = translateToTupleType(builder, func->getResultType()); if (!hostReturnType) { sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType()); @@ -129,15 +220,50 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) auto funcType = as<IRFuncType>(func->getDataType()); for (UInt i = 0; i < funcType->getParamCount(); i++) { - hostParamTypes.add(funcType->getParamType(i)); + hostParamTypes.add(translateToTupleType(builder, funcType->getParamType(i))); } auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType); func->setFullType(bindingFuncType); builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); - auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType)); List<IRInst*> instsToRemove; + List<IRInst*> oldParams; + for (auto param : func->getFirstBlock()->getParams()) + { + oldParams.add(param); + } + + List<IRInst*> newParams; + for (auto param : oldParams) + { + auto paramType = param->getFullType(); + auto newParamType = translateToTupleType(builder, paramType); + if (!newParamType) + { + sink->diagnose(param->sourceLoc, Diagnostics::invalidTorchKernelParamType, paramType); + return; + } + auto newParam = builder.emitParam(newParamType); + newParams.add(newParam); + } + + // Convert all new parameters from tuples to their original types. + for (Index i = 0; i < newParams.getCount(); i++) + { + auto oldParam = oldParams[i]; + auto newParam = newParams[i]; + auto convertedParam = makeValueFromTargetTuple(builder, oldParam->getFullType(), newParam); + if (!convertedParam) + { + return; + } + oldParam->replaceUsesWith(convertedParam); + oldParam->removeAndDeallocate(); + } + + auto allocator = builder.emitVar(builder.getType(kIROp_TorchKernelMemoryAllocatorType)); + for (auto block : func->getBlocks()) { for (auto inst : block->getChildren()) @@ -177,7 +303,7 @@ static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) else if (auto ret = as<IRReturn>(inst)) { builder.setInsertBefore(ret); - auto retVal = flattenToHostReturnTuple(builder, hostReturnType, ret->getVal()); + auto retVal = makeTargetTuple(builder, ret->getVal()); ret->setOperand(0, retVal); } } |
