diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 16:02:56 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 16:02:56 -0800 |
| commit | 4ad0470025da4e808c46023f9a2525febcf973a2 (patch) | |
| tree | 8fcb1c84121ddf40c50ca58b5de867da0da435ee /source/slang/slang-lower-to-ir.cpp | |
| parent | 97cb4851eed7a43f10196971b08d3d311386ce9f (diff) | |
Fix issues around dynamic generic function and autodiff. (#2528)
* Fix issues around dynamic generic function and autodiff.
* Fix return type issue.
* Fix type unification for generic `inout` parameter.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 37 |
1 files changed, 3 insertions, 34 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a0becdafa..09dacc20d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6863,14 +6863,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { operandCount += associatedTypeDecl->getMembersOfType<TypeConstraintDecl>().getCount(); } - else if (auto callableDecl = as<CallableDecl>(requirementDecl)) - { - // Differentiable functions has additional requirements for the derivatives. - if (callableDecl->getMembersOfType<ForwardDerivativeRequirementDecl>().getCount()) - operandCount++; - if (callableDecl->getMembersOfType<BackwardDerivativeRequirementDecl>().getCount()) - operandCount++; - } } // Allocate an IRInterfaceType with the `operandCount` operands. @@ -6957,33 +6949,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (auto callableDecl = as<CallableDecl>(requirementDecl)) { // Differentiable functions has additional requirements for the derivatives. - for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementDecl>()) + for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>()) { - auto diffKey = getInterfaceRequirementKey(diffDecl); - IRInst* diffVal = ensureDecl(subContext, diffDecl).val; - auto diffEntry = subBuilder->createInterfaceRequirementEntry(diffKey, diffVal); - if (diffVal) - { - switch (diffVal->getOp()) - { - case kIROp_Func: - case kIROp_Generic: - { - // Remove lowered `IRFunc`s since we only care about - // function types. - auto reqType = diffVal->getFullType(); - diffEntry->setRequirementVal(reqType); - break; - } - default: - break; - } - } - irInterface->setOperand(entryIndex, diffEntry); - entryIndex++; - - setValue(context, diffDecl, LoweredValInfo::simple(diffEntry)); - insertRequirementKeyAssociation(irInterface, diffDecl, requirementKey, diffKey); + auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl); + insertRequirementKeyAssociation(irInterface, diffDecl->referencedDecl, requirementKey, diffKey); } } // Add lowered requirement entry to current decl mapping to prevent |
