summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-10-03 16:02:16 -0400
committerGitHub <noreply@github.com>2024-10-03 16:02:16 -0400
commit9f246a43667b4893040669873400e2e3813328ff (patch)
treef1fafe8c266b1db6f5f2cb76ab4fb7332cc2be54
parentaa64c853142076b17bd020f1386ea5fc6fcd5e3e (diff)
Support custom derivatives of member functions of differentiable types (#5124)
* Initial work to support custom derivatives for member methods of differentiable types * Support custom derivatives of member functions of differentiable types - Also adds support for declaring custom derivatives via extensions. * Fix * move defs * Update slang-check-decl.cpp * Create diff-member-func-custom-derivative.slang.expected.txt * Update slang-check-decl.cpp * Fix for static custom derivatives * Fix diagnostics for [PreferRecompute] * Add backward custom derivative tests
-rw-r--r--source/slang/slang-check-decl.cpp269
-rw-r--r--source/slang/slang-check-expr.cpp4
-rw-r--r--source/slang/slang-diagnostic-defs.h5
-rw-r--r--tests/autodiff/diff-member-func-custom-derivative.slang59
-rw-r--r--tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt3
-rw-r--r--tests/autodiff/member-func-extension-custom-derivative.slang55
-rw-r--r--tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt2
-rw-r--r--tests/autodiff/static-func-custom-derivative.slang59
-rw-r--r--tests/autodiff/static-func-custom-derivative.slang.expected.txt3
9 files changed, 419 insertions, 40 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4d742d9f4..8c3429c9a 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -7733,13 +7733,13 @@ namespace Slang
// Two such constraints are equivalent if their `sub`
// and `sup` types are pairwise equivalent.
//
- auto leftSub = leftConstraint->sub;
- auto rightSub = getSub(m_astBuilder, rightConstraint);
+ auto leftSub = leftConstraint->sub.type;
+ auto rightSub = substInnerRightToLeft.substitute(m_astBuilder, rightConstraint.getDecl()->sub.type);
if(!leftSub->equals(rightSub))
return false;
- auto leftSup = leftConstraint->sup;
- auto rightSup = getSup(m_astBuilder, rightConstraint);
+ auto leftSup = leftConstraint->sup.type;
+ auto rightSup = substInnerRightToLeft.substitute(m_astBuilder, rightConstraint.getDecl()->sup.type);
if(!leftSup->equals(rightSup))
return false;
}
@@ -10339,10 +10339,67 @@ namespace Slang
return result;
}
+ bool areTypesCompatibile(SemanticsVisitor* visitor, Type* fst, Type* snd)
+ {
+ if (fst->equals(snd))
+ return true;
+
+ if (auto declRefType = as<DeclRefType>(fst))
+ {
+ auto decl = declRefType->getDeclRef().getDecl();
+ if (auto extGenericDecl = visitor->GetOuterGeneric(decl))
+ {
+ SemanticsVisitor::ConstraintSystem constraints;
+ constraints.loc = decl->loc;
+ constraints.genericDecl = extGenericDecl;
+
+ if (!visitor->TryUnifyTypes(constraints, SemanticsVisitor::ValUnificationContext(), fst, snd))
+ return false;
+
+ ConversionCost baseCost;
+ if (!visitor->trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>(), baseCost))
+ return false;
+
+ // If we reach here, it means we have a valid unification.
+ return true;
+ }
+ }
+ return false;
+ }
+
+ Type* getTypeForThisExpr(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl)
+ {
+ ThisExpr* expr = visitor->getASTBuilder()->create<ThisExpr>();
+ expr->scope = funcDecl->ownedScope;
+ expr->loc = funcDecl->loc;
+
+ DiagnosticSink dummySink;
+ auto tempVisitor = SemanticsVisitor(visitor->withSink(&dummySink));
+
+ auto checkedExpr = tempVisitor.CheckTerm(expr);
+
+ return !(as<ErrorType>(checkedExpr->type.type)) ? (checkedExpr->type.type) : nullptr;
+ }
+
+ Type* getTypeForThisExpr(SemanticsVisitor* visitor, DeclRef<FunctionDeclBase> funcDeclRef)
+ {
+ auto type = getTypeForThisExpr(visitor, funcDeclRef.getDecl());
+ if (type)
+ return substituteType(
+ SubstitutionSet(funcDeclRef.declRefBase),
+ visitor->getASTBuilder(),
+ type);
+ return nullptr;
+ }
+
+
struct ArgsWithDirectionInfo
{
List<Expr*> args;
List<ParameterDirection> directions;
+
+ Expr* thisArg;
+ ParameterDirection thisArgDirection;
};
template<typename TDerivativeAttr>
@@ -10351,7 +10408,9 @@ namespace Slang
Decl* funcDecl,
TDerivativeAttr* attr,
const List<Expr*>& imaginaryArguments,
- const List<ParameterDirection>& expectedParamDirections)
+ const List<ParameterDirection>& expectedParamDirections,
+ Expr* expectedThisArg,
+ ParameterDirection expectedThisArgDirection)
{
if (isInterfaceRequirement(funcDecl))
{
@@ -10402,7 +10461,18 @@ namespace Slang
return type->toString();
};
- auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
+ List<Expr*> argList = imaginaryArguments;
+ List<ParameterDirection> paramDirections = expectedParamDirections;
+ bool expectStaticFunc = false;
+
+ if (expectedThisArg)
+ {
+ argList.insert(0, expectedThisArg);
+ paramDirections.insert(0, expectedThisArgDirection);
+ expectStaticFunc = true;
+ }
+
+ auto invokeExpr = subVisitor.constructUncheckedInvokeExpr(checkedFuncExpr, argList);
auto resolved = subVisitor.ResolveInvoke(invokeExpr);
if (auto resolvedInvoke = as<InvokeExpr>(resolved))
@@ -10430,61 +10500,104 @@ namespace Slang
visitor->getSink()->diagnose(attr, Diagnostics::cannotUseInterfaceRequirementAsDerivative);
return;
}
- if (funcType->getParamCount() != imaginaryArguments.getCount())
+ if (funcType->getParamCount() != argList.getCount())
{
goto error;
}
- for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii)
+ for (Index ii = 0; ii < argList.getCount(); ++ii)
{
// Check if the resolved invoke argument type is an error type.
// If so, then we have a type mismatch.
//
if (resolvedInvoke->arguments[ii]->type.type->equals(ctx.getASTBuilder()->getErrorType()) ||
- funcType->getParamDirection(ii) != expectedParamDirections[ii])
+ funcType->getParamDirection(ii) != paramDirections[ii])
{
visitor->getSink()->diagnose(
attr,
Diagnostics::customDerivativeSignatureMismatchAtPosition,
ii,
- qualTypeToString(imaginaryArguments[ii]->type),
+ qualTypeToString(argList[ii]->type),
funcType->getParamType(ii)->toString());
}
}
// The `imaginaryArguments` list does not include the `this` parameter.
// So we need to check that `this` type matches.
bool funcIsStatic = isEffectivelyStatic(funcDecl);
+ if (funcIsStatic)
+ expectStaticFunc = true;
+
bool derivativeFuncIsStatic = isEffectivelyStatic(calleeDeclRef->declRef.getDecl());
- if (funcIsStatic != derivativeFuncIsStatic)
+
+ if (expectStaticFunc && !derivativeFuncIsStatic)
{
visitor->getSink()->diagnose(
attr,
- Diagnostics::customDerivativeSignatureThisParamMismatch);
+ Diagnostics::customDerivativeExpectedStatic);
return;
}
- if (!funcIsStatic)
+
+ if (!derivativeFuncIsStatic)
{
auto defaultFuncDeclRef = createDefaultSubstitutionsIfNeeded(
visitor->getASTBuilder(),
visitor,
makeDeclRef(funcDecl));
- auto funcThisType = visitor->calcThisType(defaultFuncDeclRef);
- auto derivativeFuncThisType = visitor->calcThisType(calleeDeclRef->declRef);
- if (!funcThisType->equals(derivativeFuncThisType))
+
+ DeclRef<FunctionDeclBase> funcDeclRef = defaultFuncDeclRef.as<FunctionDeclBase>();
+ auto funcThisType = getTypeForThisExpr(visitor, funcDeclRef);
+ DeclRef<FunctionDeclBase> calleeFuncDeclRef = calleeDeclRef->declRef.template as<FunctionDeclBase>();
+ auto derivativeFuncThisType = getTypeForThisExpr(visitor, calleeFuncDeclRef);
+
+ // If the function is a member function, we need to check that the
+ // `this` type matches the expected type. This will ensure that after lowering to IR,
+ // the two functions are compatible.
+ //
+ if (!areTypesCompatibile(visitor, funcThisType, derivativeFuncThisType))
{
visitor->getSink()->diagnose(
attr,
Diagnostics::customDerivativeSignatureThisParamMismatch);
return;
}
- if (visitor->isTypeDifferentiable(funcThisType))
+ }
+
+ // If the two decls are under different generic contexts, we'll need to check that
+ // they agree and specialize the attribute's decl-ref accordingly.
+ //
+
+ auto originalNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(funcDecl));
+ auto derivativeNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(calleeDeclRef->declRef.getDecl()));
+
+ if ((!originalNextGeneric) != (!derivativeNextGeneric))
+ {
+ // Diagnostic for when one is generic and the other is not.
+ visitor->getSink()->diagnose(attr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
+ return;
+ }
+
+ if (originalNextGeneric != derivativeNextGeneric)
+ {
+ // If the two generic containers are not the same, but are compatible, we can
+ // unify them.
+ //
+
+ DeclRef<Decl> specializedDecl;
+ if (!visitor->doGenericSignaturesMatch(originalNextGeneric, derivativeNextGeneric, &specializedDecl))
{
- visitor->getSink()->diagnose(
- attr,
- Diagnostics::customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType);
+ visitor->getSink()->diagnose(attr, Diagnostics::customDerivativeSignatureMismatch);
return;
}
- }
+ calleeDeclRef->declRef = substituteDeclRef(
+ SubstitutionSet(specializedDecl),
+ visitor->getASTBuilder(),
+ calleeDeclRef->declRef);
+ calleeDeclRef->type = substituteType(
+ SubstitutionSet(specializedDecl),
+ visitor->getASTBuilder(),
+ calleeDeclRef->type);
+ }
+
attr->funcExpr = calleeDeclRef;
if (attr->args.getCount())
attr->args[0] = attr->funcExpr;
@@ -10497,12 +10610,12 @@ namespace Slang
//
StringBuilder builder;
builder << "(";
- for (Index ii = 0; ii < imaginaryArguments.getCount(); ++ii)
+ for (Index ii = 0; ii < argList.getCount(); ++ii)
{
if (ii != 0)
builder << ", ";
- if (imaginaryArguments[ii]->type)
- builder << qualTypeToString(imaginaryArguments[ii]->type);
+ if (argList[ii]->type)
+ builder << qualTypeToString(argList[ii]->type);
else
builder << "<error>";
}
@@ -10544,11 +10657,36 @@ namespace Slang
imaginaryArguments.add(arg);
directions.add(getParameterDirection(param));
}
- return { imaginaryArguments, directions };
+ return { imaginaryArguments, directions, nullptr, ParameterDirection::kParameterDirection_In };
}
ArgsWithDirectionInfo getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
{
+ Expr* thisArgExpr = nullptr;
+ if (auto thisType = getTypeForThisExpr(visitor, originalFuncDecl))
+ {
+ thisArgExpr = visitor->getASTBuilder()->create<VarExpr>();
+ thisArgExpr->type = thisType;
+ thisArgExpr->loc = loc;
+
+ if (visitor->isTypeDifferentiable(thisType) &&
+ !originalFuncDecl->findModifier<NoDiffThisAttribute>() &&
+ !isEffectivelyStatic(originalFuncDecl))
+ {
+ auto pairType = visitor->getDifferentialPairType(thisType);
+ thisArgExpr->type.type = pairType;
+ }
+ else
+ {
+ thisArgExpr = nullptr;
+ }
+ }
+
+ ParameterDirection thisTypeDirection =
+ (thisArgExpr && !thisArgExpr->type.isLeftValue) ?
+ ParameterDirection::kParameterDirection_In :
+ ParameterDirection::kParameterDirection_InOut;
+
List<Expr*> imaginaryArguments;
for (auto param : originalFuncDecl->getParameters())
{
@@ -10574,11 +10712,40 @@ namespace Slang
expectedParamDirections.add(getParameterDirection(param));
}
- return { imaginaryArguments, expectedParamDirections };
+ return { imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection };
}
ArgsWithDirectionInfo getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
{
+ Expr* thisArgExpr = nullptr;
+ if (auto thisType = getTypeForThisExpr(visitor, originalFuncDecl))
+ {
+ thisArgExpr = visitor->getASTBuilder()->create<VarExpr>();
+ thisArgExpr->type = thisType;
+ thisArgExpr->loc = loc;
+
+ if (visitor->isTypeDifferentiable(thisType) &&
+ !originalFuncDecl->findModifier<NoDiffThisAttribute>() &&
+ !isEffectivelyStatic(originalFuncDecl))
+ {
+ auto pairType = visitor->getDifferentialPairType(thisType);
+ thisArgExpr->type.type = pairType;
+
+ // TODO: for ptr pair types, no need to set isLeftValue to true.
+ if (as<DifferentialPairType>(thisArgExpr->type.type))
+ thisArgExpr->type.isLeftValue = true;
+ }
+ else
+ {
+ thisArgExpr = nullptr;
+ }
+ }
+
+ ParameterDirection thisTypeDirection =
+ (thisArgExpr && !thisArgExpr->type.isLeftValue) ?
+ ParameterDirection::kParameterDirection_In :
+ ParameterDirection::kParameterDirection_InOut;
+
List<Expr*> imaginaryArguments;
List<ParameterDirection> expectedParamDirections;
@@ -10660,7 +10827,7 @@ namespace Slang
expectedParamDirections.add(ParameterDirection::kParameterDirection_In);
}
- return {imaginaryArguments, expectedParamDirections};
+ return {imaginaryArguments, expectedParamDirections, thisArgExpr, thisTypeDirection};
}
// This helper function is needed to workaround a gcc bug.
@@ -10685,7 +10852,11 @@ namespace Slang
higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr;
if (derivativeOfAttr->args.getCount() > 0)
higherOrderFuncExpr->loc = derivativeOfAttr->args[0]->loc;
- Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, visitor->allowStaticReferenceToNonStaticMember());
+
+ Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(
+ higherOrderFuncExpr,
+ visitor->allowStaticReferenceToNonStaticMember());
+
if (!checkedHigherOrderFuncExpr)
{
visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
@@ -10701,7 +10872,15 @@ namespace Slang
{
auto resolvedFuncExpr = as<HigherOrderInvokeExpr>(resolvedInvoke->functionExpr);
if (resolvedFuncExpr)
+ {
calleeDeclRefExpr = as<DeclRefExpr>(resolvedFuncExpr->baseFunction);
+ if (!calleeDeclRef && as<OverloadedExpr>(resolvedFuncExpr->baseFunction))
+ {
+ visitor->getSink()->diagnose(
+ derivativeOfAttr,
+ Diagnostics::overloadedFuncUsedWithDerivativeOfAttributes);
+ }
+ }
}
if (!calleeDeclRefExpr)
@@ -10729,13 +10908,6 @@ namespace Slang
// We may relax this restriction in the future by solving the "inverse" generic arguments
// from the `calleeDeclRef`, and use them to create a declRef to funcDecl from the original
// func.
- auto originalNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(calleeFunc));
- auto derivativeNextGeneric = visitor->findNextOuterGeneric(visitor->getOuterGenericOrSelf(funcDecl));
- if (originalNextGeneric != derivativeNextGeneric)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
- return;
- }
if (isInterfaceRequirement(calleeFunc))
{
@@ -10787,7 +10959,14 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToForwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(
+ visitor,
+ funcDecl,
+ attr,
+ imaginaryArguments.args,
+ imaginaryArguments.directions,
+ imaginaryArguments.thisArg,
+ imaginaryArguments.thisArgDirection);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, BackwardDerivativeAttribute* attr)
@@ -10798,7 +10977,14 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToBackwardDerivative(visitor, funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(
+ visitor,
+ funcDecl,
+ attr,
+ imaginaryArguments.args,
+ imaginaryArguments.directions,
+ imaginaryArguments.thisArg,
+ imaginaryArguments.thisArgDirection);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, PrimalSubstituteAttribute* attr)
@@ -10809,7 +10995,14 @@ namespace Slang
return;
ArgsWithDirectionInfo imaginaryArguments = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, attr->loc);
- checkDerivativeAttributeImpl(visitor, funcDecl, attr, imaginaryArguments.args, imaginaryArguments.directions);
+ checkDerivativeAttributeImpl(
+ visitor,
+ funcDecl,
+ attr,
+ imaginaryArguments.args,
+ imaginaryArguments.directions,
+ imaginaryArguments.thisArg,
+ imaginaryArguments.thisArgDirection);
}
static void checkCudaKernelAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, CudaKernelAttribute*)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index e9c257750..842ffb527 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -4754,7 +4754,9 @@ namespace Slang
scope = scope->parent;
}
- getSink()->diagnose(expr, Diagnostics::thisExpressionOutsideOfTypeDecl);
+ if (auto sink = getSink())
+ sink->diagnose(expr, Diagnostics::thisExpressionOutsideOfTypeDecl);
+
return CreateErrorExpr(expr);
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 9a14b71e4..23eba5f03 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -463,8 +463,11 @@ DIAGNOSTIC(31151, Error, cannotResolveGenericArgumentForDerivativeFunction,
"[BackwardDerivativeOf], and [PrimalSubstituteOf] attributes are not supported when the generic arguments to the derivatives cannot be automatically deduced.")
DIAGNOSTIC(31152, Error, cannotAssociateInterfaceRequirementWithDerivative, "cannot associate an interface requirement with a derivative.")
DIAGNOSTIC(31153, Error, cannotUseInterfaceRequirementAsDerivative, "cannot use an interface requirement as a derivative.")
-DIAGNOSTIC(31154, Error, customDerivativeSignatureThisParamMismatch, "custom derivative does not match expected signature on `this`. Either both the original and the derivative function are static, or they must have the same `this` type.")
+DIAGNOSTIC(31154, Error, customDerivativeSignatureThisParamMismatch, "custom derivative does not match expected signature on `this`. Both original and derivative function must have the same `this` type.")
DIAGNOSTIC(31155, Error, customDerivativeNotAllowedForMemberFunctionsOfDifferentiableType, "custom derivative is not allowed for non-static member functions of a differentiable type.")
+DIAGNOSTIC(31156, Error, customDerivativeExpectedStatic, "expected a static definition for the custom derivative.")
+DIAGNOSTIC(31157, Error, overloadedFuncUsedWithDerivativeOfAttributes, "cannot resolve overloaded functions for derivative-of attributes.")
+
DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1")
DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.")
DIAGNOSTIC(31202, Error, duplicateModifier, "modifier '$0' is redundant or conflicting with existing modifier '$1'")
diff --git a/tests/autodiff/diff-member-func-custom-derivative.slang b/tests/autodiff/diff-member-func-custom-derivative.slang
new file mode 100644
index 000000000..4e4f540f9
--- /dev/null
+++ b/tests/autodiff/diff-member-func-custom-derivative.slang
@@ -0,0 +1,59 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct A : IDifferentiable
+{
+ float x;
+
+ [ForwardDerivative(diff_f)]
+ float f(float v)
+ {
+ return v * v;
+ }
+
+ static DifferentialPair<float> diff_f(DifferentialPair<A> dpa, DifferentialPair<float> v)
+ {
+ return diffPair(v.p * v.p, v.p * v.d * 2.0);
+ }
+
+ [BackwardDerivative(diff_g)]
+ float g(float v)
+ {
+ return v * v;
+ }
+
+ static void diff_g(inout DifferentialPair<A> dpa, inout DifferentialPair<float> v, float dOut)
+ {
+ v = diffPair(v.p, dOut * 2.0);
+ }
+}
+
+[ForwardDifferentiable]
+float test(A obj, float v)
+{
+ return obj.f(v);
+}
+
+[BackwardDifferentiable]
+float test2(A obj, float v)
+{
+ return obj.g(v);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ A a = {0.0};
+ var p = diffPair(3.0, 1.0);
+ let rs = fwd_diff(test)(diffPair(a, {1.0}), p);
+
+ var q = diffPair(3.0);
+ var qa = diffPair(a);
+ bwd_diff(test2)(qa, q, 1.0);
+
+ outputBuffer[0] = rs.d;
+ outputBuffer[1] = q.d;
+}
diff --git a/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt b/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt
new file mode 100644
index 000000000..1bb28547d
--- /dev/null
+++ b/tests/autodiff/diff-member-func-custom-derivative.slang.expected.txt
@@ -0,0 +1,3 @@
+type: float
+6.000000
+2.000000 \ No newline at end of file
diff --git a/tests/autodiff/member-func-extension-custom-derivative.slang b/tests/autodiff/member-func-extension-custom-derivative.slang
new file mode 100644
index 000000000..8752dfff5
--- /dev/null
+++ b/tests/autodiff/member-func-extension-custom-derivative.slang
@@ -0,0 +1,55 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct A
+{
+ float x;
+
+ float f(float v)
+ {
+ return v * v;
+ }
+}
+
+extension A
+{
+ [ForwardDerivativeOf(f)]
+ DifferentialPair<float> diff_f(DifferentialPair<float> v)
+ {
+ return diffPair(v.p * v.p, v.p * v.d * 2.0);
+ }
+}
+
+struct Foo<T : IDifferentiable>
+{
+ T value;
+ T doThing() { return value; }
+}
+
+extension<T : IDifferentiable> Foo<T>
+{
+ [ForwardDerivativeOf(doThing)]
+ DifferentialPair<T> diff_doThing()
+ {
+ return diffPair(value, T.dzero());
+ }
+}
+
+
+[ForwardDifferentiable]
+float test(Foo<float> obj, float v)
+{
+ return obj.doThing() * v;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ Foo<float> a = {0.0};
+ var p = diffPair(3.0, 1.0);
+ let rs = __fwd_diff(test)(a, p);
+ outputBuffer[0] = rs.d;
+}
diff --git a/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt b/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt
new file mode 100644
index 000000000..4b1f4c0d9
--- /dev/null
+++ b/tests/autodiff/member-func-extension-custom-derivative.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+0.000000
diff --git a/tests/autodiff/static-func-custom-derivative.slang b/tests/autodiff/static-func-custom-derivative.slang
new file mode 100644
index 000000000..b75012735
--- /dev/null
+++ b/tests/autodiff/static-func-custom-derivative.slang
@@ -0,0 +1,59 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct A : IDifferentiable
+{
+ float x;
+
+ [ForwardDerivative(diff_f)]
+ static float f(float v)
+ {
+ return v * v;
+ }
+
+ static DifferentialPair<float> diff_f(DifferentialPair<float> v)
+ {
+ return diffPair(v.p * v.p, v.p * v.d * 2.0);
+ }
+
+ [BackwardDerivative(diff_g)]
+ static float g(float v)
+ {
+ return v * v;
+ }
+
+ static void diff_g(inout DifferentialPair<float> v, float.Differential dOut)
+ {
+ v = diffPair(v.p, dOut * 2.0);
+ }
+}
+
+[ForwardDifferentiable]
+float test(A obj, float v)
+{
+ return obj.f(v);
+}
+
+[BackwardDifferentiable]
+float test2(A obj, float v)
+{
+ return obj.g(v);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ A a = {0.0};
+ var p = diffPair(3.0, 1.0);
+ let rs = fwd_diff(test)(diffPair(a, {1.0}), p);
+
+ var q = diffPair(3.0);
+ var qa = diffPair(a);
+ bwd_diff(test2)(qa, q, 1.0);
+
+ outputBuffer[0] = rs.d;
+ outputBuffer[1] = q.d;
+}
diff --git a/tests/autodiff/static-func-custom-derivative.slang.expected.txt b/tests/autodiff/static-func-custom-derivative.slang.expected.txt
new file mode 100644
index 000000000..1bb28547d
--- /dev/null
+++ b/tests/autodiff/static-func-custom-derivative.slang.expected.txt
@@ -0,0 +1,3 @@
+type: float
+6.000000
+2.000000 \ No newline at end of file