summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-03-14 17:15:36 -0700
committerGitHub <noreply@github.com>2025-03-15 00:15:36 +0000
commit78517dc392f0d2ebba25f0ac3f4d4e004b0f0ab0 (patch)
tree104b48da3fc54e43cd7c5ce51cc66b4e2dc26d55 /source
parentc8c9e424e91e72e718529ed76df14f7586624cd6 (diff)
Fix lowering of associated types in generic interfaces (#6600)
* Fix lowering of associated types in generic interfaces. * Update diff-assoctype-generic-interface.slang * Fix-up lowering of differentiable witnesses for implicit ops * Update slang-ir-autodiff-transcriber-base.cpp * Fix issue with differentiating type-packs
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-dump.cpp2
-rw-r--r--source/slang/slang-ast-modifier.cpp2
-rw-r--r--source/slang/slang-ast-modifier.h10
-rw-r--r--source/slang/slang-check-expr.cpp28
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp36
-rw-r--r--source/slang/slang-ir-autodiff.cpp16
-rw-r--r--source/slang/slang-ir-autodiff.h9
-rw-r--r--source/slang/slang-lower-to-ir.cpp41
9 files changed, 77 insertions, 69 deletions
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index f6cdf50d8..bd366be19 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -647,7 +647,7 @@ struct ASTDumpContext
void dump(SourceLanguage language) { m_writer->emit((int)language); }
- void dump(KeyValuePair<DeclRefBase*, SubtypeWitness*> pair)
+ void dump(KeyValuePair<Type*, SubtypeWitness*> pair)
{
m_writer->emit("(");
dump(pair.key);
diff --git a/source/slang/slang-ast-modifier.cpp b/source/slang/slang-ast-modifier.cpp
index 2a245130e..383389c39 100644
--- a/source/slang/slang-ast-modifier.cpp
+++ b/source/slang/slang-ast-modifier.cpp
@@ -5,7 +5,7 @@
namespace Slang
{
-const OrderedDictionary<DeclRefBase*, SubtypeWitness*>& DifferentiableAttribute::
+const OrderedDictionary<Type*, SubtypeWitness*>& DifferentiableAttribute::
getMapTypeToIDifferentiableWitness()
{
for (Index i = m_mapToIDifferentiableWitness.getCount();
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index e4d5ccd09..5f9ccb5bb 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1391,25 +1391,25 @@ class DifferentiableAttribute : public Attribute
{
SLANG_AST_CLASS(DifferentiableAttribute)
- List<KeyValuePair<DeclRefBase*, SubtypeWitness*>> m_typeToIDifferentiableWitnessMappings;
+ List<KeyValuePair<Type*, SubtypeWitness*>> m_typeToIDifferentiableWitnessMappings;
- void addType(DeclRefBase* declRef, SubtypeWitness* witness)
+ void addType(Type* declRef, SubtypeWitness* witness)
{
getMapTypeToIDifferentiableWitness();
if (m_mapToIDifferentiableWitness.addIfNotExists(declRef, witness))
{
m_typeToIDifferentiableWitnessMappings.add(
- KeyValuePair<DeclRefBase*, SubtypeWitness*>(declRef, witness));
+ KeyValuePair<Type*, SubtypeWitness*>(declRef, witness));
}
}
/// Mapping from types to subtype witnesses for conformance to IDifferentiable.
- const OrderedDictionary<DeclRefBase*, SubtypeWitness*>& getMapTypeToIDifferentiableWitness();
+ const OrderedDictionary<Type*, SubtypeWitness*>& getMapTypeToIDifferentiableWitness();
SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet;
private:
- OrderedDictionary<DeclRefBase*, SubtypeWitness*> m_mapToIDifferentiableWitness;
+ OrderedDictionary<Type*, SubtypeWitness*> m_mapToIDifferentiableWitness;
};
class DllImportAttribute : public Attribute
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b730069b6..2f91a6a77 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1405,14 +1405,12 @@ Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, Sou
return result;
}
-void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(
- DeclRefType* type,
- SubtypeWitness* witness)
+void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(Type* type, SubtypeWitness* witness)
{
SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr);
if (witness)
{
- m_parentDifferentiableAttr->addType(type->getDeclRef(), witness);
+ m_parentDifferentiableAttr->addType(type, witness);
}
}
@@ -1468,14 +1466,14 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder*
type,
getASTBuilder()->getDifferentiableInterfaceType())))
{
- addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
}
if (auto subtypeWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(
type,
getASTBuilder()->getDifferentiableRefInterfaceType())))
{
- addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
}
if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
@@ -1515,6 +1513,15 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder*
maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i));
return;
}
+
+ // General check for types that may not be decl-ref-type, but still have some conformance to
+ // IDifferentiable/IDifferentiablePtrType
+ if (auto subtypeWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(
+ type,
+ getASTBuilder()->getDifferentiableInterfaceType())))
+ {
+ addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
+ }
}
@@ -4846,7 +4853,14 @@ Expr* SemanticsVisitor::checkBaseForMemberExpr(
auto baseExpr = inBaseExpr;
baseExpr = CheckTerm(baseExpr);
- return maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
+ auto resultBaseExpr =
+ maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
+
+ // We might want to register differentiability on any implicit ops that we add in.
+ if (this->m_parentFunc && this->m_parentFunc->findModifier<DifferentiableAttribute>())
+ maybeRegisterDifferentiableType(getASTBuilder(), resultBaseExpr->type.type);
+
+ return resultBaseExpr;
}
Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index ba3792af7..95716744c 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1512,7 +1512,7 @@ public:
/// Registers a type as conforming to IDifferentiable, along with a witness
/// describing the relationship.
///
- void addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness);
+ void addDifferentiableTypeToDiffTypeRegistry(Type* type, SubtypeWitness* witness);
void maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type);
// Construct the differential for 'type', if it exists.
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 8356e5f81..d67d75997 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -720,9 +720,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
if (auto diffType = differentiateType(builder, originalType))
{
- IRInst* diffWitnessTable = nullptr;
- IRType* diffOuterType = nullptr;
- if (isExistentialType(diffType))
+ if (isExistentialType(diffType) && !as<IRLookupWitnessMethod>(diffType))
{
// Emit null differential & pack it into an IDifferentiable existential.
@@ -789,25 +787,8 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
return result;
}
- // Since primalType has a corresponding differential type, we can lookup the
- // definition for zero().
- IRInst* zeroMethod = nullptr;
- if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
- {
- // if the differential type itself comes from a witness lookup, we can just lookup the
- // zero method from the same witness table.
- auto wt = lookupInterface->getWitnessTable();
- zeroMethod = builder->emitLookupInterfaceMethodInst(
- builder->getFuncType(List<IRType*>(), diffType),
- wt,
- autoDiffSharedContext->zeroMethodStructKey);
- builder->markInstAsPrimal(zeroMethod);
- }
- else
- {
- zeroMethod =
- differentiableTypeConformanceContext.getZeroMethodForType(builder, originalType);
- }
+ auto zeroMethod =
+ differentiableTypeConformanceContext.getZeroMethodForType(builder, originalType);
SLANG_RELEASE_ASSERT(zeroMethod);
auto emptyArgList = List<IRInst*>();
@@ -815,16 +796,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
builder->markInstAsDifferential(callInst, primalType);
- if (diffOuterType && isExistentialType(diffOuterType))
- {
- // Need to wrap the result back into an existential.
- auto existentialZero =
- builder->emitMakeExistential(diffOuterType, callInst, diffWitnessTable);
- builder->markInstAsDifferential(existentialZero, primalType);
- return existentialZero;
- }
- else
- return callInst;
+ return callInst;
}
else
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index df657476a..f3f32add2 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1362,9 +1362,10 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(
IRBuilder* builder,
IRType* origType,
IRStructKey* key,
- IRType* resultType)
+ IRType* resultType,
+ DiffConformanceKind kind)
{
- if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any))
+ if (auto conformance = tryGetDifferentiableWitness(builder, origType, kind))
return _lookupWitness(builder, conformance, key, resultType);
return nullptr;
}
@@ -2097,8 +2098,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
IRWitnessTable* table = nullptr;
if (target == DiffConformanceKind::Value)
{
- SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType));
-
auto addMethod = builder->createFunc();
auto zeroMethod = builder->createFunc();
@@ -2138,6 +2137,8 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
&b,
(IRType*)elementType,
DiffConformanceKind::Value);
+
+ SLANG_ASSERT(isDifferentiableValueType((IRType*)elementType));
IRInst* elementResult = nullptr;
if (!innerWitness)
{
@@ -2171,9 +2172,9 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
{
// Zero method.
IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
+ b.setInsertInto(zeroMethod);
+ b.addBackwardDifferentiableDecoration(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
b.emitBlock();
List<IRInst*> results;
for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
@@ -2214,7 +2215,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
else if (target == DiffConformanceKind::Ptr)
{
SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType));
-
table = builder->createWitnessTable(
sharedContext->differentiablePtrInterfaceType,
(IRType*)inTupleType);
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 4698408e3..2cd08eb28 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -252,7 +252,8 @@ struct DifferentiableTypeConformanceContext
IRBuilder* builder,
IRType* origType,
IRStructKey* key,
- IRType* resultType = nullptr);
+ IRType* resultType = nullptr,
+ DiffConformanceKind kind = DiffConformanceKind::Any);
IRType* differentiateType(IRBuilder* builder, IRInst* primalType);
@@ -411,7 +412,8 @@ struct DifferentiableTypeConformanceContext
builder,
origType,
sharedContext->zeroMethodStructKey,
- sharedContext->zeroMethodType);
+ sharedContext->zeroMethodType,
+ DiffConformanceKind::Value);
return result;
}
@@ -421,7 +423,8 @@ struct DifferentiableTypeConformanceContext
builder,
origType,
sharedContext->addMethodStructKey,
- sharedContext->addMethodType);
+ sharedContext->addMethodType,
+ DiffConformanceKind::Value);
return result;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 775986a9a..decfe4a91 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1919,6 +1919,28 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->getValue()));
}
+ IRType* visitDifferentialPairType(DifferentialPairType* pairType)
+ {
+ IRType* primalType = lowerType(context, pairType->getPrimalType());
+ if (as<IRAssociatedType>(primalType) || as<IRThisType>(primalType))
+ {
+ List<IRInst*> operands;
+ SubstitutionSet(pairType->getDeclRef())
+ .forEachSubstitutionArg(
+ [&](Val* arg)
+ {
+ auto argVal = lowerVal(context, arg).val;
+ SLANG_ASSERT(argVal);
+ operands.add(argVal);
+ });
+
+ auto undefined = getBuilder()->emitUndefined(operands[1]->getFullType());
+ return getBuilder()->getDifferentialPairUserCodeType(primalType, undefined);
+ }
+ else
+ return lowerSimpleIntrinsicType(pairType);
+ }
+
IRFuncType* visitFuncType(FuncType* type)
{
IRType* resultType = lowerType(context, type->getResultType());
@@ -10195,15 +10217,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// If our function is differentiable, register a callback so the derivative
// annotations for types can be lowered.
//
- if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
+ if (decl->findModifier<DifferentiableAttribute>() && !isInterfaceRequirement(decl))
{
+ auto diffAttr = decl->findModifier<DifferentiableAttribute>();
+
auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness();
- OrderedDictionary<DeclRefBase*, SubtypeWitness*> resolveddiffTypeWitnessMap;
+ OrderedDictionary<Type*, SubtypeWitness*> resolveddiffTypeWitnessMap;
// Go through each entry in the map and resolve the key.
for (auto& entry : diffTypeWitnessMap)
{
- auto resolvedKey = as<DeclRefBase>(entry.key->resolve());
+ auto resolvedKey = as<Type>(entry.key->resolve());
resolveddiffTypeWitnessMap[resolvedKey] =
as<SubtypeWitness>(as<Val>(entry.value)->resolve());
}
@@ -10211,14 +10235,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subContext->registerTypeCallback(
[=](IRGenContext* context, Type* type, IRType* irType)
{
- if (!as<DeclRefType>(type))
- return irType;
-
- DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase();
- if (resolveddiffTypeWitnessMap.containsKey(declRefBase))
+ if (resolveddiffTypeWitnessMap.containsKey(type))
{
- auto irWitness =
- lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val;
+ auto irWitness = lowerVal(subContext, resolveddiffTypeWitnessMap[type]).val;
if (irWitness)
{
IRInst* args[] = {irType, irWitness};
@@ -11328,7 +11347,7 @@ LoweredValInfo emitDeclRef(IRGenContext* context, Decl* decl, DeclRefBase* subst
// interface definitions.
return emitDeclRef(
context,
- createDefaultSpecializedDeclRef(context, nullptr, decl),
+ decl->getDefaultDeclRef(),
context->irBuilder->getTypeKind());
}