summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-03-14 17:15:36 -0700
committerGitHub <noreply@github.com>2025-03-15 00:15:36 +0000
commit78517dc392f0d2ebba25f0ac3f4d4e004b0f0ab0 (patch)
tree104b48da3fc54e43cd7c5ce51cc66b4e2dc26d55 /source/slang/slang-lower-to-ir.cpp
parentc8c9e424e91e72e718529ed76df14f7586624cd6 (diff)
Fix lowering of associated types in generic interfaces (#6600)
* Fix lowering of associated types in generic interfaces. * Update diff-assoctype-generic-interface.slang * Fix-up lowering of differentiable witnesses for implicit ops * Update slang-ir-autodiff-transcriber-base.cpp * Fix issue with differentiating type-packs
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp41
1 files changed, 30 insertions, 11 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 775986a9a..decfe4a91 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1919,6 +1919,28 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->getValue()));
}
+ IRType* visitDifferentialPairType(DifferentialPairType* pairType)
+ {
+ IRType* primalType = lowerType(context, pairType->getPrimalType());
+ if (as<IRAssociatedType>(primalType) || as<IRThisType>(primalType))
+ {
+ List<IRInst*> operands;
+ SubstitutionSet(pairType->getDeclRef())
+ .forEachSubstitutionArg(
+ [&](Val* arg)
+ {
+ auto argVal = lowerVal(context, arg).val;
+ SLANG_ASSERT(argVal);
+ operands.add(argVal);
+ });
+
+ auto undefined = getBuilder()->emitUndefined(operands[1]->getFullType());
+ return getBuilder()->getDifferentialPairUserCodeType(primalType, undefined);
+ }
+ else
+ return lowerSimpleIntrinsicType(pairType);
+ }
+
IRFuncType* visitFuncType(FuncType* type)
{
IRType* resultType = lowerType(context, type->getResultType());
@@ -10195,15 +10217,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// If our function is differentiable, register a callback so the derivative
// annotations for types can be lowered.
//
- if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
+ if (decl->findModifier<DifferentiableAttribute>() && !isInterfaceRequirement(decl))
{
+ auto diffAttr = decl->findModifier<DifferentiableAttribute>();
+
auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness();
- OrderedDictionary<DeclRefBase*, SubtypeWitness*> resolveddiffTypeWitnessMap;
+ OrderedDictionary<Type*, SubtypeWitness*> resolveddiffTypeWitnessMap;
// Go through each entry in the map and resolve the key.
for (auto& entry : diffTypeWitnessMap)
{
- auto resolvedKey = as<DeclRefBase>(entry.key->resolve());
+ auto resolvedKey = as<Type>(entry.key->resolve());
resolveddiffTypeWitnessMap[resolvedKey] =
as<SubtypeWitness>(as<Val>(entry.value)->resolve());
}
@@ -10211,14 +10235,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subContext->registerTypeCallback(
[=](IRGenContext* context, Type* type, IRType* irType)
{
- if (!as<DeclRefType>(type))
- return irType;
-
- DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase();
- if (resolveddiffTypeWitnessMap.containsKey(declRefBase))
+ if (resolveddiffTypeWitnessMap.containsKey(type))
{
- auto irWitness =
- lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val;
+ auto irWitness = lowerVal(subContext, resolveddiffTypeWitnessMap[type]).val;
if (irWitness)
{
IRInst* args[] = {irType, irWitness};
@@ -11328,7 +11347,7 @@ LoweredValInfo emitDeclRef(IRGenContext* context, Decl* decl, DeclRefBase* subst
// interface definitions.
return emitDeclRef(
context,
- createDefaultSpecializedDeclRef(context, nullptr, decl),
+ decl->getDefaultDeclRef(),
context->irBuilder->getTypeKind());
}