diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-decl.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ast-iterator.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 22 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 137 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 75 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 38 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 421 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 345 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.h | 153 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-error-handling.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-serialize-misc-type-info.h | 6 |
20 files changed, 1045 insertions, 251 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index cbd3f0f0c..4da832d11 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -391,6 +391,12 @@ class ModuleDecl : public NamespaceDeclBase // Module* module = nullptr; + /// Map a decl to the list of its associated decls. + /// + /// This mapping is filled in during semantic checking, as the decl declarations get checked or generated. + /// + OrderedDictionary<Decl*, RefPtr<DeclAssociationList>> mapDeclToAssociatedDecls; + SLANG_UNREFLECTED /// Map a type to the list of extensions of that type (if any) declared in this module @@ -398,6 +404,7 @@ class ModuleDecl : public NamespaceDeclBase /// This mapping is filled in during semantic checking, as `ExtensionDecl`s get checked. /// Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> mapTypeToCandidateExtensions; + }; /// A declaration that brings members of another declaration or namespace into scope diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 3a99ac15f..a0268cce9 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -465,6 +465,15 @@ class BackwardDifferentiateExpr: public DifferentiateExpr SLANG_AST_CLASS(BackwardDifferentiateExpr) }; + /// An express to mark its inner expression as an intended non-differential call. +class TreatAsDifferentiableExpr : public Expr +{ + SLANG_AST_CLASS(TreatAsDifferentiableExpr) + + Expr* innerExpr; + Scope* scope; +}; + /// A type expression of the form `__TaggedUnion(A, ...)`. /// /// An expression of this form will resolve to a `TaggedUnionType` diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index ed396139e..79aade1ee 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -268,6 +268,11 @@ struct ASTIterator iterator->maybeDispatchCallback(expr); dispatchIfNotNull(expr->baseFunction); } + + void visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + dispatchIfNotNull(expr->innerExpr); + } }; struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor> diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 7c954987e..21611bcb1 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1496,6 +1496,28 @@ namespace Slang List<ExtensionDecl*> candidateExtensions; }; + + enum class DeclAssociationKind + { + ForwardDerivativeFunc, BackwardDerivativeFunc, + }; + + struct DeclAssociation + { + SLANG_VALUE_CLASS(DeclAssociation) + DeclAssociationKind kind; + Decl* decl; + }; + + /// A reference-counted object to hold a list of associated decls for a decl. + /// + struct DeclAssociationList : SerialRefObject + { + SLANG_OBJ_CLASS(DeclAssociationList) + + List<DeclAssociation> associations; + }; + /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` enum ParameterDirection { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ffbc5a841..009d0a987 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4635,6 +4635,7 @@ namespace Slang checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr); attr->backDeclRef = fwdDerivativeAttr->funcExpr; fwdDerivativeAttr->funcExpr = nullptr; + getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), DeclAssociationKind::ForwardDerivativeFunc, funcDecl); return; } } @@ -4684,6 +4685,22 @@ namespace Slang if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>()) checkDerivativeAttribute(decl, derivativeAttr); + if (newContext.getParentDifferentiableAttribute()) + { + // Register additional types outside the function body first. + auto oldAttr = m_parentDifferentiableAttr; + m_parentDifferentiableAttr = newContext.getParentDifferentiableAttribute(); + for (auto param : decl->getParameters()) + maybeRegisterDifferentiableType(m_astBuilder, param->type.type); + maybeRegisterDifferentiableType(m_astBuilder, decl->returnType.type); + if (as<ConstructorDecl>(decl) || !isEffectivelyStatic(decl)) + { + auto thisType = calcThisType(makeDeclRef(decl)); + maybeRegisterDifferentiableType(m_astBuilder, thisType); + } + m_parentDifferentiableAttr = oldAttr; + } + if (auto body = decl->body) { checkStmt(decl->body, newContext); @@ -6379,6 +6396,126 @@ namespace Slang } } + /// Get a reference to the associated decl list for `decl` in the given dictionary + /// + /// Note: this function creates an empty list of candidates for the given type if + /// a matching entry doesn't exist already. + /// + static List<DeclAssociation>& _getDeclAssociationList( + Decl* decl, + OrderedDictionary<Decl*, RefPtr<DeclAssociationList>>& mapDeclToDeclarations) + { + RefPtr<DeclAssociationList> entry; + if (!mapDeclToDeclarations.TryGetValue(decl, entry)) + { + entry = new DeclAssociationList(); + mapDeclToDeclarations.Add(decl, entry); + } + return entry->associations; + } + + void SharedSemanticsContext::_addDeclAssociationsFromModule(ModuleDecl* moduleDecl) + { + for (auto& entry : moduleDecl->mapDeclToAssociatedDecls) + { + auto& list = _getDeclAssociationList(entry.Key, m_mapDeclToAssociatedDecls); + list.addRange(entry.Value->associations); + } + } + + void SharedSemanticsContext::registerAssociatedDecl(Decl* original, DeclAssociationKind kind, Decl* associated) + { + auto moduleDecl = getModuleDecl(associated); + DeclAssociation assoc = {kind, associated}; + _getDeclAssociationList(original, moduleDecl->mapDeclToAssociatedDecls).add(assoc); + + m_associatedDeclListsBuilt = false; + m_mapDeclToAssociatedDecls.Clear(); + } + + List<DeclAssociation> const& SharedSemanticsContext::getAssociatedDeclsForDecl(Decl* decl) + { + // This duplicates the exact same logic from `getCandidateExtensionsForTypeDecl`. + // Consider refactoring them into the same framework. + if (!m_associatedDeclListsBuilt) + { + m_associatedDeclListsBuilt = true; + + for (auto module : getSession()->stdlibModules) + { + _addDeclAssociationsFromModule(module->getModuleDecl()); + } + + if (m_module) + { + _addDeclAssociationsFromModule(m_module->getModuleDecl()); + for (auto moduleDecl : this->importedModulesList) + { + _addDeclAssociationsFromModule(moduleDecl); + } + } + else + { + for (auto module : m_linkage->loadedModulesList) + { + _addDeclAssociationsFromModule(module->getModuleDecl()); + } + } + } + return _getDeclAssociationList(decl, m_mapDeclToAssociatedDecls); + } + + bool SharedSemanticsContext::isDifferentiableFunc(FunctionDeclBase* func) + { + // A function is differentiable if it is marked as differentiable, or it + // has an associated derivative function. + if (func->findModifier<DifferentiableAttribute>()) + return true; + for (auto assocDecl : getAssociatedDeclsForDecl(func)) + { + switch (assocDecl.kind) + { + case DeclAssociationKind::ForwardDerivativeFunc: + case DeclAssociationKind::BackwardDerivativeFunc: + return true; + default: + break; + } + } + return false; + } + + bool SharedSemanticsContext::isBackwardDifferentiableFunc(FunctionDeclBase* func) + { + // A function is differentiable if it is marked as differentiable, or it + // has an associated derivative function. + if (func->findModifier<BackwardDifferentiableAttribute>()) + return true; + for (auto assocDecl : getAssociatedDeclsForDecl(func)) + { + switch (assocDecl.kind) + { + case DeclAssociationKind::BackwardDerivativeFunc: + return true; + default: + break; + } + } + if (auto builtinReq = func->findModifier<BuiltinRequirementModifier>()) + { + switch (builtinReq->kind) + { + case BuiltinRequirementKind::DAddFunc: + case BuiltinRequirementKind::DMulFunc: + case BuiltinRequirementKind::DZeroFunc: + return true; + default: + break; + } + } + return false; + } + List<ExtensionDecl*> const& getCandidateExtensions( DeclRef<AggTypeDecl> const& declRef, SemanticsVisitor* semantics) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b43a03150..2c6899269 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1893,6 +1893,34 @@ namespace Slang } } } + + if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr)) + { + if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke))) + { + auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl()); + if (funcDecl) + { + DifferentiateExpr* forwardDiff = nullptr; + DifferentiateExpr* backwardDiff = nullptr; + for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction)) + { + if (auto fwd = as<ForwardDifferentiateExpr>(node)) + forwardDiff = fwd; + if (auto bwd = as<BackwardDifferentiateExpr>(node)) + backwardDiff = bwd; + } + if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl)) + { + getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward"); + } + if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl)) + { + getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); + } + } + } + } } } return rs; @@ -1920,7 +1948,7 @@ namespace Slang auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr); - if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>()) + if (m_parentDifferentiableAttr) { if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr)) { @@ -1929,6 +1957,30 @@ namespace Slang { maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); } + if (auto calleeExpr = as<DeclRefExpr>(checkedInvokeExpr->functionExpr)) + { + if (auto calleeDecl = as<FunctionDeclBase>(calleeExpr->declRef.getDecl())) + { + if (getShared()->isDifferentiableFunc(calleeDecl)) + { + if (!m_treatAsDifferentiableExpr) + { + auto newFuncExpr = + getASTBuilder()->create<TreatAsDifferentiableExpr>(); + newFuncExpr->type = checkedInvokeExpr->type; + newFuncExpr->innerExpr = checkedInvokeExpr; + newFuncExpr->loc = checkedInvokeExpr->loc; + checkedExpr = newFuncExpr; + } + else + { + getSink()->diagnose( + m_treatAsDifferentiableExpr, + Diagnostics::useOfNoDiffOnDifferentiableFunc); + } + } + } + } } maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type); } @@ -2227,6 +2279,27 @@ namespace Slang return _checkDifferentiateExpr(this, expr, &actions); } + Expr* SemanticsExprVisitor::visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + auto subContext = withTreatAsDifferentiable(expr); + expr->innerExpr = dispatchExpr(expr->innerExpr, subContext); + expr->type = expr->innerExpr->type; + auto innerExpr = expr->innerExpr; + while (auto parenExpr = as<ParenExpr>(innerExpr)) + { + innerExpr = parenExpr->base; + } + if (!as<InvokeExpr>(innerExpr)) + { + getSink()->diagnose(expr, Diagnostics::invalidUseOfNoDiff); + } + else if (!m_parentDifferentiableAttr) + { + getSink()->diagnose(expr, Diagnostics::cannotUseNoDiffInNonDifferentiableFunc); + } + return expr; + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index c4c32a681..1c2f698bd 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -284,6 +284,13 @@ namespace Slang /// Register a candidate extension `extDecl` for `typeDecl` encountered during checking. void registerCandidateExtension(AggTypeDecl* typeDecl, ExtensionDecl* extDecl); + void registerAssociatedDecl(Decl* original, DeclAssociationKind assoc, Decl* declaration); + + List<DeclAssociation> const& getAssociatedDeclsForDecl(Decl* decl); + + bool isDifferentiableFunc(FunctionDeclBase* func); + bool isBackwardDifferentiableFunc(FunctionDeclBase* func); + private: /// Mapping from type declarations to the known extensiosn that apply to them Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> m_mapTypeDeclToCandidateExtensions; @@ -293,6 +300,17 @@ namespace Slang /// Add candidate extensions declared in `moduleDecl` to `m_mapTypeDeclToCandidateExtensions` void _addCandidateExtensionsFromModule(ModuleDecl* moduleDecl); + + /// Mapping from a decl to additional declarations of the same decl. + /// The additional declarations provide a location to hold extra decorations. + OrderedDictionary<Decl*, RefPtr<DeclAssociationList>> m_mapDeclToAssociatedDecls; + + /// Is the `m_mapDeclToAssociatedDecls` dictionary valid and up to date? + bool m_associatedDeclListsBuilt = false; + + /// Add associated decls declared in `moduleDecl` to `m_mapDeclToAssociatedDecls` + void _addDeclAssociationsFromModule(ModuleDecl* moduleDecl); + }; /// Local/scoped state of the semantic-checking system @@ -411,6 +429,13 @@ namespace Slang return result; } + SemanticsContext withTreatAsDifferentiable(TreatAsDifferentiableExpr* expr) + { + SemanticsContext result(*this); + result.m_treatAsDifferentiableExpr = expr; + return result; + } + SemanticsContext allowStaticReferenceToNonStaticMember() { SemanticsContext result(*this); @@ -444,6 +469,10 @@ namespace Slang /// is considered valid in the current context. bool m_allowStaticReferenceToNonStaticMember = false; + /// Whether or not we are in a `no_diff` environment (and therefore should treat the call to + /// a non-differentiable function as differentiable and not issue a diagnostic). + TreatAsDifferentiableExpr* m_treatAsDifferentiableExpr = nullptr; + ASTBuilder* m_astBuilder = nullptr; }; @@ -724,9 +753,6 @@ namespace Slang /// void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness); - // Check and register a type if it is differentiable. - void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); - // Construct the differential for 'type', if it exists. Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc); Type* tryGetDifferentialType(ASTBuilder* builder, Type* type); @@ -1061,6 +1087,9 @@ namespace Slang /// Gather differentiable members from decl. List<DifferentiableMemberInfo> collectDifferentiableMemberInfo(ContainerDecl* decl); + // Check and register a type if it is differentiable. + void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); + // Find the appropriate member of a declared type to // satisfy a requirement of an interface the type // claims to conform to. @@ -1266,8 +1295,6 @@ namespace Slang /// Given an immutable `expr` used as an l-value emit a special diagnostic if it was derived from `this`. void maybeDiagnoseThisNotLValue(Expr* expr); - void registerExtension(ExtensionDecl* decl); - // Figure out what type an initializer/constructor declaration // is supposed to return. In most cases this is just the type // declaration that its declaration is nested inside. @@ -1913,6 +1940,7 @@ namespace Slang Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr); + Expr* visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr); Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index c69a1e9e6..d293626ae 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -304,6 +304,8 @@ DIAGNOSTIC(30094, Error, mustUseTryClauseToCallAThrowFunc, "the callee may throw DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.") DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "type '$0' is used as a `Differential` type, therefore it must serve as its own `Differential` type.") +DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.") + DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.") // Attributes @@ -494,7 +496,9 @@ DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argu DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") DIAGNOSTIC(38029, Error, typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") -DIAGNOSTIC(30830, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.") +DIAGNOSTIC(38031, Error, invalidUseOfNoDiff, "'no_diff' can only be used to decorate a call.") +DIAGNOSTIC(38032, Error, useOfNoDiffOnDifferentiableFunc, "use 'no_diff' on a call to a differentiable function has no meaning.") +DIAGNOSTIC(38033, Error, cannotUseNoDiffInNonDifferentiableFunc, "cannot use 'no_diff' in a non-differentiable function.") DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error") @@ -568,6 +572,9 @@ DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'vo DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in the size required by its conforming interface.") DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2") DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.") +DIAGNOSTIC(41020, Error, lossOfDerivativeDueToCallOfNonDifferentiableFunction, "derivative cannot be propagated through call to non-differentiable function `$0`, use 'no_diff' to clarify intention.") +DIAGNOSTIC(41021, Error, differentiableFuncMustHaveOutput, "a differentiable function must have at least one differentiable output.") +DIAGNOSTIC(41022, Error, differentiableFuncMustHaveInput, "a differentiable function must have at least one differentiable input.") // // 5xxxx - Target code generation. diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp new file mode 100644 index 000000000..44c6324e3 --- /dev/null +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -0,0 +1,421 @@ +#include "slang-ir-check-differentiability.h" + +#include "slang-ir-diff-jvp.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ + +struct CheckDifferentiabilityPassContext : public InstPassBase +{ +public: + DiagnosticSink* sink; + AutoDiffSharedContext sharedContext; + + HashSet<IRInst*> differentiableFunctions; + + CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink) + : InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst()) + {} + + IRInst* getSpecializedVal(IRInst* inst) + { + int loopLimit = 1024; + while (inst && inst->getOp() == kIROp_Specialize) + { + inst = as<IRSpecialize>(inst)->getBase(); + loopLimit--; + if (loopLimit == 0) + return inst; + } + return inst; + } + + IRInst* getLeafFunc(IRInst* func) + { + func = getSpecializedVal(func); + if (!func) + return nullptr; + if (auto genericFunc = as<IRGeneric>(func)) + return findInnerMostGenericReturnVal(genericFunc); + return func; + } + + bool _isFuncMarkedForAutoDiff(IRInst* func) + { + func = getLeafFunc(func); + if (!func) + return false; + for (auto decorations : func->getDecorations()) + { + switch (decorations->getOp()) + { + case kIROp_ForwardDifferentiableDecoration: + case kIROp_BackwardDifferentiableDecoration: + return true; + } + } + return false; + } + + + bool _isDifferentiableFuncImpl(IRInst* func) + { + func = getLeafFunc(func); + if (!func) + return false; + + for (auto decorations : func->getDecorations()) + { + switch (decorations->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_ForwardDifferentiableDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDifferentiableDecoration: + return true; + } + } + return false; + } + + bool isDifferentiableFunc(IRInst* func) + { + switch (func->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + return true; + default: + break; + } + + func = getSpecializedVal(func); + if (!func) + return false; + + if (differentiableFunctions.Contains(func)) + return true; + + for (; func; func = func->parent) + { + if (as<IRGeneric>(func)) + { + return differentiableFunctions.Contains(func); + } + } + return false; + } + + bool isBackwardDifferentiableFunc(IRInst* func) + { + for (auto decorations : func->getDecorations()) + { + switch (decorations->getOp()) + { + case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDifferentiableDecoration: + return true; + } + } + return false; + } + + bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) + { + HashSet<IRInst*> processedSet; + while (auto ptrType = as<IRPtrTypeBase>(typeInst)) + { + typeInst = ptrType->getValueType(); + if (!processedSet.Add(typeInst)) + return false; + } + if (!typeInst) + return false; + switch (typeInst->getOp()) + { + case kIROp_FloatType: + case kIROp_DifferentialPairType: + return true; + default: + break; + } + if (context.lookUpConformanceForType(typeInst)) + return true; + // Look for equivalent types. + for (auto type : context.differentiableWitnessDictionary) + { + if (isTypeEqual(type.Key, (IRType*)typeInst)) + { + context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value; + return true; + } + } + return false; + } + + int getParamIndexInBlock(IRParam* paramInst) + { + auto block = as<IRBlock>(paramInst->getParent()); + if (!block) + return -1; + int paramIndex = 0; + for (auto param : block->getParams()) + { + if (param == paramInst) + return paramIndex; + paramIndex++; + } + return -1; + } + + bool isInstInFunc(IRInst* inst, IRInst* func) + { + while (inst) + { + if (inst == func) + return true; + inst = inst->parent; + } + return false; + } + void processFunc(IRGlobalValueWithCode* funcInst) + { + if (!_isFuncMarkedForAutoDiff(funcInst)) + return; + if (!funcInst->getFirstBlock()) + return; + + DifferentiableTypeConformanceContext diffTypeContext(&sharedContext); + diffTypeContext.setFunc(funcInst); + + HashSet<IRInst*> produceDiffSet; + HashSet<IRInst*> expectDiffSet; + int differentiableInputs = 0; + int differentiableOutputs = 0; + for (auto param : funcInst->getFirstBlock()->getParams()) + { + if (isDifferentiableType(diffTypeContext, param->getFullType())) + { + if (as<IROutTypeBase>(param->getFullType())) + differentiableOutputs++; + if (!as<IROutType>(param->getFullType())) + differentiableInputs++; + produceDiffSet.Add(param); + } + } + if (auto funcType = as<IRFuncType>(funcInst->getDataType())) + { + if (isDifferentiableType(diffTypeContext, funcType->getResultType())) + differentiableOutputs++; + } + + if (differentiableOutputs == 0) + sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveOutput); + if (differentiableInputs == 0) + sink->diagnose(funcInst, Diagnostics::differentiableFuncMustHaveInput); + + auto isInstProducingDiff = [&](IRInst* inst) -> bool + { + switch (inst->getOp()) + { + case kIROp_FloatLit: + return true; + case kIROp_Call: + return inst->findDecoration<IRTreatAsDifferentiableCallDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee()); + case kIROp_Load: + // We don't have more knowledge on whether diff is available at the destination address. + // Just assume it is producing diff. + //TODO: propagate the info if this is a load of a temporary variable intended to receive result from an `out` parameter. + return isDifferentiableType(diffTypeContext, inst->getDataType()); + default: + // default case is to assume the inst produces a diff value if any + // of its operands produces a diff value. + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (produceDiffSet.Contains(inst->getOperand(i))) + { + return true; + } + } + return false; + } + }; + + List<IRInst*> expectDiffInstWorkList; + OrderedHashSet<IRInst*> expectDiffInstWorkListSet; + auto addToExpectDiffWorkList = [&](IRInst* inst) + { + if (isInstInFunc(inst, funcInst)) + { + if (expectDiffInstWorkListSet.Add(inst)) + expectDiffInstWorkList.add(inst); + } + }; + // Run data flow analysis and generate `produceDiffSet` and an intial `expectDiffSet`. + Index lastProduceDiffCount = 0; + do + { + lastProduceDiffCount = produceDiffSet.Count(); + for (auto block : funcInst->getBlocks()) + { + if (block != funcInst->getFirstBlock()) + { + UInt paramIndex = 0; + for (auto param : block->getParams()) + { + for (auto p : block->getPredecessors()) + { + // A Phi Node is producing diff if any of its candidate values are producing diff. + if (auto branch = as<IRUnconditionalBranch>(p->getTerminator())) + { + if (branch->getArgCount() > paramIndex) + { + auto arg = branch->getArg(paramIndex); + if (produceDiffSet.Contains(arg)) + { + produceDiffSet.Add(param); + break; + } + } + } + } + paramIndex++; + } + } + for (auto inst : block->getChildren()) + { + if (isInstProducingDiff(inst)) + produceDiffSet.Add(inst); + switch (inst->getOp()) + { + case kIROp_Call: + if (isDifferentiableFunc(as<IRCall>(inst)->getCallee())) + { + addToExpectDiffWorkList(inst); + } + break; + case kIROp_Store: + { + auto storeInst = as<IRStore>(inst); + if (isDifferentiableType(diffTypeContext, as<IRStore>(inst)->getPtr()->getDataType())) + { + addToExpectDiffWorkList(storeInst->getVal()); + } + } + break; + case kIROp_Return: + if (auto returnVal = as<IRReturn>(inst)->getVal()) + { + if (isDifferentiableType(diffTypeContext, returnVal->getDataType())) + { + addToExpectDiffWorkList(inst); + } + } + break; + default: + break; + } + } + } + } while (produceDiffSet.Count() != lastProduceDiffCount); + + // Reverse propagate `expectDiffSet`. + for (int i = 0; i < expectDiffInstWorkList.getCount(); i++) + { + auto inst = expectDiffInstWorkList[i]; + // Is inst in produceDiffSet? + if (!produceDiffSet.Contains(inst)) + { + if (auto call = as<IRCall>(inst)) + { + sink->diagnose(inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, getLeafFunc(call->getCallee())); + } + } + switch (inst->getOp()) + { + case kIROp_Param: + { + auto block = as<IRBlock>(inst->getParent()); + if (block != funcInst->getFirstBlock()) + { + auto paramIndex = getParamIndexInBlock(as<IRParam>(inst)); + if (paramIndex != -1) + { + for (auto p : block->getPredecessors()) + { + // A Phi Node is producing diff if any of its candidate values are producing diff. + if (auto branch = as<IRUnconditionalBranch>(p->getTerminator())) + { + if (branch->getArgCount() > (UInt)paramIndex) + { + auto arg = branch->getArg(paramIndex); + addToExpectDiffWorkList(arg); + } + } + } + } + } + break; + } + default: + // Default behavior is to request all differentiable operands to provide differential. + for (UInt opIndex = 0; opIndex < inst->getOperandCount(); opIndex++) + { + auto operand = inst->getOperand(opIndex); + if (isDifferentiableType(diffTypeContext, operand->getFullType())) + { + addToExpectDiffWorkList(operand); + } + } + } + } + } + + void processModule() + { + // Collect set of differentiable functions. + HashSet<UnownedStringSlice> differentiableSymbolNames; + for (auto inst : module->getGlobalInsts()) + { + if (_isDifferentiableFuncImpl(inst)) + { + if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) + differentiableSymbolNames.Add(linkageDecor->getMangledName()); + differentiableFunctions.Add(inst); + } + } + for (auto inst : module->getGlobalInsts()) + { + if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) + { + if (differentiableSymbolNames.Contains(linkageDecor->getMangledName())) + differentiableFunctions.Add(inst); + } + } + + if (!sharedContext.isInterfaceAvailable) + return; + + for (auto inst : module->getGlobalInsts()) + { + if (auto genericInst = as<IRGeneric>(inst)) + { + if (auto innerFunc = as<IRGlobalValueWithCode>(findGenericReturnVal(genericInst))) + processFunc(innerFunc); + } + else if (auto funcInst = as<IRGlobalValueWithCode>(inst)) + { + processFunc(funcInst); + } + } + } +}; + +void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink) +{ + CheckDifferentiabilityPassContext context(module, sink); + context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-check-differentiability.h b/source/slang/slang-ir-check-differentiability.h new file mode 100644 index 000000000..735a918c9 --- /dev/null +++ b/source/slang/slang-ir-check-differentiability.h @@ -0,0 +1,14 @@ +// slang-ir-check-differentiability.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +struct IRModule; +class DiagnosticSink; + +// Check all auto diff usages are valid. +void checkAutoDiffUsages(IRModule* module, DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 152601dbd..4ee16aafc 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -1,8 +1,6 @@ // slang-ir-diff-jvp.cpp #include "slang-ir-diff-jvp.h" -#include "slang-ir.h" -#include "slang-ir-insts.h" #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" @@ -16,154 +14,6 @@ namespace Slang { -template<typename P, typename D> -struct Pair -{ - P primal; - D differential; - Pair() = default; - Pair(P primal, D differential) : primal(primal), differential(differential) - {} - HashCode getHashCode() const - { - Hasher hasher; - hasher << primal << differential; - return hasher.getResult(); - } - bool operator ==(const Pair& other) const - { - return primal == other.primal && differential == other.differential; - } -}; - -typedef Pair<IRInst*, IRInst*> InstPair; - -struct AutoDiffSharedContext -{ - IRModuleInst* moduleInst = nullptr; - - SharedIRBuilder* sharedBuilder = nullptr; - - // A reference to the builtin IDifferentiable interface type. - // We use this to look up all the other types (and type exprs) - // that conform to a base type. - // - IRInterfaceType* differentiableInterfaceType = nullptr; - - // The struct key for the 'Differential' associated type - // defined inside IDifferential. We use this to lookup the differential - // type in the conformance table associated with the concrete type. - // - IRStructKey* differentialAssocTypeStructKey = nullptr; - - // The struct key for the witness that `Differential` associated type conforms to - // `IDifferential`. - IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; - - - // The struct key for the 'zero()' associated type - // defined inside IDifferential. We use this to lookup the - // implementation of zero() for a given type. - // - IRStructKey* zeroMethodStructKey = nullptr; - - // The struct key for the 'add()' associated type - // defined inside IDifferential. We use this to lookup the - // implementation of add() for a given type. - // - IRStructKey* addMethodStructKey = nullptr; - - IRStructKey* mulMethodStructKey = nullptr; - - - // Modules that don't use differentiable types - // won't have the IDifferentiable interface type available. - // Set to false to indicate that we are uninitialized. - // - bool isInterfaceAvailable = false; - - - AutoDiffSharedContext(IRModuleInst* inModuleInst) - : moduleInst(inModuleInst) - { - differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); - if (differentiableInterfaceType) - { - differentialAssocTypeStructKey = findDifferentialTypeStructKey(); - differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); - zeroMethodStructKey = findZeroMethodStructKey(); - addMethodStructKey = findAddMethodStructKey(); - mulMethodStructKey = findMulMethodStructKey(); - - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; - } - } - - private: - - IRInst* findDifferentiableInterface() - { - if (auto module = as<IRModuleInst>(moduleInst)) - { - for (auto globalInst : module->getGlobalInsts()) - { - // TODO: This seems like a particularly dangerous way to look for an interface. - // See if we can lower IDifferentiable to a separate IR inst. - // - if (globalInst->getOp() == kIROp_InterfaceType && - as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable") - { - return globalInst; - } - } - } - return nullptr; - } - - IRStructKey* findDifferentialTypeStructKey() - { - return getIDifferentiableStructKeyAtIndex(0); - } - - IRStructKey* findDifferentialTypeWitnessStructKey() - { - return getIDifferentiableStructKeyAtIndex(1); - } - - IRStructKey* findZeroMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(2); - } - - IRStructKey* findAddMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(3); - } - - IRStructKey* findMulMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(4); - } - - IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index) - { - if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) - { - // Assume for now that IDifferentiable has exactly five fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); - if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) - return as<IRStructKey>(entry->getRequirementKey()); - else - { - SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); - } - } - - return nullptr; - } -}; - namespace { @@ -189,97 +39,6 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK } -struct DifferentiableTypeConformanceContext -{ - AutoDiffSharedContext* sharedContext; - - IRGlobalValueWithCode* parentFunc = nullptr; - Dictionary<IRType*, IRInst*> differentiableWitnessDictionary; - - DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) - : sharedContext(shared) - {} - - void setFunc(IRGlobalValueWithCode* func) - { - parentFunc = func; - - auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); - SLANG_RELEASE_ASSERT(decor); - - // Build lookup dictionary for type witnesses. - for (auto child = decor->getFirstChild(); child; child = child->next) - { - if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) - { - auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); - if (existingItem) - { - if (auto witness = as<IRWitnessTable>(item->getWitness())) - { - if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType) - continue; - } - *existingItem = item->getWitness(); - } - else - { - differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); - } - } - } - } - - - // Lookup a witness table for the concreteType. One should exist if concreteType - // inherits (successfully) from IDifferentiable. - // - IRInst* lookUpConformanceForType(IRInst* type) - { - IRInst* foundResult = nullptr; - differentiableWitnessDictionary.TryGetValue(type, foundResult); - return foundResult; - } - - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) - { - if (auto conformance = lookUpConformanceForType(origType)) - { - return _lookupWitness(builder, conformance, key); - } - return nullptr; - } - - // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. - // Note that inside a generic block, this will be a witness table lookup instruction - // that gets resolved during the specialization pass. - // - IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) - { - switch (origType->getOp()) - { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - case kIROp_VectorType: - return origType; - } - return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); - } - - IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); - } - - IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); - } - -}; - struct DifferentialPairTypeBuilder { @@ -3275,4 +3034,108 @@ void stripAutoDiffDecorations(IRModule* module) stripAutoDiffDecorationsFromChildren(module->getModuleInst()); } +AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst) + : moduleInst(inModuleInst) +{ + differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); + if (differentiableInterfaceType) + { + differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); + zeroMethodStructKey = findZeroMethodStructKey(); + addMethodStructKey = findAddMethodStructKey(); + mulMethodStructKey = findMulMethodStructKey(); + + if (differentialAssocTypeStructKey) + isInterfaceAvailable = true; + } +} + +IRInst* AutoDiffSharedContext::findDifferentiableInterface() +{ + if (auto module = as<IRModuleInst>(moduleInst)) + { + for (auto globalInst : module->getGlobalInsts()) + { + // TODO: This seems like a particularly dangerous way to look for an interface. + // See if we can lower IDifferentiable to a separate IR inst. + // + if (globalInst->getOp() == kIROp_InterfaceType && + as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable") + { + return globalInst; + } + } + } + return nullptr; +} + +IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) +{ + if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) + { + // Assume for now that IDifferentiable has exactly five fields. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); + if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) + return as<IRStructKey>(entry->getRequirementKey()); + else + { + SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); + } + } + + return nullptr; +} + +void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) +{ + parentFunc = func; + + auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); + SLANG_RELEASE_ASSERT(decor); + + // Build lookup dictionary for type witnesses. + for (auto child = decor->getFirstChild(); child; child = child->next) + { + if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) + { + auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); + if (existingItem) + { + if (auto witness = as<IRWitnessTable>(item->getWitness())) + { + if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType) + continue; + } + *existingItem = item->getWitness(); + } + else + { + differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); + } + } + } +} + + +// Lookup a witness table for the concreteType. One should exist if concreteType +// inherits (successfully) from IDifferentiable. +// + +IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) +{ + IRInst* foundResult = nullptr; + differentiableWitnessDictionary.TryGetValue(type, foundResult); + return foundResult; +} + +IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) +{ + if (auto conformance = lookUpConformanceForType(origType)) + { + return _lookupWitness(builder, conformance, key); + } + return nullptr; +} + } diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h index a866a3db3..9e0f9cfcc 100644 --- a/source/slang/slang-ir-diff-jvp.h +++ b/source/slang/slang-ir-diff-jvp.h @@ -2,11 +2,162 @@ #pragma once #include "slang-ir.h" +#include "slang-ir-insts.h" #include "slang-compiler.h" namespace Slang { - struct IRModule; + template<typename P, typename D> + struct DiffInstPair + { + P primal; + D differential; + DiffInstPair() = default; + DiffInstPair(P primal, D differential) : primal(primal), differential(differential) + {} + HashCode getHashCode() const + { + Hasher hasher; + hasher << primal << differential; + return hasher.getResult(); + } + bool operator ==(const DiffInstPair& other) const + { + return primal == other.primal && differential == other.differential; + } + }; + + typedef DiffInstPair<IRInst*, IRInst*> InstPair; + + struct AutoDiffSharedContext + { + IRModuleInst* moduleInst = nullptr; + + SharedIRBuilder* sharedBuilder = nullptr; + + // A reference to the builtin IDifferentiable interface type. + // We use this to look up all the other types (and type exprs) + // that conform to a base type. + // + IRInterfaceType* differentiableInterfaceType = nullptr; + + // The struct key for the 'Differential' associated type + // defined inside IDifferential. We use this to lookup the differential + // type in the conformance table associated with the concrete type. + // + IRStructKey* differentialAssocTypeStructKey = nullptr; + + // The struct key for the witness that `Differential` associated type conforms to + // `IDifferential`. + IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; + + + // The struct key for the 'zero()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of zero() for a given type. + // + IRStructKey* zeroMethodStructKey = nullptr; + + // The struct key for the 'add()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of add() for a given type. + // + IRStructKey* addMethodStructKey = nullptr; + + IRStructKey* mulMethodStructKey = nullptr; + + + // Modules that don't use differentiable types + // won't have the IDifferentiable interface type available. + // Set to false to indicate that we are uninitialized. + // + bool isInterfaceAvailable = false; + + + AutoDiffSharedContext(IRModuleInst* inModuleInst); + + private: + + IRInst* findDifferentiableInterface(); + + IRStructKey* findDifferentialTypeStructKey() + { + return getIDifferentiableStructKeyAtIndex(0); + } + + IRStructKey* findDifferentialTypeWitnessStructKey() + { + return getIDifferentiableStructKeyAtIndex(1); + } + + IRStructKey* findZeroMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(2); + } + + IRStructKey* findAddMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(3); + } + + IRStructKey* findMulMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(4); + } + + IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); + }; + + struct DifferentiableTypeConformanceContext + { + AutoDiffSharedContext* sharedContext; + + IRGlobalValueWithCode* parentFunc = nullptr; + OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary; + + DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) + : sharedContext(shared) + {} + + void setFunc(IRGlobalValueWithCode* func); + + + // Lookup a witness table for the concreteType. One should exist if concreteType + // inherits (successfully) from IDifferentiable. + // + IRInst* lookUpConformanceForType(IRInst* type); + + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); + + // Lookup and return the 'Differential' type declared in the concrete type + // in order to conform to the IDifferentiable interface. + // Note that inside a generic block, this will be a witness table lookup instruction + // that gets resolved during the specialization pass. + // + IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) + { + switch (origType->getOp()) + { + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + case kIROp_VectorType: + return origType; + } + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); + } + + IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + } + + IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); + } + + }; struct IRJVPDerivativePassOptions { diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index bf73a31d8..e11f98dcd 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -736,6 +736,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// differential member of a type in its associated differential type. INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) + /// Treat the IRCall as a call to a differentiable function. + INST(TreatAsDifferentiableCallDecoration, treatAsDifferentiableCallDecoration, 0, 0) + /// Marks a class type as a COM interface implementation, which enables /// the witness table to be easily picked up by emit. INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index d85e56d7e..fcdeed17a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -572,6 +572,18 @@ struct IRForwardDerivativeDecoration : IRDecoration IRInst* getForwardDerivativeFunc() { return getOperand(0); } }; + +struct IRBackwardDerivativeDecoration : IRDecoration +{ + enum + { + kOp = kIROp_BackwardDerivativeDecoration + }; + IR_LEAF_ISA(BackwardDerivativeDecoration) + + IRInst* getBackwardDerivativeFunc() { return getOperand(0); } +}; + struct IRBackwardDifferentiableDecoration : IRDecoration { enum @@ -581,6 +593,14 @@ struct IRBackwardDifferentiableDecoration : IRDecoration IR_LEAF_ISA(BackwardDifferentiableDecoration) }; +struct IRTreatAsDifferentiableCallDecoration : IRDecoration +{ + enum + { + kOp = kIROp_TreatAsDifferentiableCallDecoration + }; + IR_LEAF_ISA(TreatAsDifferentiableCallDecoration) +}; struct IRDerivativeMemberDecoration : IRDecoration { diff --git a/source/slang/slang-ir-lower-error-handling.cpp b/source/slang/slang-ir-lower-error-handling.cpp index 387ab45b4..e9747e3b6 100644 --- a/source/slang/slang-ir-lower-error-handling.cpp +++ b/source/slang/slang-ir-lower-error-handling.cpp @@ -90,6 +90,8 @@ struct ErrorHandlingLoweringContext args.add(tryCall->getArg(i)); } auto call = builder.emitCallInst(resultType, tryCall->getCallee(), args); + tryCall->transferDecorationsTo(call); + auto isFail = builder.emitIsResultError(call); auto failBlock = tryCall->getFailureBlock(); auto successBlock = tryCall->getSuccessBlock(); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 261f64130..de86a6a52 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5848,7 +5848,10 @@ namespace Slang return static_cast<IRConstant*>(a)->isValueEqual(static_cast<IRConstant*>(b)) && isTypeEqual(a->getFullType(), b->getFullType()); } - + if (IRSpecialize::isaImpl(opA) || opA == kIROp_lookup_interface_method) + { + return _areTypeOperandsEqual(a, b); + } SLANG_ASSERT(!"Unhandled comparison"); // We can't equate any other type.. diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 4b6c7f33d..2f05dc1db 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -426,6 +426,10 @@ public: } return dispatchIfNotNull(expr->baseFunction); } + bool visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + return dispatchIfNotNull(expr->innerExpr); + } }; struct ASTLookupStmtVisitor : public StmtVisitor<ASTLookupStmtVisitor, bool> diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4c3f4d646..61b6fcb76 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -11,6 +11,7 @@ #include "slang-ir-diff-jvp.h" #include "slang-ir-inline.h" #include "slang-ir-insts.h" +#include "slang-ir-check-differentiability.h" #include "slang-ir-missing-return.h" #include "slang-ir-sccp.h" #include "slang-ir-ssa.h" @@ -3113,6 +3114,14 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> baseVal.val)); } + LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + auto baseVal = lowerSubExpr(expr->innerExpr); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableCallDecoration); + return baseVal; + } + // Emit IR to denote the forward-mode derivative // of the inner func-expr. This will be resolved // to a concrete function during the derivative @@ -8998,6 +9007,9 @@ RefPtr<IRModule> generateIRForTranslationUnit( checkForMissingReturns(module, compileRequest->getSink()); + // Check for invalid differentiable function body. + checkAutoDiffUsages(module, compileRequest->getSink()); + // The "mandatory" optimization passes may make use of the // `IRHighLevelDeclDecoration` type to relate IR instructions // back to AST-level code in order to improve the quality diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 7cd0cdce9..b513217c8 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -5086,6 +5086,14 @@ namespace Slang return tryExpr; } + static NodeBase* parseTreatAsDifferentiableExpr(Parser* parser, void* /*userData*/) + { + auto noDiffExpr = parser->astBuilder->create<TreatAsDifferentiableExpr>(); + noDiffExpr->innerExpr = parser->ParseLeafExpression(); + noDiffExpr->scope = parser->currentScope; + return noDiffExpr; + } + static bool _isFinite(double value) { // Lets type pun double to uint64_t, so we can detect special double values @@ -6670,6 +6678,7 @@ namespace Slang _makeParseExpr("nullptr", parseNullPtrExpr), _makeParseExpr("none", parseNoneExpr), _makeParseExpr("try", parseTryExpr), + _makeParseExpr("no_diff", parseTreatAsDifferentiableExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), _makeParseExpr("__fwd_diff", parseForwardDifferentiate), _makeParseExpr("__bwd_diff", parseBackwardDifferentiate) diff --git a/source/slang/slang-serialize-misc-type-info.h b/source/slang/slang-serialize-misc-type-info.h index 191514785..d3d83e1d0 100644 --- a/source/slang/slang-serialize-misc-type-info.h +++ b/source/slang/slang-serialize-misc-type-info.h @@ -188,7 +188,11 @@ struct SerialTypeInfo<const DiagnosticInfo*> } }; - +// DeclAssociation +template <> +struct SerialTypeInfo<DeclAssociation> : SerialIdentityTypeInfo<DeclAssociation> {}; +template <> +struct SerialTypeInfo<DeclAssociationKind> : public SerialConvertTypeInfo<DeclAssociationKind, uint8_t> {}; } // namespace Slang |
