summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-lower-generic-function.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-11 15:33:28 -0800
committerGitHub <noreply@github.com>2023-01-11 15:33:28 -0800
commita3ac6e71cbc922b7c941c45f23ee18a9fc274d1f (patch)
treeacf8c18601f124e9290494f8b379d2420369fc35 /source/slang/slang-ir-lower-generic-function.cpp
parent20262684bcbb707d16669b2670039df870b65ca8 (diff)
Make backward differentiation work with generics. (#2586)
* Make backward differentiation work with generics. * Fix. * Another fix. * More fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-lower-generic-function.cpp')
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp14
1 files changed, 10 insertions, 4 deletions
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index 806ea8826..6f412d579 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -48,9 +48,12 @@ namespace Slang
IRCloneEnv cloneEnv;
IRBuilder builder(sharedContext->sharedBuilderStorage);
builder.setInsertBefore(genericParent);
+ // Do not clone func type (which would break IR def-use rules if we do it here)
+ // This is OK since we will lower the type immediately after the clone.
+ cloneEnv.mapOldValToNew[func->getFullType()] = builder.getTypeKind();
auto loweredFunc = cast<IRFunc>(cloneInstAndOperands(&cloneEnv, &builder, func));
auto loweredGenericType =
- lowerGenericFuncType(&builder, cast<IRGeneric>(genericParent->getFullType()));
+ lowerGenericFuncType(&builder, genericParent, cast<IRFuncType>(func->getFullType()));
SLANG_ASSERT(loweredGenericType);
loweredFunc->setFullType(loweredGenericType);
List<IRInst*> clonedParams;
@@ -90,7 +93,7 @@ namespace Slang
return loweredFunc;
}
- IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal)
+ IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal, IRFuncType* funcType)
{
ShortList<IRInst*> genericParamTypes;
Dictionary<IRInst*, IRInst*> typeMapping;
@@ -107,7 +110,7 @@ namespace Slang
auto innerType = (IRFuncType*)lowerFuncType(
builder,
- cast<IRFuncType>(findGenericReturnVal(genericVal)),
+ funcType,
typeMapping,
genericParamTypes.getArrayView().arrayView);
@@ -182,7 +185,10 @@ namespace Slang
}
else if (auto genericFuncType = as<IRGeneric>(requirementVal))
{
- loweredVal = lowerGenericFuncType(&builder, genericFuncType);
+ loweredVal = lowerGenericFuncType(
+ &builder,
+ genericFuncType,
+ cast<IRFuncType>(findGenericReturnVal(genericFuncType)));
}
else if (requirementVal->getOp() == kIROp_AssociatedType)
{