diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 16:02:56 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 16:02:56 -0800 |
| commit | 4ad0470025da4e808c46023f9a2525febcf973a2 (patch) | |
| tree | 8fcb1c84121ddf40c50ca58b5de867da0da435ee /source/slang | |
| parent | 97cb4851eed7a43f10196971b08d3d311386ce9f (diff) | |
Fix issues around dynamic generic function and autodiff. (#2528)
* Fix issues around dynamic generic function and autodiff.
* Fix return type issue.
* Fix type unification for generic `inout` parameter.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ast-decl.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 61 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 40 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 37 |
8 files changed, 106 insertions, 70 deletions
diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp index b2802e304..9931bbcaf 100644 --- a/source/slang/slang-ast-decl.cpp +++ b/source/slang/slang-ast-decl.cpp @@ -18,6 +18,19 @@ const TypeExp& TypeConstraintDecl::_getSupOverride() const //return TypeExp::empty; } +InterfaceDecl* findParentInterfaceDecl(Decl* decl) +{ + auto ancestor = decl->parentDecl; + for (; ancestor; ancestor = ancestor->parentDecl) + { + if (auto interfaceDecl = as<InterfaceDecl>(ancestor)) + return interfaceDecl; + + if (as<ExtensionDecl>(ancestor)) + return nullptr; + } + return nullptr; +} bool isInterfaceRequirement(Decl* decl) { diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index e7dc73a85..ccbac0286 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -518,7 +518,8 @@ class AttributeDecl : public ContainerDecl SyntaxClass<NodeBase> syntaxClass; }; -// A synthesized decl used as a placeholder for a differentiable function requirement. +// A synthesized decl used as a placeholder for a differentiable function requirement. This decl will +// be a child of interface decl. // This allows us to form an interface requirement key for the derivative of an interface function. // The synthesized `DerivativeRequirementDecl` will be a child of the original function requirement // decl after an interface type is checked. @@ -527,6 +528,14 @@ class DerivativeRequirementDecl : public FunctionDeclBase SLANG_AST_CLASS(DerivativeRequirementDecl) }; +// A reference to a synthesized decl representing a differentiable function requirement, this decl will +// be a child in the orignal function. +class DerivativeRequirementReferenceDecl : public FunctionDeclBase +{ + SLANG_AST_CLASS(DerivativeRequirementReferenceDecl) + DerivativeRequirementDecl* referencedDecl; +}; + class ForwardDerivativeRequirementDecl : public DerivativeRequirementDecl { SLANG_AST_CLASS(ForwardDerivativeRequirementDecl) @@ -538,5 +547,6 @@ class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl }; bool isInterfaceRequirement(Decl* decl); +InterfaceDecl* findParentInterfaceDecl(Decl* decl); } // namespace Slang diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 76623d01c..3fcc762ec 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -1178,5 +1178,14 @@ Val* ModifiedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionS return substType; } +Type* removeParamDirType(Type* type) +{ + for (auto paramDirType = as<ParamDirectionType>(type); paramDirType;) + { + type = paramDirType->getValueType(); + paramDirType = as<ParamDirectionType>(type); + } + return type; +} } // namespace Slang diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index d85391d58..8953f0b10 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -872,4 +872,6 @@ class ModifiedType : public Type Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; +Type* removeParamDirType(Type* type); + } // namespace Slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 36a1061c9..4d2839b8d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1926,23 +1926,33 @@ namespace Slang requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); - if (hasForwardDerivative) + if (hasForwardDerivative || hasBackwardDerivative) { - auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>(); - SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); - ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); - } + int fwdReqFound = 0; + int bwdReqFound = 0; + for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>()) + { + if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) + { + ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(fwdReq, RequirementWitness(val)); + fwdReqFound++; + } + else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) + { + BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(bwdReq, RequirementWitness(val)); + bwdReqFound++; + } + } - if (hasBackwardDerivative) - { - auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>(); - SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); - BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + SLANG_RELEASE_ASSERT( + fwdReqFound == (hasForwardDerivative ? 1 : 0) && + bwdReqFound == (hasBackwardDerivative ? 1 : 0)); } + return true; } @@ -3706,7 +3716,8 @@ namespace Slang { if(isAssociatedTypeDecl(requiredMemberDeclRef)) continue; - + if (requiredMemberDeclRef.as<DerivativeRequirementDecl>()) + continue; auto requirementSatisfied = findWitnessForInterfaceRequirement( context, subType, @@ -5617,7 +5628,7 @@ namespace Slang } decl->errorType = errorType; - if (isInterfaceRequirement(decl)) + if (auto interfaceDecl = findParentInterfaceDecl(decl)) { if (decl->hasModifier<ForwardDifferentiableAttribute>()) { @@ -5626,8 +5637,13 @@ namespace Slang auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - decl->members.add(reqDecl); - reqDecl->parentDecl = decl; + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); } if (decl->hasModifier<BackwardDifferentiableAttribute>()) { @@ -5636,8 +5652,13 @@ namespace Slang auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - decl->members.add(reqDecl); - reqDecl->parentDecl = decl; + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); } } } diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 83774303b..3867dda03 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1580,7 +1580,7 @@ namespace Slang List<Type*> paramTypes; for (UIndex ii = 0; ii < diffFuncType->getParamCount(); ii++) - paramTypes.add(diffFuncType->getParamType(ii)); + paramTypes.add(removeParamDirType(diffFuncType->getParamType(ii))); // Try to infer generic arguments, based on the updated context. DeclRef<Decl> innerRef = inferGenericArguments( diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 04a898ea9..c93522565 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -111,6 +111,8 @@ struct JVPTranscriber IRInst* lookupPrimalInst(IRInst* origInst) { + if (!origInst) + return nullptr; if (shouldUseOriginalAsPrimal(origInst)) return origInst; return cloneEnv.mapOldValToNew[origInst]; @@ -118,11 +120,15 @@ struct JVPTranscriber IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) { + if (!origInst) + return nullptr; return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; } bool hasPrimalInst(IRInst* origInst) { + if (!origInst) + return true; if (shouldUseOriginalAsPrimal(origInst)) return true; return cloneEnv.mapOldValToNew.ContainsKey(origInst); @@ -175,7 +181,7 @@ struct JVPTranscriber if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) diffReturnType = returnPairType; else - diffReturnType = builder->getVoidType(); + diffReturnType = origResultType; return builder->getFuncType(newParameterTypes, diffReturnType); } @@ -735,13 +741,12 @@ struct JVPTranscriber SLANG_ASSERT(primalArg); auto primalType = primalArg->getDataType(); - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - - if (!diffArg) - diffArg = getDifferentialZeroOfType(builder, primalType); - if (auto pairType = tryGetDiffPairType(builder, primalType)) { + auto diffArg = findOrTranscribeDiffInst(builder, origArg); + if (!diffArg) + diffArg = getDifferentialZeroOfType(builder, primalType); + // If a pair type can be formed, this must be non-null. SLANG_RELEASE_ASSERT(diffArg); auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); @@ -984,6 +989,18 @@ struct JVPTranscriber builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } + else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>()) + { + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + { + args.add(primalSpecialize->getArg(i)); + } + diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } else { return InstPair(primalSpecialize, nullptr); @@ -1365,15 +1382,14 @@ struct JVPTranscriber { differentiableTypeConformanceContext.setFunc(innerFunc); } + else if (auto funcType = as<IRFuncType>(innerVal)) + { + } else { return InstPair(origGeneric, nullptr); } - // For now, we assume there's only one generic layer. So this inst must be top level - bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr); - SLANG_RELEASE_ASSERT(isTopLevel); - IRGeneric* primalGeneric = origGeneric; IRBuilder builder(inBuilder->getSharedBuilder()); @@ -1395,10 +1411,6 @@ struct JVPTranscriber diffGeneric->setFullType(diffType); - // TODO(sai): Replace naming scheme - // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) - // builder->addNameHintDecoration(diffFunc, jvpName); - // Transcribe children from origFunc into diffFunc. builder.setInsertInto(diffGeneric); for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a0becdafa..09dacc20d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6863,14 +6863,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { operandCount += associatedTypeDecl->getMembersOfType<TypeConstraintDecl>().getCount(); } - else if (auto callableDecl = as<CallableDecl>(requirementDecl)) - { - // Differentiable functions has additional requirements for the derivatives. - if (callableDecl->getMembersOfType<ForwardDerivativeRequirementDecl>().getCount()) - operandCount++; - if (callableDecl->getMembersOfType<BackwardDerivativeRequirementDecl>().getCount()) - operandCount++; - } } // Allocate an IRInterfaceType with the `operandCount` operands. @@ -6957,33 +6949,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (auto callableDecl = as<CallableDecl>(requirementDecl)) { // Differentiable functions has additional requirements for the derivatives. - for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementDecl>()) + for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>()) { - auto diffKey = getInterfaceRequirementKey(diffDecl); - IRInst* diffVal = ensureDecl(subContext, diffDecl).val; - auto diffEntry = subBuilder->createInterfaceRequirementEntry(diffKey, diffVal); - if (diffVal) - { - switch (diffVal->getOp()) - { - case kIROp_Func: - case kIROp_Generic: - { - // Remove lowered `IRFunc`s since we only care about - // function types. - auto reqType = diffVal->getFullType(); - diffEntry->setRequirementVal(reqType); - break; - } - default: - break; - } - } - irInterface->setOperand(entryIndex, diffEntry); - entryIndex++; - - setValue(context, diffDecl, LoweredValInfo::simple(diffEntry)); - insertRequirementKeyAssociation(irInterface, diffDecl, requirementKey, diffKey); + auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl); + insertRequirementKeyAssociation(irInterface, diffDecl->referencedDecl, requirementKey, diffKey); } } // Add lowered requirement entry to current decl mapping to prevent |
