diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 16:02:56 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 16:02:56 -0800 |
| commit | 4ad0470025da4e808c46023f9a2525febcf973a2 (patch) | |
| tree | 8fcb1c84121ddf40c50ca58b5de867da0da435ee | |
| parent | 97cb4851eed7a43f10196971b08d3d311386ce9f (diff) | |
Fix issues around dynamic generic function and autodiff. (#2528)
* Fix issues around dynamic generic function and autodiff.
* Fix return type issue.
* Fix type unification for generic `inout` parameter.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ast-decl.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 61 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 40 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 37 | ||||
| -rw-r--r-- | tests/autodiff/bool-return-val.slang | 28 | ||||
| -rw-r--r-- | tests/autodiff/bool-return-val.slang.expected.txt | 5 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-generic-2.slang | 49 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-generic-2.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-generic.slang | 46 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-generic.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/generic-autodiff-1.slang | 39 | ||||
| -rw-r--r-- | tests/autodiff/generic-autodiff-1.slang.expected.txt | 6 |
16 files changed, 291 insertions, 70 deletions
diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index b2802e304..9931bbcaf 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -18,6 +18,19 @@ const TypeExp& TypeConstraintDecl::_getSupOverride() const //return TypeExp::empty; } +InterfaceDecl* findParentInterfaceDecl(Decl* decl) +{ + auto ancestor = decl->parentDecl; + for (; ancestor; ancestor = ancestor->parentDecl) + { + if (auto interfaceDecl = as<InterfaceDecl>(ancestor)) + return interfaceDecl; + + if (as<ExtensionDecl>(ancestor)) + return nullptr; + } + return nullptr; +} bool isInterfaceRequirement(Decl* decl) { diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index e7dc73a85..ccbac0286 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -518,7 +518,8 @@ class AttributeDecl : public ContainerDecl SyntaxClass<NodeBase> syntaxClass; }; -// A synthesized decl used as a placeholder for a differentiable function requirement. +// A synthesized decl used as a placeholder for a differentiable function requirement. This decl will +// be a child of interface decl. // This allows us to form an interface requirement key for the derivative of an interface function. // The synthesized `DerivativeRequirementDecl` will be a child of the original function requirement // decl after an interface type is checked. @@ -527,6 +528,14 @@ class DerivativeRequirementDecl : public FunctionDeclBase SLANG_AST_CLASS(DerivativeRequirementDecl) }; +// A reference to a synthesized decl representing a differentiable function requirement, this decl will +// be a child in the orignal function. +class DerivativeRequirementReferenceDecl : public FunctionDeclBase +{ + SLANG_AST_CLASS(DerivativeRequirementReferenceDecl) + DerivativeRequirementDecl* referencedDecl; +}; + class ForwardDerivativeRequirementDecl : public DerivativeRequirementDecl { SLANG_AST_CLASS(ForwardDerivativeRequirementDecl) @@ -538,5 +547,6 @@ class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl }; bool isInterfaceRequirement(Decl* decl); +InterfaceDecl* findParentInterfaceDecl(Decl* decl); } // namespace Slang diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 76623d01c..3fcc762ec 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -1178,5 +1178,14 @@ Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionS return substType; } +Type* removeParamDirType(Type* type) +{ + for (auto paramDirType = as<ParamDirectionType>(type); paramDirType;) + { + type = paramDirType->getValueType(); + paramDirType = as<ParamDirectionType>(type); + } + return type; +} } // namespace Slang diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index d85391d58..8953f0b10 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -872,4 +872,6 @@ class ModifiedType : public Type Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +Type* removeParamDirType(Type* type); + } // namespace Slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 36a1061c9..4d2839b8d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1926,23 +1926,33 @@ namespace Slang requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); - if (hasForwardDerivative) + if (hasForwardDerivative || hasBackwardDerivative) { - auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>(); - SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); - ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); - } + int fwdReqFound = 0; + int bwdReqFound = 0; + for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>()) + { + if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) + { + ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(fwdReq, RequirementWitness(val)); + fwdReqFound++; + } + else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) + { + BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(bwdReq, RequirementWitness(val)); + bwdReqFound++; + } + } - if (hasBackwardDerivative) - { - auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>(); - SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); - BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + SLANG_RELEASE_ASSERT( + fwdReqFound == (hasForwardDerivative ? 1 : 0) && + bwdReqFound == (hasBackwardDerivative ? 1 : 0)); } + return true; } @@ -3706,7 +3716,8 @@ namespace Slang { if(isAssociatedTypeDecl(requiredMemberDeclRef)) continue; - + if (requiredMemberDeclRef.as<DerivativeRequirementDecl>()) + continue; auto requirementSatisfied = findWitnessForInterfaceRequirement( context, subType, @@ -5617,7 +5628,7 @@ namespace Slang } decl->errorType = errorType; - if (isInterfaceRequirement(decl)) + if (auto interfaceDecl = findParentInterfaceDecl(decl)) { if (decl->hasModifier<ForwardDifferentiableAttribute>()) { @@ -5626,8 +5637,13 @@ namespace Slang auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - decl->members.add(reqDecl); - reqDecl->parentDecl = decl; + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); } if (decl->hasModifier<BackwardDifferentiableAttribute>()) { @@ -5636,8 +5652,13 @@ namespace Slang auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - decl->members.add(reqDecl); - reqDecl->parentDecl = decl; + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); } } } diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 83774303b..3867dda03 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1580,7 +1580,7 @@ namespace Slang List<Type*> paramTypes; for (UIndex ii = 0; ii < diffFuncType->getParamCount(); ii++) - paramTypes.add(diffFuncType->getParamType(ii)); + paramTypes.add(removeParamDirType(diffFuncType->getParamType(ii))); // Try to infer generic arguments, based on the updated context. DeclRef<Decl> innerRef = inferGenericArguments( diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 04a898ea9..c93522565 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -111,6 +111,8 @@ struct JVPTranscriber IRInst* lookupPrimalInst(IRInst* origInst) { + if (!origInst) + return nullptr; if (shouldUseOriginalAsPrimal(origInst)) return origInst; return cloneEnv.mapOldValToNew[origInst]; @@ -118,11 +120,15 @@ struct JVPTranscriber IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) { + if (!origInst) + return nullptr; return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; } bool hasPrimalInst(IRInst* origInst) { + if (!origInst) + return true; if (shouldUseOriginalAsPrimal(origInst)) return true; return cloneEnv.mapOldValToNew.ContainsKey(origInst); @@ -175,7 +181,7 @@ struct JVPTranscriber if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) diffReturnType = returnPairType; else - diffReturnType = builder->getVoidType(); + diffReturnType = origResultType; return builder->getFuncType(newParameterTypes, diffReturnType); } @@ -735,13 +741,12 @@ struct JVPTranscriber SLANG_ASSERT(primalArg); auto primalType = primalArg->getDataType(); - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - - if (!diffArg) - diffArg = getDifferentialZeroOfType(builder, primalType); - if (auto pairType = tryGetDiffPairType(builder, primalType)) { + auto diffArg = findOrTranscribeDiffInst(builder, origArg); + if (!diffArg) + diffArg = getDifferentialZeroOfType(builder, primalType); + // If a pair type can be formed, this must be non-null. SLANG_RELEASE_ASSERT(diffArg); auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); @@ -984,6 +989,18 @@ struct JVPTranscriber builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } + else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>()) + { + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + { + args.add(primalSpecialize->getArg(i)); + } + diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } else { return InstPair(primalSpecialize, nullptr); @@ -1365,15 +1382,14 @@ struct JVPTranscriber { differentiableTypeConformanceContext.setFunc(innerFunc); } + else if (auto funcType = as<IRFuncType>(innerVal)) + { + } else { return InstPair(origGeneric, nullptr); } - // For now, we assume there's only one generic layer. So this inst must be top level - bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr); - SLANG_RELEASE_ASSERT(isTopLevel); - IRGeneric* primalGeneric = origGeneric; IRBuilder builder(inBuilder->getSharedBuilder()); @@ -1395,10 +1411,6 @@ struct JVPTranscriber diffGeneric->setFullType(diffType); - // TODO(sai): Replace naming scheme - // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) - // builder->addNameHintDecoration(diffFunc, jvpName); - // Transcribe children from origFunc into diffFunc. builder.setInsertInto(diffGeneric); for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a0becdafa..09dacc20d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6863,14 +6863,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { operandCount += associatedTypeDecl->getMembersOfType<TypeConstraintDecl>().getCount(); } - else if (auto callableDecl = as<CallableDecl>(requirementDecl)) - { - // Differentiable functions has additional requirements for the derivatives. - if (callableDecl->getMembersOfType<ForwardDerivativeRequirementDecl>().getCount()) - operandCount++; - if (callableDecl->getMembersOfType<BackwardDerivativeRequirementDecl>().getCount()) - operandCount++; - } } // Allocate an IRInterfaceType with the `operandCount` operands. @@ -6957,33 +6949,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (auto callableDecl = as<CallableDecl>(requirementDecl)) { // Differentiable functions has additional requirements for the derivatives. - for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementDecl>()) + for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>()) { - auto diffKey = getInterfaceRequirementKey(diffDecl); - IRInst* diffVal = ensureDecl(subContext, diffDecl).val; - auto diffEntry = subBuilder->createInterfaceRequirementEntry(diffKey, diffVal); - if (diffVal) - { - switch (diffVal->getOp()) - { - case kIROp_Func: - case kIROp_Generic: - { - // Remove lowered `IRFunc`s since we only care about - // function types. - auto reqType = diffVal->getFullType(); - diffEntry->setRequirementVal(reqType); - break; - } - default: - break; - } - } - irInterface->setOperand(entryIndex, diffEntry); - entryIndex++; - - setValue(context, diffDecl, LoweredValInfo::simple(diffEntry)); - insertRequirementKeyAssociation(irInterface, diffDecl, requirementKey, diffKey); + auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl); + insertRequirementKeyAssociation(irInterface, diffDecl->referencedDecl, requirementKey, diffKey); } } // Add lowered requirement entry to current decl mapping to prevent diff --git a/tests/autodiff/bool-return-val.slang b/tests/autodiff/bool-return-val.slang new file mode 100644 index 000000000..a43495dd9 --- /dev/null +++ b/tests/autodiff/bool-return-val.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct NonDiff +{ + float a; +} + +[ForwardDifferentiable] +bool myFunc(NonDiff fIn, inout float x) +{ + x = pow(x, fIn.a); + return x > 100.f; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + float a = 10.0; + NonDiff fIn = { a }; + DifferentialPair<float> dpx = DifferentialPair<float>(4.f, 1.f); + bool res = __fwd_diff(myFunc)(fIn, dpx); + + outputBuffer[0] = res?1:0; +}
\ No newline at end of file diff --git a/tests/autodiff/bool-return-val.slang.expected.txt b/tests/autodiff/bool-return-val.slang.expected.txt new file mode 100644 index 000000000..5fce3dc6d --- /dev/null +++ b/tests/autodiff/bool-return-val.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +1.000000 +0.000000 +0.000000 +0.000000
\ No newline at end of file diff --git a/tests/autodiff/dynamic-dispatch-generic-2.slang b/tests/autodiff/dynamic-dispatch-generic-2.slang new file mode 100644 index 000000000..bbf7c7da1 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-generic-2.slang @@ -0,0 +1,49 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [ForwardDifferentiable] + float calc(float x); +} + +struct A : IInterface +{ + float z; + [ForwardDifferentiable] + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float z; + + [ForwardDifferentiable] + float calc(float x) { return x * x + z; } +}; + +[ForwardDifferentiable] +float sqr<T:IInterface>(T obj, float x) +{ + return obj.calc(x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; // A.calc, expect 12 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + outputBuffer[1] = __fwd_diff(sqr)(obj, DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-generic-2.slang.expected.txt b/tests/autodiff/dynamic-dispatch-generic-2.slang.expected.txt new file mode 100644 index 000000000..8a664e432 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-generic-2.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +3.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/dynamic-dispatch-generic.slang b/tests/autodiff/dynamic-dispatch-generic.slang new file mode 100644 index 000000000..37c37a745 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-generic.slang @@ -0,0 +1,46 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [ForwardDifferentiable] + static float calc(float x); +} + +struct A : IInterface +{ + [ForwardDifferentiable] + static float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + [ForwardDifferentiable] + static float calc(float x) { return x * x; } +}; + +[ForwardDifferentiable] +float sqr<T:IInterface>(T obj, float x) +{ + return obj.calc(x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; // A.calc, expect 12 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + outputBuffer[1] = __fwd_diff(sqr)(obj, DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-generic.slang.expected.txt b/tests/autodiff/dynamic-dispatch-generic.slang.expected.txt new file mode 100644 index 000000000..8a664e432 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-generic.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +3.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/generic-autodiff-1.slang b/tests/autodiff/generic-autodiff-1.slang new file mode 100644 index 000000000..db7ae7e87 --- /dev/null +++ b/tests/autodiff/generic-autodiff-1.slang @@ -0,0 +1,39 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [mutating] + float sample(); +} + +struct A : IInterface +{ + float z; + [mutating] + float sample() { z = z + 1.0; return 1.0; } +}; + + +[ForwardDifferentiable] +float sqr<T:IInterface>(inout T obj, float x) +{ + return obj.sample() + x*x; +} + +//TEST_INPUT: type_conformance A:IInterface = 0 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A obj; + obj.z = 0.0; + outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; +} diff --git a/tests/autodiff/generic-autodiff-1.slang.expected.txt b/tests/autodiff/generic-autodiff-1.slang.expected.txt new file mode 100644 index 000000000..64510fe94 --- /dev/null +++ b/tests/autodiff/generic-autodiff-1.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +4.000000 +0.000000 +0.000000 +0.000000 +0.000000 |
