summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-01 18:55:43 -0800
committerGitHub <noreply@github.com>2022-12-01 18:55:43 -0800
commite7df8538eb8f0ed06f0838d946bec8e9e0fe0985 (patch)
tree3c08e646600ab82ffda260f2b6deb96dd2085776 /source/slang/slang-check-decl.cpp
parentf51f69d045d9e0b83d9ab1f4623d4319ce1867be (diff)
Allow `no_diff` on `this` parameter. (#2543)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp132
1 files changed, 103 insertions, 29 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index d36e6286d..d8968e33a 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -340,6 +340,16 @@ namespace Slang
return isEffectivelyStatic(decl, parentDecl);
}
+ bool isGlobalDecl(Decl* decl)
+ {
+ if (!decl)
+ return false;
+ auto parentDecl = decl->parentDecl;
+ if (auto genericDecl = as<GenericDecl>(parentDecl))
+ parentDecl = genericDecl->parentDecl;
+ return as<NamespaceDeclBase>(parentDecl) != nullptr;
+ }
+
/// Is `decl` a global shader parameter declaration?
bool isGlobalShaderParameter(VarDeclBase* decl)
{
@@ -1920,37 +1930,21 @@ namespace Slang
if(!requiredResultType->equals(satisfyingResultType))
return false;
- witnessTable->add(
- requiredMemberDeclRef.getDecl(),
- RequirementWitness(satisfyingMemberDeclRef));
-
if (hasForwardDerivative || hasBackwardDerivative)
{
- int fwdReqFound = 0;
- int bwdReqFound = 0;
- for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>())
+ auto parentInterfaceDecl = as<InterfaceDecl>(getParentDecl(requiredMemberDeclRef.getDecl()));
+ if (parentInterfaceDecl)
{
- if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
- {
- ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(fwdReq, RequirementWitness(val));
- fwdReqFound++;
- }
- else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
- {
- BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(bwdReq, RequirementWitness(val));
- bwdReqFound++;
- }
+ auto idiffType = DeclRefType::create(m_astBuilder, m_astBuilder->getDifferentiableInterface());
+ bool noDiffThisSatisfying = !isDeclaredSubtype(witnessTable->witnessedType, idiffType);
+ bool noDiffThisRequirement = (requiredMemberDeclRef.getDecl()->findModifier<NoDiffThisAttribute>() != nullptr);
+ if (noDiffThisRequirement != noDiffThisSatisfying)
+ return false;
}
-
- SLANG_RELEASE_ASSERT(
- fwdReqFound == (hasForwardDerivative ? 1 : 0) &&
- bwdReqFound == (hasBackwardDerivative ? 1 : 0));
}
+ _addMethodWitness(witnessTable, requiredMemberDeclRef, satisfyingMemberDeclRef);
+
return true;
}
@@ -2543,7 +2537,10 @@ namespace Slang
// mangled name!
//
synFuncDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc;
-
+ if (synFuncDecl->nameAndLoc.name)
+ {
+ synFuncDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synFuncDecl->nameAndLoc.name->text);
+ }
// The result type of our synthesized method will be the expected
// result type from the interface requirement.
//
@@ -2592,6 +2589,13 @@ namespace Slang
synArg->declRef = makeDeclRef(synParamDecl);
synArg->type = paramType;
synArgs.add(synArg);
+
+ if (paramDeclRef.getDecl()->findModifier<NoDiffModifier>())
+ {
+ auto noDiffModifier = m_astBuilder->create<NoDiffModifier>();
+ noDiffModifier->keywordName = getSession()->getNameObj("no_diff");
+ addModifier(synParamDecl, noDiffModifier);
+ }
}
@@ -2625,13 +2629,52 @@ namespace Slang
synThis->type.isLeftValue = true;
auto synMutatingAttr = m_astBuilder->create<MutatingAttribute>();
- synFuncDecl->modifiers.first = synMutatingAttr;
+ addModifier(synFuncDecl, synMutatingAttr);
+ }
+
+ if (requiredMemberDeclRef.getDecl()->hasModifier<NoDiffThisAttribute>())
+ {
+ auto noDiffThisAttr = m_astBuilder->create<NoDiffThisAttribute>();
+ addModifier(synFuncDecl, noDiffThisAttr);
+ }
+ if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>())
+ {
+ auto attr = m_astBuilder->create<ForwardDifferentiableAttribute>();
+ addModifier(synFuncDecl, attr);
+ }
+ if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>())
+ {
+ auto attr = m_astBuilder->create<BackwardDifferentiableAttribute>();
+ addModifier(synFuncDecl, attr);
}
}
return synFuncDecl;
}
+ void SemanticsVisitor::_addMethodWitness(
+ WitnessTable* witnessTable,
+ DeclRef<CallableDecl> requiredMemberDeclRef,
+ DeclRef<CallableDecl> satisfyingMemberDeclRef)
+ {
+ for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>())
+ {
+ if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(fwdReq, RequirementWitness(val));
+ }
+ else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(bwdReq, RequirementWitness(val));
+ }
+ }
+ witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef));
+ }
+
bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
ConformanceCheckingContext* context,
LookupResult const& lookupResult,
@@ -2806,8 +2849,7 @@ namespace Slang
// difference between our synthetic method and a hand-written
// one with the same behavior.
//
- witnessTable->add(requiredMemberDeclRef,
- RequirementWitness(makeDeclRef(synFuncDecl)));
+ _addMethodWitness(witnessTable, requiredMemberDeclRef, makeDeclRef(synFuncDecl));
return true;
}
@@ -5593,6 +5635,7 @@ namespace Slang
if (auto interfaceDecl = findParentInterfaceDecl(decl))
{
+ bool isDiffFunc = false;
if (decl->hasModifier<ForwardDifferentiableAttribute>())
{
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
@@ -5607,6 +5650,7 @@ namespace Slang
reqRef->referencedDecl = reqDecl;
reqRef->parentDecl = decl;
decl->members.add(reqRef);
+ isDiffFunc = true;
}
if (decl->hasModifier<BackwardDifferentiableAttribute>())
{
@@ -5622,6 +5666,36 @@ namespace Slang
reqRef->referencedDecl = reqDecl;
reqRef->parentDecl = decl;
decl->members.add(reqRef);
+ isDiffFunc = true;
+ }
+ if (isDiffFunc)
+ {
+ auto interfaceDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(interfaceDecl));
+ auto interfaceType = DeclRefType::create(m_astBuilder, interfaceDeclRef);
+ bool noDiffThisRequirement = !isTypeDifferentiable(interfaceType);
+ if (noDiffThisRequirement)
+ {
+ auto noDiffThisModifier = m_astBuilder->create<NoDiffThisAttribute>();
+ addModifier(decl, noDiffThisModifier);
+ }
+ }
+ }
+ if (decl->findModifier<DifferentiableAttribute>())
+ {
+ // Add `no_diff` modifiers to parameters.
+ // This is necessary to preserve no-diff-ness for generic function before and after
+ // specialization.
+ for (auto paramDecl : decl->getParameters())
+ {
+ if (paramDecl->type.type && !isTypeDifferentiable(paramDecl->type.type))
+ {
+ if (!paramDecl->hasModifier<NoDiffModifier>())
+ {
+ auto noDiffModifier = m_astBuilder->create<NoDiffModifier>();
+ noDiffModifier->keywordName = getSession()->getNameObj("no_diff");
+ addModifier(paramDecl, noDiffModifier);
+ }
+ }
}
}
}