diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-10-05 12:52:49 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-05 09:52:49 -0700 |
| commit | 441e13e13f30b96eb04c05725ad7fe1983c92f53 (patch) | |
| tree | aee5c31b62876ef8ad60a37b2a4767b6f1a299c6 /source/slang/slang-ir-pytorch-cpp-binding.cpp | |
| parent | 65751ce222adb302e62b5b7b6312de65638abed5 (diff) | |
Various AD Fixes (#3263)
* Various fixes
* Remove unused parameter
* Update slang-ir-loop-unroll.cpp
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-pytorch-cpp-binding.cpp')
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 110 |
1 files changed, 95 insertions, 15 deletions
diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 41665ddf7..3a7e8b9fb 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -177,9 +177,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst } List<IRInst*> resultElements; auto elementType = arrayType->getElementType(); + auto tupleElementType = translateToTupleType(builder, elementType); for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { - auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i)); auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); if (!convertedElement) return nullptr; @@ -346,7 +347,7 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag List<IRType*> fieldTypes; for (auto field : as<IRStructType>(type)->getFields()) { - fieldTypes.add(translateToHostType(builder, field->getFieldType(), func)); + fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink)); } auto hostStructType = builder->createStructType(); @@ -358,6 +359,13 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag return hostStructType; } + case kIROp_ArrayType: + { + auto elementType = translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink); + if (!elementType) + return nullptr; + return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount()); + } default: break; } @@ -422,13 +430,36 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer()); } + case kIROp_ArrayType: + { + auto cudaArrayType = cast<IRArrayType>(cudaType); + auto hostArrayType = cast<IRArrayType>(hostType); + + List<IRInst*> resultElements; + for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue(); i++) + { + auto cudaElementType = cudaArrayType->getElementType(); + auto hostElementType = hostArrayType->getElementType(); + auto castedElement = castHostToCUDAType( + builder, + hostElementType, + cudaElementType, + builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i))); + + SLANG_RELEASE_ASSERT(castedElement); + resultElements.add(castedElement); + } + + return builder->emitMakeArray(cudaType, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } default: break; } - // If translateToHostType worked correctly, we shouldn't get here. - SLANG_UNREACHABLE("unhandled type"); + // If translateToHostType worked correctly, there should be no unhandled cases here. + // However, we won't diagnose here since its already diagnosed in translateToHostType() + return nullptr; } void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc) @@ -553,6 +584,12 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink); if (outType) *outType = type; + + if (!type || sink->getErrorCount() > 0) + { + return nullptr; + } + auto hostParam = builder->emitParam(type); // Add a namehint to the param by appending the suffix "_host". if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) @@ -600,6 +637,38 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink) } return; } + else if (auto arrayType = as<IRArrayType>(type)) + { + IRBuilder builder(arrayType->getModule()); + if (!arrayType->findDecoration<IRPyExportDecoration>()) + builder.addPyExportDecoration(arrayType, UnownedStringSlice("Array")); + + markTypeForPyExport(arrayType->getElementType(), sink); + return; + } +} + +String tryGetExportTypeName(IRBuilder* builder, IRType* type) +{ + if (auto structType = as<IRStructType>(type)) + { + if (auto pyExportDecoration = type->findDecoration<IRPyExportDecoration>()) + return String(pyExportDecoration->getExportName()); + else + return String(""); + } + else if (auto arrayType = as<IRArrayType>(type)) + { + StringBuilder nameBuilder; + nameBuilder << "Array_"; + nameBuilder << tryGetExportTypeName(builder, arrayType->getElementType()); + nameBuilder << "_"; + nameBuilder << cast<IRIntLit>(arrayType->getElementCount())->getValue(); + + return nameBuilder.produceString(); + } + else + return String(); } void generateReflectionForType(IRType* type, DiagnosticSink* sink) @@ -609,7 +678,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) // The list will contain the names of all the fields of the type. // - // TODO: Fix this to avoid emitting the same type reflection multiple times. if (!type->findDecoration<IRPyExportDecoration>()) return; @@ -635,20 +703,32 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) else fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); - if (!field->getFieldType()->findDecoration<IRPyExportDecoration>()) - { - fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); - continue; - } - auto fieldType = field->getFieldType(); + auto exportName = tryGetExportTypeName(&builder, fieldType); - fieldTypeNames.add( - builder.emitGetNativeString( - builder.getStringValue(fieldType->findDecoration<IRPyExportDecoration>()->getExportName()))); + if (exportName.getLength() > 0) + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(exportName.getUnownedSlice()))); + else + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); } break; } + case kIROp_ArrayType: + { + auto elementType = as<IRArrayType>(type)->getElementType(); + fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type")))); + fieldTypeNames.add( + builder.emitGetNativeString( + builder.getStringValue(tryGetExportTypeName(&builder, elementType).getUnownedSlice()))); + + auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount()); + fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size")))); + + StringBuilder elementCountStr; + elementCountStr << elementCount->getValue(); + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(elementCountStr.getUnownedSlice()))); + break; + } default: break; } @@ -676,7 +756,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) // Set function name. StringBuilder reflFuncExportName; - reflFuncExportName << "__typeinfo__" << type->findDecoration<IRPyExportDecoration>()->getExportName(); + reflFuncExportName << "__typeinfo__" << tryGetExportTypeName(&builder, type).getUnownedSlice(); builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); |
