summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp16
1 files changed, 8 insertions, 8 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index df657476a..f3f32add2 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1362,9 +1362,10 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(
IRBuilder* builder,
IRType* origType,
IRStructKey* key,
- IRType* resultType)
+ IRType* resultType,
+ DiffConformanceKind kind)
{
- if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any))
+ if (auto conformance = tryGetDifferentiableWitness(builder, origType, kind))
return _lookupWitness(builder, conformance, key, resultType);
return nullptr;
}
@@ -2097,8 +2098,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
IRWitnessTable* table = nullptr;
if (target == DiffConformanceKind::Value)
{
- SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType));
-
auto addMethod = builder->createFunc();
auto zeroMethod = builder->createFunc();
@@ -2138,6 +2137,8 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
&b,
(IRType*)elementType,
DiffConformanceKind::Value);
+
+ SLANG_ASSERT(isDifferentiableValueType((IRType*)elementType));
IRInst* elementResult = nullptr;
if (!innerWitness)
{
@@ -2171,9 +2172,9 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
{
// Zero method.
IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
+ b.setInsertInto(zeroMethod);
+ b.addBackwardDifferentiableDecoration(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
b.emitBlock();
List<IRInst*> results;
for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
@@ -2214,7 +2215,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
else if (target == DiffConformanceKind::Ptr)
{
SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType));
-
table = builder->createWitnessTable(
sharedContext->differentiablePtrInterfaceType,
(IRType*)inTupleType);