summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-26 17:37:04 -0700
committerGitHub <noreply@github.com>2023-04-26 17:37:04 -0700
commitfc54adee1f7f0ba18591fc84ce5d51ac23afa954 (patch)
tree4727ed6109ac50e95c49aadcebc0fb8b95495739 /source
parent61eb17b0b556ccc06f65f921bb0a4ea2784c4e20 (diff)
Autodiff support for dynamically dispatched generic method. (#2846)
* Autodiff support for dynamically dispatched generic method. * Fix. * Support dynamically dispatched generic type. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-decl.h3
-rw-r--r--source/slang/slang-check-decl.cpp8
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h13
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h15
-rw-r--r--source/slang/slang-ir-autodiff.cpp12
-rw-r--r--source/slang/slang-ir-autodiff.h4
-rw-r--r--source/slang/slang-ir-lower-witness-lookup.cpp1
-rw-r--r--source/slang/slang-ir.cpp3
-rw-r--r--source/slang/slang-lower-to-ir.cpp14
12 files changed, 51 insertions, 28 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index ccbac0286..e75660c7b 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -526,6 +526,9 @@ class AttributeDecl : public ContainerDecl
class DerivativeRequirementDecl : public FunctionDeclBase
{
SLANG_AST_CLASS(DerivativeRequirementDecl)
+
+ // The original requirement decl.
+ Decl* originalRequirementDecl = nullptr;
};
// A reference to a synthesized decl representing a differentiable function requirement, this decl will
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 0901d2026..b3470e882 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1427,6 +1427,7 @@ namespace Slang
varDecl->initExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate);
}
}
+ maybeRegisterDifferentiableType(getASTBuilder(), varDecl->getType());
}
// Fill in default substitutions for the 'subtype' part of a type constraint decl
@@ -4738,7 +4739,6 @@ namespace Slang
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
-
if (newContext.getParentDifferentiableAttribute())
{
// Register additional types outside the function body first.
@@ -5638,11 +5638,8 @@ namespace Slang
bool isDiffFunc = false;
if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>())
{
- if (GetOuterGeneric(decl))
- {
- getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported);
- }
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
+ reqDecl->originalRequirementDecl = decl;
cloneModifiers(reqDecl, decl);
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
@@ -5664,6 +5661,7 @@ namespace Slang
auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType));
{
auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
+ reqDecl->originalRequirementDecl = decl;
cloneModifiers(reqDecl, decl);
setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType);
interfaceDecl->members.add(reqDecl);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index ec8131824..cb441ade8 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -359,8 +359,6 @@ DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[
DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.")
DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function")
-DIAGNOSTIC(31149, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.")
-
DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1")
// Enums
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index e0b916090..819c6bc57 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -955,7 +955,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
- else if (_isDifferentiableFunc(genericInnerVal))
+ else if (_isDifferentiableFunc(genericInnerVal) || as<IRFuncType>(genericInnerVal))
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 2994a8c31..e5735b831 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -1273,7 +1273,7 @@ namespace Slang
return InstPair(primalSpecialize, diffSpecialize);
}
- else if (isBackwardDifferentiableFunc(genericInnerVal))
+ else if (isBackwardDifferentiableFunc(genericInnerVal) || as<IRFuncType>(genericInnerVal))
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 8a734446d..910c23708 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -655,13 +655,6 @@ struct DiffTransposePass
subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal);
}
- // TODO: Should move this to before all the transposition, but a lot of the
- // transposition logic seems to access the parent of blocks to find the func.
- // Replace those uses.
- //
- for (auto block : workList)
- block->removeFromParent();
-
// At this point, the only block left without terminator insts
// should be the last one. Add a void return to complete it.
//
@@ -1101,7 +1094,7 @@ struct DiffTransposePass
};
List<DiffValWriteBack> writebacks;
- auto baseFnType = as<IRFuncType>(baseFn->getDataType());
+ auto baseFnType = as<IRFuncType>(getResolvedInstForDecorations(baseFn->getDataType()));
SLANG_RELEASE_ASSERT(baseFnType);
SLANG_RELEASE_ASSERT(fwdCall->getArgCount() == baseFnType->getParamCount());
@@ -1151,8 +1144,8 @@ struct DiffTransposePass
auto pairType = as<IRDifferentialPairType>(arg->getDataType());
auto var = builder->emitVar(arg->getDataType());
- auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType());
- auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, pairType->getValueType());
+ auto diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, pairType);
+ auto zeroMethod = diffTypeContext.getDiffZeroMethodFromPairType(builder, pairType);
SLANG_ASSERT(zeroMethod);
auto diffZero = builder->emitCallInst(
diffType,
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 34f0f6c9b..63b46f779 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -210,8 +210,8 @@ struct DiffUnzipPass
auto baseFn = _getOriginalFunc(mixedCall);
SLANG_RELEASE_ASSERT(baseFn);
- auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType(
- primalBuilder, baseFn, as<IRFuncType>(baseFn->getDataType()));
+ auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->transcribe(
+ primalBuilder, baseFn->getDataType());
IRInst* intermediateType = nullptr;
@@ -251,12 +251,12 @@ struct DiffUnzipPass
intermediateVar = primalBuilder->emitVar((IRType*)intermediateType);
primalBuilder->markInstAsPrimal(intermediateVar);
}
-
+
IRInst* primalFn = nullptr;
if (intermediateVar)
{
primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar);
- primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn);
+ primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst((IRType*)primalFuncType, baseFn);
}
else
{
@@ -298,7 +298,10 @@ struct DiffUnzipPass
primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar);
primalBuilder->markInstAsPrimal(primalVal);
- SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount());
+ auto resolvedPrimalFuncType = as<IRFuncType>(getResolvedInstForDecorations(primalFuncType));
+ SLANG_RELEASE_ASSERT(resolvedPrimalFuncType);
+
+ SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= resolvedPrimalFuncType->getParamCount());
List<IRInst*> diffArgs;
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
@@ -316,7 +319,7 @@ struct DiffUnzipPass
// If arg is a mixed differential (pair), it should have already been split.
SLANG_ASSERT(primalArg);
SLANG_ASSERT(diffArg);
- auto primalParamType = primalFuncType->getParamType(ii);
+ auto primalParamType = resolvedPrimalFuncType->getParamType(ii);
if (auto outType = as<IROutType>(primalParamType))
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 4188d2ec8..4e33a01ab 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -458,6 +458,18 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB
return _getDiffTypeWitnessFromPairType(sharedContext, builder, type);
}
+IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey);
+}
+
+IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey);
+}
+
void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
for (auto globalInst : sharedContext->moduleInst->getChildren())
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 52cf346b3..91b45c5be 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -177,6 +177,10 @@ struct DifferentiableTypeConformanceContext
IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+ IRInst* getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+
+ IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+
// Lookup and return the 'Differential' type declared in the concrete type
// in order to conform to the IDifferentiable interface.
// Note that inside a generic block, this will be a witness table lookup instruction
diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp
index c1ee204b0..0e46987c7 100644
--- a/source/slang/slang-ir-lower-witness-lookup.cpp
+++ b/source/slang/slang-ir-lower-witness-lookup.cpp
@@ -350,6 +350,7 @@ struct WitnessLookupLoweringContext
{
if (auto specialize = as<IRSpecialize>(use->getUser()))
{
+ builder.setInsertBefore(use->getUser());
List<IRInst*> args;
for (UInt i = 0; i < specialize->getArgCount(); i++)
args.add(specialize->getArg(i));
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e74a57424..eefcb9eea 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -7067,6 +7067,8 @@ namespace Slang
// and then destroy it (it had better have no uses!)
void IRInst::removeAndDeallocate()
{
+ removeAndDeallocateAllDecorationsAndChildren();
+
if (auto module = getModule())
{
if (getIROpInfo(getOp()).isHoistable())
@@ -7080,7 +7082,6 @@ namespace Slang
module->getDeduplicationContext()->getInstReplacementMap().remove(this);
}
removeArguments();
- removeAndDeallocateAllDecorationsAndChildren();
removeFromParent();
// Run destructor to be sure...
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d644d01c7..c8a41c7c7 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7429,7 +7429,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
else
{
- if (auto callableDecl = as<CallableDecl>(requirementDecl))
+ CallableDecl* callableDecl = nullptr;
+ if (auto genDecl = as<GenericDecl>(requirementDecl))
+ callableDecl = as<CallableDecl>(genDecl->inner);
+ else
+ callableDecl = as<CallableDecl>(requirementDecl);
+ if (callableDecl)
{
// Differentiable functions has additional requirements for the derivatives.
for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>())
@@ -8369,7 +8374,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo lowerFuncDeclInContext(IRGenContext* subContext, IRBuilder* subBuilder, FunctionDeclBase* decl, bool emitBody = true)
{
- auto outerGeneric = emitOuterGenerics(subContext, decl, decl);
+ IRGeneric* outerGeneric = nullptr;
+
+ if (auto derivativeRequirement = as<DerivativeRequirementDecl>(decl))
+ outerGeneric = emitOuterGenerics(subContext, derivativeRequirement->originalRequirementDecl, derivativeRequirement->originalRequirementDecl);
+ else
+ outerGeneric = emitOuterGenerics(subContext, decl, decl);
// need to create an IR function here