From ba89fc84267bfd09f1c8abf10a5b85d09bbc79de Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 26 Jul 2023 17:15:21 -0400 Subject: Refactor `dmul(This, Differential)` to `dmul(T, Differential)` (#3029) * Refactor `dmul(This, Differential)` to `dmul(T, Differential)` - Add AST synthesis support for generic containers - Refactor relevant tests * Merge dmul synthesis with dadd and dzero, and disambiguate using an enum * Fix trailing spaces --- source/slang/core.meta.slang | 44 ++-- source/slang/diff.meta.slang | 97 ++++--- source/slang/slang-ast-synthesis.cpp | 8 + source/slang/slang-ast-synthesis.h | 2 + source/slang/slang-check-decl.cpp | 295 ++++++++++++++++++++-- source/slang/slang-check-impl.h | 35 ++- source/slang/slang-ir-check-differentiability.cpp | 2 +- source/slang/slang-parser.cpp | 2 +- 8 files changed, 403 insertions(+), 82 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 77b9405ba..ae70a83f4 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -112,6 +112,12 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} interface __BuiltinIntegerType : __BuiltinArithmeticType {} +/// A type that can represent non-integers +[sealed] +[builtin] +interface __BuiltinRealType : __BuiltinSignedArithmeticType {} + + __attributeTarget(AggTypeDecl) attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute; @@ -144,7 +150,7 @@ interface IDifferentiable // Note: the compiler implementation requires the `Differential` associated type to be defined // before anything else. - __builtin_requirement($( (int) BuiltinRequirementKind::DifferentialType) ) + __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialType) ) associatedtype Differential : IDifferentiable; __builtin_requirement($( (int)BuiltinRequirementKind::DZeroFunc) ) @@ -154,7 +160,8 @@ interface IDifferentiable static Differential dadd(Differential, Differential); __builtin_requirement($( (int)BuiltinRequirementKind::DMulFunc) ) - static Differential dmul(This, Differential); + __generic + static Differential dmul(T, Differential); }; @@ -219,19 +226,16 @@ struct DifferentialPair : IDifferentiable T.Differential.dadd(a.d, b.d)); } + __generic [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { return Differential( - T.dmul(a.p, b.p), - T.Differential.dmul(a.d, b.d)); + T.dmul(a, b.p), + T.Differential.dmul(a, b.d)); } }; -/// A type that can represent non-integers -[sealed] -[builtin] -interface __BuiltinRealType : __BuiltinSignedArithmeticType {} /// A type that uses a floating-point representation [sealed] @@ -339,6 +343,9 @@ __generic __intrinsic_op(select) vector operator?:(vector __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse); __generic __intrinsic_op(select) vector select(vector condition, vector ifTrue, vector ifFalse); +// Allow real-number types to be cast into each other +__intrinsic_op($(kIROp_FloatCast)) + T __realCast(U val); ${{{{ // We are going to use code generation to produce the @@ -483,12 +490,13 @@ ${{{{ { return a + b; } - + + __generic [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(Differential a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast(a) * b; } ${{{{ break; @@ -1190,11 +1198,12 @@ extension vector : IDifferentiable return a + b; } + __generic [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast(a) * b; } } @@ -1216,12 +1225,13 @@ extension matrix : IDifferentiable { return a + b; } - + + __generic [__unsafeForceInlineEarly] [BackwardDifferentiable] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { - return a * b; + return __realCast(a) * b; } } diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f2fd8e3b0..3e381e55d 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -374,12 +374,13 @@ extension Array : IDifferentiable return result; } + __generic [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) + static Differential dmul(U a, Differential b) { Array result; for (int i = 0; i < N; i++) - result[i] = T.dmul(a[i], b[i]); + result[i] = T.dmul(a, b[i]); return result; } } @@ -543,8 +544,8 @@ DifferentialPair __d_dot(DifferentialPair> dpx, DifferentialPair for (int i = 0; i < N; ++i) { result = result + dpx.p[i] * dpy.p[i]; - d_result = T.dadd(d_result, T.dmul(dpx.p[i], __slang_noop_cast(dpy.d[i]))); - d_result = T.dadd(d_result, T.dmul(dpy.p[i], __slang_noop_cast(dpx.d[i]))); + d_result = T.dadd(d_result, __slang_noop_cast(dpx.p[i] * dpy.d[i])); + d_result = T.dadd(d_result, __slang_noop_cast(dpy.p[i] * dpx.d[i])); } return DifferentialPair(result, d_result); } @@ -797,7 +798,7 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair __d_##NAME(DifferentialPair dpx) \ { \ @@ -805,7 +806,7 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair> __d_##NAME##_vector(DifferentialPair> dpx) \ { \ @@ -813,21 +814,21 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [ForwardDerivativeOf(NAME)] \ DifferentialPair> __d_##NAME##_m(DifferentialPair> dpm) \ { \ - typealias ReturnType = vector; \ - matrix.Differential diff; \ + typealias ReturnType = vector; \ + matrix.Differential diff; \ [ForceUnroll] for (int i = 0; i < M; i++) \ { \ var dpx = diffPair(dpm.p[i], dpm.d[i]); \ - diff[i] = FWD_DIFF_FUNC; \ + diff[i] = __slang_noop_cast>(FWD_DIFF_FUNC); \ } \ return diffPair(NAME(dpm.p), diff); \ } \ __generic \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME(inout DifferentialPair dpx, T.Differential dOut) \ { \ @@ -835,31 +836,57 @@ void __d_cross(inout DifferentialPair> a, inout DifferentialPair \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_vector( \ - inout DifferentialPair> dpx, vector.Differential dOut) \ + inout DifferentialPair> dpx, vector.Differential dOut) \ { \ typealias ReturnType = vector; \ dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic \ - [BackwardDifferentiable][PreferRecompute] \ + [BackwardDifferentiable] [PreferRecompute] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ - inout DifferentialPair> m, matrix.Differential mdOut) \ + inout DifferentialPair> m, matrix.Differential mdOut) \ { \ typealias ReturnType = vector; \ matrix.Differential diff; \ [ForceUnroll] for (int i = 0; i < M; i++) \ { \ var dpx = diffPair(m.p[i], m.d[i]); \ - var dOut = mdOut[i]; \ + var dOut = __slang_noop_cast>(mdOut[i]); \ diff[i] = BWD_DIFF_FUNC; \ } \ m = diffPair(m.p, diff); \ } -#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut)) +#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, __mul_p_d(DIFF_FUNC, dpx.d), __mul_p_d(DIFF_FUNC, dOut)) + +/// Element-wise multiply for scalars and vectors for (T, T.Differential) +__generic +[__unsafeForceInlineEarly] +[Differentiable] +T.Differential __mul_p_d(T a, T.Differential b) +{ + return __slang_noop_cast(a * __slang_noop_cast(b)); +} + +__generic +[__unsafeForceInlineEarly] +[Differentiable] +T __mul_p_d(T a, T b) +{ + return (a * b); +} + +__generic +[__unsafeForceInlineEarly] +[Differentiable] +vector __mul_p_d(vector a, vector b) +{ + return a * b; +} + /// Detach and set derivatives to zero. __generic @@ -871,14 +898,14 @@ T detach(T x); #define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0)))) // Absolute value -UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut)) +UNARY_DERIVATIVE_IMPL(abs, (__mul_p_d(SLANG_SIGN(dpx.p), (dpx.d))), (__mul_p_d(SLANG_SIGN(dpx.p), (dOut)))) // Saturate UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut) // raidans, degrees -SIMPLE_UNARY_DERIVATIVE_IMPL(radians, T(0.01745329251994329576923690768489)) -SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, T(57.295779513082320876798154814105)) +SIMPLE_UNARY_DERIVATIVE_IMPL(radians, ReturnType(T(0.01745329251994329576923690768489))) +SIMPLE_UNARY_DERIVATIVE_IMPL(degrees, ReturnType(T(57.295779513082320876798154814105))) // Exponent SIMPLE_UNARY_DERIVATIVE_IMPL(exp, exp(dpx.p)) SIMPLE_UNARY_DERIVATIVE_IMPL(exp2, exp2(dpx.p)* T(50.69314718055994530941723212145818)) @@ -915,8 +942,8 @@ __generic [ForwardDerivativeOf(atan2)] DifferentialPair __d_atan2(DifferentialPair dpy, DifferentialPair dpx) { - T.Differential dx = T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d); - T.Differential dy = T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d); + T.Differential dx = __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d); + T.Differential dy = __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d); return DifferentialPair( atan2(dpy.p, dpx.p), T.dadd(dx, dy)); @@ -928,8 +955,8 @@ __generic [BackwardDerivativeOf(atan2)] void __d_atan2(inout DifferentialPair dpy, inout DifferentialPair dpx, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d)); - dpy = diffPair(dpy.p, T.dmul(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d)); + dpx = diffPair(dpx.p, __mul_p_d(-dpy.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpx.d)); + dpy = diffPair(dpy.p, __mul_p_d(-dpx.p / (dpx.p * dpx.p + dpy.p * dpy.p), dpy.d)); } VECTOR_MATRIX_BINARY_DIFF_IMPL(atan2) @@ -968,8 +995,8 @@ DifferentialPair __d_pow(DifferentialPair dpx, DifferentialPair dpy) } T val = pow(dpx.p, dpy.p); - T.Differential d1 = T.dmul(val * log(dpx.p), dpy.d); - T.Differential d2 = T.dmul(val * dpy.p / dpx.p, dpx.d); + T.Differential d1 = __mul_p_d((val * log(dpx.p)), dpy.d); + T.Differential d2 = __mul_p_d((val * dpy.p / dpx.p), dpx.d); return DifferentialPair( val, T.dadd(d1, d2) @@ -993,10 +1020,10 @@ void __d_pow(inout DifferentialPair dpx, inout DifferentialPair dpy, T.Dif T val = pow(dpx.p, dpy.p); dpx = diffPair( dpx.p, - T.dmul(val * dpy.p / dpx.p, dOut)); + (__mul_p_d((val * dpy.p / dpx.p), dOut))); dpy = diffPair( dpy.p, - T.dmul(val * log(dpx.p), dOut)); + (__mul_p_d((val * log(dpx.p)), dOut))); } } @@ -1061,7 +1088,7 @@ DifferentialPair __d_lerp(DifferentialPair dpx, DifferentialPair dpy, D { return DifferentialPair( lerp(dpx.p, dpy.p, dps.p), - T.dadd(T.dadd(T.dmul((T(1.0) - dps.p), dpx.d), T.dmul(dps.p, dpy.d)), T.dmul(dpy.p - dpx.p, dps.d)) + T.dadd(T.dadd(__mul_p_d((T(1.0) - dps.p), dpx.d), __mul_p_d(dps.p, dpy.d)), __mul_p_d((dpy.p - dpx.p), dps.d)) ); } __generic @@ -1070,9 +1097,9 @@ __generic [BackwardDerivativeOf(lerp)] void __d_lerp(inout DifferentialPair dpx, inout DifferentialPair dpy, inout DifferentialPair dps, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(T(1.0) - dps.p, dOut)); - dpy = diffPair(dpy.p, T.dmul(dps.p, dOut)); - dps = diffPair(dpy.p, T.dmul((dpy.p - dpx.p), dOut)); + dpx = diffPair(dpx.p, __mul_p_d((T(1.0) - dps.p), dOut)); + dpy = diffPair(dpy.p, __mul_p_d(dps.p, dOut)); + dps = diffPair(dpy.p, __mul_p_d((dpy.p - dpx.p), dOut)); } VECTOR_MATRIX_TERNARY_DIFF_IMPL(lerp) @@ -1175,7 +1202,7 @@ DifferentialPair __d_mad(DifferentialPair dpx, DifferentialPair dpy, Di { return DifferentialPair( mad(dpx.p, dpy.p, dpz.p), - T.dadd(T.dadd(T.dmul(dpy.p, dpx.d), T.dmul(dpx.p, dpy.d)), dpz.d)); + T.dadd(T.dadd(__mul_p_d(dpy.p, dpx.d), __mul_p_d(dpx.p, dpy.d)), dpz.d)); } __generic [BackwardDifferentiable] @@ -1183,8 +1210,8 @@ __generic [PreferRecompute] void __d_mad(inout DifferentialPair dpx, inout DifferentialPair dpy, inout DifferentialPair dpz, T.Differential dOut) { - dpx = diffPair(dpx.p, T.dmul(dpy.p, dOut)); - dpy = diffPair(dpy.p, T.dmul(dpx.p, dOut)); + dpx = diffPair(dpx.p, __mul_p_d(dpy.p, dOut)); + dpy = diffPair(dpy.p, __mul_p_d(dpx.p, dOut)); dpz = diffPair(dpz.p, dOut); } VECTOR_MATRIX_TERNARY_DIFF_IMPL(mad) diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index 872088d54..65955e815 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -122,6 +122,14 @@ Expr* ASTSynthesizer::emitInvokeExpr(Expr* callee, List&& args) return rs; } +Expr* ASTSynthesizer::emitGenericAppExpr(Expr* genericExpr, List&& args) +{ + auto rs = m_builder->create(); + rs->functionExpr = genericExpr; + rs->arguments = _Move(args); + return rs; +} + Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name) { auto rs = m_builder->create(); diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h index 6568b4c83..e595afac1 100644 --- a/source/slang/slang-ast-synthesis.h +++ b/source/slang/slang-ast-synthesis.h @@ -138,6 +138,8 @@ public: Expr* emitInvokeExpr(Expr* callee, List&& args); + Expr* emitGenericAppExpr(Expr* genericExpr, List&& args); + DeclStmt* emitVarDeclStmt(Type* type, Name* name = nullptr, Expr* initVal = nullptr); ExpressionStmt* emitExprStmt(Expr* expr); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4c1e967e3..2d009c28c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2574,6 +2574,135 @@ namespace Slang return false; } + GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + List& synGenericArgs, + ThisExpr*& synThis) + { + auto synGenericDecl = m_astBuilder->create(); + + // For now our synthesized method will use the name and source + // location of the requirement we are trying to satisfy. + // + // TODO: as it stands right now our syntesized method will + // get a mangled name, which we don't actually want. Leaving + // out the name here doesn't help matters, because then *all* + // snthesized methods on a given type would share the same + // mangled name! + // + synGenericDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; + if (synGenericDecl->nameAndLoc.name) + { + synGenericDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synGenericDecl->nameAndLoc.name->text); + } + + // Dictionary to map from the original type parameters to the synthesized ones. + Dictionary mapOrigToSynTypeParams; + + // Our synthesized method will have parameters matching the names + // and types of those on the requirement, and it will use expressions + // that reference those parametesr as arguments for the call expresison + // that makes up the body. + // + for (auto member : requiredMemberDeclRef.getDecl()->members) + { + if (auto typeParamDecl = as(member)) + { + auto synTypeParamDecl = m_astBuilder->create(); + synTypeParamDecl->nameAndLoc = typeParamDecl->getNameAndLoc(); + synTypeParamDecl->initType = typeParamDecl->initType; + synTypeParamDecl->parentDecl = synGenericDecl; + synGenericDecl->members.add(synTypeParamDecl); + + mapOrigToSynTypeParams.add(typeParamDecl, synTypeParamDecl); + + // Construct a DeclRefExpr from the type parameter. + auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl); + + auto synTypeParamDeclRefExpr = m_astBuilder->create(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); + + synGenericArgs.add(synTypeParamDeclRefExpr); + } + } + + for (auto member : requiredMemberDeclRef.getDecl()->members) + { + if (auto constraintDecl = as(member)) + { + getASTBuilder()->getSpecializedDeclRef( + constraintDecl, requiredMemberDeclRef.getSubst()); + + auto synConstraintDecl = m_astBuilder->create(); + synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc(); + synConstraintDecl->parentDecl = synGenericDecl; + + // For constraints of type T : Interface, where T is a simple type parameter, + // find the declaration of T + // + if (auto typeParamDecl = as(constraintDecl->sub.type)->declRef.as().getDecl()) + { + auto synTypeParamDecl = mapOrigToSynTypeParams[typeParamDecl]; + + // Construct a DeclRefExpr from the type parameter. + auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl.getValue()); + + auto synTypeParamDeclRefExpr = m_astBuilder->create(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); + + synConstraintDecl->sub = TypeExp(synTypeParamDeclRefExpr); + synConstraintDecl->sup = constraintDecl->sup; + synGenericDecl->members.add(synConstraintDecl); + } + else + { + SLANG_UNEXPECTED("Cannot perform synthesis for requirements with complex type constraints."); + } + } + } + + // Get outer substitutions. (This inner-most substition + // must be a ThisTypeSubstition) + // + Substitutions* outer = nullptr; + if (auto thisTypeSubst = findThisTypeSubstitution( + requiredMemberDeclRef.getSubst(), + as(requiredMemberDeclRef.getParent(m_astBuilder)).getDecl())) + { + outer = thisTypeSubst; + } + + // Override generic pointer to point to the original generic container. + // This will create a substitution of the synthesized parameters for the + // original parameters. + // + GenericSubstitution* requiredFuncSubsts = createDefaultSubstitutionsForGeneric(m_astBuilder, this, requiredMemberDeclRef.getDecl(), outer); + DeclRef requiredFuncDeclRef = m_astBuilder->getSpecializedDeclRef(requiredMemberDeclRef.getDecl()->inner, requiredFuncSubsts); + + GenericSubstitution* substSynParamsForOrigGeneric = m_astBuilder->getOrCreateGenericSubstitution( + outer, + requiredMemberDeclRef.getDecl(), + createDefaultSubstitutionsForGeneric(m_astBuilder, this, synGenericDecl, nullptr)->getArgs()); + + // Substitute parameters of the synthesized generic for the parameters of the original generic. + requiredFuncDeclRef = substituteDeclRef(substSynParamsForOrigGeneric, m_astBuilder, requiredFuncDeclRef); + + SLANG_ASSERT(requiredFuncDeclRef.as()); + + synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitness( + context, + requiredFuncDeclRef.as(), + synArgs, + synThis); + synGenericDecl->inner->parentDecl = synGenericDecl; + + return synGenericDecl; + } + FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, DeclRef requiredMemberDeclRef, @@ -3274,12 +3403,30 @@ namespace Slang switch (builtinAttr->kind) { case BuiltinRequirementKind::DAddFunc: - case BuiltinRequirementKind::DMulFunc: case BuiltinRequirementKind::DZeroFunc: return trySynthesizeDifferentialMethodRequirementWitness( context, requiredFuncDeclRef, - witnessTable); + witnessTable, + SynthesisPattern::AllInductive); + } + } + return false; + } + + // For generic decl, check if we match DMulFunc, and synthesize the method. + if (auto requiredGenericDeclRef = requiredMemberDeclRef.as()) + { + if (auto builtinAttr = getInner(requiredGenericDeclRef)->findModifier()) + { + switch (builtinAttr->kind) + { + case BuiltinRequirementKind::DMulFunc: + return trySynthesizeDifferentialMethodRequirementWitness( + context, + requiredGenericDeclRef, + witnessTable, + SynthesisPattern::FixedFirstArg); } } return false; @@ -3330,7 +3477,15 @@ namespace Slang return false; } - Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List&& args, int nestingLevel = 0) + Stmt* _synthesizeMemberAssignMemberHelper( + ASTSynthesizer& synth, + Name* funcName, + Type* leftType, + Expr* leftValue, + List&& args, + List&& genericArgs, + List&& inductiveArgMask, + int nestingLevel = 0) { if (nestingLevel > 16) return nullptr; @@ -3342,11 +3497,24 @@ namespace Slang auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); addModifier(forStmt, synth.getBuilder()->create()); auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); - for (auto& arg : args) + + for (auto ii = 0; ii < args.getCount(); ++ii) { - arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); + auto& arg = args[ii]; + if (inductiveArgMask[ii]) + arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); } - auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->getElementType(), innerLeft, _Move(args), nestingLevel + 1); + + auto assignStmt = _synthesizeMemberAssignMemberHelper( + synth, + funcName, + arrayType->getElementType(), + innerLeft, + _Move(args), + _Move(genericArgs), + _Move(inductiveArgMask), + nestingLevel + 1); + synth.popScope(); if (!assignStmt) return nullptr; @@ -3354,13 +3522,18 @@ namespace Slang } auto callee = synth.emitMemberExpr(leftType, funcName); + + if (genericArgs.getCount() > 0) + callee = synth.emitGenericAppExpr(callee, _Move(genericArgs)); + return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); } bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef requirementDeclRef, - RefPtr witnessTable) + RefPtr witnessTable, + SynthesisPattern pattern) { // We support two cases of synthesis here. // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`. @@ -3371,9 +3544,10 @@ namespace Slang // ``` // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) // ``` - // Where TResult, TParam1, TParam2 is either `This` or `Differential`, - // We synthesize a memberwise dispatch to compute each field of `TResult`, - // resulting an implementation of the form: + // Where TResult,TParam1, TParam2 is either `This` or `Differential`, + // We synthesize a memberwise dispatch to compute each field of `TResult`. + // Multiple patterns are supported (see SemanticsVisitor::SynthesisPattern for a full list) + // For AllInductive, we synthesize an implementation of the form: // ``` // [BackwardDifferentiable] // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) @@ -3404,13 +3578,32 @@ namespace Slang ASTSynthesizer synth(m_astBuilder, getNamePool()); List synArgs; + List synGenericArgs; ThisExpr* synThis = nullptr; - auto synFunc = synthesizeMethodSignatureForRequirementWitness( - context, requirementDeclRef.as(), synArgs, synThis); + FuncDecl* synFunc = nullptr; + GenericDecl* synGeneric = nullptr; + + if (auto genericDeclRef = requirementDeclRef.as()) + { + synGeneric = synthesizeGenericSignatureForRequirementWitness( + context, genericDeclRef, synArgs, synGenericArgs, synThis); + synFunc = as(synGeneric->inner); + } + else if (auto funcDeclRef = requirementDeclRef.as()) + { + synFunc = synthesizeMethodSignatureForRequirementWitness( + context, funcDeclRef, synArgs, synThis); + } + + SLANG_ASSERT(synFunc); addModifier(synFunc, m_astBuilder->create()); - synFunc->parentDecl = context->parentDecl; + if (synGeneric) + synGeneric->parentDecl = context->parentDecl; + else + synFunc->parentDecl = context->parentDecl; + synth.pushContainerScope(synFunc); auto blockStmt = m_astBuilder->create(); synFunc->body = blockStmt; @@ -3438,23 +3631,71 @@ namespace Slang // Construct reference exprs to the member's corresponding fields in each parameter. List paramFields; - int paramIndex = 0; - for (auto arg : synArgs) + List inductiveArgMask; + + switch (pattern) { - auto memberExpr = m_astBuilder->create(); - memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); - paramFields.add(memberExpr); - paramIndex++; + case SynthesisPattern::AllInductive: + { + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + inductiveArgMask.add(true); + + paramIndex++; + } + break; + } + case SynthesisPattern::FixedFirstArg: + { + int paramIndex = 0; + for (auto arg : synArgs) + { + if (paramIndex == 0) + { + paramFields.add(arg); + inductiveArgMask.add(false); + + paramIndex++; + } + else + { + auto memberExpr = m_astBuilder->create(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + inductiveArgMask.add(true); + + paramIndex++; + } + } + break; + } + default: + SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern"); + break; } // Invoke the method for the field and assign the value to resultVar. // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` // is Differential type. auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); - if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields))) + if (!_synthesizeMemberAssignMemberHelper( + synth, + requirementDeclRef.getName(), + memberType, + leftVal, + _Move(paramFields), + _Move(synGenericArgs), + _Move(inductiveArgMask))) return false; } @@ -3473,11 +3714,11 @@ namespace Slang // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the // generic substitution for outer generic parameters, and apply it here. SubstitutionSet substSet; - if (auto thisTypeSusbt = findThisTypeSubstitution( + if (auto thisTypeSubst = findThisTypeSubstitution( requirementDeclRef.getSubst(), as(requirementDeclRef.getDecl()->parentDecl))) { - if (auto declRefType = as(thisTypeSusbt->witness->sub)) + if (auto declRefType = as(thisTypeSubst->witness->sub)) { substSet = declRefType->declRef.getSubst(); } @@ -3610,7 +3851,9 @@ namespace Slang // requirement, it may be possible that we can still synthesis the // implementation if this is one of the known builtin requirements. // Otherwise, report diagnostic now. - if (!requiredMemberDeclRef.getDecl()->hasModifier()) + if (!requiredMemberDeclRef.getDecl()->hasModifier() && + !(requiredMemberDeclRef.as() && + getInner(requiredMemberDeclRef.as())->hasModifier())) { getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 04112743a..575d4aff7 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1467,6 +1467,13 @@ namespace Slang DeclRef requiredMemberDeclRef, List& synArgs, ThisExpr*& synThis); + + GenericDecl* synthesizeGenericSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef requiredMemberDeclRef, + List& synArgs, + List& synGenericArgs, + ThisExpr*& synThis); void _addMethodWitness( WitnessTable* witnessTable, @@ -1503,15 +1510,39 @@ namespace Slang LookupResult const& lookupResult, DeclRef requiredMemberDeclRef, RefPtr witnessTable); + - /// Attempt to synthesize `zero`, `dadd` and `dmul` methods for a type that conforms to + enum SynthesisPattern + { + // Synthesized method inducts over all arguments. + // T fn(T x, T y, T z, ...) + // { + // typeof(T::member0)::fn(x.member0, y.member0, z.member0, ...); + // typeof(T::member1)::fn(x.member1, y.member1, z.member1, ...); + // ... + // } + // + AllInductive, + + // Synthesized method inducts over all arguments except the first. + // T fn(U x, T y, T z) + // { + // typeof(T::member0)::fn(x, y.member0, z.member0, ...); + // typeof(T::member1)::fn(x, y.member1, z.member1, ...); + // ... + // } + FixedFirstArg + }; + + /// Attempt to synthesize `zero`, `dadd` & `dmul` methods for a type that conforms to /// `IDifferentiable`. /// On success, installs the syntethesized functions and returns `true`. /// Otherwise, returns `false`. bool trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef requirementDeclRef, - RefPtr witnessTable); + RefPtr witnessTable, + SynthesisPattern pattern); /// Attempt to synthesize an associated `Differential` type for a type that conforms to /// `IDifferentiable`. diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 3207e0729..6171d9a75 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -636,7 +636,7 @@ public: { if (auto genericInst = as(inst)) { - if (auto innerFunc = as(findGenericReturnVal(genericInst))) + if (auto innerFunc = as(findInnerMostGenericReturnVal(genericInst))) processFunc(innerFunc); } else if (auto funcInst = as(inst)) diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index b0af5378c..cab01d585 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -3924,7 +3924,7 @@ namespace Slang declToModify = genericDecl->inner; _addModifiers(declToModify, modifiers); - if (containerDecl) + if (containerDecl && !as(containerDecl)) { // Make sure the decl is properly nested inside its lexical parent AddMember(containerDecl, decl); -- cgit v1.2.3