diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-07-26 17:15:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-26 17:15:21 -0400 |
| commit | ba89fc84267bfd09f1c8abf10a5b85d09bbc79de (patch) | |
| tree | 2c79fc5dafb89a030d22fa86cd6fa3d69a89a785 /source/slang | |
| parent | b8ade05df10a2774d3da5ef1fb2c7479ff48989a (diff) | |
Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` (#3029)
* Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)`
- Add AST synthesis support for generic containers
- Refactor relevant tests
* Merge dmul synthesis with dadd and dzero, and disambiguate using an enum
* Fix trailing spaces
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/core.meta.slang | 44 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 97 | ||||
| -rw-r--r-- | source/slang/slang-ast-synthesis.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ast-synthesis.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 295 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 35 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 2 |
8 files changed, 403 insertions, 82 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 77b9405ba..ae70a83f4 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -112,6 +112,12 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} interface __BuiltinIntegerType : __BuiltinArithmeticType {} +/// A type that can represent non-integers +[sealed] +[builtin] +interface __BuiltinRealType : __BuiltinSignedArithmeticType {} + + __attributeTarget(AggTypeDecl) attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute; @@ -144,7 +150,7 @@ interface IDifferentiable // Note: the compiler implementation requires the `Differential` associated type to be defined // before anything else. - __builtin_requirement($( (int) BuiltinRequirementKind::DifferentialType) ) + __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialType) ) associatedtype Differential : IDifferentiable; __builtin_requirement($( (int)BuiltinRequirementKind::DZeroFunc) ) @@ -154,7 +160,8 @@ interface IDifferentiable static Differential dadd(Differential, Differential); __builtin_requirement($( (int)BuiltinRequirementKind::DMulFunc) ) - static Differential dmul(This, Differential); + __generic<T : __BuiltinRealType> + static Differential dmul(T, Differential); }; @@ -219,19 +226,16 @@ struct DifferentialPair : IDifferentiable T.Differential.dadd(a.d, b.d)); } + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { return Differential( - T.dmul(a.p, b.p), - T.Differential.dmul(a.d, b.d)); + T.dmul<U>(a, b.p), + T.Differential.dmul<U>(a, b.d)); } }; -/// A type that can represent non-integers -[sealed] -[builtin] -interface __BuiltinRealType : __BuiltinSignedArithmeticType {} /// A type that uses a floating-point representation [sealed] @@ -339,6 +343,9 @@ __generic<T, let N : int> __intrinsic_op(select) vector<T,N> operator?:(vector<b __generic<T> __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse); __generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse); +// Allow real-number types to be cast into each other +__intrinsic_op($(kIROp_FloatCast)) + T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val); ${{{{ // We are going to use code generation to produce the @@ -483,12 +490,13 @@ ${{{{ { return a + b; } - + + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(Differential a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast<Differential, U>(a) * b; } ${{{{ break; @@ -1190,11 +1198,12 @@ extension vector<T, N> : IDifferentiable return a + b; } + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast<T, U>(a) * b; } } @@ -1216,12 +1225,13 @@ extension matrix<T, R, C> : IDifferentiable { return a + b; } - + + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast<T, U>(a) * b; } } diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f2fd8e3b0..3e381e55d 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -374,12 +374,13 @@ extension Array<T, N> : IDifferentiable return result; } + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { Array<T.Differential, N> result; for (int i = 0; i < N; i++) - result[i] = T.dmul(a[i], b[i]); + result[i] = T.dmul<U>(a, b[i]); return result; } } @@ -543,8 +544,8 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair for (int i = 0; i < N; ++i) { result = result + dpx.p[i] * dpy.p[i]; - d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast<T.Differential>(dpy.d[i]))); - d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); + d_result = T.dadd(d_result, __slang_noop_cast<T.Differential>(dpx.p[i] * dpy.d[i])); + d_result = T.dadd(d_result, __slang_noop_cast<T.Differential>(dpy.p[i] * dpx.d[i])); } return DifferentialPair<T>(result, d_result); } @@ -797,7 +798,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve #define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \ __generic<T : __BuiltinFloatingPointType> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \ { \ @@ -805,7 +806,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \ { \ @@ -813,21 +814,21 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \ { \ - typealias ReturnType = vector<T,N>; \ - matrix<T,M,N>.Differential diff; \ + typealias ReturnType = vector<T, N>; \ + matrix<T, M, N>.Differential diff; \ [ForceUnroll] for (int i = 0; i < M; i++) \ { \ var dpx = diffPair(dpm.p[i], dpm.d[i]); \ - diff[i] = FWD_DIFF_FUNC; \ + diff[i] = __slang_noop_cast<vector<T, N>>(FWD_DIFF_FUNC); \ } \ return diffPair(NAME(dpm.p), diff); \ } \ __generic<T : __BuiltinFloatingPointType> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \ { \ @@ -835,31 +836,57 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ - inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \ + inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \ { \ typealias ReturnType = vector<T, N>; \ dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ - inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \ + inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \ { \ typealias ReturnType = vector<T, N>; \ matrix<T, M, N>.Differential diff; \ [ForceUnroll] for (int i = 0; i < M; i++) \ { \ var dpx = diffPair(m.p[i], m.d[i]); \ - var dOut = mdOut[i]; \ + var dOut = __slang_noop_cast<vector<T, N>>(mdOut[i]); \ diff[i] = BWD_DIFF_FUNC; \ } \ m = diffPair(m.p, diff); \ } -#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut)) +#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, __mul_p_d(DIFF_FUNC, dpx.d), __mul_p_d(DIFF_FUNC, dOut)) + +/// Element-wise multiply for scalars and vectors for (T, T.Differential) +__generic<T : __BuiltinFloatingPointType> +[__unsafeForceInlineEarly] +[Differentiable] +T.Differential __mul_p_d(T a, T.Differential b) +{ + return __slang_noop_cast<T.Differential>(a * __slang_noop_cast<T>(b)); +} + +__generic<T : __BuiltinFloatingPointType> +[__unsafeForceInlineEarly] +[Differentiable] +T __mul_p_d(T a, T b) +{ + return (a * b); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[__unsafeForceInlineEarly] +[Differentiable] +vector<T, N> __mul_p_d(vector<T, N> a, vector<T, N> b) +{ + return a * b; +} + /// Detach and set derivatives to zero. __generic<T : IDifferentiable> @@ -871,14 +898,14 @@ T detach(T x); #define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0)))) // Absolute value -UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut)) +UNARY_DERIVATIVE_IMPL(abs, (__mul_p_d(SLANG_SIGN(dpx.p), (dpx.d))), (__mul_p_d(SLANG_SIGN(dpx.p), (dOut)))) // Saturate UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut) // raidans, degrees -SIMPLE_UNARY_DERIVATIVE_IMPL(radians, T(0.01745329251994329576923690768489)) -SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, T(57.295779513082320876798154814105)) +SIMPLE_UNARY_DERIVATIVE_IMPL(radians, ReturnType(T(0.01745329251994329576923690768489))) +SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, ReturnType(T(57.295779513082320876798154814105))) // Exponent SIMPLE_UNARY_DERIVATIVE_IMPL(exp, exp(dpx.p)) SIMPLE_UNARY_DERIVATIVE_IMPL(exp2, exp2(dpx.p)* T(50.69314718055994530941723212145818)) @@ -915,8 +942,8 @@ __generic<T : __BuiltinFloatingPointType> [ForwardDerivativeOf(atan2)] DifferentialPair<T> __d_atan2(DifferentialPair<T> dpy, DifferentialPair<T> dpx) { - T.Differential dx = T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d); - T.Differential dy = T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d); + T.Differential dx = __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d); + T.Differential dy = __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d); return DifferentialPair<T>( atan2(dpy.p, dpx.p), T.dadd(dx, dy)); @@ -928,8 +955,8 @@ __generic<T : __BuiltinFloatingPointType> [BackwardDerivativeOf(atan2)] void __d_atan2(inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpx, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d)); - dpy = diffPair(dpy.p, T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d)); + dpx = diffPair(dpx.p, __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d)); + dpy = diffPair(dpy.p, __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d)); } VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2) @@ -968,8 +995,8 @@ DifferentialPair<T> __d_pow(DifferentialPair<T> dpx, DifferentialPair<T> dpy) } T val = pow(dpx.p, dpy.p); - T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d); - T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d); + T.Differential d1 = __mul_p_d((val * log(dpx.p)), dpy.d); + T.Differential d2 = __mul_p_d((val * dpy.p / dpx.p), dpx.d); return DifferentialPair<T>( val, T.dadd(d1, d2) @@ -993,10 +1020,10 @@ void __d_pow(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, T.Dif T val = pow(dpx.p, dpy.p); dpx = diffPair( dpx.p, - T.dmul(val * dpy.p / dpx.p, dOut)); + (__mul_p_d((val * dpy.p / dpx.p), dOut))); dpy = diffPair( dpy.p, - T.dmul(val * log(dpx.p), dOut)); + (__mul_p_d((val * log(dpx.p)), dOut))); } } @@ -1061,7 +1088,7 @@ DifferentialPair<T> __d_lerp(DifferentialPair<T> dpx, DifferentialPair<T> dpy, D { return DifferentialPair<T>( lerp(dpx.p, dpy.p, dps.p), - T.dadd(T.dadd(T.dmul((T(1.0) - dps.p), dpx.d), T.dmul(dps.p, dpy.d)), T.dmul(dpy.p - dpx.p, dps.d)) + T.dadd(T.dadd(__mul_p_d((T(1.0) - dps.p), dpx.d), __mul_p_d(dps.p, dpy.d)), __mul_p_d((dpy.p - dpx.p), dps.d)) ); } __generic<T : __BuiltinFloatingPointType> @@ -1070,9 +1097,9 @@ __generic<T : __BuiltinFloatingPointType> [BackwardDerivativeOf(lerp)] void __d_lerp(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dps, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(T(1.0) - dps.p, dOut)); - dpy = diffPair(dpy.p, T.dmul(dps.p, dOut)); - dps = diffPair(dpy.p, T.dmul((dpy.p - dpx.p), dOut)); + dpx = diffPair(dpx.p, __mul_p_d((T(1.0) - dps.p), dOut)); + dpy = diffPair(dpy.p, __mul_p_d(dps.p, dOut)); + dps = diffPair(dpy.p, __mul_p_d((dpy.p - dpx.p), dOut)); } VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp) @@ -1175,7 +1202,7 @@ DifferentialPair<T> __d_mad(DifferentialPair<T> dpx, DifferentialPair<T> dpy, Di { return DifferentialPair<T>( mad(dpx.p, dpy.p, dpz.p), - T.dadd(T.dadd(T.dmul(dpy.p, dpx.d), T.dmul(dpx.p, dpy.d)), dpz.d)); + T.dadd(T.dadd(__mul_p_d(dpy.p, dpx.d), __mul_p_d(dpx.p, dpy.d)), dpz.d)); } __generic<T : __BuiltinFloatingPointType> [BackwardDifferentiable] @@ -1183,8 +1210,8 @@ __generic<T : __BuiltinFloatingPointType> [PreferRecompute] void __d_mad(inout DifferentialPair<T> dpx, inout DifferentialPair<T> dpy, inout DifferentialPair<T> dpz, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(dpy.p, dOut)); - dpy = diffPair(dpy.p, T.dmul(dpx.p, dOut)); + dpx = diffPair(dpx.p, __mul_p_d(dpy.p, dOut)); + dpy = diffPair(dpy.p, __mul_p_d(dpx.p, dOut)); dpz = diffPair(dpz.p, dOut); } VECTOR_MATRIX_TERNARY_DIFF_IMPL(mad) diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index 872088d54..65955e815 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -122,6 +122,14 @@ Expr* ASTSynthesizer::emitInvokeExpr(Expr* callee, List<Expr*>&& args) return rs; } +Expr* ASTSynthesizer::emitGenericAppExpr(Expr* genericExpr, List<Expr*>&& args) +{ + auto rs = m_builder->create<GenericAppExpr>(); + rs->functionExpr = genericExpr; + rs->arguments = _Move(args); + return rs; +} + Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name) { auto rs = m_builder->create<StaticMemberExpr>(); diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h index 6568b4c83..e595afac1 100644 --- a/source/slang/slang-ast-synthesis.h +++ b/source/slang/slang-ast-synthesis.h @@ -138,6 +138,8 @@ public: Expr* emitInvokeExpr(Expr* callee, List<Expr*>&& args); + Expr* emitGenericAppExpr(Expr* genericExpr, List<Expr*>&& args); + DeclStmt* emitVarDeclStmt(Type* type, Name* name = nullptr, Expr* initVal = nullptr); ExpressionStmt* emitExprStmt(Expr* expr); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4c1e967e3..2d009c28c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2574,6 +2574,135 @@ namespace Slang return false; } + GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<GenericDecl> requiredMemberDeclRef, + List<Expr*>& synArgs, + List<Expr*>& synGenericArgs, + ThisExpr*& synThis) + { + auto synGenericDecl = m_astBuilder->create<GenericDecl>(); + + // For now our synthesized method will use the name and source + // location of the requirement we are trying to satisfy. + // + // TODO: as it stands right now our syntesized method will + // get a mangled name, which we don't actually want. Leaving + // out the name here doesn't help matters, because then *all* + // snthesized methods on a given type would share the same + // mangled name! + // + synGenericDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; + if (synGenericDecl->nameAndLoc.name) + { + synGenericDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synGenericDecl->nameAndLoc.name->text); + } + + // Dictionary to map from the original type parameters to the synthesized ones. + Dictionary<GenericTypeParamDecl*, GenericTypeParamDecl*> mapOrigToSynTypeParams; + + // Our synthesized method will have parameters matching the names + // and types of those on the requirement, and it will use expressions + // that reference those parametesr as arguments for the call expresison + // that makes up the body. + // + for (auto member : requiredMemberDeclRef.getDecl()->members) + { + if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) + { + auto synTypeParamDecl = m_astBuilder->create<GenericTypeParamDecl>(); + synTypeParamDecl->nameAndLoc = typeParamDecl->getNameAndLoc(); + synTypeParamDecl->initType = typeParamDecl->initType; + synTypeParamDecl->parentDecl = synGenericDecl; + synGenericDecl->members.add(synTypeParamDecl); + + mapOrigToSynTypeParams.add(typeParamDecl, synTypeParamDecl); + + // Construct a DeclRefExpr from the type parameter. + auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl); + + auto synTypeParamDeclRefExpr = m_astBuilder->create<VarExpr>(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); + + synGenericArgs.add(synTypeParamDeclRefExpr); + } + } + + for (auto member : requiredMemberDeclRef.getDecl()->members) + { + if (auto constraintDecl = as<GenericTypeConstraintDecl>(member)) + { + getASTBuilder()->getSpecializedDeclRef( + constraintDecl, requiredMemberDeclRef.getSubst()); + + auto synConstraintDecl = m_astBuilder->create<GenericTypeConstraintDecl>(); + synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc(); + synConstraintDecl->parentDecl = synGenericDecl; + + // For constraints of type T : Interface, where T is a simple type parameter, + // find the declaration of T + // + if (auto typeParamDecl = as<DeclRefType>(constraintDecl->sub.type)->declRef.as<GenericTypeParamDecl>().getDecl()) + { + auto synTypeParamDecl = mapOrigToSynTypeParams[typeParamDecl]; + + // Construct a DeclRefExpr from the type parameter. + auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl.getValue()); + + auto synTypeParamDeclRefExpr = m_astBuilder->create<VarExpr>(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); + + synConstraintDecl->sub = TypeExp(synTypeParamDeclRefExpr); + synConstraintDecl->sup = constraintDecl->sup; + synGenericDecl->members.add(synConstraintDecl); + } + else + { + SLANG_UNEXPECTED("Cannot perform synthesis for requirements with complex type constraints."); + } + } + } + + // Get outer substitutions. (This inner-most substition + // must be a ThisTypeSubstition) + // + Substitutions* outer = nullptr; + if (auto thisTypeSubst = findThisTypeSubstitution( + requiredMemberDeclRef.getSubst(), + as<InterfaceDecl>(requiredMemberDeclRef.getParent(m_astBuilder)).getDecl())) + { + outer = thisTypeSubst; + } + + // Override generic pointer to point to the original generic container. + // This will create a substitution of the synthesized parameters for the + // original parameters. + // + GenericSubstitution* requiredFuncSubsts = createDefaultSubstitutionsForGeneric(m_astBuilder, this, requiredMemberDeclRef.getDecl(), outer); + DeclRef<Decl> requiredFuncDeclRef = m_astBuilder->getSpecializedDeclRef(requiredMemberDeclRef.getDecl()->inner, requiredFuncSubsts); + + GenericSubstitution* substSynParamsForOrigGeneric = m_astBuilder->getOrCreateGenericSubstitution( + outer, + requiredMemberDeclRef.getDecl(), + createDefaultSubstitutionsForGeneric(m_astBuilder, this, synGenericDecl, nullptr)->getArgs()); + + // Substitute parameters of the synthesized generic for the parameters of the original generic. + requiredFuncDeclRef = substituteDeclRef(substSynParamsForOrigGeneric, m_astBuilder, requiredFuncDeclRef); + + SLANG_ASSERT(requiredFuncDeclRef.as<FuncDecl>()); + + synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitness( + context, + requiredFuncDeclRef.as<FuncDecl>(), + synArgs, + synThis); + synGenericDecl->inner->parentDecl = synGenericDecl; + + return synGenericDecl; + } + FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, DeclRef<FuncDecl> requiredMemberDeclRef, @@ -3274,12 +3403,30 @@ namespace Slang switch (builtinAttr->kind) { case BuiltinRequirementKind::DAddFunc: - case BuiltinRequirementKind::DMulFunc: case BuiltinRequirementKind::DZeroFunc: return trySynthesizeDifferentialMethodRequirementWitness( context, requiredFuncDeclRef, - witnessTable); + witnessTable, + SynthesisPattern::AllInductive); + } + } + return false; + } + + // For generic decl, check if we match DMulFunc, and synthesize the method. + if (auto requiredGenericDeclRef = requiredMemberDeclRef.as<GenericDecl>()) + { + if (auto builtinAttr = getInner(requiredGenericDeclRef)->findModifier<BuiltinRequirementModifier>()) + { + switch (builtinAttr->kind) + { + case BuiltinRequirementKind::DMulFunc: + return trySynthesizeDifferentialMethodRequirementWitness( + context, + requiredGenericDeclRef, + witnessTable, + SynthesisPattern::FixedFirstArg); } } return false; @@ -3330,7 +3477,15 @@ namespace Slang return false; } - Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List<Expr*>&& args, int nestingLevel = 0) + Stmt* _synthesizeMemberAssignMemberHelper( + ASTSynthesizer& synth, + Name* funcName, + Type* leftType, + Expr* leftValue, + List<Expr*>&& args, + List<Expr*>&& genericArgs, + List<bool>&& inductiveArgMask, + int nestingLevel = 0) { if (nestingLevel > 16) return nullptr; @@ -3342,11 +3497,24 @@ namespace Slang auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); addModifier(forStmt, synth.getBuilder()->create<ForceUnrollAttribute>()); auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); - for (auto& arg : args) + + for (auto ii = 0; ii < args.getCount(); ++ii) { - arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); + auto& arg = args[ii]; + if (inductiveArgMask[ii]) + arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); } - auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->getElementType(), innerLeft, _Move(args), nestingLevel + 1); + + auto assignStmt = _synthesizeMemberAssignMemberHelper( + synth, + funcName, + arrayType->getElementType(), + innerLeft, + _Move(args), + _Move(genericArgs), + _Move(inductiveArgMask), + nestingLevel + 1); + synth.popScope(); if (!assignStmt) return nullptr; @@ -3354,13 +3522,18 @@ namespace Slang } auto callee = synth.emitMemberExpr(leftType, funcName); + + if (genericArgs.getCount() > 0) + callee = synth.emitGenericAppExpr(callee, _Move(genericArgs)); + return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); } bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef<Decl> requirementDeclRef, - RefPtr<WitnessTable> witnessTable) + RefPtr<WitnessTable> witnessTable, + SynthesisPattern pattern) { // We support two cases of synthesis here. // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`. @@ -3371,9 +3544,10 @@ namespace Slang // ``` // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) // ``` - // Where TResult, TParam1, TParam2 is either `This` or `Differential`, - // We synthesize a memberwise dispatch to compute each field of `TResult`, - // resulting an implementation of the form: + // Where TResult,TParam1, TParam2 is either `This` or `Differential`, + // We synthesize a memberwise dispatch to compute each field of `TResult`. + // Multiple patterns are supported (see SemanticsVisitor::SynthesisPattern for a full list) + // For AllInductive, we synthesize an implementation of the form: // ``` // [BackwardDifferentiable] // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) @@ -3404,13 +3578,32 @@ namespace Slang ASTSynthesizer synth(m_astBuilder, getNamePool()); List<Expr*> synArgs; + List<Expr*> synGenericArgs; ThisExpr* synThis = nullptr; - auto synFunc = synthesizeMethodSignatureForRequirementWitness( - context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis); + FuncDecl* synFunc = nullptr; + GenericDecl* synGeneric = nullptr; + + if (auto genericDeclRef = requirementDeclRef.as<GenericDecl>()) + { + synGeneric = synthesizeGenericSignatureForRequirementWitness( + context, genericDeclRef, synArgs, synGenericArgs, synThis); + synFunc = as<FuncDecl>(synGeneric->inner); + } + else if (auto funcDeclRef = requirementDeclRef.as<FuncDecl>()) + { + synFunc = synthesizeMethodSignatureForRequirementWitness( + context, funcDeclRef, synArgs, synThis); + } + + SLANG_ASSERT(synFunc); addModifier(synFunc, m_astBuilder->create<BackwardDifferentiableAttribute>()); - synFunc->parentDecl = context->parentDecl; + if (synGeneric) + synGeneric->parentDecl = context->parentDecl; + else + synFunc->parentDecl = context->parentDecl; + synth.pushContainerScope(synFunc); auto blockStmt = m_astBuilder->create<BlockStmt>(); synFunc->body = blockStmt; @@ -3438,23 +3631,71 @@ namespace Slang // Construct reference exprs to the member's corresponding fields in each parameter. List<Expr*> paramFields; - int paramIndex = 0; - for (auto arg : synArgs) + List<bool> inductiveArgMask; + + switch (pattern) { - auto memberExpr = m_astBuilder->create<MemberExpr>(); - memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); - paramFields.add(memberExpr); - paramIndex++; + case SynthesisPattern::AllInductive: + { + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + inductiveArgMask.add(true); + + paramIndex++; + } + break; + } + case SynthesisPattern::FixedFirstArg: + { + int paramIndex = 0; + for (auto arg : synArgs) + { + if (paramIndex == 0) + { + paramFields.add(arg); + inductiveArgMask.add(false); + + paramIndex++; + } + else + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + inductiveArgMask.add(true); + + paramIndex++; + } + } + break; + } + default: + SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern"); + break; } // Invoke the method for the field and assign the value to resultVar. // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` // is Differential type. auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); - if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields))) + if (!_synthesizeMemberAssignMemberHelper( + synth, + requirementDeclRef.getName(), + memberType, + leftVal, + _Move(paramFields), + _Move(synGenericArgs), + _Move(inductiveArgMask))) return false; } @@ -3473,11 +3714,11 @@ namespace Slang // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the // generic substitution for outer generic parameters, and apply it here. SubstitutionSet substSet; - if (auto thisTypeSusbt = findThisTypeSubstitution( + if (auto thisTypeSubst = findThisTypeSubstitution( requirementDeclRef.getSubst(), as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) { - if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + if (auto declRefType = as<DeclRefType>(thisTypeSubst->witness->sub)) { substSet = declRefType->declRef.getSubst(); } @@ -3610,7 +3851,9 @@ namespace Slang // requirement, it may be possible that we can still synthesis the // implementation if this is one of the known builtin requirements. // Otherwise, report diagnostic now. - if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>()) + if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>() && + !(requiredMemberDeclRef.as<GenericDecl>() && + getInner(requiredMemberDeclRef.as<GenericDecl>())->hasModifier<BuiltinRequirementModifier>())) { getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 04112743a..575d4aff7 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1467,6 +1467,13 @@ namespace Slang DeclRef<FuncDecl> requiredMemberDeclRef, List<Expr*>& synArgs, ThisExpr*& synThis); + + GenericDecl* synthesizeGenericSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<GenericDecl> requiredMemberDeclRef, + List<Expr*>& synArgs, + List<Expr*>& synGenericArgs, + ThisExpr*& synThis); void _addMethodWitness( WitnessTable* witnessTable, @@ -1503,15 +1510,39 @@ namespace Slang LookupResult const& lookupResult, DeclRef<Decl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable); + - /// Attempt to synthesize `zero`, `dadd` and `dmul` methods for a type that conforms to + enum SynthesisPattern + { + // Synthesized method inducts over all arguments. + // T fn(T x, T y, T z, ...) + // { + // typeof(T::member0)::fn(x.member0, y.member0, z.member0, ...); + // typeof(T::member1)::fn(x.member1, y.member1, z.member1, ...); + // ... + // } + // + AllInductive, + + // Synthesized method inducts over all arguments except the first. + // T fn(U x, T y, T z) + // { + // typeof(T::member0)::fn(x, y.member0, z.member0, ...); + // typeof(T::member1)::fn(x, y.member1, z.member1, ...); + // ... + // } + FixedFirstArg + }; + + /// Attempt to synthesize `zero`, `dadd` & `dmul` methods for a type that conforms to /// `IDifferentiable`. /// On success, installs the syntethesized functions and returns `true`. /// Otherwise, returns `false`. bool trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef<Decl> requirementDeclRef, - RefPtr<WitnessTable> witnessTable); + RefPtr<WitnessTable> witnessTable, + SynthesisPattern pattern); /// Attempt to synthesize an associated `Differential` type for a type that conforms to /// `IDifferentiable`. diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 3207e0729..6171d9a75 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -636,7 +636,7 @@ public: { if (auto genericInst = as<IRGeneric>(inst)) { - if (auto innerFunc = as<IRGlobalValueWithCode>(findGenericReturnVal(genericInst))) + if (auto innerFunc = as<IRGlobalValueWithCode>(findInnerMostGenericReturnVal(genericInst))) processFunc(innerFunc); } else if (auto funcInst = as<IRGlobalValueWithCode>(inst)) diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index b0af5378c..cab01d585 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -3924,7 +3924,7 @@ namespace Slang declToModify = genericDecl->inner; _addModifiers(declToModify, modifiers); - if (containerDecl) + if (containerDecl && !as<GenericDecl>(containerDecl)) { // Make sure the decl is properly nested inside its lexical parent AddMember(containerDecl, decl); |
