summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-18 12:37:27 -0800
committerGitHub <noreply@github.com>2022-11-18 12:37:27 -0800
commitd58e08f8237a1888ceaad53402d534679ea83b1a (patch)
treee66838e0dc31fc12ebd7c1acecbb5060e8808366 /source
parent0a050a439fa91b66f2020421d4fec3e60aed4112 (diff)
Data flow validation pass for diagnosing derivative loss. (#2523)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-decl.h7
-rw-r--r--source/slang/slang-ast-expr.h9
-rw-r--r--source/slang/slang-ast-iterator.h5
-rw-r--r--source/slang/slang-ast-support-types.h22
-rw-r--r--source/slang/slang-check-decl.cpp137
-rw-r--r--source/slang/slang-check-expr.cpp75
-rw-r--r--source/slang/slang-check-impl.h38
-rw-r--r--source/slang/slang-diagnostic-defs.h9
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp421
-rw-r--r--source/slang/slang-ir-check-differentiability.h14
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp345
-rw-r--r--source/slang/slang-ir-diff-jvp.h153
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h20
-rw-r--r--source/slang/slang-ir-lower-error-handling.cpp2
-rw-r--r--source/slang/slang-ir.cpp5
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp4
-rw-r--r--source/slang/slang-lower-to-ir.cpp12
-rw-r--r--source/slang/slang-parser.cpp9
-rw-r--r--source/slang/slang-serialize-misc-type-info.h6
20 files changed, 1045 insertions, 251 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index cbd3f0f0c..4da832d11 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -391,6 +391,12 @@ class ModuleDecl : public NamespaceDeclBase
//
Module* module = nullptr;
+ /// Map a decl to the list of its associated decls.
+ ///
+ /// This mapping is filled in during semantic checking, as the decl declarations get checked or generated.
+ ///
+ OrderedDictionary<Decl*, RefPtr<DeclAssociationList>> mapDeclToAssociatedDecls;
+
SLANG_UNREFLECTED
/// Map a type to the list of extensions of that type (if any) declared in this module
@@ -398,6 +404,7 @@ class ModuleDecl : public NamespaceDeclBase
/// This mapping is filled in during semantic checking, as `ExtensionDecl`s get checked.
///
Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> mapTypeToCandidateExtensions;
+
};
/// A declaration that brings members of another declaration or namespace into scope
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 3a99ac15f..a0268cce9 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -465,6 +465,15 @@ class BackwardDifferentiateExpr: public DifferentiateExpr
SLANG_AST_CLASS(BackwardDifferentiateExpr)
};
+ /// An express to mark its inner expression as an intended non-differential call.
+class TreatAsDifferentiableExpr : public Expr
+{
+ SLANG_AST_CLASS(TreatAsDifferentiableExpr)
+
+ Expr* innerExpr;
+ Scope* scope;
+};
+
/// A type expression of the form `__TaggedUnion(A, ...)`.
///
/// An expression of this form will resolve to a `TaggedUnionType`
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index ed396139e..79aade1ee 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -268,6 +268,11 @@ struct ASTIterator
iterator->maybeDispatchCallback(expr);
dispatchIfNotNull(expr->baseFunction);
}
+
+ void visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
+ {
+ dispatchIfNotNull(expr->innerExpr);
+ }
};
struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor>
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 7c954987e..21611bcb1 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1496,6 +1496,28 @@ namespace Slang
List<ExtensionDecl*> candidateExtensions;
};
+
+ enum class DeclAssociationKind
+ {
+ ForwardDerivativeFunc, BackwardDerivativeFunc,
+ };
+
+ struct DeclAssociation
+ {
+ SLANG_VALUE_CLASS(DeclAssociation)
+ DeclAssociationKind kind;
+ Decl* decl;
+ };
+
+ /// A reference-counted object to hold a list of associated decls for a decl.
+ ///
+ struct DeclAssociationList : SerialRefObject
+ {
+ SLANG_OBJ_CLASS(DeclAssociationList)
+
+ List<DeclAssociation> associations;
+ };
+
/// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
enum ParameterDirection
{
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index ffbc5a841..009d0a987 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -4635,6 +4635,7 @@ namespace Slang
checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr);
attr->backDeclRef = fwdDerivativeAttr->funcExpr;
fwdDerivativeAttr->funcExpr = nullptr;
+ getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), DeclAssociationKind::ForwardDerivativeFunc, funcDecl);
return;
}
}
@@ -4684,6 +4685,22 @@ namespace Slang
if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>())
checkDerivativeAttribute(decl, derivativeAttr);
+ if (newContext.getParentDifferentiableAttribute())
+ {
+ // Register additional types outside the function body first.
+ auto oldAttr = m_parentDifferentiableAttr;
+ m_parentDifferentiableAttr = newContext.getParentDifferentiableAttribute();
+ for (auto param : decl->getParameters())
+ maybeRegisterDifferentiableType(m_astBuilder, param->type.type);
+ maybeRegisterDifferentiableType(m_astBuilder, decl->returnType.type);
+ if (as<ConstructorDecl>(decl) || !isEffectivelyStatic(decl))
+ {
+ auto thisType = calcThisType(makeDeclRef(decl));
+ maybeRegisterDifferentiableType(m_astBuilder, thisType);
+ }
+ m_parentDifferentiableAttr = oldAttr;
+ }
+
if (auto body = decl->body)
{
checkStmt(decl->body, newContext);
@@ -6379,6 +6396,126 @@ namespace Slang
}
}
+ /// Get a reference to the associated decl list for `decl` in the given dictionary
+ ///
+ /// Note: this function creates an empty list of candidates for the given type if
+ /// a matching entry doesn't exist already.
+ ///
+ static List<DeclAssociation>& _getDeclAssociationList(
+ Decl* decl,
+ OrderedDictionary<Decl*, RefPtr<DeclAssociationList>>& mapDeclToDeclarations)
+ {
+ RefPtr<DeclAssociationList> entry;
+ if (!mapDeclToDeclarations.TryGetValue(decl, entry))
+ {
+ entry = new DeclAssociationList();
+ mapDeclToDeclarations.Add(decl, entry);
+ }
+ return entry->associations;
+ }
+
+ void SharedSemanticsContext::_addDeclAssociationsFromModule(ModuleDecl* moduleDecl)
+ {
+ for (auto& entry : moduleDecl->mapDeclToAssociatedDecls)
+ {
+ auto& list = _getDeclAssociationList(entry.Key, m_mapDeclToAssociatedDecls);
+ list.addRange(entry.Value->associations);
+ }
+ }
+
+ void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated)
+ {
+ auto moduleDecl = getModuleDecl(associated);
+ DeclAssociation assoc = {kind, associated};
+ _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc);
+
+ m_associatedDeclListsBuilt = false;
+ m_mapDeclToAssociatedDecls.Clear();
+ }
+
+ List<DeclAssociation> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl)
+ {
+ // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`.
+ // Consider refactoring them into the same framework.
+ if (!m_associatedDeclListsBuilt)
+ {
+ m_associatedDeclListsBuilt = true;
+
+ for (auto module : getSession()->stdlibModules)
+ {
+ _addDeclAssociationsFromModule(module->getModuleDecl());
+ }
+
+ if (m_module)
+ {
+ _addDeclAssociationsFromModule(m_module->getModuleDecl());
+ for (auto moduleDecl : this->importedModulesList)
+ {
+ _addDeclAssociationsFromModule(moduleDecl);
+ }
+ }
+ else
+ {
+ for (auto module : m_linkage->loadedModulesList)
+ {
+ _addDeclAssociationsFromModule(module->getModuleDecl());
+ }
+ }
+ }
+ return _getDeclAssociationList(decl, m_mapDeclToAssociatedDecls);
+ }
+
+ bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func)
+ {
+ // A function is differentiable if it is marked as differentiable, or it
+ // has an associated derivative function.
+ if (func->findModifier<DifferentiableAttribute>())
+ return true;
+ for (auto assocDecl : getAssociatedDeclsForDecl(func))
+ {
+ switch (assocDecl.kind)
+ {
+ case DeclAssociationKind::ForwardDerivativeFunc:
+ case DeclAssociationKind::BackwardDerivativeFunc:
+ return true;
+ default:
+ break;
+ }
+ }
+ return false;
+ }
+
+ bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func)
+ {
+ // A function is differentiable if it is marked as differentiable, or it
+ // has an associated derivative function.
+ if (func->findModifier<BackwardDifferentiableAttribute>())
+ return true;
+ for (auto assocDecl : getAssociatedDeclsForDecl(func))
+ {
+ switch (assocDecl.kind)
+ {
+ case DeclAssociationKind::BackwardDerivativeFunc:
+ return true;
+ default:
+ break;
+ }
+ }
+ if (auto builtinReq = func->findModifier<BuiltinRequirementModifier>())
+ {
+ switch (builtinReq->kind)
+ {
+ case BuiltinRequirementKind::DAddFunc:
+ case BuiltinRequirementKind::DMulFunc:
+ case BuiltinRequirementKind::DZeroFunc:
+ return true;
+ default:
+ break;
+ }
+ }
+ return false;
+ }
+
List<ExtensionDecl*> const& getCandidateExtensions(
DeclRef<AggTypeDecl> const& declRef,
SemanticsVisitor* semantics)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b43a03150..2c6899269 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1893,6 +1893,34 @@ namespace Slang
}
}
}
+
+ if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr))
+ {
+ if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke)))
+ {
+ auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl());
+ if (funcDecl)
+ {
+ DifferentiateExpr* forwardDiff = nullptr;
+ DifferentiateExpr* backwardDiff = nullptr;
+ for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction))
+ {
+ if (auto fwd = as<ForwardDifferentiateExpr>(node))
+ forwardDiff = fwd;
+ if (auto bwd = as<BackwardDifferentiateExpr>(node))
+ backwardDiff = bwd;
+ }
+ if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl))
+ {
+ getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward");
+ }
+ if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl))
+ {
+ getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward");
+ }
+ }
+ }
+ }
}
}
return rs;
@@ -1920,7 +1948,7 @@ namespace Slang
auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr);
- if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>())
+ if (m_parentDifferentiableAttr)
{
if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr))
{
@@ -1929,6 +1957,30 @@ namespace Slang
{
maybeRegisterDifferentiableType(m_astBuilder, arg->type.type);
}
+ if (auto calleeExpr = as<DeclRefExpr>(checkedInvokeExpr->functionExpr))
+ {
+ if (auto calleeDecl = as<FunctionDeclBase>(calleeExpr->declRef.getDecl()))
+ {
+ if (getShared()->isDifferentiableFunc(calleeDecl))
+ {
+ if (!m_treatAsDifferentiableExpr)
+ {
+ auto newFuncExpr =
+ getASTBuilder()->create<TreatAsDifferentiableExpr>();
+ newFuncExpr->type = checkedInvokeExpr->type;
+ newFuncExpr->innerExpr = checkedInvokeExpr;
+ newFuncExpr->loc = checkedInvokeExpr->loc;
+ checkedExpr = newFuncExpr;
+ }
+ else
+ {
+ getSink()->diagnose(
+ m_treatAsDifferentiableExpr,
+ Diagnostics::useOfNoDiffOnDifferentiableFunc);
+ }
+ }
+ }
+ }
}
maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type);
}
@@ -2227,6 +2279,27 @@ namespace Slang
return _checkDifferentiateExpr(this, expr, &actions);
}
+ Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
+ {
+ auto subContext = withTreatAsDifferentiable(expr);
+ expr->innerExpr = dispatchExpr(expr->innerExpr, subContext);
+ expr->type = expr->innerExpr->type;
+ auto innerExpr = expr->innerExpr;
+ while (auto parenExpr = as<ParenExpr>(innerExpr))
+ {
+ innerExpr = parenExpr->base;
+ }
+ if (!as<InvokeExpr>(innerExpr))
+ {
+ getSink()->diagnose(expr, Diagnostics::invalidUseOfNoDiff);
+ }
+ else if (!m_parentDifferentiableAttr)
+ {
+ getSink()->diagnose(expr, Diagnostics::cannotUseNoDiffInNonDifferentiableFunc);
+ }
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
{
expr->arrayExpr = CheckTerm(expr->arrayExpr);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index c4c32a681..1c2f698bd 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -284,6 +284,13 @@ namespace Slang
/// Register a candidate extension `extDecl` for `typeDecl` encountered during checking.
void registerCandidateExtension(AggTypeDecl* typeDecl, ExtensionDecl* extDecl);
+ void registerAssociatedDecl(Decl* original, DeclAssociationKind assoc, Decl* declaration);
+
+ List<DeclAssociation> const& getAssociatedDeclsForDecl(Decl* decl);
+
+ bool isDifferentiableFunc(FunctionDeclBase* func);
+ bool isBackwardDifferentiableFunc(FunctionDeclBase* func);
+
private:
/// Mapping from type declarations to the known extensiosn that apply to them
Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> m_mapTypeDeclToCandidateExtensions;
@@ -293,6 +300,17 @@ namespace Slang
/// Add candidate extensions declared in `moduleDecl` to `m_mapTypeDeclToCandidateExtensions`
void _addCandidateExtensionsFromModule(ModuleDecl* moduleDecl);
+
+ /// Mapping from a decl to additional declarations of the same decl.
+ /// The additional declarations provide a location to hold extra decorations.
+ OrderedDictionary<Decl*, RefPtr<DeclAssociationList>> m_mapDeclToAssociatedDecls;
+
+ /// Is the `m_mapDeclToAssociatedDecls` dictionary valid and up to date?
+ bool m_associatedDeclListsBuilt = false;
+
+ /// Add associated decls declared in `moduleDecl` to `m_mapDeclToAssociatedDecls`
+ void _addDeclAssociationsFromModule(ModuleDecl* moduleDecl);
+
};
/// Local/scoped state of the semantic-checking system
@@ -411,6 +429,13 @@ namespace Slang
return result;
}
+ SemanticsContext withTreatAsDifferentiable(TreatAsDifferentiableExpr* expr)
+ {
+ SemanticsContext result(*this);
+ result.m_treatAsDifferentiableExpr = expr;
+ return result;
+ }
+
SemanticsContext allowStaticReferenceToNonStaticMember()
{
SemanticsContext result(*this);
@@ -444,6 +469,10 @@ namespace Slang
/// is considered valid in the current context.
bool m_allowStaticReferenceToNonStaticMember = false;
+ /// Whether or not we are in a `no_diff` environment (and therefore should treat the call to
+ /// a non-differentiable function as differentiable and not issue a diagnostic).
+ TreatAsDifferentiableExpr* m_treatAsDifferentiableExpr = nullptr;
+
ASTBuilder* m_astBuilder = nullptr;
};
@@ -724,9 +753,6 @@ namespace Slang
///
void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness);
- // Check and register a type if it is differentiable.
- void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type);
-
// Construct the differential for 'type', if it exists.
Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc);
Type* tryGetDifferentialType(ASTBuilder* builder, Type* type);
@@ -1061,6 +1087,9 @@ namespace Slang
/// Gather differentiable members from decl.
List<DifferentiableMemberInfo> collectDifferentiableMemberInfo(ContainerDecl* decl);
+ // Check and register a type if it is differentiable.
+ void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type);
+
// Find the appropriate member of a declared type to
// satisfy a requirement of an interface the type
// claims to conform to.
@@ -1266,8 +1295,6 @@ namespace Slang
/// Given an immutable `expr` used as an l-value emit a special diagnostic if it was derived from `this`.
void maybeDiagnoseThisNotLValue(Expr* expr);
- void registerExtension(ExtensionDecl* decl);
-
// Figure out what type an initializer/constructor declaration
// is supposed to return. In most cases this is just the type
// declaration that its declaration is nested inside.
@@ -1913,6 +1940,7 @@ namespace Slang
Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr);
Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr);
+ Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr);
Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index c69a1e9e6..d293626ae 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -304,6 +304,8 @@ DIAGNOSTIC(30094, Error, mustUseTryClauseToCallAThrowFunc, "the callee may throw
DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.")
DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "type '$0' is used as a `Differential` type, therefore it must serve as its own `Differential` type.")
+DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.")
+
DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.")
// Attributes
@@ -494,7 +496,9 @@ DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argu
DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)")
DIAGNOSTIC(38029, Error, typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'")
-DIAGNOSTIC(30830, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.")
+DIAGNOSTIC(38031, Error, invalidUseOfNoDiff, "'no_diff' can only be used to decorate a call.")
+DIAGNOSTIC(38032, Error, useOfNoDiffOnDifferentiableFunc, "use 'no_diff' on a call to a differentiable function has no meaning.")
+DIAGNOSTIC(38033, Error, cannotUseNoDiffInNonDifferentiableFunc, "cannot use 'no_diff' in a non-differentiable function.")
DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself")
DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error")
@@ -568,6 +572,9 @@ DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'vo
DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in the size required by its conforming interface.")
DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2")
DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.")
+DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-differentiable function `$0`, use 'no_diff' to clarify intention.")
+DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable function must have at least one differentiable output.")
+DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.")
//
// 5xxxx - Target code generation.
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
new file mode 100644
index 000000000..44c6324e3
--- /dev/null
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -0,0 +1,421 @@
+#include "slang-ir-check-differentiability.h"
+
+#include "slang-ir-diff-jvp.h"
+#include "slang-ir-inst-pass-base.h"
+
+namespace Slang
+{
+
+struct CheckDifferentiabilityPassContext : public InstPassBase
+{
+public:
+ DiagnosticSink* sink;
+ AutoDiffSharedContext sharedContext;
+
+ HashSet<IRInst*> differentiableFunctions;
+
+ CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
+ : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
+ {}
+
+ IRInst* getSpecializedVal(IRInst* inst)
+ {
+ int loopLimit = 1024;
+ while (inst && inst->getOp() == kIROp_Specialize)
+ {
+ inst = as<IRSpecialize>(inst)->getBase();
+ loopLimit--;
+ if (loopLimit == 0)
+ return inst;
+ }
+ return inst;
+ }
+
+ IRInst* getLeafFunc(IRInst* func)
+ {
+ func = getSpecializedVal(func);
+ if (!func)
+ return nullptr;
+ if (auto genericFunc = as<IRGeneric>(func))
+ return findInnerMostGenericReturnVal(genericFunc);
+ return func;
+ }
+
+ bool _isFuncMarkedForAutoDiff(IRInst* func)
+ {
+ func = getLeafFunc(func);
+ if (!func)
+ return false;
+ for (auto decorations : func->getDecorations())
+ {
+ switch (decorations->getOp())
+ {
+ case kIROp_ForwardDifferentiableDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ return true;
+ }
+ }
+ return false;
+ }
+
+
+ bool _isDifferentiableFuncImpl(IRInst* func)
+ {
+ func = getLeafFunc(func);
+ if (!func)
+ return false;
+
+ for (auto decorations : func->getDecorations())
+ {
+ switch (decorations->getOp())
+ {
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_ForwardDifferentiableDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool isDifferentiableFunc(IRInst* func)
+ {
+ switch (func->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ return true;
+ default:
+ break;
+ }
+
+ func = getSpecializedVal(func);
+ if (!func)
+ return false;
+
+ if (differentiableFunctions.Contains(func))
+ return true;
+
+ for (; func; func = func->parent)
+ {
+ if (as<IRGeneric>(func))
+ {
+ return differentiableFunctions.Contains(func);
+ }
+ }
+ return false;
+ }
+
+ bool isBackwardDifferentiableFunc(IRInst* func)
+ {
+ for (auto decorations : func->getDecorations())
+ {
+ switch (decorations->getOp())
+ {
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
+ {
+ HashSet<IRInst*> processedSet;
+ while (auto ptrType = as<IRPtrTypeBase>(typeInst))
+ {
+ typeInst = ptrType->getValueType();
+ if (!processedSet.Add(typeInst))
+ return false;
+ }
+ if (!typeInst)
+ return false;
+ switch (typeInst->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_DifferentialPairType:
+ return true;
+ default:
+ break;
+ }
+ if (context.lookUpConformanceForType(typeInst))
+ return true;
+ // Look for equivalent types.
+ for (auto type : context.differentiableWitnessDictionary)
+ {
+ if (isTypeEqual(type.Key, (IRType*)typeInst))
+ {
+ context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value;
+ return true;
+ }
+ }
+ return false;
+ }
+
+ int getParamIndexInBlock(IRParam* paramInst)
+ {
+ auto block = as<IRBlock>(paramInst->getParent());
+ if (!block)
+ return -1;
+ int paramIndex = 0;
+ for (auto param : block->getParams())
+ {
+ if (param == paramInst)
+ return paramIndex;
+ paramIndex++;
+ }
+ return -1;
+ }
+
+ bool isInstInFunc(IRInst* inst, IRInst* func)
+ {
+ while (inst)
+ {
+ if (inst == func)
+ return true;
+ inst = inst->parent;
+ }
+ return false;
+ }
+ void processFunc(IRGlobalValueWithCode* funcInst)
+ {
+ if (!_isFuncMarkedForAutoDiff(funcInst))
+ return;
+ if (!funcInst->getFirstBlock())
+ return;
+
+ DifferentiableTypeConformanceContext diffTypeContext(&sharedContext);
+ diffTypeContext.setFunc(funcInst);
+
+ HashSet<IRInst*> produceDiffSet;
+ HashSet<IRInst*> expectDiffSet;
+ int differentiableInputs = 0;
+ int differentiableOutputs = 0;
+ for (auto param : funcInst->getFirstBlock()->getParams())
+ {
+ if (isDifferentiableType(diffTypeContext, param->getFullType()))
+ {
+ if (as<IROutTypeBase>(param->getFullType()))
+ differentiableOutputs++;
+ if (!as<IROutType>(param->getFullType()))
+ differentiableInputs++;
+ produceDiffSet.Add(param);
+ }
+ }
+ if (auto funcType = as<IRFuncType>(funcInst->getDataType()))
+ {
+ if (isDifferentiableType(diffTypeContext, funcType->getResultType()))
+ differentiableOutputs++;
+ }
+
+ if (differentiableOutputs == 0)
+ sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveOutput);
+ if (differentiableInputs == 0)
+ sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveInput);
+
+ auto isInstProducingDiff = [&](IRInst* inst) -> bool
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_FloatLit:
+ return true;
+ case kIROp_Call:
+ return inst->findDecoration<IRTreatAsDifferentiableCallDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
+ case kIROp_Load:
+ // We don't have more knowledge on whether diff is available at the destination address.
+ // Just assume it is producing diff.
+ //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter.
+ return isDifferentiableType(diffTypeContext, inst->getDataType());
+ default:
+ // default case is to assume the inst produces a diff value if any
+ // of its operands produces a diff value.
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (produceDiffSet.Contains(inst->getOperand(i)))
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+ };
+
+ List<IRInst*> expectDiffInstWorkList;
+ OrderedHashSet<IRInst*> expectDiffInstWorkListSet;
+ auto addToExpectDiffWorkList = [&](IRInst* inst)
+ {
+ if (isInstInFunc(inst, funcInst))
+ {
+ if (expectDiffInstWorkListSet.Add(inst))
+ expectDiffInstWorkList.add(inst);
+ }
+ };
+ // Run data flow analysis and generate `produceDiffSet` and an intial `expectDiffSet`.
+ Index lastProduceDiffCount = 0;
+ do
+ {
+ lastProduceDiffCount = produceDiffSet.Count();
+ for (auto block : funcInst->getBlocks())
+ {
+ if (block != funcInst->getFirstBlock())
+ {
+ UInt paramIndex = 0;
+ for (auto param : block->getParams())
+ {
+ for (auto p : block->getPredecessors())
+ {
+ // A Phi Node is producing diff if any of its candidate values are producing diff.
+ if (auto branch = as<IRUnconditionalBranch>(p->getTerminator()))
+ {
+ if (branch->getArgCount() > paramIndex)
+ {
+ auto arg = branch->getArg(paramIndex);
+ if (produceDiffSet.Contains(arg))
+ {
+ produceDiffSet.Add(param);
+ break;
+ }
+ }
+ }
+ }
+ paramIndex++;
+ }
+ }
+ for (auto inst : block->getChildren())
+ {
+ if (isInstProducingDiff(inst))
+ produceDiffSet.Add(inst);
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ if (isDifferentiableFunc(as<IRCall>(inst)->getCallee()))
+ {
+ addToExpectDiffWorkList(inst);
+ }
+ break;
+ case kIROp_Store:
+ {
+ auto storeInst = as<IRStore>(inst);
+ if (isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType()))
+ {
+ addToExpectDiffWorkList(storeInst->getVal());
+ }
+ }
+ break;
+ case kIROp_Return:
+ if (auto returnVal = as<IRReturn>(inst)->getVal())
+ {
+ if (isDifferentiableType(diffTypeContext, returnVal->getDataType()))
+ {
+ addToExpectDiffWorkList(inst);
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ } while (produceDiffSet.Count() != lastProduceDiffCount);
+
+ // Reverse propagate `expectDiffSet`.
+ for (int i = 0; i < expectDiffInstWorkList.getCount(); i++)
+ {
+ auto inst = expectDiffInstWorkList[i];
+ // Is inst in produceDiffSet?
+ if (!produceDiffSet.Contains(inst))
+ {
+ if (auto call = as<IRCall>(inst))
+ {
+ sink->diagnose(inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getLeafFunc(call->getCallee()));
+ }
+ }
+ switch (inst->getOp())
+ {
+ case kIROp_Param:
+ {
+ auto block = as<IRBlock>(inst->getParent());
+ if (block != funcInst->getFirstBlock())
+ {
+ auto paramIndex = getParamIndexInBlock(as<IRParam>(inst));
+ if (paramIndex != -1)
+ {
+ for (auto p : block->getPredecessors())
+ {
+ // A Phi Node is producing diff if any of its candidate values are producing diff.
+ if (auto branch = as<IRUnconditionalBranch>(p->getTerminator()))
+ {
+ if (branch->getArgCount() > (UInt)paramIndex)
+ {
+ auto arg = branch->getArg(paramIndex);
+ addToExpectDiffWorkList(arg);
+ }
+ }
+ }
+ }
+ }
+ break;
+ }
+ default:
+ // Default behavior is to request all differentiable operands to provide differential.
+ for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++)
+ {
+ auto operand = inst->getOperand(opIndex);
+ if (isDifferentiableType(diffTypeContext, operand->getFullType()))
+ {
+ addToExpectDiffWorkList(operand);
+ }
+ }
+ }
+ }
+ }
+
+ void processModule()
+ {
+ // Collect set of differentiable functions.
+ HashSet<UnownedStringSlice> differentiableSymbolNames;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (_isDifferentiableFuncImpl(inst))
+ {
+ if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
+ differentiableSymbolNames.Add(linkageDecor->getMangledName());
+ differentiableFunctions.Add(inst);
+ }
+ }
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
+ {
+ if (differentiableSymbolNames.Contains(linkageDecor->getMangledName()))
+ differentiableFunctions.Add(inst);
+ }
+ }
+
+ if (!sharedContext.isInterfaceAvailable)
+ return;
+
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto genericInst = as<IRGeneric>(inst))
+ {
+ if (auto innerFunc = as<IRGlobalValueWithCode>(findGenericReturnVal(genericInst)))
+ processFunc(innerFunc);
+ }
+ else if (auto funcInst = as<IRGlobalValueWithCode>(inst))
+ {
+ processFunc(funcInst);
+ }
+ }
+ }
+};
+
+void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink)
+{
+ CheckDifferentiabilityPassContext context(module, sink);
+ context.processModule();
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-check-differentiability.h b/source/slang/slang-ir-check-differentiability.h
new file mode 100644
index 000000000..735a918c9
--- /dev/null
+++ b/source/slang/slang-ir-check-differentiability.h
@@ -0,0 +1,14 @@
+// slang-ir-check-differentiability.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct IRModule;
+class DiagnosticSink;
+
+// Check all auto diff usages are valid.
+void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 152601dbd..4ee16aafc 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -1,8 +1,6 @@
// slang-ir-diff-jvp.cpp
#include "slang-ir-diff-jvp.h"
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
@@ -16,154 +14,6 @@
namespace Slang
{
-template<typename P, typename D>
-struct Pair
-{
- P primal;
- D differential;
- Pair() = default;
- Pair(P primal, D differential) : primal(primal), differential(differential)
- {}
- HashCode getHashCode() const
- {
- Hasher hasher;
- hasher << primal << differential;
- return hasher.getResult();
- }
- bool operator ==(const Pair& other) const
- {
- return primal == other.primal && differential == other.differential;
- }
-};
-
-typedef Pair<IRInst*, IRInst*> InstPair;
-
-struct AutoDiffSharedContext
-{
- IRModuleInst* moduleInst = nullptr;
-
- SharedIRBuilder* sharedBuilder = nullptr;
-
- // A reference to the builtin IDifferentiable interface type.
- // We use this to look up all the other types (and type exprs)
- // that conform to a base type.
- //
- IRInterfaceType* differentiableInterfaceType = nullptr;
-
- // The struct key for the 'Differential' associated type
- // defined inside IDifferential. We use this to lookup the differential
- // type in the conformance table associated with the concrete type.
- //
- IRStructKey* differentialAssocTypeStructKey = nullptr;
-
- // The struct key for the witness that `Differential` associated type conforms to
- // `IDifferential`.
- IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
-
-
- // The struct key for the 'zero()' associated type
- // defined inside IDifferential. We use this to lookup the
- // implementation of zero() for a given type.
- //
- IRStructKey* zeroMethodStructKey = nullptr;
-
- // The struct key for the 'add()' associated type
- // defined inside IDifferential. We use this to lookup the
- // implementation of add() for a given type.
- //
- IRStructKey* addMethodStructKey = nullptr;
-
- IRStructKey* mulMethodStructKey = nullptr;
-
-
- // Modules that don't use differentiable types
- // won't have the IDifferentiable interface type available.
- // Set to false to indicate that we are uninitialized.
- //
- bool isInterfaceAvailable = false;
-
-
- AutoDiffSharedContext(IRModuleInst* inModuleInst)
- : moduleInst(inModuleInst)
- {
- differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
- if (differentiableInterfaceType)
- {
- differentialAssocTypeStructKey = findDifferentialTypeStructKey();
- differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
- zeroMethodStructKey = findZeroMethodStructKey();
- addMethodStructKey = findAddMethodStructKey();
- mulMethodStructKey = findMulMethodStructKey();
-
- if (differentialAssocTypeStructKey)
- isInterfaceAvailable = true;
- }
- }
-
- private:
-
- IRInst* findDifferentiableInterface()
- {
- if (auto module = as<IRModuleInst>(moduleInst))
- {
- for (auto globalInst : module->getGlobalInsts())
- {
- // TODO: This seems like a particularly dangerous way to look for an interface.
- // See if we can lower IDifferentiable to a separate IR inst.
- //
- if (globalInst->getOp() == kIROp_InterfaceType &&
- as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable")
- {
- return globalInst;
- }
- }
- }
- return nullptr;
- }
-
- IRStructKey* findDifferentialTypeStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(0);
- }
-
- IRStructKey* findDifferentialTypeWitnessStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(1);
- }
-
- IRStructKey* findZeroMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(2);
- }
-
- IRStructKey* findAddMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(3);
- }
-
- IRStructKey* findMulMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(4);
- }
-
- IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index)
- {
- if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
- {
- // Assume for now that IDifferentiable has exactly five fields.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
- if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
- return as<IRStructKey>(entry->getRequirementKey());
- else
- {
- SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
- }
- }
-
- return nullptr;
- }
-};
-
namespace
{
@@ -189,97 +39,6 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
}
-struct DifferentiableTypeConformanceContext
-{
- AutoDiffSharedContext* sharedContext;
-
- IRGlobalValueWithCode* parentFunc = nullptr;
- Dictionary<IRType*, IRInst*> differentiableWitnessDictionary;
-
- DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
- : sharedContext(shared)
- {}
-
- void setFunc(IRGlobalValueWithCode* func)
- {
- parentFunc = func;
-
- auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- SLANG_RELEASE_ASSERT(decor);
-
- // Build lookup dictionary for type witnesses.
- for (auto child = decor->getFirstChild(); child; child = child->next)
- {
- if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
- {
- auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
- if (existingItem)
- {
- if (auto witness = as<IRWitnessTable>(item->getWitness()))
- {
- if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType)
- continue;
- }
- *existingItem = item->getWitness();
- }
- else
- {
- differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
- }
- }
- }
- }
-
-
- // Lookup a witness table for the concreteType. One should exist if concreteType
- // inherits (successfully) from IDifferentiable.
- //
- IRInst* lookUpConformanceForType(IRInst* type)
- {
- IRInst* foundResult = nullptr;
- differentiableWitnessDictionary.TryGetValue(type, foundResult);
- return foundResult;
- }
-
- IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
- {
- if (auto conformance = lookUpConformanceForType(origType))
- {
- return _lookupWitness(builder, conformance, key);
- }
- return nullptr;
- }
-
- // Lookup and return the 'Differential' type declared in the concrete type
- // in order to conform to the IDifferentiable interface.
- // Note that inside a generic block, this will be a witness table lookup instruction
- // that gets resolved during the specialization pass.
- //
- IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
- {
- switch (origType->getOp())
- {
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- case kIROp_VectorType:
- return origType;
- }
- return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
- }
-
- IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
- }
-
- IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
- }
-
-};
-
struct DifferentialPairTypeBuilder
{
@@ -3275,4 +3034,108 @@ void stripAutoDiffDecorations(IRModule* module)
stripAutoDiffDecorationsFromChildren(module->getModuleInst());
}
+AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst)
+ : moduleInst(inModuleInst)
+{
+ differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
+ if (differentiableInterfaceType)
+ {
+ differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
+ zeroMethodStructKey = findZeroMethodStructKey();
+ addMethodStructKey = findAddMethodStructKey();
+ mulMethodStructKey = findMulMethodStructKey();
+
+ if (differentialAssocTypeStructKey)
+ isInterfaceAvailable = true;
+ }
+}
+
+IRInst* AutoDiffSharedContext::findDifferentiableInterface()
+{
+ if (auto module = as<IRModuleInst>(moduleInst))
+ {
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ // TODO: This seems like a particularly dangerous way to look for an interface.
+ // See if we can lower IDifferentiable to a separate IR inst.
+ //
+ if (globalInst->getOp() == kIROp_InterfaceType &&
+ as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable")
+ {
+ return globalInst;
+ }
+ }
+ }
+ return nullptr;
+}
+
+IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index)
+{
+ if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
+ {
+ // Assume for now that IDifferentiable has exactly five fields.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
+ if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
+ return as<IRStructKey>(entry->getRequirementKey());
+ else
+ {
+ SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
+ }
+ }
+
+ return nullptr;
+}
+
+void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
+{
+ parentFunc = func;
+
+ auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+ SLANG_RELEASE_ASSERT(decor);
+
+ // Build lookup dictionary for type witnesses.
+ for (auto child = decor->getFirstChild(); child; child = child->next)
+ {
+ if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
+ {
+ auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+ if (existingItem)
+ {
+ if (auto witness = as<IRWitnessTable>(item->getWitness()))
+ {
+ if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType)
+ continue;
+ }
+ *existingItem = item->getWitness();
+ }
+ else
+ {
+ differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
+ }
+ }
+ }
+}
+
+
+// Lookup a witness table for the concreteType. One should exist if concreteType
+// inherits (successfully) from IDifferentiable.
+//
+
+IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
+{
+ IRInst* foundResult = nullptr;
+ differentiableWitnessDictionary.TryGetValue(type, foundResult);
+ return foundResult;
+}
+
+IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+{
+ if (auto conformance = lookUpConformanceForType(origType))
+ {
+ return _lookupWitness(builder, conformance, key);
+ }
+ return nullptr;
+}
+
}
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
index a866a3db3..9e0f9cfcc 100644
--- a/source/slang/slang-ir-diff-jvp.h
+++ b/source/slang/slang-ir-diff-jvp.h
@@ -2,11 +2,162 @@
#pragma once
#include "slang-ir.h"
+#include "slang-ir-insts.h"
#include "slang-compiler.h"
namespace Slang
{
- struct IRModule;
+ template<typename P, typename D>
+ struct DiffInstPair
+ {
+ P primal;
+ D differential;
+ DiffInstPair() = default;
+ DiffInstPair(P primal, D differential) : primal(primal), differential(differential)
+ {}
+ HashCode getHashCode() const
+ {
+ Hasher hasher;
+ hasher << primal << differential;
+ return hasher.getResult();
+ }
+ bool operator ==(const DiffInstPair& other) const
+ {
+ return primal == other.primal && differential == other.differential;
+ }
+ };
+
+ typedef DiffInstPair<IRInst*, IRInst*> InstPair;
+
+ struct AutoDiffSharedContext
+ {
+ IRModuleInst* moduleInst = nullptr;
+
+ SharedIRBuilder* sharedBuilder = nullptr;
+
+ // A reference to the builtin IDifferentiable interface type.
+ // We use this to look up all the other types (and type exprs)
+ // that conform to a base type.
+ //
+ IRInterfaceType* differentiableInterfaceType = nullptr;
+
+ // The struct key for the 'Differential' associated type
+ // defined inside IDifferential. We use this to lookup the differential
+ // type in the conformance table associated with the concrete type.
+ //
+ IRStructKey* differentialAssocTypeStructKey = nullptr;
+
+ // The struct key for the witness that `Differential` associated type conforms to
+ // `IDifferential`.
+ IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
+
+
+ // The struct key for the 'zero()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of zero() for a given type.
+ //
+ IRStructKey* zeroMethodStructKey = nullptr;
+
+ // The struct key for the 'add()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of add() for a given type.
+ //
+ IRStructKey* addMethodStructKey = nullptr;
+
+ IRStructKey* mulMethodStructKey = nullptr;
+
+
+ // Modules that don't use differentiable types
+ // won't have the IDifferentiable interface type available.
+ // Set to false to indicate that we are uninitialized.
+ //
+ bool isInterfaceAvailable = false;
+
+
+ AutoDiffSharedContext(IRModuleInst* inModuleInst);
+
+ private:
+
+ IRInst* findDifferentiableInterface();
+
+ IRStructKey* findDifferentialTypeStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(0);
+ }
+
+ IRStructKey* findDifferentialTypeWitnessStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(1);
+ }
+
+ IRStructKey* findZeroMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(2);
+ }
+
+ IRStructKey* findAddMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(3);
+ }
+
+ IRStructKey* findMulMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(4);
+ }
+
+ IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index);
+ };
+
+ struct DifferentiableTypeConformanceContext
+ {
+ AutoDiffSharedContext* sharedContext;
+
+ IRGlobalValueWithCode* parentFunc = nullptr;
+ OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary;
+
+ DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
+ : sharedContext(shared)
+ {}
+
+ void setFunc(IRGlobalValueWithCode* func);
+
+
+ // Lookup a witness table for the concreteType. One should exist if concreteType
+ // inherits (successfully) from IDifferentiable.
+ //
+ IRInst* lookUpConformanceForType(IRInst* type);
+
+ IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);
+
+ // Lookup and return the 'Differential' type declared in the concrete type
+ // in order to conform to the IDifferentiable interface.
+ // Note that inside a generic block, this will be a witness table lookup instruction
+ // that gets resolved during the specialization pass.
+ //
+ IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
+ {
+ switch (origType->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ case kIROp_VectorType:
+ return origType;
+ }
+ return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
+ }
+
+ IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
+ }
+
+ IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
+ }
+
+ };
struct IRJVPDerivativePassOptions
{
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index bf73a31d8..e11f98dcd 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -736,6 +736,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
+ /// Treat the IRCall as a call to a differentiable function.
+ INST(TreatAsDifferentiableCallDecoration, treatAsDifferentiableCallDecoration, 0, 0)
+
/// Marks a class type as a COM interface implementation, which enables
/// the witness table to be easily picked up by emit.
INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index d85e56d7e..fcdeed17a 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -572,6 +572,18 @@ struct IRForwardDerivativeDecoration : IRDecoration
IRInst* getForwardDerivativeFunc() { return getOperand(0); }
};
+
+struct IRBackwardDerivativeDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_BackwardDerivativeDecoration
+ };
+ IR_LEAF_ISA(BackwardDerivativeDecoration)
+
+ IRInst* getBackwardDerivativeFunc() { return getOperand(0); }
+};
+
struct IRBackwardDifferentiableDecoration : IRDecoration
{
enum
@@ -581,6 +593,14 @@ struct IRBackwardDifferentiableDecoration : IRDecoration
IR_LEAF_ISA(BackwardDifferentiableDecoration)
};
+struct IRTreatAsDifferentiableCallDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_TreatAsDifferentiableCallDecoration
+ };
+ IR_LEAF_ISA(TreatAsDifferentiableCallDecoration)
+};
struct IRDerivativeMemberDecoration : IRDecoration
{
diff --git a/source/slang/slang-ir-lower-error-handling.cpp b/source/slang/slang-ir-lower-error-handling.cpp
index 387ab45b4..e9747e3b6 100644
--- a/source/slang/slang-ir-lower-error-handling.cpp
+++ b/source/slang/slang-ir-lower-error-handling.cpp
@@ -90,6 +90,8 @@ struct ErrorHandlingLoweringContext
args.add(tryCall->getArg(i));
}
auto call = builder.emitCallInst(resultType, tryCall->getCallee(), args);
+ tryCall->transferDecorationsTo(call);
+
auto isFail = builder.emitIsResultError(call);
auto failBlock = tryCall->getFailureBlock();
auto successBlock = tryCall->getSuccessBlock();
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 261f64130..de86a6a52 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5848,7 +5848,10 @@ namespace Slang
return static_cast<IRConstant*>(a)->isValueEqual(static_cast<IRConstant*>(b)) &&
isTypeEqual(a->getFullType(), b->getFullType());
}
-
+ if (IRSpecialize::isaImpl(opA) || opA == kIROp_lookup_interface_method)
+ {
+ return _areTypeOperandsEqual(a, b);
+ }
SLANG_ASSERT(!"Unhandled comparison");
// We can't equate any other type..
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index 4b6c7f33d..2f05dc1db 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -426,6 +426,10 @@ public:
}
return dispatchIfNotNull(expr->baseFunction);
}
+ bool visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
+ {
+ return dispatchIfNotNull(expr->innerExpr);
+ }
};
struct ASTLookupStmtVisitor : public StmtVisitor<ASTLookupStmtVisitor, bool>
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 4c3f4d646..61b6fcb76 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -11,6 +11,7 @@
#include "slang-ir-diff-jvp.h"
#include "slang-ir-inline.h"
#include "slang-ir-insts.h"
+#include "slang-ir-check-differentiability.h"
#include "slang-ir-missing-return.h"
#include "slang-ir-sccp.h"
#include "slang-ir-ssa.h"
@@ -3113,6 +3114,14 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
+ {
+ auto baseVal = lowerSubExpr(expr->innerExpr);
+ SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
+ getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableCallDecoration);
+ return baseVal;
+ }
+
// Emit IR to denote the forward-mode derivative
// of the inner func-expr. This will be resolved
// to a concrete function during the derivative
@@ -8998,6 +9007,9 @@ RefPtr<IRModule> generateIRForTranslationUnit(
checkForMissingReturns(module, compileRequest->getSink());
+ // Check for invalid differentiable function body.
+ checkAutoDiffUsages(module, compileRequest->getSink());
+
// The "mandatory" optimization passes may make use of the
// `IRHighLevelDeclDecoration` type to relate IR instructions
// back to AST-level code in order to improve the quality
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 7cd0cdce9..b513217c8 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -5086,6 +5086,14 @@ namespace Slang
return tryExpr;
}
+ static NodeBase* parseTreatAsDifferentiableExpr(Parser* parser, void* /*userData*/)
+ {
+ auto noDiffExpr = parser->astBuilder->create<TreatAsDifferentiableExpr>();
+ noDiffExpr->innerExpr = parser->ParseLeafExpression();
+ noDiffExpr->scope = parser->currentScope;
+ return noDiffExpr;
+ }
+
static bool _isFinite(double value)
{
// Lets type pun double to uint64_t, so we can detect special double values
@@ -6670,6 +6678,7 @@ namespace Slang
_makeParseExpr("nullptr", parseNullPtrExpr),
_makeParseExpr("none", parseNoneExpr),
_makeParseExpr("try", parseTryExpr),
+ _makeParseExpr("no_diff", parseTreatAsDifferentiableExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
_makeParseExpr("__fwd_diff", parseForwardDifferentiate),
_makeParseExpr("__bwd_diff", parseBackwardDifferentiate)
diff --git a/source/slang/slang-serialize-misc-type-info.h b/source/slang/slang-serialize-misc-type-info.h
index 191514785..d3d83e1d0 100644
--- a/source/slang/slang-serialize-misc-type-info.h
+++ b/source/slang/slang-serialize-misc-type-info.h
@@ -188,7 +188,11 @@ struct SerialTypeInfo<const DiagnosticInfo*>
}
};
-
+// DeclAssociation
+template <>
+struct SerialTypeInfo<DeclAssociation> : SerialIdentityTypeInfo<DeclAssociation> {};
+template <>
+struct SerialTypeInfo<DeclAssociationKind> : public SerialConvertTypeInfo<DeclAssociationKind, uint8_t> {};
} // namespace Slang