summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/core.meta.slang44
-rw-r--r--source/slang/diff.meta.slang97
-rw-r--r--source/slang/slang-ast-synthesis.cpp8
-rw-r--r--source/slang/slang-ast-synthesis.h2
-rw-r--r--source/slang/slang-check-decl.cpp295
-rw-r--r--source/slang/slang-check-impl.h35
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp2
-rw-r--r--source/slang/slang-parser.cpp2
-rw-r--r--tests/autodiff/auto-differential-type.slang5
-rw-r--r--tests/autodiff/custom-intrinsic.slang56
-rw-r--r--tests/autodiff/differential-method-synthesis.slang8
-rw-r--r--tests/autodiff/differential-method-synthesis.slang.expected.txt2
-rw-r--r--tests/autodiff/generic-impl-jvp.slang18
-rw-r--r--tests/autodiff/generic-jvp.slang8
-rw-r--r--tests/autodiff/getter-setter-multi.slang4
-rw-r--r--tests/autodiff/getter-setter.slang4
16 files changed, 456 insertions, 134 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 77b9405ba..ae70a83f4 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -112,6 +112,12 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {}
interface __BuiltinIntegerType : __BuiltinArithmeticType
{}
+/// A type that can represent non-integers
+[sealed]
+[builtin]
+interface __BuiltinRealType : __BuiltinSignedArithmeticType {}
+
+
__attributeTarget(AggTypeDecl)
attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute;
@@ -144,7 +150,7 @@ interface IDifferentiable
// Note: the compiler implementation requires the `Differential` associated type to be defined
// before anything else.
- __builtin_requirement($( (int) BuiltinRequirementKind::DifferentialType) )
+ __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialType) )
associatedtype Differential : IDifferentiable;
__builtin_requirement($( (int)BuiltinRequirementKind::DZeroFunc) )
@@ -154,7 +160,8 @@ interface IDifferentiable
static Differential dadd(Differential, Differential);
__builtin_requirement($( (int)BuiltinRequirementKind::DMulFunc) )
- static Differential dmul(This, Differential);
+ __generic<T : __BuiltinRealType>
+ static Differential dmul(T, Differential);
};
@@ -219,19 +226,16 @@ struct DifferentialPair : IDifferentiable
T.Differential.dadd(a.d, b.d));
}
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ static Differential dmul(U a, Differential b)
{
return Differential(
- T.dmul(a.p, b.p),
- T.Differential.dmul(a.d, b.d));
+ T.dmul<U>(a, b.p),
+ T.Differential.dmul<U>(a, b.d));
}
};
-/// A type that can represent non-integers
-[sealed]
-[builtin]
-interface __BuiltinRealType : __BuiltinSignedArithmeticType {}
/// A type that uses a floating-point representation
[sealed]
@@ -339,6 +343,9 @@ __generic<T, let N : int> __intrinsic_op(select) vector<T,N> operator?:(vector<b
__generic<T> __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse);
__generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse);
+// Allow real-number types to be cast into each other
+__intrinsic_op($(kIROp_FloatCast))
+ T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val);
${{{{
// We are going to use code generation to produce the
@@ -483,12 +490,13 @@ ${{{{
{
return a + b;
}
-
+
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
[BackwardDifferentiable]
- static Differential dmul(Differential a, Differential b)
+ static Differential dmul(U a, Differential b)
{
- return a * b;
+ return __realCast<Differential, U>(a) * b;
}
${{{{
break;
@@ -1190,11 +1198,12 @@ extension vector<T, N> : IDifferentiable
return a + b;
}
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
[BackwardDifferentiable]
- static Differential dmul(This a, Differential b)
+ static Differential dmul(U a, Differential b)
{
- return a * b;
+ return __realCast<T, U>(a) * b;
}
}
@@ -1216,12 +1225,13 @@ extension matrix<T, R, C> : IDifferentiable
{
return a + b;
}
-
+
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
[BackwardDifferentiable]
- static Differential dmul(This a, Differential b)
+ static Differential dmul(U a, Differential b)
{
- return a * b;
+ return __realCast<T, U>(a) * b;
}
}
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index f2fd8e3b0..3e381e55d 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -374,12 +374,13 @@ extension Array<T, N> : IDifferentiable
return result;
}
+ __generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ static Differential dmul(U a, Differential b)
{
Array<T.Differential, N> result;
for (int i = 0; i < N; i++)
- result[i] = T.dmul(a[i], b[i]);
+ result[i] = T.dmul<U>(a, b[i]);
return result;
}
}
@@ -543,8 +544,8 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair
for (int i = 0; i < N; ++i)
{
result = result + dpx.p[i] * dpy.p[i];
- d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast<T.Differential>(dpy.d[i])));
- d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast<T.Differential>(dpx.d[i])));
+ d_result = T.dadd(d_result, __slang_noop_cast<T.Differential>(dpx.p[i] * dpy.d[i]));
+ d_result = T.dadd(d_result, __slang_noop_cast<T.Differential>(dpy.p[i] * dpx.d[i]));
}
return DifferentialPair<T>(result, d_result);
}
@@ -797,7 +798,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
#define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \
__generic<T : __BuiltinFloatingPointType> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \
{ \
@@ -805,7 +806,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \
{ \
@@ -813,21 +814,21 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[ForwardDerivativeOf(NAME)] \
DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \
{ \
- typealias ReturnType = vector<T,N>; \
- matrix<T,M,N>.Differential diff; \
+ typealias ReturnType = vector<T, N>; \
+ matrix<T, M, N>.Differential diff; \
[ForceUnroll] for (int i = 0; i < M; i++) \
{ \
var dpx = diffPair(dpm.p[i], dpm.d[i]); \
- diff[i] = FWD_DIFF_FUNC; \
+ diff[i] = __slang_noop_cast<vector<T, N>>(FWD_DIFF_FUNC); \
} \
return diffPair(NAME(dpm.p), diff); \
} \
__generic<T : __BuiltinFloatingPointType> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \
{ \
@@ -835,31 +836,57 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let N : int> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
- inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \
+ inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \
{ \
typealias ReturnType = vector<T, N>; \
dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \
} \
__generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
- [BackwardDifferentiable][PreferRecompute] \
+ [BackwardDifferentiable] [PreferRecompute] \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_matrix( \
- inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \
+ inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \
{ \
typealias ReturnType = vector<T, N>; \
matrix<T, M, N>.Differential diff; \
[ForceUnroll] for (int i = 0; i < M; i++) \
{ \
var dpx = diffPair(m.p[i], m.d[i]); \
- var dOut = mdOut[i]; \
+ var dOut = __slang_noop_cast<vector<T, N>>(mdOut[i]); \
diff[i] = BWD_DIFF_FUNC; \
} \
m = diffPair(m.p, diff); \
}
-#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut))
+#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, __mul_p_d(DIFF_FUNC, dpx.d), __mul_p_d(DIFF_FUNC, dOut))
+
+/// Element-wise multiply for scalars and vectors for (T, T.Differential)
+__generic<T : __BuiltinFloatingPointType>
+[__unsafeForceInlineEarly]
+[Differentiable]
+T.Differential __mul_p_d(T a, T.Differential b)
+{
+ return __slang_noop_cast<T.Differential>(a * __slang_noop_cast<T>(b));
+}
+
+__generic<T : __BuiltinFloatingPointType>
+[__unsafeForceInlineEarly]
+[Differentiable]
+T __mul_p_d(T a, T b)
+{
+ return (a * b);
+}
+
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[__unsafeForceInlineEarly]
+[Differentiable]
+vector<T, N> __mul_p_d(vector<T, N> a, vector<T, N> b)
+{
+ return a * b;
+}
+
/// Detach and set derivatives to zero.
__generic<T : IDifferentiable>
@@ -871,14 +898,14 @@ T detach(T x);
#define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0))))
// Absolute value
-UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut))
+UNARY_DERIVATIVE_IMPL(abs, (__mul_p_d(SLANG_SIGN(dpx.p), (dpx.d))), (__mul_p_d(SLANG_SIGN(dpx.p), (dOut))))
// Saturate
UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut))
// frac
UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut)
// raidans, degrees
-SIMPLE_UNARY_DERIVATIVE_IMPL(radians, T(0.01745329251994329576923690768489))
-SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, T(57.295779513082320876798154814105))
+SIMPLE_UNARY_DERIVATIVE_IMPL(radians, ReturnType(T(0.01745329251994329576923690768489)))
+SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, ReturnType(T(57.295779513082320876798154814105)))
// Exponent
SIMPLE_UNARY_DERIVATIVE_IMPL(exp, exp(dpx.p))
SIMPLE_UNARY_DERIVATIVE_IMPL(exp2, exp2(dpx.p)* T(50.69314718055994530941723212145818))
@@ -915,8 +942,8 @@ __generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(atan2)]
DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx)
{
- T.Differential dx = T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d);
- T.Differential dy = T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d);
+ T.Differential dx = __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d);
+ T.Differential dy = __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d);
return DifferentialPair<T>(
atan2(dpy.p, dpx.p),
T.dadd(dx, dy));
@@ -928,8 +955,8 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(atan2)]
void __d_atan2(inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpx, T.Differential dOut)
{
- dpx = diffPair(dpx.p, T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d));
- dpy = diffPair(dpy.p, T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d));
+ dpx = diffPair(dpx.p, __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d));
+ dpy = diffPair(dpy.p, __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d));
}
VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2)
@@ -968,8 +995,8 @@ DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy)
}
T val = pow(dpx.p, dpy.p);
- T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d);
- T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d);
+ T.Differential d1 = __mul_p_d((val * log(dpx.p)), dpy.d);
+ T.Differential d2 = __mul_p_d((val * dpy.p / dpx.p), dpx.d);
return DifferentialPair<T>(
val,
T.dadd(d1, d2)
@@ -993,10 +1020,10 @@ void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Dif
T val = pow(dpx.p, dpy.p);
dpx = diffPair(
dpx.p,
- T.dmul(val * dpy.p / dpx.p, dOut));
+ (__mul_p_d((val * dpy.p / dpx.p), dOut)));
dpy = diffPair(
dpy.p,
- T.dmul(val * log(dpx.p), dOut));
+ (__mul_p_d((val * log(dpx.p)), dOut)));
}
}
@@ -1061,7 +1088,7 @@ DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, D
{
return DifferentialPair<T>(
lerp(dpx.p, dpy.p, dps.p),
- T.dadd(T.dadd(T.dmul((T(1.0) - dps.p), dpx.d), T.dmul(dps.p, dpy.d)), T.dmul(dpy.p - dpx.p, dps.d))
+ T.dadd(T.dadd(__mul_p_d((T(1.0) - dps.p), dpx.d), __mul_p_d(dps.p, dpy.d)), __mul_p_d((dpy.p - dpx.p), dps.d))
);
}
__generic<T : __BuiltinFloatingPointType>
@@ -1070,9 +1097,9 @@ __generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(lerp)]
void __d_lerp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dps, T.Differential dOut)
{
- dpx = diffPair(dpx.p, T.dmul(T(1.0) - dps.p, dOut));
- dpy = diffPair(dpy.p, T.dmul(dps.p, dOut));
- dps = diffPair(dpy.p, T.dmul((dpy.p - dpx.p), dOut));
+ dpx = diffPair(dpx.p, __mul_p_d((T(1.0) - dps.p), dOut));
+ dpy = diffPair(dpy.p, __mul_p_d(dps.p, dOut));
+ dps = diffPair(dpy.p, __mul_p_d((dpy.p - dpx.p), dOut));
}
VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp)
@@ -1175,7 +1202,7 @@ DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, Di
{
return DifferentialPair<T>(
mad(dpx.p, dpy.p, dpz.p),
- T.dadd(T.dadd(T.dmul(dpy.p, dpx.d), T.dmul(dpx.p, dpy.d)), dpz.d));
+ T.dadd(T.dadd(__mul_p_d(dpy.p, dpx.d), __mul_p_d(dpx.p, dpy.d)), dpz.d));
}
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
@@ -1183,8 +1210,8 @@ __generic<T : __BuiltinFloatingPointType>
[PreferRecompute]
void __d_mad(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut)
{
- dpx = diffPair(dpx.p, T.dmul(dpy.p, dOut));
- dpy = diffPair(dpy.p, T.dmul(dpx.p, dOut));
+ dpx = diffPair(dpx.p, __mul_p_d(dpy.p, dOut));
+ dpy = diffPair(dpy.p, __mul_p_d(dpx.p, dOut));
dpz = diffPair(dpz.p, dOut);
}
VECTOR_MATRIX_TERNARY_DIFF_IMPL(mad)
diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp
index 872088d54..65955e815 100644
--- a/source/slang/slang-ast-synthesis.cpp
+++ b/source/slang/slang-ast-synthesis.cpp
@@ -122,6 +122,14 @@ Expr* ASTSynthesizer::emitInvokeExpr(Expr* callee, List<Expr*>&& args)
return rs;
}
+Expr* ASTSynthesizer::emitGenericAppExpr(Expr* genericExpr, List<Expr*>&& args)
+{
+ auto rs = m_builder->create<GenericAppExpr>();
+ rs->functionExpr = genericExpr;
+ rs->arguments = _Move(args);
+ return rs;
+}
+
Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name)
{
auto rs = m_builder->create<StaticMemberExpr>();
diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h
index 6568b4c83..e595afac1 100644
--- a/source/slang/slang-ast-synthesis.h
+++ b/source/slang/slang-ast-synthesis.h
@@ -138,6 +138,8 @@ public:
Expr* emitInvokeExpr(Expr* callee, List<Expr*>&& args);
+ Expr* emitGenericAppExpr(Expr* genericExpr, List<Expr*>&& args);
+
DeclStmt* emitVarDeclStmt(Type* type, Name* name = nullptr, Expr* initVal = nullptr);
ExpressionStmt* emitExprStmt(Expr* expr);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4c1e967e3..2d009c28c 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2574,6 +2574,135 @@ namespace Slang
return false;
}
+ GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<GenericDecl> requiredMemberDeclRef,
+ List<Expr*>& synArgs,
+ List<Expr*>& synGenericArgs,
+ ThisExpr*& synThis)
+ {
+ auto synGenericDecl = m_astBuilder->create<GenericDecl>();
+
+ // For now our synthesized method will use the name and source
+ // location of the requirement we are trying to satisfy.
+ //
+ // TODO: as it stands right now our syntesized method will
+ // get a mangled name, which we don't actually want. Leaving
+ // out the name here doesn't help matters, because then *all*
+ // snthesized methods on a given type would share the same
+ // mangled name!
+ //
+ synGenericDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc;
+ if (synGenericDecl->nameAndLoc.name)
+ {
+ synGenericDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synGenericDecl->nameAndLoc.name->text);
+ }
+
+ // Dictionary to map from the original type parameters to the synthesized ones.
+ Dictionary<GenericTypeParamDecl*, GenericTypeParamDecl*> mapOrigToSynTypeParams;
+
+ // Our synthesized method will have parameters matching the names
+ // and types of those on the requirement, and it will use expressions
+ // that reference those parametesr as arguments for the call expresison
+ // that makes up the body.
+ //
+ for (auto member : requiredMemberDeclRef.getDecl()->members)
+ {
+ if (auto typeParamDecl = as<GenericTypeParamDecl>(member))
+ {
+ auto synTypeParamDecl = m_astBuilder->create<GenericTypeParamDecl>();
+ synTypeParamDecl->nameAndLoc = typeParamDecl->getNameAndLoc();
+ synTypeParamDecl->initType = typeParamDecl->initType;
+ synTypeParamDecl->parentDecl = synGenericDecl;
+ synGenericDecl->members.add(synTypeParamDecl);
+
+ mapOrigToSynTypeParams.add(typeParamDecl, synTypeParamDecl);
+
+ // Construct a DeclRefExpr from the type parameter.
+ auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl);
+
+ auto synTypeParamDeclRefExpr = m_astBuilder->create<VarExpr>();
+ synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef;
+ synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc());
+
+ synGenericArgs.add(synTypeParamDeclRefExpr);
+ }
+ }
+
+ for (auto member : requiredMemberDeclRef.getDecl()->members)
+ {
+ if (auto constraintDecl = as<GenericTypeConstraintDecl>(member))
+ {
+ getASTBuilder()->getSpecializedDeclRef(
+ constraintDecl, requiredMemberDeclRef.getSubst());
+
+ auto synConstraintDecl = m_astBuilder->create<GenericTypeConstraintDecl>();
+ synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc();
+ synConstraintDecl->parentDecl = synGenericDecl;
+
+ // For constraints of type T : Interface, where T is a simple type parameter,
+ // find the declaration of T
+ //
+ if (auto typeParamDecl = as<DeclRefType>(constraintDecl->sub.type)->declRef.as<GenericTypeParamDecl>().getDecl())
+ {
+ auto synTypeParamDecl = mapOrigToSynTypeParams[typeParamDecl];
+
+ // Construct a DeclRefExpr from the type parameter.
+ auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl.getValue());
+
+ auto synTypeParamDeclRefExpr = m_astBuilder->create<VarExpr>();
+ synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef;
+ synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc());
+
+ synConstraintDecl->sub = TypeExp(synTypeParamDeclRefExpr);
+ synConstraintDecl->sup = constraintDecl->sup;
+ synGenericDecl->members.add(synConstraintDecl);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Cannot perform synthesis for requirements with complex type constraints.");
+ }
+ }
+ }
+
+ // Get outer substitutions. (This inner-most substition
+ // must be a ThisTypeSubstition)
+ //
+ Substitutions* outer = nullptr;
+ if (auto thisTypeSubst = findThisTypeSubstitution(
+ requiredMemberDeclRef.getSubst(),
+ as<InterfaceDecl>(requiredMemberDeclRef.getParent(m_astBuilder)).getDecl()))
+ {
+ outer = thisTypeSubst;
+ }
+
+ // Override generic pointer to point to the original generic container.
+ // This will create a substitution of the synthesized parameters for the
+ // original parameters.
+ //
+ GenericSubstitution* requiredFuncSubsts = createDefaultSubstitutionsForGeneric(m_astBuilder, this, requiredMemberDeclRef.getDecl(), outer);
+ DeclRef<Decl> requiredFuncDeclRef = m_astBuilder->getSpecializedDeclRef(requiredMemberDeclRef.getDecl()->inner, requiredFuncSubsts);
+
+ GenericSubstitution* substSynParamsForOrigGeneric = m_astBuilder->getOrCreateGenericSubstitution(
+ outer,
+ requiredMemberDeclRef.getDecl(),
+ createDefaultSubstitutionsForGeneric(m_astBuilder, this, synGenericDecl, nullptr)->getArgs());
+
+ // Substitute parameters of the synthesized generic for the parameters of the original generic.
+ requiredFuncDeclRef = substituteDeclRef(substSynParamsForOrigGeneric, m_astBuilder, requiredFuncDeclRef);
+
+ SLANG_ASSERT(requiredFuncDeclRef.as<FuncDecl>());
+
+ synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitness(
+ context,
+ requiredFuncDeclRef.as<FuncDecl>(),
+ synArgs,
+ synThis);
+ synGenericDecl->inner->parentDecl = synGenericDecl;
+
+ return synGenericDecl;
+ }
+
FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<FuncDecl> requiredMemberDeclRef,
@@ -3274,12 +3403,30 @@ namespace Slang
switch (builtinAttr->kind)
{
case BuiltinRequirementKind::DAddFunc:
- case BuiltinRequirementKind::DMulFunc:
case BuiltinRequirementKind::DZeroFunc:
return trySynthesizeDifferentialMethodRequirementWitness(
context,
requiredFuncDeclRef,
- witnessTable);
+ witnessTable,
+ SynthesisPattern::AllInductive);
+ }
+ }
+ return false;
+ }
+
+ // For generic decl, check if we match DMulFunc, and synthesize the method.
+ if (auto requiredGenericDeclRef = requiredMemberDeclRef.as<GenericDecl>())
+ {
+ if (auto builtinAttr = getInner(requiredGenericDeclRef)->findModifier<BuiltinRequirementModifier>())
+ {
+ switch (builtinAttr->kind)
+ {
+ case BuiltinRequirementKind::DMulFunc:
+ return trySynthesizeDifferentialMethodRequirementWitness(
+ context,
+ requiredGenericDeclRef,
+ witnessTable,
+ SynthesisPattern::FixedFirstArg);
}
}
return false;
@@ -3330,7 +3477,15 @@ namespace Slang
return false;
}
- Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List<Expr*>&& args, int nestingLevel = 0)
+ Stmt* _synthesizeMemberAssignMemberHelper(
+ ASTSynthesizer& synth,
+ Name* funcName,
+ Type* leftType,
+ Expr* leftValue,
+ List<Expr*>&& args,
+ List<Expr*>&& genericArgs,
+ List<bool>&& inductiveArgMask,
+ int nestingLevel = 0)
{
if (nestingLevel > 16)
return nullptr;
@@ -3342,11 +3497,24 @@ namespace Slang
auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar);
addModifier(forStmt, synth.getBuilder()->create<ForceUnrollAttribute>());
auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar));
- for (auto& arg : args)
+
+ for (auto ii = 0; ii < args.getCount(); ++ii)
{
- arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar));
+ auto& arg = args[ii];
+ if (inductiveArgMask[ii])
+ arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar));
}
- auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->getElementType(), innerLeft, _Move(args), nestingLevel + 1);
+
+ auto assignStmt = _synthesizeMemberAssignMemberHelper(
+ synth,
+ funcName,
+ arrayType->getElementType(),
+ innerLeft,
+ _Move(args),
+ _Move(genericArgs),
+ _Move(inductiveArgMask),
+ nestingLevel + 1);
+
synth.popScope();
if (!assignStmt)
return nullptr;
@@ -3354,13 +3522,18 @@ namespace Slang
}
auto callee = synth.emitMemberExpr(leftType, funcName);
+
+ if (genericArgs.getCount() > 0)
+ callee = synth.emitGenericAppExpr(callee, _Move(genericArgs));
+
return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args)));
}
bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<Decl> requirementDeclRef,
- RefPtr<WitnessTable> witnessTable)
+ RefPtr<WitnessTable> witnessTable,
+ SynthesisPattern pattern)
{
// We support two cases of synthesis here.
// Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`.
@@ -3371,9 +3544,10 @@ namespace Slang
// ```
// static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
// ```
- // Where TResult, TParam1, TParam2 is either `This` or `Differential`,
- // We synthesize a memberwise dispatch to compute each field of `TResult`,
- // resulting an implementation of the form:
+ // Where TResult,TParam1, TParam2 is either `This` or `Differential`,
+ // We synthesize a memberwise dispatch to compute each field of `TResult`.
+ // Multiple patterns are supported (see SemanticsVisitor::SynthesisPattern for a full list)
+ // For AllInductive, we synthesize an implementation of the form:
// ```
// [BackwardDifferentiable]
// static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
@@ -3404,13 +3578,32 @@ namespace Slang
ASTSynthesizer synth(m_astBuilder, getNamePool());
List<Expr*> synArgs;
+ List<Expr*> synGenericArgs;
ThisExpr* synThis = nullptr;
- auto synFunc = synthesizeMethodSignatureForRequirementWitness(
- context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis);
+ FuncDecl* synFunc = nullptr;
+ GenericDecl* synGeneric = nullptr;
+
+ if (auto genericDeclRef = requirementDeclRef.as<GenericDecl>())
+ {
+ synGeneric = synthesizeGenericSignatureForRequirementWitness(
+ context, genericDeclRef, synArgs, synGenericArgs, synThis);
+ synFunc = as<FuncDecl>(synGeneric->inner);
+ }
+ else if (auto funcDeclRef = requirementDeclRef.as<FuncDecl>())
+ {
+ synFunc = synthesizeMethodSignatureForRequirementWitness(
+ context, funcDeclRef, synArgs, synThis);
+ }
+
+ SLANG_ASSERT(synFunc);
addModifier(synFunc, m_astBuilder->create<BackwardDifferentiableAttribute>());
- synFunc->parentDecl = context->parentDecl;
+ if (synGeneric)
+ synGeneric->parentDecl = context->parentDecl;
+ else
+ synFunc->parentDecl = context->parentDecl;
+
synth.pushContainerScope(synFunc);
auto blockStmt = m_astBuilder->create<BlockStmt>();
synFunc->body = blockStmt;
@@ -3438,23 +3631,71 @@ namespace Slang
// Construct reference exprs to the member's corresponding fields in each parameter.
List<Expr*> paramFields;
- int paramIndex = 0;
- for (auto arg : synArgs)
+ List<bool> inductiveArgMask;
+
+ switch (pattern)
{
- auto memberExpr = m_astBuilder->create<MemberExpr>();
- memberExpr->baseExpression = arg;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
- // Differential type.
- memberExpr->name = varMember->getName();
- paramFields.add(memberExpr);
- paramIndex++;
+ case SynthesisPattern::AllInductive:
+ {
+ int paramIndex = 0;
+ for (auto arg : synArgs)
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
+ // Differential type.
+ memberExpr->name = varMember->getName();
+ paramFields.add(memberExpr);
+ inductiveArgMask.add(true);
+
+ paramIndex++;
+ }
+ break;
+ }
+ case SynthesisPattern::FixedFirstArg:
+ {
+ int paramIndex = 0;
+ for (auto arg : synArgs)
+ {
+ if (paramIndex == 0)
+ {
+ paramFields.add(arg);
+ inductiveArgMask.add(false);
+
+ paramIndex++;
+ }
+ else
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
+ // Differential type.
+ memberExpr->name = varMember->getName();
+ paramFields.add(memberExpr);
+ inductiveArgMask.add(true);
+
+ paramIndex++;
+ }
+ }
+ break;
+ }
+ default:
+ SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern");
+ break;
}
// Invoke the method for the field and assign the value to resultVar.
// TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
// is Differential type.
auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName());
- if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields)))
+ if (!_synthesizeMemberAssignMemberHelper(
+ synth,
+ requirementDeclRef.getName(),
+ memberType,
+ leftVal,
+ _Move(paramFields),
+ _Move(synGenericArgs),
+ _Move(inductiveArgMask)))
return false;
}
@@ -3473,11 +3714,11 @@ namespace Slang
// This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the
// generic substitution for outer generic parameters, and apply it here.
SubstitutionSet substSet;
- if (auto thisTypeSusbt = findThisTypeSubstitution(
+ if (auto thisTypeSubst = findThisTypeSubstitution(
requirementDeclRef.getSubst(),
as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
{
- if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ if (auto declRefType = as<DeclRefType>(thisTypeSubst->witness->sub))
{
substSet = declRefType->declRef.getSubst();
}
@@ -3610,7 +3851,9 @@ namespace Slang
// requirement, it may be possible that we can still synthesis the
// implementation if this is one of the known builtin requirements.
// Otherwise, report diagnostic now.
- if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>())
+ if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>() &&
+ !(requiredMemberDeclRef.as<GenericDecl>() &&
+ getInner(requiredMemberDeclRef.as<GenericDecl>())->hasModifier<BuiltinRequirementModifier>()))
{
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 04112743a..575d4aff7 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1467,6 +1467,13 @@ namespace Slang
DeclRef<FuncDecl> requiredMemberDeclRef,
List<Expr*>& synArgs,
ThisExpr*& synThis);
+
+ GenericDecl* synthesizeGenericSignatureForRequirementWitness(
+ ConformanceCheckingContext* context,
+ DeclRef<GenericDecl> requiredMemberDeclRef,
+ List<Expr*>& synArgs,
+ List<Expr*>& synGenericArgs,
+ ThisExpr*& synThis);
void _addMethodWitness(
WitnessTable* witnessTable,
@@ -1503,15 +1510,39 @@ namespace Slang
LookupResult const& lookupResult,
DeclRef<Decl> requiredMemberDeclRef,
RefPtr<WitnessTable> witnessTable);
+
- /// Attempt to synthesize `zero`, `dadd` and `dmul` methods for a type that conforms to
+ enum SynthesisPattern
+ {
+ // Synthesized method inducts over all arguments.
+ // T fn(T x, T y, T z, ...)
+ // {
+ // typeof(T::member0)::fn(x.member0, y.member0, z.member0, ...);
+ // typeof(T::member1)::fn(x.member1, y.member1, z.member1, ...);
+ // ...
+ // }
+ //
+ AllInductive,
+
+ // Synthesized method inducts over all arguments except the first.
+ // T fn(U x, T y, T z)
+ // {
+ // typeof(T::member0)::fn(x, y.member0, z.member0, ...);
+ // typeof(T::member1)::fn(x, y.member1, z.member1, ...);
+ // ...
+ // }
+ FixedFirstArg
+ };
+
+ /// Attempt to synthesize `zero`, `dadd` & `dmul` methods for a type that conforms to
/// `IDifferentiable`.
/// On success, installs the syntethesized functions and returns `true`.
/// Otherwise, returns `false`.
bool trySynthesizeDifferentialMethodRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<Decl> requirementDeclRef,
- RefPtr<WitnessTable> witnessTable);
+ RefPtr<WitnessTable> witnessTable,
+ SynthesisPattern pattern);
/// Attempt to synthesize an associated `Differential` type for a type that conforms to
/// `IDifferentiable`.
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 3207e0729..6171d9a75 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -636,7 +636,7 @@ public:
{
if (auto genericInst = as<IRGeneric>(inst))
{
- if (auto innerFunc = as<IRGlobalValueWithCode>(findGenericReturnVal(genericInst)))
+ if (auto innerFunc = as<IRGlobalValueWithCode>(findInnerMostGenericReturnVal(genericInst)))
processFunc(innerFunc);
}
else if (auto funcInst = as<IRGlobalValueWithCode>(inst))
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index b0af5378c..cab01d585 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -3924,7 +3924,7 @@ namespace Slang
declToModify = genericDecl->inner;
_addModifiers(declToModify, modifiers);
- if (containerDecl)
+ if (containerDecl && !as<GenericDecl>(containerDecl))
{
// Make sure the decl is properly nested inside its lexical parent
AddMember(containerDecl, decl);
diff --git a/tests/autodiff/auto-differential-type.slang b/tests/autodiff/auto-differential-type.slang
index efeebb459..a253a25bb 100644
--- a/tests/autodiff/auto-differential-type.slang
+++ b/tests/autodiff/auto-differential-type.slang
@@ -26,9 +26,10 @@ struct A : IDifferentiable
}
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ __generic<T : __BuiltinRealType>
+ static Differential dmul(T a, Differential b)
{
- Differential o = {a.x * b.x, 0.0};
+ Differential o = { __realCast<float, T>(a * __realCast<T, float>(b.x)), 0.0};
return o;
}
};
diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang
index 8048c60ff..dd122a674 100644
--- a/tests/autodiff/custom-intrinsic.slang
+++ b/tests/autodiff/custom-intrinsic.slang
@@ -6,81 +6,81 @@ RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
-typealias IDFloat = IFloat & IDifferentiable;
+typealias IDFloat = __BuiltinFloatingPointType & IDifferentiable;
namespace myintrinsiclib
{
__generic<T : IDFloat>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
+ __target_intrinsic(hlsl, "exp($0)")
+ __target_intrinsic(glsl, "exp($0)")
__target_intrinsic(cuda, "$P_exp($0)")
__target_intrinsic(cpp, "$P_exp($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [ForwardDerivative(d_exp<T>)]
- T exp(T x);
+ [ForwardDerivative(d_myexp<T>)]
+ T myexp(T x);
__generic<T : IDFloat>
- DifferentialPair<T> d_exp(DifferentialPair<T> dpx)
+ DifferentialPair<T> d_myexp(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- exp(dpx.p),
- T.dmul(exp(dpx.p), dpx.d));
+ myexp(dpx.p),
+ T.dmul(myexp(dpx.p), dpx.d));
}
// Sine
__generic<T : IDFloat>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
+ __target_intrinsic(hlsl, "sin($0)")
+ __target_intrinsic(glsl, "sin($0)")
__target_intrinsic(cuda, "$P_sin($0)")
__target_intrinsic(cpp, "$P_sin($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
- [ForwardDerivative(d_sin<T>)]
- T sin(T x);
+ [ForwardDerivative(d_mysin<T>)]
+ T mysin(T x);
__generic<T : IDFloat>
- DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
+ DifferentialPair<T> d_mysin(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- sin(dpx.p),
- T.dmul(cos(dpx.p), dpx.d));
+ mysin(dpx.p),
+ T.dmul(mycos(dpx.p), dpx.d));
}
// Cosine
__generic<T : IDFloat>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
+ __target_intrinsic(hlsl, "cos($0)")
+ __target_intrinsic(glsl, "cos($0)")
__target_intrinsic(cuda, "$P_cos($0)")
__target_intrinsic(cpp, "$P_cos($0)")
__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
- [ForwardDerivative(d_cos<T>)]
- T cos(T x);
+ [ForwardDerivative(d_mycos<T>)]
+ T mycos(T x);
__generic<T : IDFloat>
- DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
+ DifferentialPair<T> d_mycos(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- cos(dpx.p),
+ mycos(dpx.p),
T.dmul(-sin(dpx.p), dpx.d));
}
// Sine and cosine
__generic<T : IDFloat>
- __target_intrinsic(hlsl)
+ __target_intrinsic(hlsl, "sincos($0, $1, $2)")
__target_intrinsic(cuda, "$P_sincos($0, $1, $2)")
- [ForwardDerivative(d_sincos<T>)]
- void sincos(T x, out T s, out T c)
+ [ForwardDerivative(d_mysincos<T>)]
+ void mysincos(T x, out T s, out T c)
{
s = sin(x);
c = cos(x);
}
__generic<T : IDFloat>
- void d_sincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c)
+ void d_mysincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c)
{
T _s;
T _c;
- sincos(x.p, _s, _c);
+ mysincos(x.p, _s, _c);
s = DifferentialPair<T>(_s, T.dmul(_c, x.d));
c = DifferentialPair<T>(_c, T.dmul(-_s, x.d));
@@ -90,7 +90,7 @@ namespace myintrinsiclib
[ForwardDifferentiable]
float f(float x)
{
- return myintrinsiclib.exp(x);
+ return myintrinsiclib.myexp(x);
}
[ForwardDifferentiable]
@@ -98,7 +98,7 @@ float g(float x)
{
float s;
float t;
- myintrinsiclib.sincos(x, s, t);
+ myintrinsiclib.mysincos(x, s, t);
return s + t;
}
diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang
index 3220976e7..e9385b78c 100644
--- a/tests/autodiff/differential-method-synthesis.slang
+++ b/tests/autodiff/differential-method-synthesis.slang
@@ -41,8 +41,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
A a = {1.0, 2.0};
A.Differential b = {0.2};
dpA dpa = dpA(a, b);
- outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0
- outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4
- outputBuffer[2] = A.dmul(a, b).b.x; // Expect: 0.2
+ outputBuffer[0] = __fwd_diff(f)(dpa).d.b.x; // Expect: 0
+ outputBuffer[1] = A.dadd(b, b).b.x; // Expect: 0.4
+ outputBuffer[2] = A.dmul<float>(2.0, b).b.x; // Expect: 0.4
}
-}
+}
diff --git a/tests/autodiff/differential-method-synthesis.slang.expected.txt b/tests/autodiff/differential-method-synthesis.slang.expected.txt
index 5fbff9752..353c35ec8 100644
--- a/tests/autodiff/differential-method-synthesis.slang.expected.txt
+++ b/tests/autodiff/differential-method-synthesis.slang.expected.txt
@@ -1,6 +1,6 @@
type: float
0.000000
0.400000
-0.200000
+0.400000
0.000000
0.000000
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
index 98adc4a7c..674d5c5ca 100644
--- a/tests/autodiff/generic-impl-jvp.slang
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -6,7 +6,7 @@ RWStructuredBuffer<float> outputBuffer;
typedef float Real;
-typealias IDFloat = IFloat & IDifferentiable;
+typealias IDFloat = __BuiltinRealType & IDifferentiable;
__generic<T : IDifferentiable, let N : int>
struct dvector : IDifferentiable
@@ -44,13 +44,13 @@ struct myvector : IDifferentiable
}
- static Differential dmul(This a, Differential b)
+ static Differential dmul<U: __BuiltinRealType>(U a, Differential b)
{
Differential output;
for (int i = 0; i < N; i++)
{
- output.values[i] = T.dmul(a.values[i], b.values[i]);
+ output.values[i] = T.dmul<U>(a, b.values[i]);
}
return output;
@@ -112,7 +112,7 @@ __generic<T : IDFloat, let N : int>
[ForwardDerivative(dot_jvp)]
T dot(myvector<T, N> a, myvector<T, N> b)
{
- T curr = (T)0.0;
+ T curr = __realCast<T, float>(0.f);
[ForceUnroll]
for (int i = 0; i < N; i++)
{
@@ -129,7 +129,7 @@ __generic<T : IDFloat, let N : int>
DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b)
{
T.Differential curr_d = (T.dzero());
- T curr_p = (T)0.0;
+ T curr_p = __realCast<T, float>(0.f);
[ForceUnroll]
for (int i = 0; i < N; i++)
{
@@ -137,8 +137,8 @@ DifferentialPair<T> dot_jvp(dpvector<T, N> a, dpvector<T, N> b)
curr_d = T.dadd(
curr_d,
T.dadd(
- T.dmul(a.p.values[i], b.d.values[i]),
- T.dmul(b.p.values[i], a.d.values[i])));
+ T.dmul<T>(a.p.values[i], b.d.values[i]),
+ T.dmul<T>(b.p.values[i], a.d.values[i])));
}
return DifferentialPair<T>(curr_p, curr_d);
@@ -203,9 +203,9 @@ struct linearvector : MyLinearArithmeticType, IDifferentiable
return { myvector<Real, N>.dadd(a.val, b.val) };
}
- static Differential dmul(This a, Differential b)
+ static Differential dmul<T: __BuiltinRealType>(T a, Differential b)
{
- return { myvector<Real, N>.dmul(a.val, b.val) };
+ return { myvector<Real, N>.dmul<T>(a, b.val) };
}
[ForwardDifferentiable]
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index 2be0045d4..7e5625477 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -113,9 +113,9 @@ extension myfloat3 : IDifferentiable
}
[ForwardDifferentiable]
- static Differential dmul(Differential a, Differential b)
+ static Differential dmul<T : __BuiltinRealType>(T a, Differential b)
{
- return a * b;
+ return { __realCast<Real, T>(a) * b.val };
}
};
@@ -139,9 +139,9 @@ extension myfloat4 : IDifferentiable
}
[ForwardDifferentiable]
- static Differential dmul(Differential a, Differential b)
+ static Differential dmul<T: __BuiltinRealType>(T a, Differential b)
{
- return a * b;
+ return { __realCast<Real, T>(a) * b.val };
}
};
diff --git a/tests/autodiff/getter-setter-multi.slang b/tests/autodiff/getter-setter-multi.slang
index 9055e860a..9f03ac4eb 100644
--- a/tests/autodiff/getter-setter-multi.slang
+++ b/tests/autodiff/getter-setter-multi.slang
@@ -34,9 +34,9 @@ struct A : IDifferentiable
}
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ static Differential dmul<T: __BuiltinRealType>(T a, Differential b)
{
- B o = {a.x * b.z};
+ B o = {__realCast<float, T>(a) * b.z};
return o;
}
};
diff --git a/tests/autodiff/getter-setter.slang b/tests/autodiff/getter-setter.slang
index 06caadce8..bc7343f27 100644
--- a/tests/autodiff/getter-setter.slang
+++ b/tests/autodiff/getter-setter.slang
@@ -32,9 +32,9 @@ struct A : IDifferentiable
}
[__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
+ static Differential dmul<T : __BuiltinRealType>(T a, Differential b)
{
- B o = {a.x * b.z};
+ B o = {__realCast<float, T>(a) * b.z};
return o;
}
};