diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-10-05 12:52:49 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-05 09:52:49 -0700 |
| commit | 441e13e13f30b96eb04c05725ad7fe1983c92f53 (patch) | |
| tree | aee5c31b62876ef8ad60a37b2a4767b6f1a299c6 | |
| parent | 65751ce222adb302e62b5b7b6312de65638abed5 (diff) | |
Various AD Fixes (#3263)
* Various fixes
* Remove unused parameter
* Update slang-ir-loop-unroll.cpp
---------
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-pytorch-cpp-binding.cpp | 110 | ||||
| -rw-r--r-- | source/slang/slang-ir-synthesize-active-mask.cpp | 59 | ||||
| -rw-r--r-- | tests/autodiff/generic-differential-synthesis.slang | 35 | ||||
| -rw-r--r-- | tests/autodiff/generic-differential-synthesis.slang.expected.txt | 5 |
9 files changed, 185 insertions, 59 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index abdd89b01..05a6ed249 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -422,7 +422,9 @@ namespace Slang return result; } - Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr) + Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( + LookupResultItem const& item, + Expr* originalExpr) { // If the only result from lookup is an entry in an interface decl, it could be that // the user is leaving out an explicit definition for the requirement and depending on @@ -521,13 +523,16 @@ namespace Slang conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); conformanceDecl->parentDecl = structDecl; structDecl->members.add(conformanceDecl); + structDecl->parentDecl = parent; synthesizedDecl = structDecl; auto typeDef = m_astBuilder->create<TypeAliasDecl>(); typeDef->nameAndLoc.name = getName("Differential"); - auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); - typeDef->type.type = DeclRefType::create(m_astBuilder, declRef); typeDef->parentDecl = structDecl; + + auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); + + typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); structDecl->members.add(typeDef); } break; @@ -545,8 +550,9 @@ namespace Slang auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); addModifier(synthesizedDecl, toBeSynthesized); + auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl); return ConstructDeclRefExpr( - makeDeclRef(synthesizedDecl), + synthDeclMemberRef, nullptr, originalExpr ? originalExpr->loc : SourceLoc(), originalExpr); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index bbf6885a8..86136a010 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -459,6 +459,9 @@ Result linkAndOptimizeIR( break; } + if (sink->getErrorCount() != 0) + return SLANG_FAIL; + // If we have a target that is GPU like we use the string hashing mechanism // but for that to work we need to inline such that calls (or returns) of strings // boil down into getStringHash(stringLiteral) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index bd32a1896..026b8b110 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -552,8 +552,6 @@ INST(TargetSwitch, targetSwitch, 1, 0) // A generic asm inst has an return semantics that terminates the control flow. INST(GenericAsm, GenericAsm, 1, 0) -INST(RequirePrelude, RequirePrelude, 1, 0) - INST(discard, discard, 0, 0) /* IRUnreachable */ @@ -563,6 +561,8 @@ INST_RANGE(Unreachable, MissingReturn, Unreachable) INST_RANGE(TerminatorInst, Return, Unreachable) +INST(RequirePrelude, RequirePrelude, 1, 0) + // TODO: We should consider splitting the basic arithmetic/comparison // ops into cases for signed integers, unsigned integers, and floating-point // values, to better match downstream targets that want to treat them diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 070f989b5..c04450b82 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3040,7 +3040,7 @@ struct IRSPIRVAsm : IRInst } }; -struct IRGenericAsm : IRInst +struct IRGenericAsm : IRTerminatorInst { IR_LEAF_ISA(GenericAsm) UnownedStringSlice getAsm() { return as<IRStringLit>(getOperand(0))->getStringSlice(); } diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index b5af2d974..6970942c9 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -472,15 +472,9 @@ bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink) for (auto inst : module->getGlobalInsts()) { if (auto genFunc = as<IRGeneric>(inst)) - { - if (auto func = as<IRGlobalValueWithCode>(findGenericReturnVal(genFunc))) - { - bool result = unrollLoopsInFunc(module, func, sink); - if (!result) - return false; - } - } - else if (auto func = as<IRGlobalValueWithCode>(inst)) + continue; + + if (auto func = as<IRGlobalValueWithCode>(inst)) { bool result = unrollLoopsInFunc(module, func, sink); if (!result) diff --git a/source/slang/slang-ir-pytorch-cpp-binding.cpp b/source/slang/slang-ir-pytorch-cpp-binding.cpp index 41665ddf7..3a7e8b9fb 100644 --- a/source/slang/slang-ir-pytorch-cpp-binding.cpp +++ b/source/slang/slang-ir-pytorch-cpp-binding.cpp @@ -177,9 +177,10 @@ static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst } List<IRInst*> resultElements; auto elementType = arrayType->getElementType(); + auto tupleElementType = translateToTupleType(builder, elementType); for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { - auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); + auto tupleElement = builder.emitTargetTupleGetElement(tupleElementType, val, builder.getIntValue(builder.getIntType(), i)); auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); if (!convertedElement) return nullptr; @@ -346,7 +347,7 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag List<IRType*> fieldTypes; for (auto field : as<IRStructType>(type)->getFields()) { - fieldTypes.add(translateToHostType(builder, field->getFieldType(), func)); + fieldTypes.add(translateToHostType(builder, field->getFieldType(), func, sink)); } auto hostStructType = builder->createStructType(); @@ -358,6 +359,13 @@ IRType* translateToHostType(IRBuilder* builder, IRType* type, IRInst* func, Diag return hostStructType; } + case kIROp_ArrayType: + { + auto elementType = translateToHostType(builder, as<IRArrayType>(type)->getElementType(), func, sink); + if (!elementType) + return nullptr; + return builder->getArrayType(elementType, as<IRArrayType>(type)->getElementCount()); + } default: break; } @@ -422,13 +430,36 @@ IRInst* castHostToCUDAType(IRBuilder* builder, IRType* hostType, IRType* cudaTyp return builder->emitMakeStruct(cudaType, (UInt)resultFields.getCount(), resultFields.getBuffer()); } + case kIROp_ArrayType: + { + auto cudaArrayType = cast<IRArrayType>(cudaType); + auto hostArrayType = cast<IRArrayType>(hostType); + + List<IRInst*> resultElements; + for (UInt i = 0; i < (UInt)cast<IRIntLit>(cudaArrayType->getElementCount())->getValue(); i++) + { + auto cudaElementType = cudaArrayType->getElementType(); + auto hostElementType = hostArrayType->getElementType(); + auto castedElement = castHostToCUDAType( + builder, + hostElementType, + cudaElementType, + builder->emitElementExtract(inst, builder->getIntValue(builder->getIntType(), i))); + + SLANG_RELEASE_ASSERT(castedElement); + resultElements.add(castedElement); + } + + return builder->emitMakeArray(cudaType, (UInt)resultElements.getCount(), resultElements.getBuffer()); + } default: break; } - // If translateToHostType worked correctly, we shouldn't get here. - SLANG_UNREACHABLE("unhandled type"); + // If translateToHostType worked correctly, there should be no unhandled cases here. + // However, we won't diagnose here since its already diagnosed in translateToHostType() + return nullptr; } void generateReflectionFunc(IRBuilder* builder, IRFunc* kernelFunc, IRFunc* hostFunc) @@ -553,6 +584,12 @@ IRInst* generateHostParamForCUDAParam(IRBuilder* builder, IRParam* param, Diagno auto type = translateToHostType(builder, param->getDataType(), getParentFunc(param), sink); if (outType) *outType = type; + + if (!type || sink->getErrorCount() > 0) + { + return nullptr; + } + auto hostParam = builder->emitParam(type); // Add a namehint to the param by appending the suffix "_host". if (auto nameHint = param->findDecoration<IRNameHintDecoration>()) @@ -600,6 +637,38 @@ void markTypeForPyExport(IRType* type, DiagnosticSink* sink) } return; } + else if (auto arrayType = as<IRArrayType>(type)) + { + IRBuilder builder(arrayType->getModule()); + if (!arrayType->findDecoration<IRPyExportDecoration>()) + builder.addPyExportDecoration(arrayType, UnownedStringSlice("Array")); + + markTypeForPyExport(arrayType->getElementType(), sink); + return; + } +} + +String tryGetExportTypeName(IRBuilder* builder, IRType* type) +{ + if (auto structType = as<IRStructType>(type)) + { + if (auto pyExportDecoration = type->findDecoration<IRPyExportDecoration>()) + return String(pyExportDecoration->getExportName()); + else + return String(""); + } + else if (auto arrayType = as<IRArrayType>(type)) + { + StringBuilder nameBuilder; + nameBuilder << "Array_"; + nameBuilder << tryGetExportTypeName(builder, arrayType->getElementType()); + nameBuilder << "_"; + nameBuilder << cast<IRIntLit>(arrayType->getElementCount())->getValue(); + + return nameBuilder.produceString(); + } + else + return String(); } void generateReflectionForType(IRType* type, DiagnosticSink* sink) @@ -609,7 +678,6 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) // The list will contain the names of all the fields of the type. // - // TODO: Fix this to avoid emitting the same type reflection multiple times. if (!type->findDecoration<IRPyExportDecoration>()) return; @@ -635,20 +703,32 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) else fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); - if (!field->getFieldType()->findDecoration<IRPyExportDecoration>()) - { - fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); - continue; - } - auto fieldType = field->getFieldType(); + auto exportName = tryGetExportTypeName(&builder, fieldType); - fieldTypeNames.add( - builder.emitGetNativeString( - builder.getStringValue(fieldType->findDecoration<IRPyExportDecoration>()->getExportName()))); + if (exportName.getLength() > 0) + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(exportName.getUnownedSlice()))); + else + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("")))); } break; } + case kIROp_ArrayType: + { + auto elementType = as<IRArrayType>(type)->getElementType(); + fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("type")))); + fieldTypeNames.add( + builder.emitGetNativeString( + builder.getStringValue(tryGetExportTypeName(&builder, elementType).getUnownedSlice()))); + + auto elementCount = as<IRIntLit>(as<IRArrayType>(type)->getElementCount()); + fieldNames.add(builder.emitGetNativeString(builder.getStringValue(UnownedStringSlice("size")))); + + StringBuilder elementCountStr; + elementCountStr << elementCount->getValue(); + fieldTypeNames.add(builder.emitGetNativeString(builder.getStringValue(elementCountStr.getUnownedSlice()))); + break; + } default: break; } @@ -676,7 +756,7 @@ void generateReflectionForType(IRType* type, DiagnosticSink* sink) // Set function name. StringBuilder reflFuncExportName; - reflFuncExportName << "__typeinfo__" << type->findDecoration<IRPyExportDecoration>()->getExportName(); + reflFuncExportName << "__typeinfo__" << tryGetExportTypeName(&builder, type).getUnownedSlice(); builder.addTorchEntryPointDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); builder.addExternCppDecoration(reflFunc, reflFuncExportName.getUnownedSlice()); diff --git a/source/slang/slang-ir-synthesize-active-mask.cpp b/source/slang/slang-ir-synthesize-active-mask.cpp index 75246d553..60e13b418 100644 --- a/source/slang/slang-ir-synthesize-active-mask.cpp +++ b/source/slang/slang-ir-synthesize-active-mask.cpp @@ -1855,37 +1855,40 @@ struct SynthesizeActiveMaskForFunctionContext } else if( toBlock->getPredecessors().getCount() > 1 ) { - // If the target block is one with multiple - // predecessors, such that it will have an - // added block parameter (phi node) to select - // the corect mask value, then we need to - // pass along the mask value to use as an - // additional argument on the unconditional branch. - // - // If the old unconditional branch was: - // - // <op>(arg0, arg1, arg2, ...); - // - // Then our new branch will be: - // - // <op>(arg0, arg1, arg2, ..., toActiveMask); - // - List<IRInst*> newOperands; - UInt oldOperandCount = terminator->getOperandCount(); - for( UInt i = 0; i < oldOperandCount; ++i ) + if (doesBlockNeedActiveMask(toBlock)) { - newOperands.add(terminator->getOperand(i)); - } - newOperands.add(toActiveMask); + // If the target block is one with multiple + // predecessors, such that it will have an + // added block parameter (phi node) to select + // the corect mask value, then we need to + // pass along the mask value to use as an + // additional argument on the unconditional branch. + // + // If the old unconditional branch was: + // + // <op>(arg0, arg1, arg2, ...); + // + // Then our new branch will be: + // + // <op>(arg0, arg1, arg2, ..., toActiveMask); + // + List<IRInst*> newOperands; + UInt oldOperandCount = terminator->getOperandCount(); + for( UInt i = 0; i < oldOperandCount; ++i ) + { + newOperands.add(terminator->getOperand(i)); + } + newOperands.add(toActiveMask); - IRInst* newTerminator = builder.emitIntrinsicInst( - terminator->getFullType(), - terminator->getOp(), - newOperands.getCount(), - newOperands.getBuffer()); + IRInst* newTerminator = builder.emitIntrinsicInst( + terminator->getFullType(), + terminator->getOp(), + newOperands.getCount(), + newOperands.getBuffer()); - terminator->replaceUsesWith(newTerminator); - terminator->removeAndDeallocate(); + terminator->replaceUsesWith(newTerminator); + terminator->removeAndDeallocate(); + } } else { diff --git a/tests/autodiff/generic-differential-synthesis.slang b/tests/autodiff/generic-differential-synthesis.slang new file mode 100644 index 000000000..8c858b9b3 --- /dev/null +++ b/tests/autodiff/generic-differential-synthesis.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +__generic<let C : int> +struct Foo : IDifferentiable +{ + float x[C]; +}; + +[Differentiable] +Foo<3> getFoo(float x) +{ + return { { x, x, x } }; +} + +[Differentiable] +float foobar(float x) +{ + int i = 3 * int(floor(x)); + Foo<3> foo = getFoo(x); + return foo.x[i] * foo.x[i]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + float a = 0.5; + var d = fwd_diff(foobar)(diffPair(a, 1.0)).d; + outputBuffer[0] = d; + } +} diff --git a/tests/autodiff/generic-differential-synthesis.slang.expected.txt b/tests/autodiff/generic-differential-synthesis.slang.expected.txt new file mode 100644 index 000000000..97de29f1f --- /dev/null +++ b/tests/autodiff/generic-differential-synthesis.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +1.000000 +0.000000 +0.000000 +0.000000 |
