summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-decl.cpp13
-rw-r--r--source/slang/slang-ast-decl.h12
-rw-r--r--source/slang/slang-ast-type.cpp9
-rw-r--r--source/slang/slang-ast-type.h2
-rw-r--r--source/slang/slang-check-decl.cpp61
-rw-r--r--source/slang/slang-check-overload.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp40
-rw-r--r--source/slang/slang-lower-to-ir.cpp37
-rw-r--r--tests/autodiff/bool-return-val.slang28
-rw-r--r--tests/autodiff/bool-return-val.slang.expected.txt5
-rw-r--r--tests/autodiff/dynamic-dispatch-generic-2.slang49
-rw-r--r--tests/autodiff/dynamic-dispatch-generic-2.slang.expected.txt6
-rw-r--r--tests/autodiff/dynamic-dispatch-generic.slang46
-rw-r--r--tests/autodiff/dynamic-dispatch-generic.slang.expected.txt6
-rw-r--r--tests/autodiff/generic-autodiff-1.slang39
-rw-r--r--tests/autodiff/generic-autodiff-1.slang.expected.txt6
16 files changed, 291 insertions, 70 deletions
diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp
index b2802e304..9931bbcaf 100644
--- a/source/slang/slang-ast-decl.cpp
+++ b/source/slang/slang-ast-decl.cpp
@@ -18,6 +18,19 @@ const TypeExp& TypeConstraintDecl::_getSupOverride() const
//return TypeExp::empty;
}
+InterfaceDecl* findParentInterfaceDecl(Decl* decl)
+{
+ auto ancestor = decl->parentDecl;
+ for (; ancestor; ancestor = ancestor->parentDecl)
+ {
+ if (auto interfaceDecl = as<InterfaceDecl>(ancestor))
+ return interfaceDecl;
+
+ if (as<ExtensionDecl>(ancestor))
+ return nullptr;
+ }
+ return nullptr;
+}
bool isInterfaceRequirement(Decl* decl)
{
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index e7dc73a85..ccbac0286 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -518,7 +518,8 @@ class AttributeDecl : public ContainerDecl
SyntaxClass<NodeBase> syntaxClass;
};
-// A synthesized decl used as a placeholder for a differentiable function requirement.
+// A synthesized decl used as a placeholder for a differentiable function requirement. This decl will
+// be a child of interface decl.
// This allows us to form an interface requirement key for the derivative of an interface function.
// The synthesized `DerivativeRequirementDecl` will be a child of the original function requirement
// decl after an interface type is checked.
@@ -527,6 +528,14 @@ class DerivativeRequirementDecl : public FunctionDeclBase
SLANG_AST_CLASS(DerivativeRequirementDecl)
};
+// A reference to a synthesized decl representing a differentiable function requirement, this decl will
+// be a child in the orignal function.
+class DerivativeRequirementReferenceDecl : public FunctionDeclBase
+{
+ SLANG_AST_CLASS(DerivativeRequirementReferenceDecl)
+ DerivativeRequirementDecl* referencedDecl;
+};
+
class ForwardDerivativeRequirementDecl : public DerivativeRequirementDecl
{
SLANG_AST_CLASS(ForwardDerivativeRequirementDecl)
@@ -538,5 +547,6 @@ class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl
};
bool isInterfaceRequirement(Decl* decl);
+InterfaceDecl* findParentInterfaceDecl(Decl* decl);
} // namespace Slang
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 76623d01c..3fcc762ec 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -1178,5 +1178,14 @@ Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionS
return substType;
}
+Type* removeParamDirType(Type* type)
+{
+ for (auto paramDirType = as<ParamDirectionType>(type); paramDirType;)
+ {
+ type = paramDirType->getValueType();
+ paramDirType = as<ParamDirectionType>(type);
+ }
+ return type;
+}
} // namespace Slang
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index d85391d58..8953f0b10 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -872,4 +872,6 @@ class ModifiedType : public Type
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
+Type* removeParamDirType(Type* type);
+
} // namespace Slang
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 36a1061c9..4d2839b8d 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1926,23 +1926,33 @@ namespace Slang
requiredMemberDeclRef.getDecl(),
RequirementWitness(satisfyingMemberDeclRef));
- if (hasForwardDerivative)
+ if (hasForwardDerivative || hasBackwardDerivative)
{
- auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>();
- SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty());
- ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(reqDecl.getFirst(), RequirementWitness(val));
- }
+ int fwdReqFound = 0;
+ int bwdReqFound = 0;
+ for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>())
+ {
+ if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(fwdReq, RequirementWitness(val));
+ fwdReqFound++;
+ }
+ else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(bwdReq, RequirementWitness(val));
+ bwdReqFound++;
+ }
+ }
- if (hasBackwardDerivative)
- {
- auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>();
- SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty());
- BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(reqDecl.getFirst(), RequirementWitness(val));
+ SLANG_RELEASE_ASSERT(
+ fwdReqFound == (hasForwardDerivative ? 1 : 0) &&
+ bwdReqFound == (hasBackwardDerivative ? 1 : 0));
}
+
return true;
}
@@ -3706,7 +3716,8 @@ namespace Slang
{
if(isAssociatedTypeDecl(requiredMemberDeclRef))
continue;
-
+ if (requiredMemberDeclRef.as<DerivativeRequirementDecl>())
+ continue;
auto requirementSatisfied = findWitnessForInterfaceRequirement(
context,
subType,
@@ -5617,7 +5628,7 @@ namespace Slang
}
decl->errorType = errorType;
- if (isInterfaceRequirement(decl))
+ if (auto interfaceDecl = findParentInterfaceDecl(decl))
{
if (decl->hasModifier<ForwardDifferentiableAttribute>())
{
@@ -5626,8 +5637,13 @@ namespace Slang
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
- decl->members.add(reqDecl);
- reqDecl->parentDecl = decl;
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
}
if (decl->hasModifier<BackwardDifferentiableAttribute>())
{
@@ -5636,8 +5652,13 @@ namespace Slang
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef));
setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
- decl->members.add(reqDecl);
- reqDecl->parentDecl = decl;
+ interfaceDecl->members.add(reqDecl);
+ reqDecl->parentDecl = interfaceDecl;
+
+ auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
+ reqRef->referencedDecl = reqDecl;
+ reqRef->parentDecl = decl;
+ decl->members.add(reqRef);
}
}
}
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 83774303b..3867dda03 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1580,7 +1580,7 @@ namespace Slang
List<Type*> paramTypes;
for (UIndex ii = 0; ii < diffFuncType->getParamCount(); ii++)
- paramTypes.add(diffFuncType->getParamType(ii));
+ paramTypes.add(removeParamDirType(diffFuncType->getParamType(ii)));
// Try to infer generic arguments, based on the updated context.
DeclRef<Decl> innerRef = inferGenericArguments(
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 04a898ea9..c93522565 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -111,6 +111,8 @@ struct JVPTranscriber
IRInst* lookupPrimalInst(IRInst* origInst)
{
+ if (!origInst)
+ return nullptr;
if (shouldUseOriginalAsPrimal(origInst))
return origInst;
return cloneEnv.mapOldValToNew[origInst];
@@ -118,11 +120,15 @@ struct JVPTranscriber
IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
{
+ if (!origInst)
+ return nullptr;
return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
}
bool hasPrimalInst(IRInst* origInst)
{
+ if (!origInst)
+ return true;
if (shouldUseOriginalAsPrimal(origInst))
return true;
return cloneEnv.mapOldValToNew.ContainsKey(origInst);
@@ -175,7 +181,7 @@ struct JVPTranscriber
if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
diffReturnType = returnPairType;
else
- diffReturnType = builder->getVoidType();
+ diffReturnType = origResultType;
return builder->getFuncType(newParameterTypes, diffReturnType);
}
@@ -735,13 +741,12 @@ struct JVPTranscriber
SLANG_ASSERT(primalArg);
auto primalType = primalArg->getDataType();
- auto diffArg = findOrTranscribeDiffInst(builder, origArg);
-
- if (!diffArg)
- diffArg = getDifferentialZeroOfType(builder, primalType);
-
if (auto pairType = tryGetDiffPairType(builder, primalType))
{
+ auto diffArg = findOrTranscribeDiffInst(builder, origArg);
+ if (!diffArg)
+ diffArg = getDifferentialZeroOfType(builder, primalType);
+
// If a pair type can be formed, this must be non-null.
SLANG_RELEASE_ASSERT(diffArg);
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
@@ -984,6 +989,18 @@ struct JVPTranscriber
builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
+ else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>())
+ {
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ {
+ args.add(primalSpecialize->getArg(i));
+ }
+ diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase());
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
+ }
else
{
return InstPair(primalSpecialize, nullptr);
@@ -1365,15 +1382,14 @@ struct JVPTranscriber
{
differentiableTypeConformanceContext.setFunc(innerFunc);
}
+ else if (auto funcType = as<IRFuncType>(innerVal))
+ {
+ }
else
{
return InstPair(origGeneric, nullptr);
}
- // For now, we assume there's only one generic layer. So this inst must be top level
- bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr);
- SLANG_RELEASE_ASSERT(isTopLevel);
-
IRGeneric* primalGeneric = origGeneric;
IRBuilder builder(inBuilder->getSharedBuilder());
@@ -1395,10 +1411,6 @@ struct JVPTranscriber
diffGeneric->setFullType(diffType);
- // TODO(sai): Replace naming scheme
- // if (auto jvpName = this->getJVPFuncName(builder, primalFn))
- // builder->addNameHintDecoration(diffFunc, jvpName);
-
// Transcribe children from origFunc into diffFunc.
builder.setInsertInto(diffGeneric);
for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index a0becdafa..09dacc20d 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -6863,14 +6863,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
operandCount += associatedTypeDecl->getMembersOfType<TypeConstraintDecl>().getCount();
}
- else if (auto callableDecl = as<CallableDecl>(requirementDecl))
- {
- // Differentiable functions has additional requirements for the derivatives.
- if (callableDecl->getMembersOfType<ForwardDerivativeRequirementDecl>().getCount())
- operandCount++;
- if (callableDecl->getMembersOfType<BackwardDerivativeRequirementDecl>().getCount())
- operandCount++;
- }
}
// Allocate an IRInterfaceType with the `operandCount` operands.
@@ -6957,33 +6949,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (auto callableDecl = as<CallableDecl>(requirementDecl))
{
// Differentiable functions has additional requirements for the derivatives.
- for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementDecl>())
+ for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>())
{
- auto diffKey = getInterfaceRequirementKey(diffDecl);
- IRInst* diffVal = ensureDecl(subContext, diffDecl).val;
- auto diffEntry = subBuilder->createInterfaceRequirementEntry(diffKey, diffVal);
- if (diffVal)
- {
- switch (diffVal->getOp())
- {
- case kIROp_Func:
- case kIROp_Generic:
- {
- // Remove lowered `IRFunc`s since we only care about
- // function types.
- auto reqType = diffVal->getFullType();
- diffEntry->setRequirementVal(reqType);
- break;
- }
- default:
- break;
- }
- }
- irInterface->setOperand(entryIndex, diffEntry);
- entryIndex++;
-
- setValue(context, diffDecl, LoweredValInfo::simple(diffEntry));
- insertRequirementKeyAssociation(irInterface, diffDecl, requirementKey, diffKey);
+ auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl);
+ insertRequirementKeyAssociation(irInterface, diffDecl->referencedDecl, requirementKey, diffKey);
}
}
// Add lowered requirement entry to current decl mapping to prevent
diff --git a/tests/autodiff/bool-return-val.slang b/tests/autodiff/bool-return-val.slang
new file mode 100644
index 000000000..a43495dd9
--- /dev/null
+++ b/tests/autodiff/bool-return-val.slang
@@ -0,0 +1,28 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct NonDiff
+{
+ float a;
+}
+
+[ForwardDifferentiable]
+bool myFunc(NonDiff fIn, inout float x)
+{
+ x = pow(x, fIn.a);
+ return x > 100.f;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ float a = 10.0;
+ NonDiff fIn = { a };
+ DifferentialPair<float> dpx = DifferentialPair<float>(4.f, 1.f);
+ bool res = __fwd_diff(myFunc)(fIn, dpx);
+
+ outputBuffer[0] = res?1:0;
+} \ No newline at end of file
diff --git a/tests/autodiff/bool-return-val.slang.expected.txt b/tests/autodiff/bool-return-val.slang.expected.txt
new file mode 100644
index 000000000..5fce3dc6d
--- /dev/null
+++ b/tests/autodiff/bool-return-val.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+1.000000
+0.000000
+0.000000
+0.000000 \ No newline at end of file
diff --git a/tests/autodiff/dynamic-dispatch-generic-2.slang b/tests/autodiff/dynamic-dispatch-generic-2.slang
new file mode 100644
index 000000000..bbf7c7da1
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-generic-2.slang
@@ -0,0 +1,49 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IInterface
+{
+ [ForwardDifferentiable]
+ float calc(float x);
+}
+
+struct A : IInterface
+{
+ float z;
+ [ForwardDifferentiable]
+ float calc(float x) { return x * x * x; }
+};
+
+struct B : IInterface
+{
+ float z;
+
+ [ForwardDifferentiable]
+ float calc(float x) { return x * x + z; }
+};
+
+[ForwardDifferentiable]
+float sqr<T:IInterface>(T obj, float x)
+{
+ return obj.calc(x);
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A
+ outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; // A.calc, expect 12
+
+ obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B
+ outputBuffer[1] = __fwd_diff(sqr)(obj, DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3
+}
diff --git a/tests/autodiff/dynamic-dispatch-generic-2.slang.expected.txt b/tests/autodiff/dynamic-dispatch-generic-2.slang.expected.txt
new file mode 100644
index 000000000..8a664e432
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-generic-2.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+12.000000
+3.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/dynamic-dispatch-generic.slang b/tests/autodiff/dynamic-dispatch-generic.slang
new file mode 100644
index 000000000..37c37a745
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-generic.slang
@@ -0,0 +1,46 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IInterface
+{
+ [ForwardDifferentiable]
+ static float calc(float x);
+}
+
+struct A : IInterface
+{
+ [ForwardDifferentiable]
+ static float calc(float x) { return x * x * x; }
+};
+
+struct B : IInterface
+{
+ [ForwardDifferentiable]
+ static float calc(float x) { return x * x; }
+};
+
+[ForwardDifferentiable]
+float sqr<T:IInterface>(T obj, float x)
+{
+ return obj.calc(x);
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A
+ outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; // A.calc, expect 12
+
+ obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B
+ outputBuffer[1] = __fwd_diff(sqr)(obj, DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3
+}
diff --git a/tests/autodiff/dynamic-dispatch-generic.slang.expected.txt b/tests/autodiff/dynamic-dispatch-generic.slang.expected.txt
new file mode 100644
index 000000000..8a664e432
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-generic.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+12.000000
+3.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/generic-autodiff-1.slang b/tests/autodiff/generic-autodiff-1.slang
new file mode 100644
index 000000000..db7ae7e87
--- /dev/null
+++ b/tests/autodiff/generic-autodiff-1.slang
@@ -0,0 +1,39 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IInterface
+{
+ [mutating]
+ float sample();
+}
+
+struct A : IInterface
+{
+ float z;
+ [mutating]
+ float sample() { z = z + 1.0; return 1.0; }
+};
+
+
+[ForwardDifferentiable]
+float sqr<T:IInterface>(inout T obj, float x)
+{
+ return obj.sample() + x*x;
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ A obj;
+ obj.z = 0.0;
+ outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d;
+}
diff --git a/tests/autodiff/generic-autodiff-1.slang.expected.txt b/tests/autodiff/generic-autodiff-1.slang.expected.txt
new file mode 100644
index 000000000..64510fe94
--- /dev/null
+++ b/tests/autodiff/generic-autodiff-1.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+4.000000
+0.000000
+0.000000
+0.000000
+0.000000