diff options
| -rw-r--r-- | source/slang/slang-ast-decl.h | 15 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 69 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 56 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 70 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 6 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-bwd-diff.slang | 52 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt | 6 |
8 files changed, 151 insertions, 137 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 81a6e3f7d..ccbac0286 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -546,21 +546,6 @@ class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl SLANG_AST_CLASS(BackwardDerivativeRequirementDecl) }; -class BackwardDerivativePrimalRequirementDecl : public DerivativeRequirementDecl -{ - SLANG_AST_CLASS(BackwardDerivativePrimalRequirementDecl) -}; - -class BackwardDerivativePropagateRequirementDecl : public DerivativeRequirementDecl -{ - SLANG_AST_CLASS(BackwardDerivativePropagateRequirementDecl) -}; - -class BackwardDerivativeIntermediateTypeRequirementDecl : public DerivativeRequirementDecl -{ - SLANG_AST_CLASS(BackwardDerivativeIntermediateTypeRequirementDecl) -}; - bool isInterfaceRequirement(Decl* decl); InterfaceDecl* findParentInterfaceDecl(Decl* decl); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 142842e12..a1d5acfb0 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2677,24 +2677,6 @@ namespace Slang val->func = satisfyingMemberDeclRef; witnessTable->add(bwdReq, RequirementWitness(val)); } - else if (auto primalReq = as<BackwardDerivativePrimalRequirementDecl>(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePrimalVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(primalReq, RequirementWitness(val)); - } - else if (auto propReq = as<BackwardDerivativePropagateRequirementDecl>(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePropagateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(propReq, RequirementWitness(val)); - } - else if (auto itypeReq = as<BackwardDerivativeIntermediateTypeRequirementDecl>(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateIntermediateTypeVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(itypeReq, RequirementWitness(val)); - } } witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); } @@ -5920,7 +5902,7 @@ namespace Slang if (auto interfaceDecl = findParentInterfaceDecl(decl)) { bool isDiffFunc = false; - if (decl->hasModifier<ForwardDifferentiableAttribute>()) + if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>()) { auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>(); cloneModifiers(reqDecl, decl); @@ -5954,55 +5936,6 @@ namespace Slang reqRef->parentDecl = decl; decl->members.add(reqRef); } - // Requirement for backward derivative intermediate type. - auto intermediateTypeReqDecl = m_astBuilder->create<BackwardDerivativeIntermediateTypeRequirementDecl>(); - auto intermediateType = m_astBuilder->getOrCreateDeclRefType( - intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl)); - { - cloneModifiers(intermediateTypeReqDecl, decl); - interfaceDecl->members.add(intermediateTypeReqDecl); - intermediateTypeReqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); - reqRef->referencedDecl = intermediateTypeReqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - // Requirement for backward derivative primal func. - { - auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>(); - cloneModifiers(reqDecl, decl); - FuncType* primalFuncType = m_astBuilder->create<FuncType>(); - primalFuncType->resultType = originalFuncType->resultType; - primalFuncType->paramTypes.addRange(originalFuncType->paramTypes); - auto outType = m_astBuilder->getOutType(intermediateType); - primalFuncType->paramTypes.add(outType); - setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - // Requirement for backward derivative propagate func. - { - auto reqDecl = m_astBuilder->create<BackwardDerivativePropagateRequirementDecl>(); - cloneModifiers(reqDecl, decl); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - FuncType* propagateFuncType = m_astBuilder->create<FuncType>(); - propagateFuncType->resultType = diffFuncType->resultType; - propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes); - propagateFuncType->paramTypes.add(intermediateType); - setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType); - auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - isDiffFunc = true; } if (isDiffFunc) diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index a4c79d09a..e1832b9eb 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -851,6 +851,8 @@ struct DiffTransposePass if (as<IRDecoration>(child) || as<IRParam>(child)) continue; + if (as<IRType>(child)) + continue; if (isDifferentialInst(child)) transposeInst(&builder, child); @@ -1332,10 +1334,6 @@ struct DiffTransposePass } } - // The call must have been decorated with the continuation context after splitting. - auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>(); - SLANG_RELEASE_ASSERT(primalContextDecor); - auto baseFn = fwdDiffCallee->getBaseFn(); List<IRInst*> args; @@ -1453,20 +1451,52 @@ struct DiffTransposePass argRequiresLoad.add(false); } - // Ensure availability of the primal context var - auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar()); - SLANG_RELEASE_ASSERT(primalContextVar); + // If the callee provides a primal implementation that produces continuation context for propagation phase + // we grab it and pass it as argument to the propagation function. + if (auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) + { + // Ensure availability of the primal context var + auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar()); + SLANG_RELEASE_ASSERT(primalContextVar); - args.add(builder->emitLoad(primalContextVar)); - argTypes.add(as<IRPtrTypeBase>( + args.add(builder->emitLoad(primalContextVar)); + argTypes.add(as<IRPtrTypeBase>( primalContextVar->getDataType()) ->getValueType()); - argRequiresLoad.add(false); + argRequiresLoad.add(false); + } auto revFnType = builder->getFuncType(argTypes, builder->getVoidType()); - auto revCallee = builder->emitBackwardDifferentiatePropagateInst( - revFnType, - baseFn); + IRInst* revCallee = nullptr; + if (getResolvedInstForDecorations(baseFn)->getOp() == kIROp_LookupWitness) + { + // This is an interface method call, we can simply transcribe it here. + auto specialize = as<IRSpecialize>(baseFn); + auto innerFn = baseFn; + if (specialize) + innerFn = specialize->getBase(); + auto lookupWitness = as<IRLookupWitnessMethod>(innerFn); + SLANG_RELEASE_ASSERT(lookupWitness); + auto diffDecor = lookupWitness->getRequirementKey()->findDecoration<IRBackwardDerivativeDecoration>(); + SLANG_RELEASE_ASSERT(diffDecor); + auto diffKey = diffDecor->getBackwardDerivativeFunc(); + revCallee = builder->emitLookupInterfaceMethodInst(builder->getTypeKind(), lookupWitness->getWitnessTable(), diffKey); + if (specialize) + { + List<IRInst*> specArgs; + for (UInt i = 0; i < specialize->getArgCount(); i++) + specArgs.add(specialize->getArg(i)); + revCallee = builder->emitSpecializeInst(builder->getTypeKind(), revCallee, specArgs.getCount(), specArgs.getBuffer()); + } + revCallee->setFullType(revFnType); + } + else + { + // All other calls, we insert a `backwardDifferentiate` inst so we will process it in a follow-up iteration. + revCallee = builder->emitBackwardDifferentiatePropagateInst( + revFnType, + baseFn); + } List<IRInst*> callArgs; for (auto arg : args) diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 2ebc330f0..a30826370 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -830,28 +830,51 @@ struct DiffUnzipPass { auto func = findSpecializeReturnVal(specialize); auto outerGen = findOuterGeneric(func); - intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen); - List<IRInst*> args; - for (UInt i = 0; i < specialize->getArgCount(); i++) - args.add(specialize->getArg(i)); - intermediateType = primalBuilder->emitSpecializeInst( - primalBuilder->getTypeKind(), - intermediateType, - args.getCount(), - args.getBuffer()); + if (func->getOp() == kIROp_LookupWitness) + { + // An interface method won't have intermediate type. + intermediateType = primalBuilder->getVoidType(); + } + else + { + intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen); + List<IRInst*> args; + for (UInt i = 0; i < specialize->getArgCount(); i++) + args.add(specialize->getArg(i)); + intermediateType = primalBuilder->emitSpecializeInst( + primalBuilder->getTypeKind(), + intermediateType, + args.getCount(), + args.getBuffer()); + } } else { - intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn); + if (baseFn->getOp() == kIROp_LookupWitness) + intermediateType = primalBuilder->getVoidType(); + else + intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn); } - auto intermediateVar = primalBuilder->emitVar((IRType*)intermediateType); - primalBuilder->markInstAsPrimal(intermediateVar); + IRVar* intermediateVar = nullptr; + if (!as<IRVoidType>(intermediateType)) + { + intermediateVar = primalBuilder->emitVar((IRType*)intermediateType); + primalBuilder->markInstAsPrimal(intermediateVar); + } - primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar); - - auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn); - + IRInst* primalFn = nullptr; + if (intermediateVar) + { + primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar); + primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn); + } + else + { + // If we decided not to use diff-primal func that stores an reuse context, + // we can just call the original function instead. + primalFn = baseFn; + } List<IRInst*> primalArgs; for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) { @@ -865,7 +888,8 @@ struct DiffUnzipPass primalArgs.add(arg); } } - primalArgs.add(intermediateVar); + if (intermediateType->getOp() != kIROp_VoidType) + primalArgs.add(intermediateVar); auto mixedDecoration = mixedCall->findDecoration<IRMixedDifferentialInstDecoration>(); SLANG_ASSERT(mixedDecoration); @@ -881,7 +905,8 @@ struct DiffUnzipPass } auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs); - primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); + if (intermediateVar) + primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); primalBuilder->markInstAsPrimal(primalVal); SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount()); @@ -960,9 +985,12 @@ struct DiffUnzipPass diffArgs); diffBuilder->markInstAsDifferential(callInst, primalType); - disableIRValidationAtInsert(); - diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar); - enableIRValidationAtInsert(); + if (intermediateVar) + { + disableIRValidationAtInsert(); + diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar); + enableIRValidationAtInsert(); + } IRInst* diffVal = nullptr; if (as<IRDifferentialPairType>(callInst->getDataType())) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d09c35eea..261e08168 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6899,14 +6899,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { op = kIROp_BackwardDerivativeDecoration; } - else if (as<BackwardDerivativePropagateRequirementDecl>(requirementDecl)) - { - op = kIROp_BackwardDerivativePropagateDecoration; - } - else if (as<BackwardDerivativePrimalRequirementDecl>(requirementDecl)) - { - op = kIROp_BackwardDerivativePrimalDecoration; - } else if (as<ForwardDerivativeRequirementDecl>(requirementDecl)) { op = kIROp_ForwardDerivativeDecoration; @@ -8534,12 +8526,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } - LoweredValInfo visitBackwardDerivativeIntermediateTypeRequirementDecl(BackwardDerivativeIntermediateTypeRequirementDecl* decl) - { - SLANG_UNUSED(decl); - return LoweredValInfo(getBuilder()->getTypeKind()); - } - LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) { // A function declaration may have multiple, target-specific diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index da5099934..a7d047a0c 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -521,12 +521,6 @@ namespace Slang emitRaw(context, "FwdReq_"); else if (as<BackwardDerivativeRequirementDecl>(decl)) emitRaw(context, "BwdReq_"); - else if (as<BackwardDerivativePropagateRequirementDecl>(decl)) - emitRaw(context, "BwdReq_Prop_"); - else if (as<BackwardDerivativePrimalRequirementDecl>(decl)) - emitRaw(context, "BwdReq_Primal_"); - else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(decl)) - emitRaw(context, "BwdReq_CtxType_"); else { // TODO: handle other cases diff --git a/tests/autodiff/dynamic-dispatch-bwd-diff.slang b/tests/autodiff/dynamic-dispatch-bwd-diff.slang new file mode 100644 index 000000000..5945c22cd --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-bwd-diff.slang @@ -0,0 +1,52 @@ +// Test calling backward differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [BackwardDifferentiable] + float calc(float x); +} + +struct A : IInterface +{ + float a; + [BackwardDifferentiable] + float calc(float x) { return a*x*x; } +}; + +struct B : IInterface +{ + float a; + [BackwardDifferentiable] + float calc(float x) { return a*x*x*x; } +}; + +[BackwardDifferentiable] +float run(IInterface obj, float x) +{ + return obj.calc(x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0.5f); // A + var p = diffPair(3.0); + + __bwd_diff(run)(obj, p, 1.0f); + outputBuffer[0] = p.d; // A.calc, expect 3 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 1.5f); // B + p = diffPair(3.0); + __bwd_diff(run)(obj, p, 1.0f); + outputBuffer[1] = p.d; // B.calc, expect 40.5 +} diff --git a/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt b/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt new file mode 100644 index 000000000..57bb1ee65 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +3.000000 +40.500000 +0.000000 +0.000000 +0.000000 |
