From e7df8538eb8f0ed06f0838d946bec8e9e0fe0985 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 1 Dec 2022 18:55:43 -0800 Subject: Allow `no_diff` on `this` parameter. (#2543) --- source/slang/slang-check-decl.cpp | 132 +++++++++++++++++++++++++++++--------- 1 file changed, 103 insertions(+), 29 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index d36e6286d..d8968e33a 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -340,6 +340,16 @@ namespace Slang return isEffectivelyStatic(decl, parentDecl); } + bool isGlobalDecl(Decl* decl) + { + if (!decl) + return false; + auto parentDecl = decl->parentDecl; + if (auto genericDecl = as(parentDecl)) + parentDecl = genericDecl->parentDecl; + return as(parentDecl) != nullptr; + } + /// Is `decl` a global shader parameter declaration? bool isGlobalShaderParameter(VarDeclBase* decl) { @@ -1920,37 +1930,21 @@ namespace Slang if(!requiredResultType->equals(satisfyingResultType)) return false; - witnessTable->add( - requiredMemberDeclRef.getDecl(), - RequirementWitness(satisfyingMemberDeclRef)); - if (hasForwardDerivative || hasBackwardDerivative) { - int fwdReqFound = 0; - int bwdReqFound = 0; - for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType()) + auto parentInterfaceDecl = as(getParentDecl(requiredMemberDeclRef.getDecl())); + if (parentInterfaceDecl) { - if (auto fwdReq = as(reqRefDecl->referencedDecl)) - { - ForwardDifferentiateVal* val = m_astBuilder->create(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(fwdReq, RequirementWitness(val)); - fwdReqFound++; - } - else if (auto bwdReq = as(reqRefDecl->referencedDecl)) - { - BackwardDifferentiateVal* val = m_astBuilder->create(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(bwdReq, RequirementWitness(val)); - bwdReqFound++; - } + auto idiffType = DeclRefType::create(m_astBuilder, m_astBuilder->getDifferentiableInterface()); + bool noDiffThisSatisfying = !isDeclaredSubtype(witnessTable->witnessedType, idiffType); + bool noDiffThisRequirement = (requiredMemberDeclRef.getDecl()->findModifier() != nullptr); + if (noDiffThisRequirement != noDiffThisSatisfying) + return false; } - - SLANG_RELEASE_ASSERT( - fwdReqFound == (hasForwardDerivative ? 1 : 0) && - bwdReqFound == (hasBackwardDerivative ? 1 : 0)); } + _addMethodWitness(witnessTable, requiredMemberDeclRef, satisfyingMemberDeclRef); + return true; } @@ -2543,7 +2537,10 @@ namespace Slang // mangled name! // synFuncDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; - + if (synFuncDecl->nameAndLoc.name) + { + synFuncDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synFuncDecl->nameAndLoc.name->text); + } // The result type of our synthesized method will be the expected // result type from the interface requirement. // @@ -2592,6 +2589,13 @@ namespace Slang synArg->declRef = makeDeclRef(synParamDecl); synArg->type = paramType; synArgs.add(synArg); + + if (paramDeclRef.getDecl()->findModifier()) + { + auto noDiffModifier = m_astBuilder->create(); + noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); + addModifier(synParamDecl, noDiffModifier); + } } @@ -2625,13 +2629,52 @@ namespace Slang synThis->type.isLeftValue = true; auto synMutatingAttr = m_astBuilder->create(); - synFuncDecl->modifiers.first = synMutatingAttr; + addModifier(synFuncDecl, synMutatingAttr); + } + + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + auto noDiffThisAttr = m_astBuilder->create(); + addModifier(synFuncDecl, noDiffThisAttr); + } + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + auto attr = m_astBuilder->create(); + addModifier(synFuncDecl, attr); + } + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + auto attr = m_astBuilder->create(); + addModifier(synFuncDecl, attr); } } return synFuncDecl; } + void SemanticsVisitor::_addMethodWitness( + WitnessTable* witnessTable, + DeclRef requiredMemberDeclRef, + DeclRef satisfyingMemberDeclRef) + { + for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType()) + { + if (auto fwdReq = as(reqRefDecl->referencedDecl)) + { + ForwardDifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(fwdReq, RequirementWitness(val)); + } + else if (auto bwdReq = as(reqRefDecl->referencedDecl)) + { + BackwardDifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(bwdReq, RequirementWitness(val)); + } + } + witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); + } + bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( ConformanceCheckingContext* context, LookupResult const& lookupResult, @@ -2806,8 +2849,7 @@ namespace Slang // difference between our synthetic method and a hand-written // one with the same behavior. // - witnessTable->add(requiredMemberDeclRef, - RequirementWitness(makeDeclRef(synFuncDecl))); + _addMethodWitness(witnessTable, requiredMemberDeclRef, makeDeclRef(synFuncDecl)); return true; } @@ -5593,6 +5635,7 @@ namespace Slang if (auto interfaceDecl = findParentInterfaceDecl(decl)) { + bool isDiffFunc = false; if (decl->hasModifier()) { auto reqDecl = m_astBuilder->create(); @@ -5607,6 +5650,7 @@ namespace Slang reqRef->referencedDecl = reqDecl; reqRef->parentDecl = decl; decl->members.add(reqRef); + isDiffFunc = true; } if (decl->hasModifier()) { @@ -5622,6 +5666,36 @@ namespace Slang reqRef->referencedDecl = reqDecl; reqRef->parentDecl = decl; decl->members.add(reqRef); + isDiffFunc = true; + } + if (isDiffFunc) + { + auto interfaceDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(interfaceDecl)); + auto interfaceType = DeclRefType::create(m_astBuilder, interfaceDeclRef); + bool noDiffThisRequirement = !isTypeDifferentiable(interfaceType); + if (noDiffThisRequirement) + { + auto noDiffThisModifier = m_astBuilder->create(); + addModifier(decl, noDiffThisModifier); + } + } + } + if (decl->findModifier()) + { + // Add `no_diff` modifiers to parameters. + // This is necessary to preserve no-diff-ness for generic function before and after + // specialization. + for (auto paramDecl : decl->getParameters()) + { + if (paramDecl->type.type && !isTypeDifferentiable(paramDecl->type.type)) + { + if (!paramDecl->hasModifier()) + { + auto noDiffModifier = m_astBuilder->create(); + noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); + addModifier(paramDecl, noDiffModifier); + } + } } } } -- cgit v1.2.3