summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-19 03:10:28 -0400
committerGitHub <noreply@github.com>2024-09-19 00:10:28 -0700
commitccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch)
tree435e9c462a78fb848ab3b36c23287543d1a859de /source
parent1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff)
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang37
-rw-r--r--source/slang/slang-ast-builder.cpp23
-rw-r--r--source/slang/slang-ast-builder.h9
-rw-r--r--source/slang/slang-ast-support-types.h3
-rw-r--r--source/slang/slang-ast-type.h11
-rw-r--r--source/slang/slang-check-conformance.cpp9
-rw-r--r--source/slang/slang-check-decl.cpp8
-rw-r--r--source/slang/slang-check-expr.cpp32
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp80
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp13
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp29
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp173
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h6
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h3
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp5
-rw-r--r--source/slang/slang-ir-autodiff.cpp793
-rw-r--r--source/slang/slang-ir-autodiff.h167
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp2
-rw-r--r--source/slang/slang-ir-inst-defs.h12
-rw-r--r--source/slang/slang-ir-insts.h28
-rw-r--r--source/slang/slang-ir.cpp141
-rw-r--r--source/slang/slang-ir.h5
24 files changed, 1215 insertions, 379 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index afcff8e65..476279ab8 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -285,6 +285,13 @@ interface IDifferentiable
static Differential dmul(T, Differential);
};
+__magic_type(DifferentiablePtrType)
+interface IDifferentiablePtrType
+{
+ __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) )
+ associatedtype Differential : IDifferentiablePtrType;
+};
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
@@ -357,6 +364,36 @@ struct DifferentialPair : IDifferentiable
}
};
+__generic<T : IDifferentiablePtrType>
+__magic_type(DifferentialPtrPairType)
+__intrinsic_type($(kIROp_DifferentialPtrPairType))
+struct DifferentialPtrPair : IDifferentiablePtrType
+{
+ typedef DifferentialPtrPair<T.Differential> Differential;
+ typedef T.Differential DifferentialElementType;
+
+ __intrinsic_op($(kIROp_MakeDifferentialPtrPair))
+ __init(T _primal, T.Differential _differential);
+
+ property p : T
+ {
+ __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal))
+ get;
+ }
+
+ property v : T
+ {
+ __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal))
+ get;
+ }
+
+ property d : T.Differential
+ {
+ __intrinsic_op($(kIROp_DifferentialPtrPairGetDifferential))
+ get;
+ }
+};
+
/// A type that uses a floating-point representation
[sealed]
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 9879a4187..b66af34fa 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo
DifferentialPairType* ASTBuilder::getDifferentialPairType(
Type* valueType,
- Witness* primalIsDifferentialWitness)
+ Witness* diffTypeWitness)
{
- Val* args[] = { valueType, primalIsDifferentialWitness };
+ Val* args[] = { valueType, diffTypeWitness };
return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType"));
}
+DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType(
+ Type* valueType,
+ Witness* diffRefTypeWitness)
+{
+ Val* args[] = { valueType, diffRefTypeWitness };
+ return as<DifferentialPtrPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType"));
+}
+
DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl()
{
DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr));
return declRef;
}
+DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableRefInterfaceDecl()
+{
+ DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiablePtrType", nullptr));
+ return declRef;
+}
+
bool ASTBuilder::isDifferentiableInterfaceAvailable()
{
return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr);
@@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType()
return DeclRefType::create(this, getDifferentiableInterfaceDecl());
}
+Type* ASTBuilder::getDifferentiableRefInterfaceType()
+{
+ return DeclRefType::create(this, getDifferentiableRefInterfaceDecl());
+}
+
DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg)
{
auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index b9b1f7ab8..08951513d 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -489,10 +489,17 @@ public:
DifferentialPairType* getDifferentialPairType(
Type* valueType,
- Witness* primalIsDifferentialWitness);
+ Witness* diffTypeWitness);
+
+ DifferentialPtrPairType* getDifferentialPtrPairType(
+ Type* valueType,
+ Witness* diffRefTypeWitness);
DeclRef<InterfaceDecl> getDifferentiableInterfaceDecl();
+ DeclRef<InterfaceDecl> getDifferentiableRefInterfaceDecl();
+
Type* getDifferentiableInterfaceType();
+ Type* getDifferentiableRefInterfaceType();
bool isDifferentiableInterfaceAvailable();
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 83a4cf353..56101bb91 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -9,7 +9,7 @@
#include "slang-profile.h"
#include "slang-type-system-shared.h"
-#include "slang.h"
+#include "../../include/slang.h"
#include "../core/slang-semantic-version.h"
@@ -1606,6 +1606,7 @@ namespace Slang
DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method
DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement
+ DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement
DZeroFunc, ///< The `IDifferentiable.dzero` function requirement
DAddFunc, ///< The `IDifferentiable.dadd` function requirement
DMulFunc, ///< The `IDifferentiable.dmul` function requirement
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 401d73e29..46ea3ea55 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -462,11 +462,22 @@ class DifferentialPairType : public ArithmeticExpressionType
Type* getPrimalType();
};
+class DifferentialPtrPairType : public ArithmeticExpressionType
+{
+ SLANG_AST_CLASS(DifferentialPtrPairType)
+ Type* getPrimalRefType();
+};
+
class DifferentiableType : public BuiltinType
{
SLANG_AST_CLASS(DifferentiableType)
};
+class DifferentiablePtrType : public BuiltinType
+{
+ SLANG_AST_CLASS(DifferentiablePtrType)
+};
+
class DefaultInitializableType : public BuiltinType
{
SLANG_AST_CLASS(DefaultInitializableType);
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index ffa037996..9d9047e41 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -274,9 +274,14 @@ namespace Slang
return isInterfaceType(type);
}
- bool SemanticsVisitor::isTypeDifferentiable(Type* type)
+ SubtypeWitness* SemanticsVisitor::isTypeDifferentiable(Type* type)
{
- return isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None);
+ if (auto valueWitness = isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None))
+ return valueWitness;
+ else if (auto ptrWitness = isSubtype(type, m_astBuilder->getDifferentiableRefInterfaceType(), IsSubTypeOptions::None))
+ return ptrWitness;
+
+ return nullptr;
}
bool SemanticsVisitor::doesTypeHaveTag(Type* type, TypeTag tag)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index deb8c55eb..8e78ff084 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -10204,7 +10204,8 @@ namespace Slang
bool isDiffParam = (!param->findModifier<NoDiffModifier>());
if (isDiffParam)
{
- if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
+ auto diffPair = visitor->getDifferentialPairType(param->getType());
+ if (auto pairType = as<DifferentialPairType>(diffPair))
{
arg->type.type = pairType;
arg->type.isLeftValue = true;
@@ -10225,6 +10226,11 @@ namespace Slang
direction = ParameterDirection::kParameterDirection_InOut;
}
}
+ else if (auto refPairType = as<DifferentialPtrPairType>(diffPair))
+ {
+ // no need to change direction of ref-pairs.
+ arg->type.type = refPairType;
+ }
else
{
isDiffParam = false;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 8f24ec5b0..5233008fd 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1131,7 +1131,8 @@ namespace Slang
{
if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>())
{
- if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType)
+ if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType
+ || builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType)
{
// We are trying to get differential type from a differential type.
// The result is itself.
@@ -1139,7 +1140,10 @@ namespace Slang
}
}
type = resolveType(type);
- if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())))
+ auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()));
+ if (!witness)
+ witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType()));
+ if (witness)
{
auto diffTypeLookupResult = lookUpMember(
getASTBuilder(),
@@ -1367,6 +1371,13 @@ namespace Slang
{
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
}
+
+ if (auto subtypeWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType())))
+ {
+ addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ }
+
if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
{
foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
@@ -2899,18 +2910,25 @@ namespace Slang
return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType));
}
}
+
// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType();
+ auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType();
- auto conformanceWitness = as<Witness>(isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None));
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
- if (conformanceWitness)
+ if (auto conformanceWitness = isTypeDifferentiable(primalType))
{
- return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
+ if (conformanceWitness->getSup() == differentiableInterface)
+ {
+ return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
+ }
+ else if (conformanceWitness->getSup() == differentiableRefInterface)
+ {
+ return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness);
+ }
}
- else
- return primalType;
+ return primalType;
}
Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType)
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index adb7e81f3..29a57ae35 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2208,7 +2208,7 @@ namespace Slang
bool isValidGenericConstraintType(Type* type);
- bool isTypeDifferentiable(Type* type);
+ SubtypeWitness* isTypeDifferentiable(Type* type);
bool doesTypeHaveTag(Type* type, TypeTag tag);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index fe7c77ba0..609bcd8a3 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -336,8 +336,8 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig
{
auto origPtr = origLoad->getPtr();
auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr);
- auto primalPtrType = as<IRPtrTypeBase>(primalPtr->getFullType());
- if (primalPtrType)
+
+ if (auto primalPtrType = as<IRPtrTypeBase>(primalPtr->getFullType()))
{
if (auto diffPairType = as<IRDifferentialPairType>(primalPtrType->getValueType()))
{
@@ -355,6 +355,18 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig
(IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load);
return InstPair(primalElement, diffElement);
}
+ else if (auto diffPtrPairType = as<IRDifferentialPtrPairType>(primalPtrType->getValueType()))
+ {
+ auto load = builder->emitLoad(primalPtr);
+ builder->markInstAsPrimal(load);
+
+ auto primalElement = builder->emitDifferentialPtrPairGetPrimal(load);
+ auto diffElement = builder->emitDifferentialPtrPairGetDifferential(
+ (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPtrPairType), load);
+ builder->markInstAsPrimal(primalElement);
+ builder->markInstAsPrimal(diffElement);
+ return InstPair(primalElement, diffElement);
+ }
}
auto primalLoad = maybeCloneForPrimalInst(builder, origLoad);
@@ -389,6 +401,16 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or
return InstPair(store, nullptr);
}
+ else if (auto diffRefPairType = as<IRDifferentialPtrPairType>(primalLocationPtrType->getValueType()))
+ {
+ auto valToStore = builder->emitMakeDifferentialPtrPair(diffRefPairType, primalStoreVal, diffStoreVal);
+ builder->markInstAsPrimal(valToStore);
+
+ auto store = builder->emitStore(primalStoreLocation, valToStore);
+ builder->markInstAsPrimal(store);
+
+ return InstPair(store, nullptr);
+ }
}
auto primalStore = maybeCloneForPrimalInst(builder, origStore);
@@ -404,7 +426,7 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or
// Default case, storing the entire type (and not a member)
diffStore = as<IRStore>(
builder->emitStore(diffStoreLocation, diffStoreVal));
-
+ markDiffTypeInst(builder, diffStore, primalStoreVal->getDataType());
return InstPair(primalStore, diffStore);
}
@@ -696,14 +718,16 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (auto pairType = tryGetDiffPairType(&argBuilder, primalType))
{
auto pairPtrType = as<IRPtrTypeBase>(pairType);
- auto pairValType = as<IRDifferentialPairType>(
+
+ auto pairValType = as<IRDifferentialPairTypeBase>(
pairPtrType ? pairPtrType->getValueType() : pairType);
+
auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(&argBuilder, pairValType);
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
auto srcVar = argBuilder.emitVar(pairValType);
- argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType());
+ markDiffPairTypeInst(&argBuilder, srcVar, pairValType);
auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg);
if (ptrParamType->getOp() == kIROp_InOutType)
@@ -716,28 +740,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
else
{
diffArgVal = argBuilder.emitLoad(diffArg);
- argBuilder.markInstAsDifferential(diffArgVal, pairValType->getValueType());
+ markDiffTypeInst(&argBuilder, diffArgVal, pairValType->getValueType());
}
auto initVal = argBuilder.emitMakeDifferentialPair(pairValType, primalVal, diffArgVal);
- argBuilder.markInstAsMixedDifferential(initVal, primalType);
+ markDiffPairTypeInst(&argBuilder, initVal, pairValType);
auto store = argBuilder.emitStore(srcVar, initVal);
- argBuilder.markInstAsMixedDifferential(store, primalType);
+ markDiffPairTypeInst(&argBuilder, store, pairValType);
}
if (as<IROutTypeBase>(ptrParamType))
{
// Read back new value.
auto newVal = afterBuilder.emitLoad(srcVar);
- afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType());
+ markDiffPairTypeInst(&afterBuilder, newVal, pairValType);
auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(pairValType->getValueType(), newVal);
afterBuilder.emitStore(primalArg, newPrimalVal);
if (diffArg)
{
auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal);
- afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType());
+ markDiffTypeInst(&afterBuilder, newDiffVal, pairValType->getValueType());
auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal);
- afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType());
+ markDiffTypeInst(&afterBuilder, storeInst, pairValType->getValueType());
}
}
args.add(srcVar);
@@ -753,7 +777,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
SLANG_RELEASE_ASSERT(diffArg);
auto diffPair = argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffArg);
- argBuilder.markInstAsMixedDifferential(diffPair, pairType);
+ markDiffPairTypeInst(&argBuilder, diffPair, pairType);
args.add(diffPair);
continue;
@@ -779,12 +803,13 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
diffCallee,
args);
placeholderCall->removeAndDeallocate();
+
argBuilder.markInstAsMixedDifferential(callInst, diffReturnType);
argBuilder.addAutoDiffOriginalValueDecoration(callInst, primalCallee);
*builder = afterBuilder;
- if (diffReturnType->getOp() == kIROp_DifferentialPairType)
+ if (as<IRDifferentialPairType>(diffReturnType) || as<IRDifferentialPtrPairType>(diffReturnType))
{
IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst);
auto diffType = differentiateType(&afterBuilder, origCall->getFullType());
@@ -1751,12 +1776,13 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
IRInst* valToStore = nullptr;
if (writeBack.value.differential)
{
+ auto pairValType = cast<IRPtrTypeBase>(param->getFullType())->getValueType();
auto diffVal = builder.emitLoad(writeBack.value.differential);
- builder.markInstAsDifferential(diffVal, primalVal->getFullType());
+ markDiffTypeInst(&builder, diffVal, primalVal->getFullType());
- valToStore = builder.emitMakeDifferentialPair(cast<IRPtrTypeBase>(param->getFullType())->getValueType(),
- primalVal, diffVal);
- builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType());
+ valToStore = builder.emitMakeDifferentialPair(pairValType, primalVal, diffVal);
+
+ markDiffPairTypeInst(&builder, valToStore, pairValType);
}
else
{
@@ -1767,7 +1793,7 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
if (writeBack.value.differential)
{
- builder.markInstAsMixedDifferential(storeInst, valToStore->getFullType());
+ markDiffPairTypeInst(&builder, storeInst, valToStore->getFullType());
}
}
}
@@ -2043,24 +2069,25 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam
SLANG_ASSERT(diffPairParam);
- if (auto pairType = as<IRDifferentialPairType>(diffPairType))
+ if (as<IRDifferentialPairType>(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType))
{
return InstPair(
builder->emitDifferentialPairGetPrimal(diffPairParam),
builder->emitDifferentialPairGetDifferential(
- (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, pairType),
+ (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(
+ builder,
+ as<IRDifferentialPairTypeBase>(diffPairType)),
diffPairParam));
}
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
{
- auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
+ auto ptrInnerPairType = as<IRDifferentialPairTypeBase>(pairPtrType->getValueType());
// Make a local copy of the parameter for primal and diff parts.
auto primal = builder->emitVar(ptrInnerPairType->getValueType());
auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType());
auto diff = builder->emitVar(diffType);
- builder->markInstAsDifferential(
- diff, builder->getPtrType(ptrInnerPairType->getValueType()));
+ markDiffTypeInst(builder, diff, builder->getPtrType(ptrInnerPairType->getValueType()));
IRInst* primalInitVal = nullptr;
IRInst* diffInitVal = nullptr;
@@ -2072,17 +2099,18 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam
else
{
auto initVal = builder->emitLoad(diffPairParam);
- builder->markInstAsMixedDifferential(initVal, ptrInnerPairType);
+ markDiffPairTypeInst(builder, initVal, ptrInnerPairType);
primalInitVal = builder->emitDifferentialPairGetPrimal(initVal);
diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal);
}
- builder->markInstAsDifferential(diffInitVal, ptrInnerPairType->getValueType());
+ markDiffTypeInst(builder, diffInitVal, ptrInnerPairType->getValueType());
+
builder->emitStore(primal, primalInitVal);
auto diffStore = builder->emitStore(diff, diffInitVal);
- builder->markInstAsDifferential(diffStore, ptrInnerPairType->getValueType());
+ markDiffTypeInst(builder, diffStore, ptrInnerPairType->getValueType());
mapInOutParamToWriteBackValue[diffPairParam] = InstPair(primal, diff);
return InstPair(primal, diff);
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
index 7fc8ebbe6..3a6d52bea 100644
--- a/source/slang/slang-ir-autodiff-pairs.cpp
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -107,10 +107,13 @@ struct DiffPairLoweringPass : InstPassBase
case kIROp_DifferentialPairGetPrimal:
case kIROp_DifferentialPairGetDifferentialUserCode:
case kIROp_DifferentialPairGetPrimalUserCode:
+ case kIROp_DifferentialPtrPairGetDifferential:
+ case kIROp_DifferentialPtrPairGetPrimal:
lowerPairAccess(builder, inst);
break;
case kIROp_MakeDifferentialPairUserCode:
+ case kIROp_MakeDifferentialPtrPair:
lowerMakePair(builder, inst);
break;
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index f51178f0f..2881abe3e 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -892,6 +892,16 @@ void applyToInst(
}
}
SLANG_ASSERT(replacement);
+
+ // If the replacement and inst are not the exact same type, use an int-cast
+ // (e.g. uint vs. int)
+ //
+ if (replacement->getDataType() != inst->getDataType())
+ {
+ setInsertAfterOrdinaryInst(builder, replacement);
+ replacement = builder->emitCast(inst->getDataType(), replacement);
+ }
+
cloneCtx->cloneEnv.mapOldValToNew[inst] = replacement;
cloneCtx->registerClonedInst(builder, inst, replacement);
return;
@@ -1998,6 +2008,7 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_MakeArrayFromElement:
case kIROp_MakeDifferentialPair:
case kIROp_MakeDifferentialPairUserCode:
+ case kIROp_MakeDifferentialPtrPair:
case kIROp_MakeOptionalNone:
case kIROp_MakeOptionalValue:
case kIROp_MakeExistential:
@@ -2005,6 +2016,8 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_DifferentialPairGetPrimal:
case kIROp_DifferentialPairGetDifferentialUserCode:
case kIROp_DifferentialPairGetPrimalUserCode:
+ case kIROp_DifferentialPtrPairGetDifferential:
+ case kIROp_DifferentialPtrPairGetPrimal:
case kIROp_ExtractExistentialValue:
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialWitnessTable:
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 2fb73c4ac..169dd31ee 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -152,6 +152,16 @@ namespace Slang
builder->emitBlock();
params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc));
params.removeLast();
+
+ // Unwrap any ref pairs. We need this special case for trivial funcs.
+ for (Int i = 0; i < params.getCount(); i++)
+ {
+ if (as<IRDifferentialPtrPairType>(params[i]->getDataType()))
+ {
+ params[i] = builder->emitDifferentialPtrPairGetPrimal(params[i]);
+ }
+ }
+
IRInst* originalFuncRefFromPrimalFunc = originalFunc;
if (originalGeneric)
originalFuncRefFromPrimalFunc = maybeSpecializeWithGeneric(*builder, originalGeneric, existingPriamlFuncGeneric);
@@ -266,7 +276,20 @@ namespace Slang
if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
return primalNoDiffType;
- return (IRType*)findOrTranscribePrimalInst(builder, paramType);
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);
+
+ // Differentiable pointer types are treated as primal pairs, since they aren't involved in the transposition
+ // process.
+ //
+ if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType))
+ {
+ auto diffPairType = tryGetDiffPairType(builder, primalType);
+ SLANG_ASSERT(diffPairType);
+
+ return diffPairType;
+ }
+
+ return primalType;
}
IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType)
@@ -292,7 +315,7 @@ namespace Slang
auto diffPairType = tryGetDiffPairType(builder, paramType);
if (diffPairType)
{
- if (!as<IRPtrTypeBase>(diffPairType))
+ if (!as<IRPtrTypeBase>(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
return builder->getInOutType(diffPairType);
return diffPairType;
}
@@ -961,7 +984,7 @@ namespace Slang
// Initialize the var with input diff param at start.
// Note that we insert the store in the primal block so it won't get transposed.
auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam);
- nextBlockBuilder.markInstAsDifferential(storeInst, diffPairType);
+ nextBlockBuilder.markInstAsDifferential(storeInst, primalType);
// Since this store inst is specific to propagate function, we track it in a
// set so we can remove it when we generate the primal func.
result.propagateFuncSpecificPrimalInsts.add(storeInst);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 1fa76c730..2141837b5 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -174,45 +174,54 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI
IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey);
-IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType)
+IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind)
{
- return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType);
+ if (kind == DiffConformanceKind::Any)
+ {
+ if (auto valueWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Value))
+ return valueWitness;
+ if (auto ptrWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Ptr))
+ return ptrWitness;
+ }
+ else
+ {
+ return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, kind);
+ }
+ return nullptr;
}
IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
{
- return builder->getDifferentialPairType(
- (IRType*)primalType,
- witness);
+ auto conformanceType = differentiableTypeConformanceContext.getConformanceTypeFromWitness(witness);
+ if (autoDiffSharedContext->isInterfaceAvailable &&
+ conformanceType == autoDiffSharedContext->differentiableInterfaceType)
+ {
+ return builder->getDifferentialPairType((IRType*)primalType, witness);
+ }
+ else if (autoDiffSharedContext->isPtrInterfaceAvailable &&
+ conformanceType == autoDiffSharedContext->differentiablePtrInterfaceType)
+ {
+ return builder->getDifferentialPtrPairType((IRType*)primalType, witness);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unexpected witness type");
+ }
}
IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType)
{
- auto primalType = lookupPrimalInst(builder, originalType, nullptr);
+ auto primalType = lookupPrimalInst(builder, originalType, originalType);
SLANG_RELEASE_ASSERT(primalType);
IRInst* witness = nullptr;
- if (auto lookup = as<IRLookupWitnessMethod>(primalType))
- {
- if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey)
- {
- witness = builder->emitLookupInterfaceMethodInst(
- lookup->getWitnessTable()->getDataType(),
- lookup->getWitnessTable(),
- autoDiffSharedContext->differentialAssocTypeWitnessStructKey);
- }
- }
-
- // Obtain the witness that primalType conforms to IDifferentiable.
+
+ // Obtain the witness that primalType conforms to IDifferentiable/IDifferentiablePtrType
if (!witness)
- witness = tryGetDifferentiableWitness(builder, originalType);
+ witness = tryGetDifferentiableWitness(builder, primalType, DiffConformanceKind::Any);
SLANG_RELEASE_ASSERT(witness);
- auto pairType = builder->getDifferentialPairType(
- (IRType*)primalType,
- witness);
-
- return pairType;
+ return getOrCreateDiffPairType(builder, primalType, witness);
}
IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType)
@@ -223,8 +232,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
// Special-case for differentiable existential types.
if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType))
{
- if (differentiableTypeConformanceContext.lookUpConformanceForType(origType))
+ if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Value))
return autoDiffSharedContext->differentiableInterfaceType;
+ else if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Ptr))
+ return autoDiffSharedContext->differentiablePtrInterfaceType;
else
return nullptr;
}
@@ -278,8 +289,9 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
}
case kIROp_DifferentialPairType:
+ case kIROp_DifferentialPtrPairType:
{
- auto primalPairType = as<IRDifferentialPairType>(primalType);
+ auto primalPairType = as<IRDifferentialPairTypeBase>(primalType);
return getOrCreateDiffPairType(
builder,
differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType),
@@ -445,8 +457,24 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType()));
if (!interfaceType)
return nullptr;
- List<IRInterfaceRequirementEntry*> lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath(
+
+ List<IRInterfaceRequirementEntry*> lookupKeyPath;
+ IRStructKey* diffStructKey = nullptr;
+
+ List<IRInterfaceRequirementEntry*> lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath(
autoDiffSharedContext->differentiableInterfaceType, interfaceType);
+ if (lookupPathValueType.getCount() > 0)
+ {
+ lookupKeyPath = lookupPathValueType;
+ diffStructKey = autoDiffSharedContext->differentialAssocTypeStructKey;
+ }
+ else
+ {
+ // Try IDifferentiablePtrType
+ lookupKeyPath = differentiableTypeConformanceContext.findInterfaceLookupPath(
+ autoDiffSharedContext->differentiablePtrInterfaceType, interfaceType);
+ diffStructKey = autoDiffSharedContext->differentialAssocRefTypeStructKey;
+ }
if (lookupKeyPath.getCount())
{
@@ -456,7 +484,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
{
outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey());
}
- auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, autoDiffSharedContext->differentialAssocTypeStructKey);
+ auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, diffStructKey);
return (IRType*)diffType;
}
return nullptr;
@@ -561,10 +589,31 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
return InstPair(primal, diffWitness);
}
+ else if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiablePtrInterfaceType)
+ {
+ auto primalDiffType = builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ primal,
+ autoDiffSharedContext->differentialAssocRefTypeStructKey);
+ auto diffWitness = builder->emitLookupInterfaceMethodInst(
+ (IRType*)primalDiffType,
+ primal,
+ autoDiffSharedContext->differentialAssocRefTypeWitnessStructKey);
+
+ // Mark both as primal since we're working with types
+ // (which don't need transposing)
+ //
+ builder->markInstAsPrimal(primalDiffType);
+ builder->markInstAsPrimal(diffWitness);
+
+ return InstPair(primal, diffWitness);
+ }
}
+
auto decor =
lookupInst->getRequirementKey()->findDecorationImpl(
getInterfaceRequirementDerivativeDecorationOp());
+
if (!decor)
{
return InstPair(primal, nullptr);
@@ -589,6 +638,10 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(
{
originalType = (IRType*)unwrapAttributedType(originalType);
auto primalType = (IRType*)lookupPrimalInst(builder, originalType);
+
+ // Can't generate zero for differentiable ptr types. Should never hit this case.
+ SLANG_ASSERT(!differentiableTypeConformanceContext.isDifferentiablePtrType(originalType));
+
if (auto diffType = differentiateType(builder, originalType))
{
IRInst* diffWitnessTable = nullptr;
@@ -985,7 +1038,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
&& !as<IRConstant>(pair.differential))
{
auto primalType = (IRType*)(pair.primal->getDataType());
- builder->markInstAsDifferential(pair.differential, primalType);
+ markDiffTypeInst(builder, pair.differential, primalType);
}
}
else
@@ -997,7 +1050,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
if (as<IRType>(pair.differential))
break;
auto mixedType = (IRType*)(pair.primal->getDataType());
- builder->markInstAsMixedDifferential(pair.primal, mixedType);
+ markDiffPairTypeInst(builder, pair.primal, mixedType);
}
}
@@ -1076,4 +1129,64 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori
return result;
}
+
+void AutoDiffTranscriberBase::markDiffTypeInst(IRBuilder* builder, IRInst* diffInst, IRType* primalType)
+{
+ // Ignore module-level insts.
+ if (as<IRModuleInst>(diffInst->getParent()))
+ return;
+
+ // Also ignore generic-container-level insts.
+ if (as<IRBlock>(diffInst->getParent()) &&
+ as<IRGeneric>(diffInst->getParent()->getParent()))
+ return;
+
+ // TODO: This logic is a bit of a hack. We need to determine if the type is
+ // relevant to ptr-type computation or not, or more complex applications
+ // that use dynamic dispatch + ptr types will fail.
+ //
+ if (as<IRType>(diffInst))
+ {
+ builder->markInstAsDifferential(diffInst, nullptr);
+ return;
+ }
+
+ SLANG_ASSERT(diffInst);
+ SLANG_ASSERT(primalType);
+
+ if (differentiableTypeConformanceContext.isDifferentiableValueType(primalType))
+ {
+ builder->markInstAsDifferential(diffInst, primalType);
+ }
+ else if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType))
+ {
+ builder->markInstAsPrimal(diffInst);
+ }
+ else
+ {
+ // Stop-gap solution to go with differential inst for now.
+ builder->markInstAsDifferential(diffInst, primalType);
+ }
+}
+
+void AutoDiffTranscriberBase::markDiffPairTypeInst(IRBuilder* builder, IRInst* diffPairInst, IRType* pairType)
+{
+ SLANG_ASSERT(diffPairInst);
+ SLANG_ASSERT(pairType);
+ SLANG_ASSERT(as<IRDifferentialPairTypeBase>(pairType));
+
+ if (as<IRDifferentialPairType>(pairType))
+ {
+ builder->markInstAsMixedDifferential(diffPairInst, pairType);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairType))
+ {
+ builder->markInstAsPrimal(diffPairInst);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unexpected differentiable type");
+ }
+}
+
}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index f7f2dd6f2..9f3cfe56f 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -91,7 +91,7 @@ struct AutoDiffTranscriberBase
void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc);
- IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType);
+ IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind);
IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness);
@@ -152,6 +152,10 @@ struct AutoDiffTranscriberBase
virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) = 0;
virtual IROp getInterfaceRequirementDerivativeDecorationOp() = 0;
+
+ void markDiffTypeInst(IRBuilder* builder, IRInst* inst, IRType* primalType);
+
+ void markDiffPairTypeInst(IRBuilder* builder, IRInst* inst, IRType* primalType);
};
}
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 1f8c3052e..8669df5a4 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -2116,7 +2116,8 @@ struct DiffTransposePass
// If we reach this point, revValue must be a differentiable type.
auto revTypeWitness = diffTypeContext.tryGetDifferentiableWitness(
builder,
- primalType);
+ primalType,
+ DiffConformanceKind::Value);
SLANG_ASSERT(revTypeWitness);
auto baseExistential = fwdInst->getOperand(0);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 0953c535a..507a2bf92 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -144,7 +144,10 @@ struct ExtractPrimalFuncContext
}
auto structField = genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType);
- if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness(&genTypeBuilder, (IRType*)fieldType))
+ if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness(
+ &genTypeBuilder,
+ (IRType*)fieldType,
+ DiffConformanceKind::Value))
{
genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, witness);
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 07a6a76fb..94a605a68 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -25,7 +25,7 @@ bool isBackwardDifferentiableFunc(IRInst* func)
return false;
}
-IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey, IRType* resultType = nullptr)
{
if (auto witnessTable = as<IRWitnessTable>(witness))
{
@@ -53,15 +53,16 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
}
else
{
+ SLANG_ASSERT(resultType);
return builder->emitLookupInterfaceMethodInst(
- builder->getTypeKind(),
+ resultType,
witness,
requirementKey);
}
return nullptr;
}
-static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
+static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
auto witness = type->getWitness();
SLANG_RELEASE_ASSERT(witness);
@@ -70,16 +71,48 @@ static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRB
if (as<IRInterfaceType>(type->getValueType()) || as<IRAssociatedType>(type->getValueType()))
{
// The differential type is the IDifferentiable interface type.
- return sharedContext->differentiableInterfaceType;
+ if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type))
+ return sharedContext->differentiableInterfaceType;
+ else if (as<IRDifferentialPtrPairType>(type))
+ return sharedContext->differentiablePtrInterfaceType;
+ else
+ SLANG_UNEXPECTED("Unexpected differential pair type");
}
- return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey);
+ if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type))
+ return _lookupWitness(
+ builder,
+ witness,
+ sharedContext->differentialAssocTypeStructKey,
+ builder->getTypeKind());
+ else if (as<IRDifferentialPtrPairType>(type))
+ return _lookupWitness(
+ builder,
+ witness,
+ sharedContext->differentialAssocRefTypeStructKey,
+ builder->getTypeKind());
+ else
+ SLANG_UNEXPECTED("Unexpected differential pair type");
}
static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
+
+ if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type))
+ return _lookupWitness(
+ builder,
+ witnessTable,
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ sharedContext->differentialAssocTypeWitnessTableType);
+ else if (as<IRDifferentialPtrPairType>(type))
+ return _lookupWitness(
+ builder,
+ witnessTable,
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ sharedContext->differentialAssocRefTypeWitnessTableType);
+ else
+ SLANG_UNEXPECTED("Unexpected differential pair type");
}
bool isNoDiffType(IRType* paramType)
@@ -320,6 +353,24 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
return result;
}
+IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst)
+{
+ for (auto inst : moduleInst->getGlobalInsts())
+ {
+ if (auto interfaceType = as<IRInterfaceType>(inst))
+ {
+ if (auto decor = interfaceType->findDecoration<IRNameHintDecoration>())
+ {
+ if (decor->getName() == "IDifferentiablePtrType")
+ {
+ return interfaceType;
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst* inModuleInst)
: moduleInst(inModuleInst), targetProgram(target)
{
@@ -328,14 +379,27 @@ AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst
{
differentialAssocTypeStructKey = findDifferentialTypeStructKey();
differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
+ differentialAssocTypeWitnessTableType = findDifferentialTypeWitnessTableType();
zeroMethodStructKey = findZeroMethodStructKey();
+ zeroMethodType = cast<IRFuncType>(getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementVal());
addMethodStructKey = findAddMethodStructKey();
+ addMethodType = cast<IRFuncType>(getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementVal());
mulMethodStructKey = findMulMethodStructKey();
nullDifferentialStructType = findNullDifferentialStructType();
nullDifferentialWitness = findNullDifferentialWitness();
- if (differentialAssocTypeStructKey)
- isInterfaceAvailable = true;
+ isInterfaceAvailable = true;
+ }
+
+ differentiablePtrInterfaceType = as<IRInterfaceType>(findDifferentiableRefInterface(inModuleInst));
+
+ if (differentiablePtrInterfaceType)
+ {
+ differentialAssocRefTypeStructKey = findDifferentialPtrTypeStructKey();
+ differentialAssocRefTypeWitnessStructKey = findDifferentialPtrTypeWitnessStructKey();
+ differentialAssocRefTypeWitnessTableType = findDifferentialPtrTypeWitnessTableType();
+
+ isPtrInterfaceAvailable = true;
}
}
@@ -404,14 +468,14 @@ IRInst* AutoDiffSharedContext::findNullDifferentialWitness()
}
-IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index)
+IRInterfaceRequirementEntry* AutoDiffSharedContext::getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index)
{
- if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
+ if (as<IRModuleInst>(moduleInst) && interface)
{
// Assume for now that IDifferentiable has exactly five fields.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
- if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
- return as<IRStructKey>(entry->getRequirementKey());
+ // SLANG_ASSERT(interface->getOperandCount() == 5);
+ if (auto entry = as<IRInterfaceRequirementEntry>(interface->getOperand(index)))
+ return entry;
else
{
SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
@@ -421,6 +485,50 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde
return nullptr;
}
+// Extracts conformance interface from a witness inst while accounting for some
+// quirks in the type system around interfaces that conform to other interfaces.
+//
+IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWitness(IRInst* witness)
+{
+ IRInterfaceType* diffInterfaceType = nullptr;
+ if (auto witnessTableType = as<IRWitnessTableType>(witness->getDataType()))
+ {
+ diffInterfaceType = cast<IRInterfaceType>(witnessTableType->getConformanceType());
+ }
+ else if (auto structKey = as<IRStructKey>(witness))
+ {
+ // We currently assume that a struct key is used uniquely for a single interface-requirement-entry.
+ // Find that entry
+ for (IRUse* use = structKey->firstUse; use; use = use->nextUse)
+ {
+ if (auto entry = as<IRInterfaceRequirementEntry>(use->getUser()))
+ {
+ auto innerWitnessTableType = cast<IRWitnessTableType>(entry->getRequirementVal());
+ diffInterfaceType = cast<IRInterfaceType>(innerWitnessTableType->getConformanceType());
+ break;
+ }
+ }
+ }
+ else if (auto interfaceRequirementEntry = as<IRInterfaceRequirementEntry>(witness))
+ {
+ auto innerWitnessTableType = cast<IRWitnessTableType>(interfaceRequirementEntry->getRequirementVal());
+ diffInterfaceType = cast<IRInterfaceType>(innerWitnessTableType->getConformanceType());
+ }
+ else if (auto tupleType = as<IRTupleType>(witness->getDataType()))
+ {
+ SLANG_ASSERT(tupleType->getOperandCount() >= 1);
+ auto operand = tupleType->getOperand(0);
+ auto innerWitnessTableType = cast<IRWitnessTableType>(operand);
+ return cast<IRInterfaceType>(innerWitnessTableType->getConformanceType());
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unexpected witness type");
+ }
+
+ return diffInterfaceType;
+}
+
void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
parentFunc = func;
@@ -434,7 +542,13 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
{
- auto existingItem = differentiableWitnessDictionary.tryGetValue(item->getConcreteType());
+ IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness());
+
+ SLANG_ASSERT(
+ diffInterfaceType == sharedContext->differentiableInterfaceType
+ || diffInterfaceType == sharedContext->differentiablePtrInterfaceType);
+
+ auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType());
if (existingItem)
{
*existingItem = item->getWitness();
@@ -458,20 +572,26 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
auto element = concreteType->getOperand(i);
auto elementWitness = witnessPack->getOperand(i);
- differentiableWitnessDictionary.addIfNotExists(
- (IRType*)element,
- _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey));
+
+ if (diffInterfaceType == sharedContext->differentiableInterfaceType)
+ addTypeToDictionary(
+ (IRType*)element,
+ elementWitness);
+ else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType)
+ addTypeToDictionary(
+ (IRType*)element,
+ elementWitness);
}
return;
}
}
- differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
+ addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness());
if (!as<IRInterfaceType>(item->getConcreteType()))
{
- differentiableWitnessDictionary.addIfNotExists(
- (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey),
+ addTypeToDictionary(
+ (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey, subBuilder.getTypeKind()),
item->getWitness());
}
@@ -480,29 +600,55 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
// For differential pair types, register the differential type as well.
IRBuilder builder(diffPairType);
builder.setInsertAfter(diffPairType->getWitness());
- auto diffType = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey);
- auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey);
- if (diffType && diffWitness)
- {
- differentiableWitnessDictionary.addIfNotExists((IRType*)diffType, diffWitness);
- }
+
+ // TODO(sai): lot of this logic is duplicated. need to refactor.
+ auto diffType = (diffInterfaceType == sharedContext->differentiableInterfaceType) ?
+ _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey, builder.getTypeKind()) :
+ _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocRefTypeStructKey, builder.getTypeKind());
+ auto diffWitness = (diffInterfaceType == sharedContext->differentiableInterfaceType) ?
+ _lookupWitness(
+ &builder,
+ diffPairType->getWitness(),
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ sharedContext->differentialAssocTypeWitnessTableType) :
+ _lookupWitness(
+ &builder,
+ diffPairType->getWitness(),
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ sharedContext->differentialAssocRefTypeWitnessTableType);
+
+ addTypeToDictionary((IRType*)diffType, diffWitness);
}
}
}
}
}
-IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
+IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type, DiffConformanceKind kind)
{
IRInst* foundResult = nullptr;
- differentiableWitnessDictionary.tryGetValue(type, foundResult);
- return foundResult;
+ differentiableTypeWitnessDictionary.tryGetValue(type, foundResult);
+ if (!foundResult)
+ return nullptr;
+
+ if (kind == DiffConformanceKind::Any)
+ return foundResult;
+
+ if (auto baseType = getConformanceTypeFromWitness(foundResult))
+ {
+ if (baseType == sharedContext->differentiableInterfaceType && kind == DiffConformanceKind::Value)
+ return foundResult;
+ else if (baseType == sharedContext->differentiablePtrInterfaceType && kind == DiffConformanceKind::Ptr)
+ return foundResult;
+ }
+
+ return nullptr;
}
-IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType)
{
- if (auto conformance = tryGetDifferentiableWitness(builder, origType))
- return _lookupWitness(builder, conformance, key);
+ if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any))
+ return _lookupWitness(builder, conformance, key, resultType);
return nullptr;
}
@@ -514,7 +660,7 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp
IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
- return _getDiffTypeFromPairType(sharedContext, builder, type);
+ return this->differentiateType(builder, type->getValueType());
}
IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
@@ -525,20 +671,34 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB
IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey);
+ return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
}
IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey);
+ return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey, sharedContext->addMethodType);
+}
+
+void DifferentiableTypeConformanceContext::addTypeToDictionary(IRType* type, IRInst* witness)
+{
+ auto conformanceType = getConformanceTypeFromWitness(witness);
+
+ if (!sharedContext->isInterfaceAvailable && !sharedContext->isPtrInterfaceAvailable)
+ return;
+
+ SLANG_ASSERT(
+ conformanceType == sharedContext->differentiableInterfaceType ||
+ conformanceType == sharedContext->differentiablePtrInterfaceType);
+
+ differentiableTypeWitnessDictionary.addIfNotExists(type, witness);
}
IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable)
{
SLANG_RELEASE_ASSERT(interfaceType);
- List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath(
+ List<IRInterfaceRequirementEntry*> lookupKeyPath = findInterfaceLookupPath(
sharedContext->differentiableInterfaceType, interfaceType);
IRInst* differentialTypeWitness = witnessTable;
@@ -549,6 +709,7 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface
{
differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey());
// Lookup insts are always primal values.
+
builder->markInstAsPrimal(differentialTypeWitness);
}
return differentialTypeWitness;
@@ -557,10 +718,10 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface
return nullptr;
}
-// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`.
-static bool _findDifferentiableInterfaceLookupPathImpl(
+// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `supType`.
+static bool _findInterfaceLookupPathImpl(
HashSet<IRInst*>& processedTypes,
- IRInterfaceType* idiffType,
+ IRInterfaceType* supType,
IRInterfaceType* type,
List<IRInterfaceRequirementEntry*>& currentPath)
{
@@ -576,13 +737,13 @@ static bool _findDifferentiableInterfaceLookupPathImpl(
if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal()))
{
currentPath.add(entry);
- if (wt->getConformanceType() == idiffType)
+ if (wt->getConformanceType() == supType)
{
return true;
}
else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType()))
{
- if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath))
+ if (_findInterfaceLookupPathImpl(processedTypes, supType, subInterfaceType, currentPath))
return true;
}
currentPath.removeLast();
@@ -591,11 +752,11 @@ static bool _findDifferentiableInterfaceLookupPathImpl(
return false;
}
-List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findDifferentiableInterfaceLookupPath(IRInterfaceType *idiffType, IRInterfaceType *type)
+List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findInterfaceLookupPath(IRInterfaceType *supType, IRInterfaceType *type)
{
List<IRInterfaceRequirementEntry*> currentPath;
HashSet<IRInst*> processedTypes;
- _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath);
+ _findInterfaceLookupPathImpl(processedTypes, supType, type, currentPath);
return currentPath;
}
@@ -722,7 +883,7 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst))
{
- differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness());
+ addTypeToDictionary(pairType->getValueType(), pairType->getWitness());
}
}
}
@@ -762,9 +923,8 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
case kIROp_DifferentialPairType:
{
auto primalPairType = as<IRDifferentialPairType>(primalType);
- return getOrCreateDiffPairType(
- builder,
- getDiffTypeFromPairType(builder, primalPairType),
+ return builder->getDifferentialPairType(
+ (IRType*)getDiffTypeFromPairType(builder, primalPairType),
getDiffTypeWitnessFromPairType(builder, primalPairType));
}
@@ -776,6 +936,14 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
getDiffTypeWitnessFromPairType(builder, primalPairType));
}
+ case kIROp_DifferentialPtrPairType:
+ {
+ auto primalPairType = as<IRDifferentialPtrPairType>(primalType);
+ return builder->getDifferentialPtrPairType(
+ (IRType*)getDiffTypeFromPairType(builder, primalPairType),
+ getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
case kIROp_FuncType:
{
SLANG_UNIMPLEMENTED_X("Impl");
@@ -817,12 +985,12 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
}
}
-IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType)
+IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType, DiffConformanceKind kind)
{
if (isNoDiffType((IRType*)primalType))
return nullptr;
-
- IRInst* witness = lookUpConformanceForType((IRType*)primalType);
+
+ IRInst* witness = lookUpConformanceForType((IRType*)primalType, kind);
if (witness)
{
SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType));
@@ -834,31 +1002,60 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil
witness = nullptr;
}
- if (!witness)
+ if (witness)
+ return witness;
+
+ // If a witness is not already mapped, build one if possible.
+ SLANG_RELEASE_ASSERT(primalType);
+ if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType))
{
- SLANG_RELEASE_ASSERT(primalType);
- if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType))
- {
- witness = getOrCreateDifferentiablePairWitness(builder, primalPairType);
- }
- else if (auto arrayType = as<IRArrayType>(primalType))
- {
- witness = getArrayWitness(builder, arrayType);
- }
- else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
- {
- witness = getExtractExistensialTypeWitness(builder, extractExistential);
- }
- else if (auto typePack = as<IRTypePack>(primalType))
+ witness = buildDifferentiablePairWitness(builder, primalPairType, kind);
+ }
+ else if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ witness = buildArrayWitness(builder, arrayType, kind);
+ }
+ else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
+ {
+ witness = buildExtractExistensialTypeWitness(builder, extractExistential, kind);
+ }
+ else if (auto typePack = as<IRTypePack>(primalType))
+ {
+ witness = buildTupleWitness(builder, typePack, kind);
+ }
+ else if (auto tupleType = as<IRTupleType>(primalType))
+ {
+ witness = buildTupleWitness(builder, tupleType, kind);
+ }
+ else if (auto lookup = as<IRLookupWitnessMethod>(primalType))
+ {
+ // For types that are lookups from a table, we can simply lookup the witness from the same table
+ if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey)
{
- witness = getTupleWitness(builder, typePack);
+ witness = builder->emitLookupInterfaceMethodInst(
+ lookup->getWitnessTable()->getDataType(),
+ lookup->getWitnessTable(),
+ sharedContext->differentialAssocTypeWitnessStructKey);
}
- else if (auto tupleType = as<IRTupleType>(primalType))
+
+ if (lookup->getRequirementKey() == sharedContext->differentialAssocRefTypeStructKey)
{
- witness = getTupleWitness(builder, tupleType);
+ witness = builder->emitLookupInterfaceMethodInst(
+ lookup->getWitnessTable()->getDataType(),
+ lookup->getWitnessTable(),
+ sharedContext->differentialAssocRefTypeWitnessStructKey);
}
}
- return witness;
+
+ // If we created a witness, register it.
+ if (witness)
+ {
+ addTypeToDictionary((IRType*)primalType, witness);
+ return witness;
+ }
+
+ // Failed. Type is either non-differentiable, or unhandled.
+ return nullptr;
}
IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
@@ -868,77 +1065,97 @@ IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder*
witness);
}
-IRInst* DifferentiableTypeConformanceContext::getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType)
+IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
+ IRBuilder* builder,
+ IRDifferentialPairTypeBase* pairType,
+ DiffConformanceKind target)
{
- // Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType);
-
- auto addMethod = builder->createFunc();
- auto zeroMethod = builder->createFunc();
-
- auto table = builder->createWitnessTable(this->sharedContext->differentiableInterfaceType, (IRType*)pairType);
-
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType);
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
-
- bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;
-
- // Fill in differential method implementations.
- auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType();
- auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness();
-
- {
- // Add method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType };
- addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType));
- b.emitBlock();
- auto p0 = b.emitParam(diffDiffPairType);
- auto p1 = b.emitParam(diffDiffPairType);
-
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
- auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
- IRInst* argsPrimal[2] = {
- isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0),
- isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) };
- auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal);
- IRInst* argsDiff[2] = {
- isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0),
- isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)};
- auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff);
- auto retVal =
- isUserCodeType
- ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart)
- : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart);
- b.emitReturn(retVal);
- }
- {
- // Zero method.
- IRBuilder b = *builder;
- b.setInsertInto(zeroMethod);
- zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType));
- b.emitBlock();
- auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
- auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
- auto retVal =
- isUserCodeType
- ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal)
- : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal);
- b.emitReturn(retVal);
+ IRWitnessTable* table = nullptr;
+ if (target == DiffConformanceKind::Value)
+ {
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType);
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
+ table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)pairType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+
+ bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;
+
+ // Fill in differential method implementations.
+ auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType();
+ auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness();
+
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffDiffPairType);
+ auto p1 = b.emitParam(diffDiffPairType);
+
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType);
+ IRInst* argsPrimal[2] = {
+ isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0),
+ isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) };
+ auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal);
+ IRInst* argsDiff[2] = {
+ isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0),
+ isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)};
+ auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff);
+ auto retVal =
+ isUserCodeType
+ ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart)
+ : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart);
+ b.emitReturn(retVal);
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType));
+ b.emitBlock();
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
+ auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
+ auto retVal =
+ isUserCodeType
+ ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal)
+ : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal);
+ b.emitReturn(retVal);
+ }
+ }
+ else if (target == DiffConformanceKind::Ptr)
+ {
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType);
+
+ table = builder->createWitnessTable(
+ sharedContext->differentiablePtrInterfaceType,
+ (IRType*)pairType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffDiffPairType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table);
}
-
- // Record this in the context for future lookups
- differentiableWitnessDictionary[(IRType*)pairType] = table;
return table;
}
-IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder, IRArrayType* arrayType)
+IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
+ IRBuilder* builder,
+ IRArrayType* arrayType,
+ DiffConformanceKind target)
{
// Differentiate the pair type to get it's differential (which is itself a pair)
auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)arrayType);
@@ -946,70 +1163,89 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder
if (!diffArrayType)
return nullptr;
- auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType());
+ IRWitnessTable* table = nullptr;
+ if (target == DiffConformanceKind::Value)
+ {
+ SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType));
+ auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType(), DiffConformanceKind::Value);
- auto addMethod = builder->createFunc();
- auto zeroMethod = builder->createFunc();
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
- auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType);
+ table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType);
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
- auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
+ auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
- // Fill in differential method implementations.
+ // Fill in differential method implementations.
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffArrayType, diffArrayType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffArrayType);
+ auto p1 = b.emitParam(diffArrayType);
+
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType);
+ auto resultVar = b.emitVar(diffArrayType);
+ IRBlock* loopBodyBlock = nullptr;
+ IRBlock* loopBreakBlock = nullptr;
+ auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock);
+ b.setInsertBefore(loopBodyBlock->getTerminator());
+
+ IRInst* args[2] = {
+ b.emitElementExtract(p0, loopCounter),
+ b.emitElementExtract(p1, loopCounter) };
+ auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args);
+ auto addr = b.emitElementAddress(resultVar, loopCounter);
+ b.emitStore(addr, elementResult);
+ b.setInsertInto(loopBreakBlock);
+ b.emitReturn(b.emitLoad(resultVar));
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType));
+ b.emitBlock();
+
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
+ auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
+ auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal);
+ b.emitReturn(retVal);
+ }
+ }
+ else if (target == DiffConformanceKind::Ptr)
{
- // Add method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- IRType* paramTypes[2] = { diffArrayType, diffArrayType };
- addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType));
- b.emitBlock();
- auto p0 = b.emitParam(diffArrayType);
- auto p1 = b.emitParam(diffArrayType);
+ SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType));
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
- auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
- auto resultVar = b.emitVar(diffArrayType);
- IRBlock* loopBodyBlock = nullptr;
- IRBlock* loopBreakBlock = nullptr;
- auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock);
- b.setInsertBefore(loopBodyBlock->getTerminator());
+ table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)arrayType);
- IRInst* args[2] = {
- b.emitElementExtract(p0, loopCounter),
- b.emitElementExtract(p1, loopCounter) };
- auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args);
- auto addr = b.emitElementAddress(resultVar, loopCounter);
- b.emitStore(addr, elementResult);
- b.setInsertInto(loopBreakBlock);
- b.emitReturn(b.emitLoad(resultVar));
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffArrayType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table);
}
+ else
{
- // Zero method.
- IRBuilder b = *builder;
- b.setInsertInto(zeroMethod);
- zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType));
- b.emitBlock();
-
- auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
- auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
- auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal);
- b.emitReturn(retVal);
+ SLANG_UNEXPECTED("Invalid conformance kind for synthesis");
}
- // Record this in the context for future lookups
- differentiableWitnessDictionary[(IRType*)arrayType] = table;
-
return table;
}
-IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType)
+IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
+ IRBuilder* builder,
+ IRInst* inTupleType,
+ DiffConformanceKind target)
{
// Differentiate the pair type to get it's differential (which is itself a pair)
auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType);
@@ -1017,100 +1253,116 @@ IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder
if (!diffTupleType)
return nullptr;
- auto addMethod = builder->createFunc();
- auto zeroMethod = builder->createFunc();
-
- auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType);
-
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType);
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
-
- // Fill in differential method implementations.
- {
- // Add method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- IRType* paramTypes[2] = { diffTupleType, diffTupleType };
- addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType));
- b.emitBlock();
- auto p0 = b.emitParam(diffTupleType);
- auto p1 = b.emitParam(diffTupleType);
- List<IRInst*> results;
- for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
- {
- auto elementType = inTupleType->getOperand(i);
- auto diffElementType = (IRType*)diffTupleType->getOperand(i);
- auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType);
- IRInst* elementResult = nullptr;
- if (!innerWitness)
+ IRWitnessTable* table = nullptr;
+ if (target == DiffConformanceKind::Value)
+ {
+ SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType));
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
+ table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+
+ // Fill in differential method implementations.
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffTupleType, diffTupleType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffTupleType);
+ auto p1 = b.emitParam(diffTupleType);
+ List<IRInst*> results;
+ for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
{
- elementResult = b.getVoidValue();
+ auto elementType = inTupleType->getOperand(i);
+ auto diffElementType = (IRType*)diffTupleType->getOperand(i);
+ auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value);
+ IRInst* elementResult = nullptr;
+ if (!innerWitness)
+ {
+ elementResult = b.getVoidValue();
+ }
+ else
+ {
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType);
+ auto iVal = b.getIntValue(b.getIntType(), i);
+ IRInst* args[2] = {
+ b.emitGetTupleElement(diffElementType, p0, iVal),
+ b.emitGetTupleElement(diffElementType, p1, iVal) };
+ elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args);
+ }
+ results.add(elementResult);
}
+ IRInst* resultVal = nullptr;
+ if (diffTupleType->getOp() == kIROp_TupleType)
+ resultVal = b.emitMakeTuple(diffTupleType, results);
else
- {
- auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
- auto iVal = b.getIntValue(b.getIntType(), i);
- IRInst* args[2] = {
- b.emitGetTupleElement(diffElementType, p0, iVal),
- b.emitGetTupleElement(diffElementType, p1, iVal) };
- elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args);
- }
- results.add(elementResult);
+ resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
+ b.emitReturn(resultVal);
}
- IRInst* resultVal = nullptr;
- if (diffTupleType->getOp() == kIROp_TupleType)
- resultVal = b.emitMakeTuple(diffTupleType, results);
- else
- resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
- b.emitReturn(resultVal);
- }
- {
- // Zero method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
- b.emitBlock();
- List<IRInst*> results;
- for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
- {
- auto elementType = inTupleType->getOperand(i);
- auto diffElementType = (IRType*)diffTupleType->getOperand(i);
- auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType);
- IRInst* elementResult = nullptr;
- if (!innerWitness)
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
+ b.emitBlock();
+ List<IRInst*> results;
+ for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
{
- elementResult = b.getVoidValue();
+ auto elementType = inTupleType->getOperand(i);
+ auto diffElementType = (IRType*)diffTupleType->getOperand(i);
+ auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value);
+ IRInst* elementResult = nullptr;
+ if (!innerWitness)
+ {
+ elementResult = b.getVoidValue();
+ }
+ else
+ {
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
+ elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr);
+ }
+ results.add(elementResult);
}
+ IRInst* resultVal = nullptr;
+ if (diffTupleType->getOp() == kIROp_TupleType)
+ resultVal = b.emitMakeTuple(diffTupleType, results);
else
- {
- auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
- elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr);
- }
- results.add(elementResult);
+ resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
+ b.emitReturn(resultVal);
}
- IRInst* resultVal = nullptr;
- if (diffTupleType->getOp() == kIROp_TupleType)
- resultVal = b.emitMakeTuple(diffTupleType, results);
- else
- resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
- b.emitReturn(resultVal);
}
+ else if (target == DiffConformanceKind::Ptr)
+ {
+ SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType));
- // Record this in the context for future lookups
- differentiableWitnessDictionary[(IRType*)inTupleType] = table;
+ table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffTupleType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table);
+ }
return table;
}
-IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(
+IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness(
IRBuilder* builder,
- IRExtractExistentialType* extractExistentialType)
+ IRExtractExistentialType* extractExistentialType,
+ DiffConformanceKind target)
{
+ SLANG_UNUSED(target); // logic is the same for both value and ptr
+
// Check that the type's base is differentiable
if (differentiateType(builder, extractExistentialType->getOperand(0)->getDataType()))
{
@@ -1310,12 +1562,13 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst*
if (context.isDifferentiableType((IRType*)typeInst))
return true;
+
// Look for equivalent types.
- for (auto type : context.differentiableWitnessDictionary)
+ for (auto type : context.differentiableTypeWitnessDictionary)
{
if (isTypeEqual(type.key, (IRType*)typeInst))
{
- context.differentiableWitnessDictionary[(IRType*)typeInst] = type.value;
+ context.differentiableTypeWitnessDictionary[(IRType*)typeInst] = type.value;
return true;
}
}
@@ -1672,7 +1925,7 @@ struct AutoDiffPass : public InstPassBase
IRBuilder keyBuilder = builder;
keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType));
auto diffKey = keyBuilder.createStructKey();
- auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey);
+ auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey, builder.getTypeKind());
info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType);
info.witness = diffFieldWitness;
builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey);
@@ -1695,7 +1948,11 @@ struct AutoDiffPass : public InstPassBase
List<IRInst*> fieldVals;
for (auto info : diffFields)
{
- auto innerZeroMethod = _lookupWitness(&builder, info.witness, autodiffContext->zeroMethodStructKey);
+ auto innerZeroMethod = _lookupWitness(
+ &builder,
+ info.witness,
+ autodiffContext->zeroMethodStructKey,
+ autodiffContext->zeroMethodType);
IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr);
fieldVals.add(val);
}
@@ -1719,7 +1976,11 @@ struct AutoDiffPass : public InstPassBase
List<IRInst*> fieldVals;
for (auto info : diffFields)
{
- auto innerAddMethod = _lookupWitness(&builder, info.witness, autodiffContext->addMethodStructKey);
+ auto innerAddMethod = _lookupWitness(
+ &builder,
+ info.witness,
+ autodiffContext->addMethodStructKey,
+ autodiffContext->addMethodType);
IRInst* args[2] = {
builder.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()),
builder.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()),
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 812471fe3..ad2486aad 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -57,6 +57,14 @@ struct DiffTranscriberSet
AutoDiffTranscriberBase* backwardTranscriber = nullptr;
};
+
+enum class DiffConformanceKind
+{
+ Any = 0, // Perform actions for any conformance (infer from context)
+ Ptr = 1, // Perform actions for IDifferentiablePtrType
+ Value = 2 // Perform actions for IDifferentiable
+};
+
struct AutoDiffSharedContext
{
TargetProgram* targetProgram = nullptr;
@@ -78,6 +86,7 @@ struct AutoDiffSharedContext
// The struct key for the witness that `Differential` associated type conforms to
// `IDifferential`.
IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
+ IRWitnessTableType* differentialAssocTypeWitnessTableType = nullptr;
// The struct key for the 'zero()' associated type
@@ -85,12 +94,14 @@ struct AutoDiffSharedContext
// implementation of zero() for a given type.
//
IRStructKey* zeroMethodStructKey = nullptr;
+ IRFuncType* zeroMethodType = nullptr;
// The struct key for the 'add()' associated type
// defined inside IDifferential. We use this to lookup the
// implementation of add() for a given type.
//
IRStructKey* addMethodStructKey = nullptr;
+ IRFuncType* addMethodType = nullptr;
IRStructKey* mulMethodStructKey = nullptr;
@@ -104,12 +115,27 @@ struct AutoDiffSharedContext
//
IRInst* nullDifferentialWitness = nullptr;
+
+ // A reference to the builtin IDifferentiablePtrType interface type.
+ IRInterfaceType* differentiablePtrInterfaceType = nullptr;
+
+ // The struct key for the 'Differential' associated type
+ // defined inside IDifferentialPtrType. We use this to lookup the differential
+ // type in the conformance table associated with the concrete type.
+ //
+ IRStructKey* differentialAssocRefTypeStructKey = nullptr;
+
+ // The struct key for the witness that `Differential` associated type conforms to
+ // `IDifferentialPtrType`.
+ IRStructKey* differentialAssocRefTypeWitnessStructKey = nullptr;
+ IRWitnessTableType* differentialAssocRefTypeWitnessTableType = nullptr;
// Modules that don't use differentiable types
// won't have the IDifferentiable interface type available.
// Set to false to indicate that we are uninitialized.
//
bool isInterfaceAvailable = false;
+ bool isPtrInterfaceAvailable = false;
List<FuncBodyTranscriptionTask> followUpFunctionsToTranscribe;
@@ -127,38 +153,70 @@ private:
IRStructKey* findDifferentialTypeStructKey()
{
- return getIDifferentiableStructKeyAtIndex(0);
+ return cast<IRStructKey>(
+ getInterfaceEntryAtIndex(differentiableInterfaceType, 0)->getRequirementKey());
}
IRStructKey* findDifferentialTypeWitnessStructKey()
{
- return getIDifferentiableStructKeyAtIndex(1);
+ return cast<IRStructKey>(
+ getInterfaceEntryAtIndex(differentiableInterfaceType, 1)->getRequirementKey());
+ }
+
+ IRWitnessTableType* findDifferentialTypeWitnessTableType()
+ {
+ return cast<IRWitnessTableType>(
+ getInterfaceEntryAtIndex(differentiableInterfaceType, 1)->getRequirementVal());
}
IRStructKey* findZeroMethodStructKey()
{
- return getIDifferentiableStructKeyAtIndex(2);
+ return cast<IRStructKey>(
+ getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementKey());
}
IRStructKey* findAddMethodStructKey()
{
- return getIDifferentiableStructKeyAtIndex(3);
+ return cast<IRStructKey>(
+ getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementKey());
}
IRStructKey* findMulMethodStructKey()
{
- return getIDifferentiableStructKeyAtIndex(4);
+ return cast<IRStructKey>(
+ getInterfaceEntryAtIndex(differentiableInterfaceType, 4)->getRequirementKey());
+ }
+
+
+ IRStructKey* findDifferentialPtrTypeStructKey()
+ {
+ return cast<IRStructKey>(
+ getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 0)->getRequirementKey());
}
- IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index);
+ IRStructKey* findDifferentialPtrTypeWitnessStructKey()
+ {
+ return cast<IRStructKey>(
+ getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 1)->getRequirementKey());
+ }
+
+ IRWitnessTableType* findDifferentialPtrTypeWitnessTableType()
+ {
+ return cast<IRWitnessTableType>(
+ getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 1)->getRequirementVal());
+ }
+
+ //IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index);
+ IRInterfaceRequirementEntry* getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index);
};
+
struct DifferentiableTypeConformanceContext
{
AutoDiffSharedContext* sharedContext;
IRGlobalValueWithCode* parentFunc = nullptr;
- OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary;
+ OrderedDictionary<IRType*, IRInst*> differentiableTypeWitnessDictionary;
IRFunc* existentialDAddFunc = nullptr;
@@ -167,7 +225,7 @@ struct DifferentiableTypeConformanceContext
{
// Populate dictionary with null differential type.
if (sharedContext->nullDifferentialStructType)
- differentiableWitnessDictionary.add(
+ differentiableTypeWitnessDictionary.add(
sharedContext->nullDifferentialStructType,
sharedContext->nullDifferentialWitness);
}
@@ -179,21 +237,13 @@ struct DifferentiableTypeConformanceContext
// Lookup a witness table for the concreteType. One should exist if concreteType
// inherits (successfully) from IDifferentiable.
//
- IRInst* lookUpConformanceForType(IRInst* type);
+ IRInst* lookUpConformanceForType(IRInst* type, DiffConformanceKind kind);
- IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);
+ IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType = nullptr);
IRType* differentiateType(IRBuilder* builder, IRInst* primalType);
- IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType);
-
- IRInst* getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType);
-
- IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType);
-
- IRInst* getTupleWitness(IRBuilder* builder, IRInst* tupleType);
-
- IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType);
+ IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind);
IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness);
@@ -207,17 +257,21 @@ struct DifferentiableTypeConformanceContext
IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+ void addTypeToDictionary(IRType* type, IRInst* witness);
+
+ IRInterfaceType* getConformanceTypeFromWitness(IRInst* witness);
+
IRInst* tryExtractConformanceFromInterfaceType(
IRBuilder* builder,
IRInterfaceType* interfaceType,
IRWitnessTable* witnessTable);
- List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath(
- IRInterfaceType* idiffType,
+ List<IRInterfaceRequirementEntry*> findInterfaceLookupPath(
+ IRInterfaceType* supType,
IRInterfaceType* type);
// Lookup and return the 'Differential' type declared in the concrete type
- // in order to conform to the IDifferentiable interface.
+ // in order to conform to the IDifferentiable/IDifferentiablePtrType interfaces
// Note that inside a generic block, this will be a witness table lookup instruction
// that gets resolved during the specialization pass.
//
@@ -227,8 +281,10 @@ struct DifferentiableTypeConformanceContext
{
case kIROp_InterfaceType:
{
- if (isDifferentiableType(origType))
+ if (isDifferentiableValueType(origType))
return this->sharedContext->differentiableInterfaceType;
+ else if (isDifferentiablePtrType(origType))
+ return this->sharedContext->differentiablePtrInterfaceType;
else
return nullptr;
}
@@ -254,13 +310,30 @@ struct DifferentiableTypeConformanceContext
auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType);
return builder->getDifferentialPairUserCodeType((IRType*)diffType, diffWitness);
}
+ case kIROp_DifferentialPtrPairType:
+ {
+ auto diffPairType = as<IRDifferentialPairTypeBase>(origType);
+ auto diffType = getDiffTypeFromPairType(builder, diffPairType);
+ auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType);
+ return builder->getDifferentialPtrPairType((IRType*)diffType, diffWitness);
+ }
default:
- return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
+ if (isDifferentiableValueType(origType))
+ return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey, builder->getTypeKind());
+ else if (isDifferentiablePtrType(origType))
+ return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocRefTypeStructKey, builder->getTypeKind());
+ else
+ return nullptr;
}
}
bool isDifferentiableType(IRType* origType)
{
+ return isDifferentiableValueType(origType) || isDifferentiablePtrType(origType);
+ }
+
+ bool isDifferentiableValueType(IRType* origType)
+ {
for (; origType;)
{
switch (origType->getOp())
@@ -279,7 +352,27 @@ struct DifferentiableTypeConformanceContext
origType = (IRType*)origType->getOperand(0);
continue;
default:
- return lookUpConformanceForType(origType) != nullptr;
+ return lookUpConformanceForType(origType, DiffConformanceKind::Value) != nullptr;
+ }
+ }
+ return false;
+ }
+
+ bool isDifferentiablePtrType(IRType* origType)
+ {
+ for (; origType;)
+ {
+ switch (origType->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_ArrayType:
+ case kIROp_PtrType:
+ case kIROp_OutType:
+ case kIROp_InOutType:
+ origType = (IRType*)origType->getOperand(0);
+ continue;
+ default:
+ return lookUpConformanceForType(origType, DiffConformanceKind::Ptr) != nullptr;
}
}
return false;
@@ -287,13 +380,13 @@ struct DifferentiableTypeConformanceContext
IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
{
- auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
+ auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
return result;
}
IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
{
- auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
+ auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey, sharedContext->addMethodType);
return result;
}
@@ -307,8 +400,28 @@ struct DifferentiableTypeConformanceContext
IRFunc* getOrCreateExistentialDAddMethod();
+ IRInst* buildDifferentiablePairWitness(
+ IRBuilder* builder,
+ IRDifferentialPairTypeBase* pairType,
+ DiffConformanceKind target);
+
+ IRInst* buildArrayWitness(
+ IRBuilder* builder,
+ IRArrayType* pairType,
+ DiffConformanceKind target);
+
+ IRInst* buildTupleWitness(
+ IRBuilder* builder,
+ IRInst* tupleType,
+ DiffConformanceKind target);
+
+ IRInst* buildExtractExistensialTypeWitness(
+ IRBuilder* builder,
+ IRExtractExistentialType* extractExistentialType,
+ DiffConformanceKind target);
};
+
struct DifferentialPairTypeBuilder
{
DifferentialPairTypeBuilder() = default;
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 8b4886a2c..cae47fffd 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -625,7 +625,7 @@ public:
}
}
- if (!sharedContext.isInterfaceAvailable)
+ if (!sharedContext.isInterfaceAvailable && !sharedContext.isPtrInterfaceAvailable)
return;
for (auto inst : module->getGlobalInsts())
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 301a9c789..0d689660e 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -61,7 +61,8 @@ INST(Nop, nop, 0, 0)
INST(DifferentialPairType, DiffPair, 1, HOISTABLE)
INST(DifferentialPairUserCodeType, DiffPairUserCode, 1, HOISTABLE)
- INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPairUserCodeType)
+ INST(DifferentialPtrPairType, DiffRefPair, 1, HOISTABLE)
+ INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPtrPairType)
INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE)
@@ -325,15 +326,18 @@ INST(DefaultConstruct, defaultConstruct, 0, 0)
INST(MakeDifferentialPair, MakeDiffPair, 2, 0)
INST(MakeDifferentialPairUserCode, MakeDiffPairUserCode, 2, 0)
-INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPairUserCode)
+INST(MakeDifferentialPtrPair, MakeDiffRefPair, 2, 0)
+INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPtrPair)
INST(DifferentialPairGetDifferential, GetDifferential, 1, 0)
INST(DifferentialPairGetDifferentialUserCode, GetDifferentialUserCode, 1, 0)
-INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPairGetDifferentialUserCode)
+INST(DifferentialPtrPairGetDifferential, GetDifferentialPtr, 1, 0)
+INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPtrPairGetDifferential)
INST(DifferentialPairGetPrimal, GetPrimal, 1, 0)
INST(DifferentialPairGetPrimalUserCode, GetPrimalUserCode, 1, 0)
-INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPairGetPrimalUserCode)
+INST(DifferentialPtrPairGetPrimal, GetPrimalRef, 1, 0)
+INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPtrPairGetPrimal)
INST(Specialize, specialize, 2, HOISTABLE)
INST(LookupWitness, lookupWitness, 2, HOISTABLE)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 37f242e55..f31a56673 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2959,6 +2959,10 @@ struct IRMakeDifferentialPairUserCode : IRMakeDifferentialPairBase
{
IR_LEAF_ISA(MakeDifferentialPairUserCode)
};
+struct IRMakeDifferentialPtrPair : IRMakeDifferentialPairBase
+{
+ IR_LEAF_ISA(MakeDifferentialPtrPair)
+};
struct IRDifferentialPairGetDifferentialBase : IRInst
{
@@ -2973,6 +2977,10 @@ struct IRDifferentialPairGetDifferentialUserCode : IRDifferentialPairGetDifferen
{
IR_LEAF_ISA(DifferentialPairGetDifferentialUserCode)
};
+struct IRDifferentialPtrPairGetDifferential : IRDifferentialPairGetDifferentialBase
+{
+ IR_LEAF_ISA(DifferentialPtrPairGetDifferential)
+};
struct IRDifferentialPairGetPrimalBase : IRInst
{
@@ -2987,6 +2995,10 @@ struct IRDifferentialPairGetPrimalUserCode : IRDifferentialPairGetPrimalBase
{
IR_LEAF_ISA(DifferentialPairGetPrimalUserCode)
};
+struct IRDifferentialPtrPairGetPrimal : IRDifferentialPairGetPrimalBase
+{
+ IR_LEAF_ISA(DifferentialPtrPairGetPrimal)
+};
struct IRDetachDerivative : IRInst
{
@@ -3657,6 +3669,10 @@ public:
IRDifferentialPairType* getDifferentialPairType(
IRType* valueType,
IRInst* witnessTable);
+
+ IRDifferentialPtrPairType* getDifferentialPtrPairType(
+ IRType* valueType,
+ IRInst* witnessTable);
IRDifferentialPairUserCodeType* getDifferentialPairUserCodeType(
IRType* valueType,
@@ -3797,6 +3813,8 @@ public:
IRInst* emitGetTorchCudaStream();
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
+ IRInst* emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential);
+ IRInst* emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential);
IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential);
IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target);
@@ -3979,9 +3997,19 @@ public:
IRInst* emitGetOptionalValue(IRInst* optValue);
IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value);
IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue);
+
IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair);
+ IRInst* emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair);
+ IRInst* emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair);
+
IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair);
+ IRInst* emitDifferentialValuePairGetPrimal(IRInst* diffPair);
+ IRInst* emitDifferentialPtrPairGetPrimal(IRInst* diffPair);
+
IRInst* emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair);
+ IRInst* emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair);
+ IRInst* emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair);
+
IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair);
IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair);
IRInst* emitMakeVector(
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6c7691d13..b89929f55 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3022,6 +3022,17 @@ namespace Slang
operands);
}
+ IRDifferentialPtrPairType* IRBuilder::getDifferentialPtrPairType(
+ IRType* valueType,
+ IRInst* witnessTable)
+ {
+ IRInst* operands[] = { valueType, witnessTable };
+ return (IRDifferentialPtrPairType*)getType(
+ kIROp_DifferentialPtrPairType,
+ sizeof(operands) / sizeof(operands[0]),
+ operands);
+ }
+
IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType(
IRType* valueType,
IRInst* witnessTable)
@@ -3503,7 +3514,7 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
+ IRInst* IRBuilder::emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type));
SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)->getValueType() != nullptr);
@@ -3516,6 +3527,98 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential)
+ {
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type));
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type)->getValueType() != nullptr);
+
+ IRInst* args[] = {primal, differential};
+ auto inst = createInstWithTrailingArgs<IRMakeDifferentialPtrPair>(
+ this, kIROp_MakeDifferentialPtrPair, type, 2, args);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitMakeDifferentialPair(IRType* pairType, IRInst* primalVal, IRInst* diffVal)
+ {
+ if (as<IRDifferentialPairType>(pairType))
+ {
+ return emitMakeDifferentialValuePair(pairType, primalVal, diffVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairType))
+ {
+ // Quick optimization:
+ // If primalVal and diffVal are extracted from the same pointer-pair,
+ // we can just use the pointer-pair directly.
+ //
+ if (auto primalPtrVal = as<IRDifferentialPtrPairGetPrimal>(primalVal))
+ {
+ if (auto diffPtrVal = as<IRDifferentialPtrPairGetDifferential>(diffVal))
+ {
+ if (primalPtrVal->getBase() == diffPtrVal->getBase())
+ return primalPtrVal->getBase();
+ }
+ }
+ return emitMakeDifferentialPtrPair(pairType, primalVal, diffVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetDifferential(diffType, pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetDifferential(diffType, pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetPrimal(pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetPrimal(pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* pairVal)
+ {
+ if (as<IRDifferentialPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialValuePairGetPrimal(primalType, pairVal);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairVal->getDataType()))
+ {
+ return emitDifferentialPtrPairGetPrimal(primalType, pairVal);
+ }
+ else
+ {
+ SLANG_ASSERT(!"unreachable");
+ return nullptr;
+ }
+ }
+
IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential)
{
SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type));
@@ -4222,7 +4325,7 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeVector, argCount, args);
}
- IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair)
+ IRInst* IRBuilder::emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair)
{
SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
return emitIntrinsicInst(
@@ -4232,7 +4335,18 @@ namespace Slang
&diffPair);
}
- IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
+
+ IRInst* IRBuilder::emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair)
+ {
+ SLANG_ASSERT(as<IRDifferentialPtrPairType>(diffPair->getDataType()));
+ return emitIntrinsicInst(
+ diffType,
+ kIROp_DifferentialPtrPairGetDifferential,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRInst* diffPair)
{
auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
return emitIntrinsicInst(
@@ -4242,7 +4356,7 @@ namespace Slang
&diffPair);
}
- IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair)
+ IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair)
{
return emitIntrinsicInst(
primalType,
@@ -4251,6 +4365,25 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRInst* diffPair)
+ {
+ auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
+ return emitIntrinsicInst(
+ valueType,
+ kIROp_DifferentialPtrPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair)
+ {
+ return emitIntrinsicInst(
+ primalType,
+ kIROp_DifferentialPtrPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair)
{
SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 375107d1d..14dde200f 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1660,6 +1660,11 @@ struct IRDifferentialPairType : IRDifferentialPairTypeBase
IR_LEAF_ISA(DifferentialPairType)
};
+struct IRDifferentialPtrPairType : IRDifferentialPairTypeBase
+{
+ IR_LEAF_ISA(DifferentialPtrPairType)
+};
+
struct IRDifferentialPairUserCodeType : IRDifferentialPairTypeBase
{
IR_LEAF_ISA(DifferentialPairUserCodeType)