summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-23 16:02:56 -0800
committerGitHub <noreply@github.com>2022-11-23 16:02:56 -0800
commit4ad0470025da4e808c46023f9a2525febcf973a2 (patch)
tree8fcb1c84121ddf40c50ca58b5de867da0da435ee /source/slang/slang-lower-to-ir.cpp
parent97cb4851eed7a43f10196971b08d3d311386ce9f (diff)
Fix issues around dynamic generic function and autodiff. (#2528)
* Fix issues around dynamic generic function and autodiff. * Fix return type issue. * Fix type unification for generic `inout` parameter. * Fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp37
1 files changed, 3 insertions, 34 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index a0becdafa..09dacc20d 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -6863,14 +6863,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
operandCount += associatedTypeDecl->getMembersOfType<TypeConstraintDecl>().getCount();
}
- else if (auto callableDecl = as<CallableDecl>(requirementDecl))
- {
- // Differentiable functions has additional requirements for the derivatives.
- if (callableDecl->getMembersOfType<ForwardDerivativeRequirementDecl>().getCount())
- operandCount++;
- if (callableDecl->getMembersOfType<BackwardDerivativeRequirementDecl>().getCount())
- operandCount++;
- }
}
// Allocate an IRInterfaceType with the `operandCount` operands.
@@ -6957,33 +6949,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (auto callableDecl = as<CallableDecl>(requirementDecl))
{
// Differentiable functions has additional requirements for the derivatives.
- for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementDecl>())
+ for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>())
{
- auto diffKey = getInterfaceRequirementKey(diffDecl);
- IRInst* diffVal = ensureDecl(subContext, diffDecl).val;
- auto diffEntry = subBuilder->createInterfaceRequirementEntry(diffKey, diffVal);
- if (diffVal)
- {
- switch (diffVal->getOp())
- {
- case kIROp_Func:
- case kIROp_Generic:
- {
- // Remove lowered `IRFunc`s since we only care about
- // function types.
- auto reqType = diffVal->getFullType();
- diffEntry->setRequirementVal(reqType);
- break;
- }
- default:
- break;
- }
- }
- irInterface->setOperand(entryIndex, diffEntry);
- entryIndex++;
-
- setValue(context, diffDecl, LoweredValInfo::simple(diffEntry));
- insertRequirementKeyAssociation(irInterface, diffDecl, requirementKey, diffKey);
+ auto diffKey = getInterfaceRequirementKey(diffDecl->referencedDecl);
+ insertRequirementKeyAssociation(irInterface, diffDecl->referencedDecl, requirementKey, diffKey);
}
}
// Add lowered requirement entry to current decl mapping to prevent