diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-07-26 17:15:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-26 17:15:21 -0400 |
| commit | ba89fc84267bfd09f1c8abf10a5b85d09bbc79de (patch) | |
| tree | 2c79fc5dafb89a030d22fa86cd6fa3d69a89a785 /source/slang/slang-check-decl.cpp | |
| parent | b8ade05df10a2774d3da5ef1fb2c7479ff48989a (diff) | |
Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)` (#3029)
* Refactor `dmul(This, Differential)` to `dmul<T:Real>(T, Differential)`
- Add AST synthesis support for generic containers
- Refactor relevant tests
* Merge dmul synthesis with dadd and dzero, and disambiguate using an enum
* Fix trailing spaces
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 295 |
1 files changed, 269 insertions, 26 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4c1e967e3..2d009c28c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2574,6 +2574,135 @@ namespace Slang return false; } + GenericDecl* SemanticsVisitor::synthesizeGenericSignatureForRequirementWitness( + ConformanceCheckingContext* context, + DeclRef<GenericDecl> requiredMemberDeclRef, + List<Expr*>& synArgs, + List<Expr*>& synGenericArgs, + ThisExpr*& synThis) + { + auto synGenericDecl = m_astBuilder->create<GenericDecl>(); + + // For now our synthesized method will use the name and source + // location of the requirement we are trying to satisfy. + // + // TODO: as it stands right now our syntesized method will + // get a mangled name, which we don't actually want. Leaving + // out the name here doesn't help matters, because then *all* + // snthesized methods on a given type would share the same + // mangled name! + // + synGenericDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; + if (synGenericDecl->nameAndLoc.name) + { + synGenericDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synGenericDecl->nameAndLoc.name->text); + } + + // Dictionary to map from the original type parameters to the synthesized ones. + Dictionary<GenericTypeParamDecl*, GenericTypeParamDecl*> mapOrigToSynTypeParams; + + // Our synthesized method will have parameters matching the names + // and types of those on the requirement, and it will use expressions + // that reference those parametesr as arguments for the call expresison + // that makes up the body. + // + for (auto member : requiredMemberDeclRef.getDecl()->members) + { + if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) + { + auto synTypeParamDecl = m_astBuilder->create<GenericTypeParamDecl>(); + synTypeParamDecl->nameAndLoc = typeParamDecl->getNameAndLoc(); + synTypeParamDecl->initType = typeParamDecl->initType; + synTypeParamDecl->parentDecl = synGenericDecl; + synGenericDecl->members.add(synTypeParamDecl); + + mapOrigToSynTypeParams.add(typeParamDecl, synTypeParamDecl); + + // Construct a DeclRefExpr from the type parameter. + auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl); + + auto synTypeParamDeclRefExpr = m_astBuilder->create<VarExpr>(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); + + synGenericArgs.add(synTypeParamDeclRefExpr); + } + } + + for (auto member : requiredMemberDeclRef.getDecl()->members) + { + if (auto constraintDecl = as<GenericTypeConstraintDecl>(member)) + { + getASTBuilder()->getSpecializedDeclRef( + constraintDecl, requiredMemberDeclRef.getSubst()); + + auto synConstraintDecl = m_astBuilder->create<GenericTypeConstraintDecl>(); + synConstraintDecl->nameAndLoc = constraintDecl->getNameAndLoc(); + synConstraintDecl->parentDecl = synGenericDecl; + + // For constraints of type T : Interface, where T is a simple type parameter, + // find the declaration of T + // + if (auto typeParamDecl = as<DeclRefType>(constraintDecl->sub.type)->declRef.as<GenericTypeParamDecl>().getDecl()) + { + auto synTypeParamDecl = mapOrigToSynTypeParams[typeParamDecl]; + + // Construct a DeclRefExpr from the type parameter. + auto synTypeParamDeclRef = makeDeclRef(synTypeParamDecl.getValue()); + + auto synTypeParamDeclRefExpr = m_astBuilder->create<VarExpr>(); + synTypeParamDeclRefExpr->declRef = synTypeParamDeclRef; + synTypeParamDeclRefExpr->type = getTypeForDeclRef(m_astBuilder, synTypeParamDeclRef, SourceLoc()); + + synConstraintDecl->sub = TypeExp(synTypeParamDeclRefExpr); + synConstraintDecl->sup = constraintDecl->sup; + synGenericDecl->members.add(synConstraintDecl); + } + else + { + SLANG_UNEXPECTED("Cannot perform synthesis for requirements with complex type constraints."); + } + } + } + + // Get outer substitutions. (This inner-most substition + // must be a ThisTypeSubstition) + // + Substitutions* outer = nullptr; + if (auto thisTypeSubst = findThisTypeSubstitution( + requiredMemberDeclRef.getSubst(), + as<InterfaceDecl>(requiredMemberDeclRef.getParent(m_astBuilder)).getDecl())) + { + outer = thisTypeSubst; + } + + // Override generic pointer to point to the original generic container. + // This will create a substitution of the synthesized parameters for the + // original parameters. + // + GenericSubstitution* requiredFuncSubsts = createDefaultSubstitutionsForGeneric(m_astBuilder, this, requiredMemberDeclRef.getDecl(), outer); + DeclRef<Decl> requiredFuncDeclRef = m_astBuilder->getSpecializedDeclRef(requiredMemberDeclRef.getDecl()->inner, requiredFuncSubsts); + + GenericSubstitution* substSynParamsForOrigGeneric = m_astBuilder->getOrCreateGenericSubstitution( + outer, + requiredMemberDeclRef.getDecl(), + createDefaultSubstitutionsForGeneric(m_astBuilder, this, synGenericDecl, nullptr)->getArgs()); + + // Substitute parameters of the synthesized generic for the parameters of the original generic. + requiredFuncDeclRef = substituteDeclRef(substSynParamsForOrigGeneric, m_astBuilder, requiredFuncDeclRef); + + SLANG_ASSERT(requiredFuncDeclRef.as<FuncDecl>()); + + synGenericDecl->inner = synthesizeMethodSignatureForRequirementWitness( + context, + requiredFuncDeclRef.as<FuncDecl>(), + synArgs, + synThis); + synGenericDecl->inner->parentDecl = synGenericDecl; + + return synGenericDecl; + } + FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, DeclRef<FuncDecl> requiredMemberDeclRef, @@ -3274,12 +3403,30 @@ namespace Slang switch (builtinAttr->kind) { case BuiltinRequirementKind::DAddFunc: - case BuiltinRequirementKind::DMulFunc: case BuiltinRequirementKind::DZeroFunc: return trySynthesizeDifferentialMethodRequirementWitness( context, requiredFuncDeclRef, - witnessTable); + witnessTable, + SynthesisPattern::AllInductive); + } + } + return false; + } + + // For generic decl, check if we match DMulFunc, and synthesize the method. + if (auto requiredGenericDeclRef = requiredMemberDeclRef.as<GenericDecl>()) + { + if (auto builtinAttr = getInner(requiredGenericDeclRef)->findModifier<BuiltinRequirementModifier>()) + { + switch (builtinAttr->kind) + { + case BuiltinRequirementKind::DMulFunc: + return trySynthesizeDifferentialMethodRequirementWitness( + context, + requiredGenericDeclRef, + witnessTable, + SynthesisPattern::FixedFirstArg); } } return false; @@ -3330,7 +3477,15 @@ namespace Slang return false; } - Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List<Expr*>&& args, int nestingLevel = 0) + Stmt* _synthesizeMemberAssignMemberHelper( + ASTSynthesizer& synth, + Name* funcName, + Type* leftType, + Expr* leftValue, + List<Expr*>&& args, + List<Expr*>&& genericArgs, + List<bool>&& inductiveArgMask, + int nestingLevel = 0) { if (nestingLevel > 16) return nullptr; @@ -3342,11 +3497,24 @@ namespace Slang auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); addModifier(forStmt, synth.getBuilder()->create<ForceUnrollAttribute>()); auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); - for (auto& arg : args) + + for (auto ii = 0; ii < args.getCount(); ++ii) { - arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); + auto& arg = args[ii]; + if (inductiveArgMask[ii]) + arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); } - auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->getElementType(), innerLeft, _Move(args), nestingLevel + 1); + + auto assignStmt = _synthesizeMemberAssignMemberHelper( + synth, + funcName, + arrayType->getElementType(), + innerLeft, + _Move(args), + _Move(genericArgs), + _Move(inductiveArgMask), + nestingLevel + 1); + synth.popScope(); if (!assignStmt) return nullptr; @@ -3354,13 +3522,18 @@ namespace Slang } auto callee = synth.emitMemberExpr(leftType, funcName); + + if (genericArgs.getCount() > 0) + callee = synth.emitGenericAppExpr(callee, _Move(genericArgs)); + return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); } bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef<Decl> requirementDeclRef, - RefPtr<WitnessTable> witnessTable) + RefPtr<WitnessTable> witnessTable, + SynthesisPattern pattern) { // We support two cases of synthesis here. // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`. @@ -3371,9 +3544,10 @@ namespace Slang // ``` // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) // ``` - // Where TResult, TParam1, TParam2 is either `This` or `Differential`, - // We synthesize a memberwise dispatch to compute each field of `TResult`, - // resulting an implementation of the form: + // Where TResult,TParam1, TParam2 is either `This` or `Differential`, + // We synthesize a memberwise dispatch to compute each field of `TResult`. + // Multiple patterns are supported (see SemanticsVisitor::SynthesisPattern for a full list) + // For AllInductive, we synthesize an implementation of the form: // ``` // [BackwardDifferentiable] // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) @@ -3404,13 +3578,32 @@ namespace Slang ASTSynthesizer synth(m_astBuilder, getNamePool()); List<Expr*> synArgs; + List<Expr*> synGenericArgs; ThisExpr* synThis = nullptr; - auto synFunc = synthesizeMethodSignatureForRequirementWitness( - context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis); + FuncDecl* synFunc = nullptr; + GenericDecl* synGeneric = nullptr; + + if (auto genericDeclRef = requirementDeclRef.as<GenericDecl>()) + { + synGeneric = synthesizeGenericSignatureForRequirementWitness( + context, genericDeclRef, synArgs, synGenericArgs, synThis); + synFunc = as<FuncDecl>(synGeneric->inner); + } + else if (auto funcDeclRef = requirementDeclRef.as<FuncDecl>()) + { + synFunc = synthesizeMethodSignatureForRequirementWitness( + context, funcDeclRef, synArgs, synThis); + } + + SLANG_ASSERT(synFunc); addModifier(synFunc, m_astBuilder->create<BackwardDifferentiableAttribute>()); - synFunc->parentDecl = context->parentDecl; + if (synGeneric) + synGeneric->parentDecl = context->parentDecl; + else + synFunc->parentDecl = context->parentDecl; + synth.pushContainerScope(synFunc); auto blockStmt = m_astBuilder->create<BlockStmt>(); synFunc->body = blockStmt; @@ -3438,23 +3631,71 @@ namespace Slang // Construct reference exprs to the member's corresponding fields in each parameter. List<Expr*> paramFields; - int paramIndex = 0; - for (auto arg : synArgs) + List<bool> inductiveArgMask; + + switch (pattern) { - auto memberExpr = m_astBuilder->create<MemberExpr>(); - memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); - paramFields.add(memberExpr); - paramIndex++; + case SynthesisPattern::AllInductive: + { + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + inductiveArgMask.add(true); + + paramIndex++; + } + break; + } + case SynthesisPattern::FixedFirstArg: + { + int paramIndex = 0; + for (auto arg : synArgs) + { + if (paramIndex == 0) + { + paramFields.add(arg); + inductiveArgMask.add(false); + + paramIndex++; + } + else + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + inductiveArgMask.add(true); + + paramIndex++; + } + } + break; + } + default: + SLANG_UNIMPLEMENTED_X("unhandled synthesis pattern"); + break; } // Invoke the method for the field and assign the value to resultVar. // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` // is Differential type. auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); - if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields))) + if (!_synthesizeMemberAssignMemberHelper( + synth, + requirementDeclRef.getName(), + memberType, + leftVal, + _Move(paramFields), + _Move(synGenericArgs), + _Move(inductiveArgMask))) return false; } @@ -3473,11 +3714,11 @@ namespace Slang // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the // generic substitution for outer generic parameters, and apply it here. SubstitutionSet substSet; - if (auto thisTypeSusbt = findThisTypeSubstitution( + if (auto thisTypeSubst = findThisTypeSubstitution( requirementDeclRef.getSubst(), as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) { - if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + if (auto declRefType = as<DeclRefType>(thisTypeSubst->witness->sub)) { substSet = declRefType->declRef.getSubst(); } @@ -3610,7 +3851,9 @@ namespace Slang // requirement, it may be possible that we can still synthesis the // implementation if this is one of the known builtin requirements. // Otherwise, report diagnostic now. - if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>()) + if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>() && + !(requiredMemberDeclRef.as<GenericDecl>() && + getInner(requiredMemberDeclRef.as<GenericDecl>())->hasModifier<BuiltinRequirementModifier>())) { getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); |
