diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-03-14 17:15:36 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-03-15 00:15:36 +0000 |
| commit | 78517dc392f0d2ebba25f0ac3f4d4e004b0f0ab0 (patch) | |
| tree | 104b48da3fc54e43cd7c5ce51cc66b4e2dc26d55 /source/slang/slang-lower-to-ir.cpp | |
| parent | c8c9e424e91e72e718529ed76df14f7586624cd6 (diff) | |
Fix lowering of associated types in generic interfaces (#6600)
* Fix lowering of associated types in generic interfaces.
* Update diff-assoctype-generic-interface.slang
* Fix-up lowering of differentiable witnesses for implicit ops
* Update slang-ir-autodiff-transcriber-base.cpp
* Fix issue with differentiating type-packs
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 41 |
1 files changed, 30 insertions, 11 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 775986a9a..decfe4a91 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1919,6 +1919,28 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->getValue())); } + IRType* visitDifferentialPairType(DifferentialPairType* pairType) + { + IRType* primalType = lowerType(context, pairType->getPrimalType()); + if (as<IRAssociatedType>(primalType) || as<IRThisType>(primalType)) + { + List<IRInst*> operands; + SubstitutionSet(pairType->getDeclRef()) + .forEachSubstitutionArg( + [&](Val* arg) + { + auto argVal = lowerVal(context, arg).val; + SLANG_ASSERT(argVal); + operands.add(argVal); + }); + + auto undefined = getBuilder()->emitUndefined(operands[1]->getFullType()); + return getBuilder()->getDifferentialPairUserCodeType(primalType, undefined); + } + else + return lowerSimpleIntrinsicType(pairType); + } + IRFuncType* visitFuncType(FuncType* type) { IRType* resultType = lowerType(context, type->getResultType()); @@ -10195,15 +10217,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // If our function is differentiable, register a callback so the derivative // annotations for types can be lowered. // - if (auto diffAttr = decl->findModifier<DifferentiableAttribute>()) + if (decl->findModifier<DifferentiableAttribute>() && !isInterfaceRequirement(decl)) { + auto diffAttr = decl->findModifier<DifferentiableAttribute>(); + auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness(); - OrderedDictionary<DeclRefBase*, SubtypeWitness*> resolveddiffTypeWitnessMap; + OrderedDictionary<Type*, SubtypeWitness*> resolveddiffTypeWitnessMap; // Go through each entry in the map and resolve the key. for (auto& entry : diffTypeWitnessMap) { - auto resolvedKey = as<DeclRefBase>(entry.key->resolve()); + auto resolvedKey = as<Type>(entry.key->resolve()); resolveddiffTypeWitnessMap[resolvedKey] = as<SubtypeWitness>(as<Val>(entry.value)->resolve()); } @@ -10211,14 +10235,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subContext->registerTypeCallback( [=](IRGenContext* context, Type* type, IRType* irType) { - if (!as<DeclRefType>(type)) - return irType; - - DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase(); - if (resolveddiffTypeWitnessMap.containsKey(declRefBase)) + if (resolveddiffTypeWitnessMap.containsKey(type)) { - auto irWitness = - lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val; + auto irWitness = lowerVal(subContext, resolveddiffTypeWitnessMap[type]).val; if (irWitness) { IRInst* args[] = {irType, irWitness}; @@ -11328,7 +11347,7 @@ LoweredValInfo emitDeclRef(IRGenContext* context, Decl* decl, DeclRefBase* subst // interface definitions. return emitDeclRef( context, - createDefaultSpecializedDeclRef(context, nullptr, decl), + decl->getDefaultDeclRef(), context->irBuilder->getTypeKind()); } |
