summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-01 18:55:43 -0800
committerGitHub <noreply@github.com>2022-12-01 18:55:43 -0800
commite7df8538eb8f0ed06f0838d946bec8e9e0fe0985 (patch)
tree3c08e646600ab82ffda260f2b6deb96dd2085776 /source
parentf51f69d045d9e0b83d9ab1f4623d4319ce1867be (diff)
Allow `no_diff` on `this` parameter. (#2543)
Diffstat (limited to 'source')
-rw-r--r--source/slang/diff.meta.slang5
-rw-r--r--source/slang/slang-ast-builder.cpp10
-rw-r--r--source/slang/slang-ast-builder.h6
-rw-r--r--source/slang/slang-ast-modifier.h7
-rw-r--r--source/slang/slang-check-conformance.cpp5
-rw-r--r--source/slang/slang-check-decl.cpp132
-rw-r--r--source/slang/slang-check-expr.cpp10
-rw-r--r--source/slang/slang-check-impl.h9
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-emit.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp187
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp38
-rw-r--r--source/slang/slang-ir-autodiff.cpp59
-rw-r--r--source/slang/slang-ir-autodiff.h22
-rw-r--r--source/slang/slang-ir-hoist-local-types.cpp4
-rw-r--r--source/slang/slang-ir-hoist-local-types.h2
-rw-r--r--source/slang/slang-lower-to-ir.cpp5
-rw-r--r--source/slang/slang-syntax.cpp8
-rw-r--r--source/slang/slang-syntax.h5
20 files changed, 386 insertions, 135 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 69ced9156..033c173ab 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -11,13 +11,16 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
-
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
+// Exclude "this" parameter from differentiation.
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 623a9161b..ab161065d 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -141,6 +141,16 @@ Type* SharedASTBuilder::getNoneType()
return m_noneType;
}
+Type* SharedASTBuilder::getDiffInterfaceType()
+{
+ if (!m_diffInterfaceType)
+ {
+ auto decl = findMagicDecl("DifferentiableType");
+ m_diffInterfaceType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(decl));
+ }
+ return m_diffInterfaceType;
+}
+
SharedASTBuilder::~SharedASTBuilder()
{
// Release built in types..
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index bdc03dda5..72d8ec50a 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -36,6 +36,8 @@ public:
Type* getNullPtrType();
/// Get the NullPtr type
Type* getNoneType();
+ /// Get the `IDifferentiable` type
+ Type* getDiffInterfaceType();
const ReflectClassInfo* findClassInfo(Name* name);
SyntaxClass<NodeBase> findSyntaxClass(Name* name);
@@ -85,7 +87,7 @@ protected:
Type* m_dynamicType = nullptr;
Type* m_nullPtrType = nullptr;
Type* m_noneType = nullptr;
- Type* m_diffBottomType = nullptr;
+ Type* m_diffInterfaceType = nullptr;
Type* m_builtinTypes[Index(BaseType::CountOf)];
Dictionary<String, Decl*> m_magicDecls;
@@ -308,7 +310,7 @@ public:
Type* getNullPtrType() { return m_sharedASTBuilder->getNullPtrType(); }
Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); }
Type* getEnumTypeType() { return m_sharedASTBuilder->getEnumTypeType(); }
-
+ Type* getDiffInterfaceType() { return m_sharedASTBuilder->getDiffInterfaceType(); }
// Construct the type `Ptr<valueType>`, where `Ptr`
// is looked up as a builtin type.
PtrType* getPtrType(Type* valueType);
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 2adbcf6c6..c85464061 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1113,6 +1113,13 @@ class BackwardDerivativeOfAttribute : public DifferentiableAttribute
Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction.
};
+ /// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be
+ /// included for differentiation.
+class NoDiffThisAttribute : public Attribute
+{
+ SLANG_AST_CLASS(NoDiffThisAttribute)
+};
+
/// Indicates that the modified declaration is one of the "magic" declarations
/// that NVAPI uses to communicate extended operations. When NVAPI is being included
/// via the prelude for downstream compilation, declarations with this modifier
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index 4d983b746..3a50897de 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -527,6 +527,11 @@ namespace Slang
return false;
}
+ bool SemanticsVisitor::isTypeDifferentiable(Type* type)
+ {
+ return isDeclaredSubtype(type, m_astBuilder->getDiffInterfaceType());
+ }
+
Val* SemanticsVisitor::tryGetSubtypeWitness(
Type* subType,
DeclRef<AggTypeDecl> superTypeDeclRef)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index d36e6286d..d8968e33a 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -340,6 +340,16 @@ namespace Slang
return isEffectivelyStatic(decl, parentDecl);
}
+ bool isGlobalDecl(Decl* decl)
+ {
+ if (!decl)
+ return false;
+ auto parentDecl = decl->parentDecl;
+ if (auto genericDecl = as<GenericDecl>(parentDecl))
+ parentDecl = genericDecl->parentDecl;
+ return as<NamespaceDeclBase>(parentDecl) != nullptr;
+ }
+
/// Is `decl` a global shader parameter declaration?
bool isGlobalShaderParameter(VarDeclBase* decl)
{
@@ -1920,37 +1930,21 @@ namespace Slang
if(!requiredResultType->equals(satisfyingResultType))
return false;
- witnessTable->add(
- requiredMemberDeclRef.getDecl(),
- RequirementWitness(satisfyingMemberDeclRef));
-
if (hasForwardDerivative || hasBackwardDerivative)
{
- int fwdReqFound = 0;
- int bwdReqFound = 0;
- for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>())
+ auto parentInterfaceDecl = as<InterfaceDecl>(getParentDecl(requiredMemberDeclRef.getDecl()));
+ if (parentInterfaceDecl)
{
- if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
- {
- ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(fwdReq, RequirementWitness(val));
- fwdReqFound++;
- }
- else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
- {
- BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(bwdReq, RequirementWitness(val));
- bwdReqFound++;
- }
+ auto idiffType = DeclRefType::create(m_astBuilder, m_astBuilder->getDifferentiableInterface());
+ bool noDiffThisSatisfying = !isDeclaredSubtype(witnessTable->witnessedType, idiffType);
+ bool noDiffThisRequirement = (requiredMemberDeclRef.getDecl()->findModifier<NoDiffThisAttribute>() != nullptr);
+ if (noDiffThisRequirement != noDiffThisSatisfying)
+ return false;
}
-
- SLANG_RELEASE_ASSERT(
- fwdReqFound == (hasForwardDerivative ? 1 : 0) &&
- bwdReqFound == (hasBackwardDerivative ? 1 : 0));
}
+ _addMethodWitness(witnessTable, requiredMemberDeclRef, satisfyingMemberDeclRef);
+
return true;
}
@@ -2543,7 +2537,10 @@ namespace Slang
// mangled name!
//
synFuncDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc;
-
+ if (synFuncDecl->nameAndLoc.name)
+ {
+ synFuncDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synFuncDecl->nameAndLoc.name->text);
+ }
// The result type of our synthesized method will be the expected
// result type from the interface requirement.
//
@@ -2592,6 +2589,13 @@ namespace Slang
synArg->declRef = makeDeclRef(synParamDecl);
synArg->type = paramType;
synArgs.add(synArg);
+
+ if (paramDeclRef.getDecl()->findModifier<NoDiffModifier>())
+ {
+ auto noDiffModifier = m_astBuilder->create<NoDiffModifier>();
+ noDiffModifier->keywordName = getSession()->getNameObj("no_diff");
+ addModifier(synParamDecl, noDiffModifier);
+ }
}
@@ -2625,13 +2629,52 @@ namespace Slang
synThis->type.isLeftValue = true;
auto synMutatingAttr = m_astBuilder->create<MutatingAttribute>();
- synFuncDecl->modifiers.first = synMutatingAttr;
+ addModifier(synFuncDecl, synMutatingAttr);
+ }
+
+ if (requiredMemberDeclRef.getDecl()->hasModifier<NoDiffThisAttribute>())
+ {
+ auto noDiffThisAttr = m_astBuilder->create<NoDiffThisAttribute>();
+ addModifier(synFuncDecl, noDiffThisAttr);
+ }
+ if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>())
+ {
+ auto attr = m_astBuilder->create<ForwardDifferentiableAttribute>();
+ addModifier(synFuncDecl, attr);
+ }
+ if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>())
+ {
+ auto attr = m_astBuilder->create<BackwardDifferentiableAttribute>();
+ addModifier(synFuncDecl, attr);
}
}
return synFuncDecl;
}
+ void SemanticsVisitor::_addMethodWitness(
+ WitnessTable* witnessTable,
+ DeclRef<CallableDecl> requiredMemberDeclRef,
+ DeclRef<CallableDecl> satisfyingMemberDeclRef)
+ {
+ for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>())
+ {
+ if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(fwdReq, RequirementWitness(val));
+ }
+ else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl))
+ {
+ BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(bwdReq, RequirementWitness(val));
+ }
+ }
+ witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef));
+ }
+
bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
ConformanceCheckingContext* context,
LookupResult const& lookupResult,
@@ -2806,8 +2849,7 @@ namespace Slang
// difference between our synthetic method and a hand-written
// one with the same behavior.
//
- witnessTable->add(requiredMemberDeclRef,
- RequirementWitness(makeDeclRef(synFuncDecl)));
+ _addMethodWitness(witnessTable, requiredMemberDeclRef, makeDeclRef(synFuncDecl));
return true;
}
@@ -5593,6 +5635,7 @@ namespace Slang
if (auto interfaceDecl = findParentInterfaceDecl(decl))
{
+ bool isDiffFunc = false;
if (decl->hasModifier<ForwardDifferentiableAttribute>())
{
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
@@ -5607,6 +5650,7 @@ namespace Slang
reqRef->referencedDecl = reqDecl;
reqRef->parentDecl = decl;
decl->members.add(reqRef);
+ isDiffFunc = true;
}
if (decl->hasModifier<BackwardDifferentiableAttribute>())
{
@@ -5622,6 +5666,36 @@ namespace Slang
reqRef->referencedDecl = reqDecl;
reqRef->parentDecl = decl;
decl->members.add(reqRef);
+ isDiffFunc = true;
+ }
+ if (isDiffFunc)
+ {
+ auto interfaceDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(interfaceDecl));
+ auto interfaceType = DeclRefType::create(m_astBuilder, interfaceDeclRef);
+ bool noDiffThisRequirement = !isTypeDifferentiable(interfaceType);
+ if (noDiffThisRequirement)
+ {
+ auto noDiffThisModifier = m_astBuilder->create<NoDiffThisAttribute>();
+ addModifier(decl, noDiffThisModifier);
+ }
+ }
+ }
+ if (decl->findModifier<DifferentiableAttribute>())
+ {
+ // Add `no_diff` modifiers to parameters.
+ // This is necessary to preserve no-diff-ness for generic function before and after
+ // specialization.
+ for (auto paramDecl : decl->getParameters())
+ {
+ if (paramDecl->type.type && !isTypeDifferentiable(paramDecl->type.type))
+ {
+ if (!paramDecl->hasModifier<NoDiffModifier>())
+ {
+ auto noDiffModifier = m_astBuilder->create<NoDiffModifier>();
+ noDiffModifier->keywordName = getSession()->getNameObj("no_diff");
+ addModifier(paramDecl, noDiffModifier);
+ }
+ }
}
}
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 7297ca282..336682bf4 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -503,7 +503,11 @@ namespace Slang
auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
addModifier(synthesizedDecl, toBeSynthesized);
- return ConstructDeclRefExpr(makeDeclRef(synthesizedDecl), nullptr, originalExpr->loc, originalExpr);
+ return ConstructDeclRefExpr(
+ makeDeclRef(synthesizedDecl),
+ nullptr,
+ originalExpr ? originalExpr->loc : SourceLoc(),
+ originalExpr);
}
Expr* SemanticsVisitor::ConstructLookupResultExpr(
@@ -1927,6 +1931,10 @@ namespace Slang
{
getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward");
}
+ if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl))
+ {
+ getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl);
+ }
}
}
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 1c2f698bd..fb47a38c1 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -16,6 +16,8 @@ namespace Slang
bool isEffectivelyStatic(
Decl* decl);
+ bool isGlobalDecl(Decl* decl);
+
Type* checkProperType(
Linkage* linkage,
TypeExp typeExp,
@@ -1026,6 +1028,11 @@ namespace Slang
List<Expr*>& synArgs,
ThisExpr*& synThis);
+ void _addMethodWitness(
+ WitnessTable* witnessTable,
+ DeclRef<CallableDecl> requirement,
+ DeclRef<CallableDecl> method);
+
/// Attempt to synthesize a method that can satisfy `requiredMemberDeclRef` using `lookupResult`.
///
/// On success, installs the syntethesized method in `witnessTable` and returns `true`.
@@ -1431,6 +1438,8 @@ namespace Slang
bool isInterfaceType(Type* type);
+ bool isTypeDifferentiable(Type* type);
+
/// Check whether `subType` is a sub-type of `superTypeDeclRef`,
/// and return a witness to the sub-type relationship if it holds
/// (return null otherwise).
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d293626ae..fc92241e1 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -305,6 +305,7 @@ DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error typ
DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "type '$0' is used as a `Differential` type, therefore it must serve as its own `Differential` type.")
DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.")
+DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-static function reference '$0' is not allowed here.")
DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 508402736..2476f79e5 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -382,7 +382,9 @@ Result linkAndOptimizeIR(
if (!changed)
break;
}
-
+
+ finalizeAutoDiffPass(irModule);
+
lowerReinterpret(targetRequest, irModule, sink);
validateIRModuleIfEnabled(codeGenContext, irModule);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index d45dd0c10..c94342736 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -11,6 +11,12 @@
namespace Slang
{
+static IRInst* _unwrapAttributedType(IRInst* type)
+{
+ while (auto attrType = as<IRAttributedType>(type))
+ type = attrType->getBaseType();
+ return type;
+}
DiagnosticSink* ForwardDerivativeTranscriber::getSink()
{
@@ -183,8 +189,12 @@ IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType
IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType)
{
IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- auto witness = as<IRWitnessTable>(
+ if (!primalType->next)
+ builder.setInsertInto(primalType->parent);
+ else
+ builder.setInsertBefore(primalType->next);
+
+ IRInst* witness = as<IRWitnessTable>(
differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
if (!witness)
@@ -193,6 +203,10 @@ IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType
{
witness = getDifferentialPairWitness(primalPairType);
}
+ else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
+ {
+ differentiateExtractExistentialType(&builder, extractExistential, witness);
+ }
}
return builder.getDifferentialPairType(
@@ -271,6 +285,12 @@ IRType* ForwardDerivativeTranscriber::_differentiateTypeImpl(IRBuilder* builder,
else
return nullptr;
+ case kIROp_ExtractExistentialType:
+ {
+ IRInst* wt = nullptr;
+ return differentiateExtractExistentialType(builder, as<IRExtractExistentialType>(primalType), wt);
+ }
+
case kIROp_TupleType:
{
auto tupleType = as<IRTupleType>(primalType);
@@ -288,6 +308,75 @@ IRType* ForwardDerivativeTranscriber::_differentiateTypeImpl(IRBuilder* builder,
}
}
+ // Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`.
+bool _findDifferentiableInterfaceLookupPathImpl(
+ HashSet<IRInst*>& processedTypes,
+ IRInterfaceType* idiffType,
+ IRInterfaceType* type,
+ List<IRInterfaceRequirementEntry*>& currentPath)
+{
+ if (processedTypes.Contains(type))
+ return false;
+ processedTypes.Add(type);
+
+ List<IRInterfaceRequirementEntry*> lookupKeyPath;
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i));
+ if (!entry) continue;
+ if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal()))
+ {
+ currentPath.add(entry);
+ if (wt->getConformanceType() == idiffType)
+ {
+ return true;
+ }
+ else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType()))
+ {
+ if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath))
+ return true;
+ }
+ currentPath.removeLast();
+ }
+ }
+ return false;
+}
+
+List<IRInterfaceRequirementEntry*> _findDifferentiableInterfaceLookupPath(
+ IRInterfaceType* idiffType,
+ IRInterfaceType* type)
+{
+ List<IRInterfaceRequirementEntry*> currentPath;
+ HashSet<IRInst*> processedTypes;
+ _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath);
+ return currentPath;
+}
+
+IRType* ForwardDerivativeTranscriber::differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable)
+{
+ witnessTable = nullptr;
+
+ // Search for IDifferentiable conformance.
+ auto interfaceType = as<IRInterfaceType>(_unwrapAttributedType(origType->getOperand(0)->getDataType()));
+ if (!interfaceType)
+ return nullptr;
+ List<IRInterfaceRequirementEntry*> lookupKeyPath = _findDifferentiableInterfaceLookupPath(
+ autoDiffSharedContext->differentiableInterfaceType, interfaceType);
+
+ if (lookupKeyPath.getCount())
+ {
+ // `interfaceType` does conform to `IDifferentiable`.
+ witnessTable = builder->emitExtractExistentialWitnessTable(origType->getOperand(0));
+ for (auto node : lookupKeyPath)
+ {
+ witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey());
+ }
+ auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), witnessTable, autoDiffSharedContext->differentialAssocTypeStructKey);
+ return (IRType*)diffType;
+ }
+ return nullptr;
+}
+
IRType* ForwardDerivativeTranscriber::tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
{
// If this is a PtrType (out, inout, etc..), then create diff pair from
@@ -699,6 +788,10 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
return InstPair(primalCall, nullptr);
}
+ auto calleeType = as<IRFuncType>(diffCallee->getDataType());
+ SLANG_ASSERT(calleeType);
+ SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount());
+
List<IRInst*> args;
// Go over the parameter list and create pairs for each input (if required)
for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
@@ -707,7 +800,15 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
auto primalArg = findOrTranscribePrimalInst(builder, origArg);
SLANG_ASSERT(primalArg);
- auto primalType = primalArg->getDataType();
+ auto primalType = primalArg->getDataType();
+ auto paramType = calleeType->getParamType(ii);
+ if (!isNoDiffType(paramType))
+ {
+ if (isNoDiffType(primalType))
+ {
+ while (auto attrType = as<IRAttributedType>(primalType))
+ primalType = attrType->getBaseType();
+ }
if (auto pairType = tryGetDiffPairType(builder, primalType))
{
auto diffArg = findOrTranscribeDiffInst(builder, origArg);
@@ -718,16 +819,16 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
SLANG_RELEASE_ASSERT(diffArg);
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
args.add(diffPair);
- }
- else
- {
- // Add original/primal argument.
- args.add(primalArg);
+ continue;
}
}
+ // Argument is not differentiable.
+ // Add original/primal argument.
+ args.add(primalArg);
+ }
- IRType* diffReturnType = nullptr;
- diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
+ IRType* diffReturnType = nullptr;
+ diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
if (!diffReturnType)
{
@@ -942,37 +1043,37 @@ InstPair ForwardDerivativeTranscriber::transcribeSpecialize(IRBuilder* builder,
// Make sure this isn't itself a specialize .
SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
- return InstPair(primalSpecialize, jvpFunc);
- }
- else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>())
- {
- diffBase = derivativeDecoration->getForwardDerivativeFunc();
- List<IRInst*> args;
- for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
- {
- args.add(primalSpecialize->getArg(i));
- }
- auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
- return InstPair(primalSpecialize, diffSpecialize);
- }
- else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>())
+ return InstPair(primalSpecialize, jvpFunc);
+ }
+ else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>())
+ {
+ diffBase = derivativeDecoration->getForwardDerivativeFunc();
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
{
- List<IRInst*> args;
- for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
- {
- args.add(primalSpecialize->getArg(i));
- }
- diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase());
- auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
- return InstPair(primalSpecialize, diffSpecialize);
+ args.add(primalSpecialize->getArg(i));
}
- else
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
+ }
+ else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>())
+ {
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
{
- return InstPair(primalSpecialize, nullptr);
+ args.add(primalSpecialize->getArg(i));
}
+ diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase());
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
+ }
+ else
+ {
+ return InstPair(primalSpecialize, nullptr);
}
+}
InstPair ForwardDerivativeTranscriber::transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst)
{
@@ -981,7 +1082,7 @@ InstPair ForwardDerivativeTranscriber::transcribeLookupInterfaceMethod(IRBuilder
auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType());
auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey);
- auto interfaceType = as<IRInterfaceType>(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType());
+ auto interfaceType = as<IRInterfaceType>(_unwrapAttributedType(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType()));
if (!interfaceType)
{
return InstPair(primal, nullptr);
@@ -1031,7 +1132,17 @@ IRInst* ForwardDerivativeTranscriber::getDifferentialZeroOfType(IRBuilder* build
// Since primalType has a corresponding differential type, we can lookup the
// definition for zero().
auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
- SLANG_ASSERT(zeroMethod);
+ if (!zeroMethod)
+ {
+ // if the differential type itself comes from a witness lookup, we can just lookup the
+ // zero method from the same witness table.
+ if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
+ {
+ auto wt = lookupInterface->getWitnessTable();
+ zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
+ }
+ }
+ SLANG_RELEASE_ASSERT(zeroMethod);
auto emptyArgList = List<IRInst*>();
return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index ab5d753d6..678677625 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -106,6 +106,8 @@ struct ForwardDerivativeTranscriber
IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType);
+ IRType* differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable);
+
IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType);
InstPair transcribeParam(IRBuilder* builder, IRParam* origParam);
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
index b9b4a8b66..dc72ed44a 100644
--- a/source/slang/slang-ir-autodiff-pairs.cpp
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -1,4 +1,5 @@
#include "slang-ir-autodiff-pairs.h"
+#include "slang-ir-hoist-local-types.h"
namespace Slang
{
@@ -13,25 +14,22 @@ struct DiffPairLoweringPass : InstPassBase
pairBuilder = &pairBuilderStorage;
}
- IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr)
+ IRInst* lowerPairType(IRBuilder* builder, IRType* pairType)
{
builder->setInsertBefore(pairType);
- auto loweredPairTypeInfo = pairBuilder->lowerDiffPairType(
+ auto loweredPairType = pairBuilder->lowerDiffPairType(
builder,
pairType);
- if (isTrivial)
- *isTrivial = loweredPairTypeInfo.isTrivial;
- return loweredPairTypeInfo.loweredType;
+ return loweredPairType;
}
IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
{
-
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
bool isTrivial = false;
auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType());
- if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial))
+ if (auto loweredPairType = lowerPairType(builder, pairType))
{
builder->setInsertBefore(makePairInst);
IRInst* result = nullptr;
@@ -63,7 +61,7 @@ struct DiffPairLoweringPass : InstPassBase
pairType = pairPtrType->getValueType();
}
- if (lowerPairType(builder, pairType, nullptr))
+ if (lowerPairType(builder, pairType))
{
builder->setInsertBefore(getDiffInst);
IRInst* diffFieldExtract = nullptr;
@@ -81,7 +79,7 @@ struct DiffPairLoweringPass : InstPassBase
pairType = pairPtrType->getValueType();
}
- if (lowerPairType(builder, pairType, nullptr))
+ if (lowerPairType(builder, pairType))
{
builder->setInsertBefore(getPrimalInst);
@@ -99,27 +97,9 @@ struct DiffPairLoweringPass : InstPassBase
bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren)
{
bool modified = false;
+
// Hoist all pair types to global scope when possible.
- auto moduleInst = module->getModuleInst();
- processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
- {
- if (originalPairType->parent != moduleInst)
- {
- originalPairType->removeFromParent();
- ShortList<IRInst*> operands;
- for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
- {
- operands.add(originalPairType->getOperand(i));
- }
- auto newPairType = builder->findOrEmitHoistableInst(
- originalPairType->getFullType(),
- originalPairType->getOp(),
- originalPairType->getOperandCount(),
- operands.getArrayView().getBuffer());
- originalPairType->replaceUsesWith(newPairType);
- originalPairType->removeAndDeallocate();
- }
- });
+ hoistLocalTypes(module);
autodiffContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 5b5832073..86429f9ba 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -27,6 +27,20 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
return nullptr;
}
+bool isNoDiffType(IRType* paramType)
+{
+ while (auto ptrType = as<IRPtrTypeBase>(paramType))
+ paramType = ptrType->getValueType();
+ while (auto attrType = as<IRAttributedType>(paramType))
+ {
+ if (attrType->findAttr<IRNoDiffAttr>())
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
IRStructField* DifferentialPairTypeBuilder::findField(IRInst* type, IRStructKey* key)
{
if (auto irStructType = as<IRStructType>(type))
@@ -80,19 +94,14 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(IRBuilder* builder, IRIns
IRInst* pairType = nullptr;
if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType()))
{
- auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType());
+ auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType());
- // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types,
- // especially since we probably don't need diff bottom anymore.
- //
- SLANG_ASSERT(!baseTypeInfo.isTrivial);
-
- pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType);
+ pairType = builder->getPtrType(kIROp_PtrType, (IRType*)loweredType);
}
else
{
- auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
- pairType = baseTypeInfo.loweredType;
+ auto loweredType = lowerDiffPairType(builder, baseInst->getDataType());
+ pairType = loweredType;
}
if (auto basePairStructType = as<IRStructType>(pairType))
@@ -240,33 +249,29 @@ IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* b
return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
}
-DifferentialPairTypeBuilder::LoweredPairTypeInfo DifferentialPairTypeBuilder::lowerDiffPairType(
+IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
IRBuilder* builder, IRType* originalPairType)
{
- LoweredPairTypeInfo result = {};
-
+ IRInst* result = nullptr;
if (pairTypeCache.TryGetValue(originalPairType, result))
return result;
auto pairType = as<IRDifferentialPairType>(originalPairType);
if (!pairType)
{
- result.isTrivial = true;
- result.loweredType = originalPairType;
+ result = originalPairType;
return result;
}
auto primalType = pairType->getValueType();
if (as<IRParam>(primalType))
{
- result.isTrivial = false;
- result.loweredType = nullptr;
+ result = nullptr;
return result;
}
auto diffType = getDiffTypeFromPairType(builder, pairType);
if (!diffType)
return result;
- result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
- result.isTrivial = false;
+ result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
pairTypeCache.Add(originalPairType, result);
return result;
@@ -469,6 +474,22 @@ bool processAutodiffCalls(
// Process reverse derivative calls.
modified |= processReverseDerivativeCalls(&autodiffContext, sink);
+ return modified;
+}
+
+bool finalizeAutoDiffPass(IRModule* module)
+{
+ bool modified = false;
+
+ // Create shared context for all auto-diff related passes
+ AutoDiffSharedContext autodiffContext(module->getModuleInst());
+
+ SharedIRBuilder sharedBuilder;
+ sharedBuilder.init(module);
+ sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+
+ autodiffContext.sharedBuilder = &sharedBuilder;
+
// Replaces IRDifferentialPairType with an auto-generated struct,
// IRDifferentialPairGetDifferential with 'differential' field access,
// IRDifferentialPairGetPrimal with 'primal' field access, and
@@ -481,7 +502,7 @@ bool processAutodiffCalls(
// Remove auto-diff related decorations.
stripAutoDiffDecorations(module);
- return modified;
+ return false;
}
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index e470044a4..25cbe16f4 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -147,12 +147,6 @@ struct DifferentiableTypeConformanceContext
struct DifferentialPairTypeBuilder
{
- struct LoweredPairTypeInfo
- {
- IRInst* loweredType;
- bool isTrivial;
- };
-
DifferentialPairTypeBuilder() = default;
DifferentialPairTypeBuilder(AutoDiffSharedContext* sharedContext) : sharedContext(sharedContext) {}
@@ -177,10 +171,16 @@ struct DifferentialPairTypeBuilder
IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type);
- LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType);
+ IRInst* lowerDiffPairType(IRBuilder* builder, IRType* originalPairType);
+ struct PairStructKey
+ {
+ IRInst* originalType;
+ IRInst* diffType;
+ };
- Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache;
+ // Cache from `IRDifferentialPairType` to materialized struct type.
+ Dictionary<IRInst*, IRInst*> pairTypeCache;
IRStructKey* globalPrimalKey = nullptr;
@@ -197,6 +197,8 @@ void stripAutoDiffDecorations(IRModule* module);
IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey);
+bool isNoDiffType(IRType* paramType);
+
struct IRAutodiffPassOptions
{
// Nothing for now...
@@ -207,4 +209,6 @@ bool processAutodiffCalls(
DiagnosticSink* sink,
IRAutodiffPassOptions const& options = IRAutodiffPassOptions());
-}; \ No newline at end of file
+bool finalizeAutoDiffPass(IRModule* module);
+
+};
diff --git a/source/slang/slang-ir-hoist-local-types.cpp b/source/slang/slang-ir-hoist-local-types.cpp
index 756a25c49..cf091f701 100644
--- a/source/slang/slang-ir-hoist-local-types.cpp
+++ b/source/slang/slang-ir-hoist-local-types.cpp
@@ -8,7 +8,6 @@ namespace Slang
struct HoistLocalTypesContext
{
IRModule* module;
- DiagnosticSink* sink;
SharedIRBuilder sharedBuilderStorage;
@@ -98,11 +97,10 @@ struct HoistLocalTypesContext
}
};
-void hoistLocalTypes(IRModule* module, DiagnosticSink* sink)
+void hoistLocalTypes(IRModule* module)
{
HoistLocalTypesContext context;
context.module = module;
- context.sink = sink;
context.processModule();
}
diff --git a/source/slang/slang-ir-hoist-local-types.h b/source/slang/slang-ir-hoist-local-types.h
index 6b742746f..55e62ce57 100644
--- a/source/slang/slang-ir-hoist-local-types.h
+++ b/source/slang/slang-ir-hoist-local-types.h
@@ -13,6 +13,6 @@ class DiagnosticSink;
/// can be hoisted to global scope. This pass examines all local type defintions
// and try to hoist them to global scope if the definition is no longer dependent on
// the local context.
-void hoistLocalTypes(IRModule* module, DiagnosticSink* sink);
+void hoistLocalTypes(IRModule* module);
} // namespace Slang
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 28639ae53..f836824f7 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2884,6 +2884,11 @@ void collectParameterLists(
auto thisType = getThisParamTypeForContainer(context, parentDeclRef);
if(thisType)
{
+ if (declRef.getDecl()->findModifier<NoDiffThisAttribute>())
+ {
+ auto noDiffAttr = context->astBuilder->getNoDiffModifierVal();
+ thisType = context->astBuilder->getModifiedType(thisType, 1, &noDiffAttr);
+ }
addThisParameter(innerThisParamDirection, thisType, ioParameterLists);
}
}
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index a79c48227..4e5db17c0 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -1368,6 +1368,14 @@ Module* getModule(Decl* decl)
return moduleDecl->module;
}
+Decl* getParentDecl(Decl* decl)
+{
+ decl = decl->parentDecl;
+ while (as<GenericDecl>(decl))
+ decl = decl->parentDecl;
+ return decl;
+}
+
static const ImageFormatInfo kImageFormatInfos[] =
{
#define SLANG_IMAGE_FORMAT_INFO(TYPE, COUNT, SIZE) SLANG_SCALAR_TYPE_##TYPE, uint8_t(COUNT), uint8_t(SIZE)
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index 441dcb8e7..e36ee944c 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -329,10 +329,11 @@ namespace Slang
/// Get the module dclaration that a declaration is associated with, if any.
ModuleDecl* getModuleDecl(Decl* decl);
- /// Get the module that a declaration is associated with, if any.
+ /// Get the module that a declaration is associated with, if any.
Module* getModule(Decl* decl);
-
+ /// Get the parent decl, skipping any generic decls in between.
+ Decl* getParentDecl(Decl* decl);
} // namespace Slang