diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-27 11:06:14 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-27 11:06:14 -0700 |
| commit | 8e11063bfcec528e70f5e80e5db9fca7d4016737 (patch) | |
| tree | 60a4a492e732479bb1f72e76675b2b494c140b6b | |
| parent | f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (diff) | |
Auto synthesis of IDifferntial interface methods. (#2469)
* Auto synthesis of IDifferntial interface methods.
* Add comments.
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/core.meta.slang | 9 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 7 | ||||
| -rw-r--r-- | source/slang/slang-ast-dump.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 302 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 41 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 4 | ||||
| -rw-r--r-- | tests/autodiff/differential-method-synthesis.slang | 45 | ||||
| -rw-r--r-- | tests/autodiff/differential-method-synthesis.slang.expected.txt | 6 |
11 files changed, 345 insertions, 105 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index a25ce03bd..75bc65562 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2742,12 +2742,15 @@ attribute_syntax [Differentiable] : DifferentiableAttribute; __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; -enum _BuiltinAssociatedTypeRequirementKind +enum _BuiltinRequirementKind { - Differential = $( (int) BuiltinAssociatedTypeRequirementKind::Differential), + DifferentialType = $( (int) BuiltinRequirementKind::DifferentialType), + DZeroFunc = $( (int) BuiltinRequirementKind::DZeroFunc), + DAddFunc = $( (int) BuiltinRequirementKind::DAddFunc), + DMulFunc = $( (int) BuiltinRequirementKind::DMulFunc), }; __attributeTarget(DeclBase) -attribute_syntax [__BuiltinAssociatedTypeRequirementAttribute(kind: _BuiltinAssociatedTypeRequirementKind)] : BuiltinAssociatedTypeRequirementAttribute; +attribute_syntax [__BuiltinRequirement(kind: _BuiltinRequirementKind)] : BuiltinRequirementAttribute; __attributeTarget(DeclBase) attribute_syntax [builtin] : BuiltinAttribute; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ea204c839..38d7270e4 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -21,13 +21,16 @@ interface IDifferentiable // Note: the compiler implementation requires the `Differential` associated type to be defined // before anything else. - [__BuiltinAssociatedTypeRequirementAttribute(_BuiltinAssociatedTypeRequirementKind.Differential)] + [__BuiltinRequirement(_BuiltinRequirementKind.DifferentialType)] associatedtype Differential; - + + [__BuiltinRequirement(_BuiltinRequirementKind.DZeroFunc)] static Differential zero(); + [__BuiltinRequirement(_BuiltinRequirementKind.DAddFunc)] static Differential dadd(Differential, Differential); + [__BuiltinRequirement(_BuiltinRequirementKind.DMulFunc)] static Differential dmul(This, Differential); }; diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index 32f9dd16f..455a9db74 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -345,7 +345,7 @@ struct ASTDumpContext { m_writer->emit(getTryClauseTypeName(clauseType)); } - void dump(BuiltinAssociatedTypeRequirementKind kind) + void dump(BuiltinRequirementKind kind) { m_writer->emit((int)kind); } diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index c439c7437..6220fcb95 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -591,11 +591,11 @@ class Attribute : public AttributeBase }; // A modifier that indicates a built-in associated type requirement (e.g., `Differential`) -class BuiltinAssociatedTypeRequirementAttribute : public Attribute +class BuiltinRequirementAttribute : public Attribute { - SLANG_AST_CLASS(BuiltinAssociatedTypeRequirementAttribute); + SLANG_AST_CLASS(BuiltinRequirementAttribute); - BuiltinAssociatedTypeRequirementKind kind; + BuiltinRequirementKind kind; }; class UserDefinedAttribute : public Attribute diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 9a32d816c..61580ca9e 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1490,10 +1490,13 @@ namespace Slang kParameterDirection_Ref, ///< By-reference }; - /// The type of a builtin associated type requirement. - enum class BuiltinAssociatedTypeRequirementKind + /// The kind of a builtin interface requirement that can be automatically synthesized. + enum class BuiltinRequirementKind { - Differential + DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement + DZeroFunc, ///< The `IDifferentiable.dzero` function requirement + DAddFunc, ///< The `IDifferentiable.dadd` function requirement + DMulFunc, ///< The `IDifferentiable.dmul` function requirement }; } // namespace Slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index fa05dde11..f28f46deb 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -81,8 +81,6 @@ namespace Slang void checkCallableDeclCommon(CallableDecl* decl); - void maybeCheckDifferentiableAccessorSignature(FuncDecl* funcDecl); - void visitFuncDecl(FuncDecl* funcDecl); void visitParamDecl(ParamDecl* paramDecl); @@ -889,6 +887,7 @@ namespace Slang auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder()); diffTypeDict->parentDecl = containerDecl; containerDecl->members.add(diffTypeDict); + containerDecl->invalidateMemberDictionary(); } } } @@ -1072,8 +1071,11 @@ namespace Slang auto initExpr = varDecl->initExpr; if(!initExpr) { - getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); - varDecl->type.type = m_astBuilder->getErrorType(); + if (!varDecl->type.type) + { + getSink()->diagnose(varDecl, Diagnostics::varWithoutTypeMustHaveInitializer); + varDecl->type.type = m_astBuilder->getErrorType(); + } } else { @@ -1390,7 +1392,7 @@ namespace Slang context->parentDecl->members.add((aggTypeDecl)); aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; - context->parentDecl->getMemberDictionary().Add(aggTypeDecl->getName(), aggTypeDecl); + context->parentDecl->invalidateMemberDictionary(); } // TODO: if we want to make the synthesized type itself to be differentiable, @@ -1409,6 +1411,7 @@ namespace Slang diffField->checkState = DeclCheckState::SignatureChecked; diffField->parentDecl = aggTypeDecl; aggTypeDecl->members.add(diffField); + aggTypeDecl->invalidateMemberDictionary(); // Inject a `DerivativeMember` modifier on the original decl. auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>(); @@ -2412,54 +2415,12 @@ namespace Slang return false; } - bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( + FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, - LookupResult const& lookupResult, - DeclRef<FuncDecl> requiredMemberDeclRef, - RefPtr<WitnessTable> witnessTable) + DeclRef<FuncDecl> requiredMemberDeclRef, + List<Expr*>& synArgs, + ThisExpr*& synThis) { - // The situation here is that the context of an inheritance - // declaration didn't provide an exact match for a required - // method. E.g.: - // - // interface ICounter { [mutating] int increment(); } - // struct MyCounter : ICounter - // { - // [murtating] int increment(int val = 1) { ... } - // } - // - // It is clear in this case that the `MyCounter` type *can* - // satisfy the signature required by `ICounter`, but it has - // no explicit method declaration that is a perfect match. - // - // The approach in this function will be to construct a - // synthesized method along the lines of: - // - // struct MyCounter ... - // { - // ... - // [murtating] int synthesized() - // { - // return this.increment(); - // } - // } - // - // That is, we construct a method with the exact signature - // of the requirement (same parameter and result types), - // and then provide it with a body that simple `return`s - // the result of applying the desired requirement name - // (`increment` in this case) to those parameters. - // - // If the synthesized method type-checks, then we can say - // that the type must satisfy the requirement structurally, - // even if there isn't an exact signature match. More - // importantly, the method we just synthesized can be - // used as a witness to the fact that the requirement is - // satisfied. - - // With the big picture spelled out, we can settle into - // the work of constructing our synthesized method. - // auto synFuncDecl = m_astBuilder->create<FuncDecl>(); // For now our synthesized method will use the name and source @@ -2497,8 +2458,7 @@ namespace Slang // that reference those parametesr as arguments for the call expresison // that makes up the body. // - List<Expr*> synArgs; - for( auto paramDeclRef : getParameters(requiredMemberDeclRef) ) + for (auto paramDeclRef : getParameters(requiredMemberDeclRef)) { auto paramType = getType(m_astBuilder, paramDeclRef); @@ -2524,13 +2484,13 @@ namespace Slang synArgs.add(synArg); } + // Required interface methods can be `static` or non-`static`, // and non-`static` methods can be `[mutating]` or non-`[mutating]`. // All of these details affect how we introduce our `this` parameter, // if any. // - ThisExpr* synThis = nullptr; - if( requiredMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>() ) + if (requiredMemberDeclRef.getDecl()->hasModifier<HLSLStaticModifier>()) { auto synStaticModifier = m_astBuilder->create<HLSLStaticModifier>(); synFuncDecl->modifiers.first = synStaticModifier; @@ -2546,7 +2506,7 @@ namespace Slang // synThis->type.type = context->conformingType; - if( requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>() ) + if (requiredMemberDeclRef.getDecl()->hasModifier<MutatingAttribute>()) { // If the interface requirement is `[mutating]` then our // synthesized method should be too, and also the `this` @@ -2559,6 +2519,64 @@ namespace Slang } } + return synFuncDecl; + } + + bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( + ConformanceCheckingContext* context, + LookupResult const& lookupResult, + DeclRef<FuncDecl> requiredMemberDeclRef, + RefPtr<WitnessTable> witnessTable) + { + // The situation here is that the context of an inheritance + // declaration didn't provide an exact match for a required + // method. E.g.: + // + // interface ICounter { [mutating] int increment(); } + // struct MyCounter : ICounter + // { + // [murtating] int increment(int val = 1) { ... } + // } + // + // It is clear in this case that the `MyCounter` type *can* + // satisfy the signature required by `ICounter`, but it has + // no explicit method declaration that is a perfect match. + // + // The approach in this function will be to construct a + // synthesized method along the lines of: + // + // struct MyCounter ... + // { + // ... + // [murtating] int synthesized() + // { + // return this.increment(); + // } + // } + // + // That is, we construct a method with the exact signature + // of the requirement (same parameter and result types), + // and then provide it with a body that simple `return`s + // the result of applying the desired requirement name + // (`increment` in this case) to those parameters. + // + // If the synthesized method type-checks, then we can say + // that the type must satisfy the requirement structurally, + // even if there isn't an exact signature match. More + // importantly, the method we just synthesized can be + // used as a witness to the fact that the requirement is + // satisfied. + + // With the big picture spelled out, we can settle into + // the work of constructing our synthesized method. + // + ThisExpr* synThis = nullptr; + List<Expr*> synArgs; + auto synFuncDecl = synthesizeMethodSignatureForRequirementWitness( + context, requiredMemberDeclRef, synArgs, synThis); + + auto resultType = synFuncDecl->returnType.type; + // The body of our synthesized method is going to try to // make a call using the name of the method requirement (e.g., // the name `increment` in our example at the top of this function). @@ -3036,11 +3054,27 @@ namespace Slang if (auto requiredFuncDeclRef = requiredMemberDeclRef.as<FuncDecl>()) { // Check signature match. - return trySynthesizeMethodRequirementWitness( + if (trySynthesizeMethodRequirementWitness( context, lookupResult, requiredFuncDeclRef, - witnessTable); + witnessTable)) + return true; + + if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>()) + { + switch (builtinAttr->kind) + { + case BuiltinRequirementKind::DAddFunc: + case BuiltinRequirementKind::DMulFunc: + case BuiltinRequirementKind::DZeroFunc: + return trySynthesizeDifferentialMethodRequirementWitness( + context, + requiredFuncDeclRef, + witnessTable); + } + } + return false; } if( auto requiredPropertyDeclRef = requiredMemberDeclRef.as<PropertyDecl>() ) @@ -3054,11 +3088,11 @@ namespace Slang if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>()) { - if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>()) + if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>()) { switch (builtinAttr->kind) { - case BuiltinAssociatedTypeRequirementKind::Differential: + case BuiltinRequirementKind::DifferentialType: return trySynthesizeDifferentialAssociatedTypeRequirementWitness( context, requiredAssocTypeDeclRef, @@ -3088,6 +3122,130 @@ namespace Slang return false; } + bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<Decl> requirementDeclRef, + RefPtr<WitnessTable> witnessTable) + { + // This method implements a general code synthesis pattern. + // For requirement of the form: + // ``` + // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) + // ``` + // Where TResult, TParam1, TParam2 is either `This` or `Differential`, + // We synthesize a memberwise dispatch to compute each field of `TResult`, + // resulting an implementation of the form: + // ``` + // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) + // { + // TResult result; + // result.member0 = decltype(result.member0).requiredMethod(p0.member0, p1.member0); + // result.member1 = decltype(result.member1).requiredMethod(p0.member1, p1.member1); + // ... + // return result; + // } + // ``` + List<Expr*> synArgs; + ThisExpr* synThis = nullptr; + auto synFunc = synthesizeMethodSignatureForRequirementWitness( + context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis); + + auto blockStmt = m_astBuilder->create<BlockStmt>(); + synFunc->body = blockStmt; + auto seqStmt = m_astBuilder->create<SeqStmt>(); + blockStmt->body = seqStmt; + + // Create a variable for return value. + auto scopeDecl = m_astBuilder->create<ScopeDecl>(); + synFunc->members.add(scopeDecl); + scopeDecl->parentDecl = synFunc; + auto varStmt = m_astBuilder->create<DeclStmt>(); + seqStmt->stmts.add(varStmt); + + auto returnVar = m_astBuilder->create<VarDecl>(); + returnVar->parentDecl = scopeDecl; + scopeDecl->members.add(returnVar); + + returnVar->type.type = synFunc->returnType.type; + returnVar->nameAndLoc.name = getName("result"); + varStmt->decl = returnVar; + auto resultVarExpr = m_astBuilder->create<VarExpr>(); + resultVarExpr->declRef = makeDeclRef(returnVar); + resultVarExpr->type.type = synFunc->returnType.type; + resultVarExpr->type.isLeftValue = true; + + for (auto member : context->parentDecl->members) + { + auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); + if (!derivativeAttr) + continue; + auto varMember = as<VarDeclBase>(member); + if (!varMember) + continue; + ensureDecl(varMember, DeclCheckState::ReadyForReference); + auto memberType = varMember->getType(); + auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); + if (!diffMemberType) + continue; + + // Construct reference exprs to the member's corresponding fields in each parameter. + List<Expr*> paramFields; + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + paramIndex++; + } + + // Invoke the method for the field. + auto callee = m_astBuilder->create<StaticMemberExpr>(); + auto baseSharedType = m_astBuilder->create<SharedTypeExpr>(); + auto baseSharedTypeType = m_astBuilder->create<TypeType>(); + baseSharedTypeType->type = memberType; + baseSharedType->type = baseSharedTypeType; + baseSharedType->base.type = memberType; + callee->baseExpression = baseSharedType; + callee->name = requirementDeclRef.getName(); + callee->loc = synFunc->loc; + auto invokeExpr = m_astBuilder->create<InvokeExpr>(); + invokeExpr->functionExpr = callee; + invokeExpr->arguments = _Move(paramFields); + + // Assign the value to resultVar. + auto leftVal = m_astBuilder->create<MemberExpr>(); + leftVal->baseExpression = resultVarExpr; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` + // is Differential type. + leftVal->name = varMember->getName(); + + auto assignExpr = m_astBuilder->create<AssignExpr>(); + assignExpr->left = leftVal; + assignExpr->right = invokeExpr; + auto assignStmt = m_astBuilder->create<ExpressionStmt>(); + assignStmt->expression = assignExpr; + seqStmt->stmts.add(assignStmt); + } + + // TODO: synthesize assignments for inherited members here. + + auto synReturn = m_astBuilder->create<ReturnStmt>(); + synReturn->expression = resultVarExpr; + seqStmt->stmts.add(synReturn); + + synFunc->parentDecl = context->parentDecl; + context->parentDecl->members.add(synFunc); + context->parentDecl->invalidateMemberDictionary(); + addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>()); + + witnessTable->add(requirementDeclRef, RequirementWitness(makeDeclRef(synFunc))); + return true; + } + bool SemanticsVisitor::findWitnessForInterfaceRequirement( ConformanceCheckingContext* context, Type* subType, @@ -3210,18 +3368,16 @@ namespace Slang if(!lookupResult.isValid()) { - // If we failed to even look up a member with the name of the - // requirement, then we can be certain that the type doesn't - // satisfy the requirement. - // - // TODO: If we ever allowed certain kinds of requirements to - // be inferred (e.g., inferring associated types from the - // signatures of methods, as is done for Swift), we'd - // need to revisit this step. - // - getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); - getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); - return false; + // If we failed to look up a member with the name of the + // requirement, it may be possible that we can still synthesis the + // implementation if this is one of the known builtin requirements. + // Otherwise, report diagnostic now. + if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementAttribute>()) + { + getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); + getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); + return false; + } } // Iterate over the members and look for one that matches @@ -5103,7 +5259,7 @@ namespace Slang { resultType = CheckProperType(resultType); } - else + else if (!funcDecl->returnType.type) { resultType = TypeExp(m_astBuilder->getVoidType()); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index d69cd39ed..d1e737720 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -405,14 +405,17 @@ namespace Slang switch (item.declRef.getDecl()->astNodeType) { case ASTNodeType::AssocTypeDecl: - return maybeUseSynthesizedTypeDeclForLookupResult(item, originalExpr); + break; + case ASTNodeType::FuncDecl: + // We don't need to intercept lookup results with synthesized decls for methods, + // because function lookups will only take place when we are checking the decl bodies. + // At that point conformance check and synthesis is already done so they will always resolve + // to the synthesized method. + return nullptr; default: return nullptr; } - } - Expr* SemanticsVisitor::maybeUseSynthesizedTypeDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr) - { // We need to check if the lookup should resolve to a definition in an implementation type // if it existed. // This will be the case when the lookup is initiated from the concrete implementation type instead of @@ -425,7 +428,7 @@ namespace Slang // We will only ever need to synthesis a type to satisfy an associatedtype requirement. // In this case the lookup should have resolved to a known associatedtype decl. - auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier<BuiltinAssociatedTypeRequirementAttribute>(); + auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier<BuiltinRequirementAttribute>(); if (!builtinAssocTypeAttr) return nullptr; @@ -465,22 +468,32 @@ namespace Slang if (!parent) return nullptr; - // If we reach here, we are expecting a synthesized associated type defined in `subType`. - // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder type + // If we reach here, we are expecting a synthesized decl defined in `subType`. + // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl // in `subType` and return a DeclRefExpr to the synthesized decl. - auto assocType = m_astBuilder->create<StructDecl>(); - assocType->parentDecl = parent; - assocType->nameAndLoc.name = item.declRef.getName(); - assocType->loc = parent->loc; - parent->members.add(assocType); + + Decl* synthesizedDecl = nullptr; + switch (builtinAssocTypeAttr->kind) + { + case BuiltinRequirementKind::DifferentialType: + synthesizedDecl = m_astBuilder->create<StructDecl>(); + break; + default: + break; + } + synthesizedDecl = m_astBuilder->create<StructDecl>(); + synthesizedDecl->parentDecl = parent; + synthesizedDecl->nameAndLoc.name = item.declRef.getName(); + synthesizedDecl->loc = parent->loc; + parent->members.add(synthesizedDecl); parent->invalidateMemberDictionary(); // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it // from user-provided definitions, and proceed to fill in its definition. auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); - addModifier(assocType, toBeSynthesized); + addModifier(synthesizedDecl, toBeSynthesized); - return ConstructDeclRefExpr(makeDeclRef(assocType), nullptr, originalExpr->loc, originalExpr); + return ConstructDeclRefExpr(makeDeclRef(synthesizedDecl), nullptr, originalExpr->loc, originalExpr); } Expr* SemanticsVisitor::ConstructLookupResultExpr( diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 2dc08262e..31075c3e8 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -630,10 +630,6 @@ namespace Slang Expr* base, SourceLoc loc); - Expr* maybeUseSynthesizedTypeDeclForLookupResult( - LookupResultItem const& item, - Expr* orignalExpr); - Expr* maybeUseSynthesizedDeclForLookupResult( LookupResultItem const& item, Expr* orignalExpr); @@ -1073,6 +1069,12 @@ namespace Slang Dictionary<DeclRef<InterfaceDecl>, RefPtr<WitnessTable>> mapInterfaceToWitnessTable; }; + FuncDecl* synthesizeMethodSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<FuncDecl> requiredMemberDeclRef, + List<Expr*>& synArgs, + ThisExpr*& synThis); + /// Attempt to synthesize a method that can satisfy `requiredMemberDeclRef` using `lookupResult`. /// /// On success, installs the syntethesized method in `witnessTable` and returns `true`. @@ -1104,6 +1106,15 @@ namespace Slang DeclRef<Decl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable); + /// Attempt to synthesize `zero`, `dadd` and `dmul` methods for a type that conforms to + /// `IDifferentiable`. + /// On success, installs the syntethesized functions and returns `true`. + /// Otherwise, returns `false`. + bool trySynthesizeDifferentialMethodRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<Decl> requirementDeclRef, + RefPtr<WitnessTable> witnessTable); + /// Attempt to synthesize an associated `Differential` type for a type that conforms to /// `IDifferentiable`. /// diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 20e5d5378..e189b9114 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -484,14 +484,14 @@ namespace Slang return false; } } - else if (auto builtinAssocTypeAttr = as<BuiltinAssociatedTypeRequirementAttribute>(attr)) + else if (auto builtinAssocTypeAttr = as<BuiltinRequirementAttribute>(attr)) { if (attr->args.getCount() == 1) { //IntVal* outIntVal; if (auto cInt = checkConstantEnumVal(attr->args[0])) { - builtinAssocTypeAttr->kind = (BuiltinAssociatedTypeRequirementKind)(cInt->value); + builtinAssocTypeAttr->kind = (BuiltinRequirementKind)(cInt->value); } else { diff --git a/tests/autodiff/differential-method-synthesis.slang b/tests/autodiff/differential-method-synthesis.slang new file mode 100644 index 000000000..73afc4411 --- /dev/null +++ b/tests/autodiff/differential-method-synthesis.slang @@ -0,0 +1,45 @@ +// Tests automatic synthesis of Differential type and method requirements. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct B : IDifferentiable +{ + float x; +} + +struct A : IDifferentiable +{ + B b; + float y; +}; + +typedef __DifferentialPair<A> dpA; + +A nonDiff(A a) +{ + return a; +} + +__differentiate_jvp A f(A a) +{ + A aout; + aout.y = 2 * a.b.x; + aout.b.x = 5 * a.b.x; + + return nonDiff(aout); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + A a = {1.0, 2.0}; + A.Differential b = {0.2}; + dpA dpa = dpA(a, b); + outputBuffer[0] = __jvp(f)(dpa).d().b.x; // Expect: 0 + } +} diff --git a/tests/autodiff/differential-method-synthesis.slang.expected.txt b/tests/autodiff/differential-method-synthesis.slang.expected.txt new file mode 100644 index 000000000..e070cf84d --- /dev/null +++ b/tests/autodiff/differential-method-synthesis.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.000000 +0.000000 +0.000000 +0.000000 +0.000000 |
