summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/slang-ast-modifier.h11
-rw-r--r--source/slang/slang-check-modifier.cpp11
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp184
-rw-r--r--source/slang/slang-ir-diff-jvp.h4
-rw-r--r--source/slang/slang-ir-link.cpp1
-rw-r--r--source/slang/slang-lower-to-ir.cpp10
7 files changed, 200 insertions, 24 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 721529726..e1f7503a8 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2212,3 +2212,6 @@ attribute_syntax [noinline] : NoInlineAttribute;
__attributeTarget(StructDecl)
attribute_syntax [payload] : PayloadAttribute;
+// Custom JVP Function reference
+__attributeTarget(FuncDecl)
+attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index d0f215cbe..62a32045b 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -955,6 +955,15 @@ class RequiresNVAPIAttribute : public Attribute
SLANG_AST_CLASS(RequiresNVAPIAttribute)
};
+ /// The `[__custom_jvp(function)]` attribute specifies a custom function that should
+ /// be used as the derivative for the decorated function.
+class CustomJVPAttribute : public Attribute
+{
+ SLANG_AST_CLASS(CustomJVPAttribute)
+
+ DeclRefExpr* funcDeclRef;
+};
+
/// Indicates that the modified declaration is one of the "magic" declarations
/// that NVAPI uses to communicate extended operations. When NVAPI is being included
/// via the prelude for downstream compilation, declarations with this modifier
@@ -1052,4 +1061,6 @@ class SNormModifier : public ResourceElementFormatModifier
SLANG_AST_CLASS(SNormModifier)
};
+
+
} // namespace Slang
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index c6a33930b..28164c126 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -595,6 +595,17 @@ namespace Slang
callablePayloadAttr->location = (int32_t)val->value;
}
+ else if (auto customJVPAttr = as<CustomJVPAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+
+ // Ensure that the argument is a reference to a function definition or declaration.
+ auto funcExpr = as<DeclRefExpr>(CheckTerm(attr->args[0]));
+ if (!as<FuncType>(funcExpr->type))
+ return false;
+
+ customJVPAttr->funcDeclRef = funcExpr;
+ }
else
{
if(attr->args.getCount() == 0)
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 2a42a7b6e..00210daaa 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -13,11 +13,20 @@ struct JVPTranscriber
// Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
// their differential values.
- Dictionary<IRInst*, IRInst*> instMapD;
+ Dictionary<IRInst*, IRInst*> instMapD;
// Cloning environment to hold mapping from old to new copies for the primal
// instructions.
- IRCloneEnv cloneEnv;
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ DiagnosticSink* getSink()
+ {
+ SLANG_ASSERT(sink);
+ return sink;
+ }
void mapDifferentialInst(IRInst* instP, IRInst* instD)
{
@@ -156,7 +165,9 @@ struct JVPTranscriber
rightP->getDataType(), rightP, rightP
));
default:
- SLANG_UNEXPECTED("Attempting to differentiate unsupported arithmetic");
+ getSink()->diagnose(arith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
}
}
@@ -179,8 +190,11 @@ struct JVPTranscriber
}
return nullptr;
}
-
- SLANG_UNEXPECTED("Attempting to differentiate an unsupported load instruction");
+ else
+ getSink()->diagnose(loadP->sourceLoc,
+ Diagnostics::unimplemented,
+ "this load instruction cannot be differentiated");
+ return nullptr;
}
IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP)
@@ -204,8 +218,11 @@ struct JVPTranscriber
}
return nullptr;
}
-
- SLANG_UNEXPECTED("Attempting to differentiate an unsupported store instruction");
+ else
+ getSink()->diagnose(storeP->sourceLoc,
+ Diagnostics::unimplemented,
+ "this store instruction cannot be differentiated");
+ return nullptr;
}
IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP)
@@ -226,12 +243,13 @@ struct JVPTranscriber
//
IRInst* differentiateConstruct(IRBuilder*, IRInst* consP)
{
-
if (as<IRConstant>(consP->getOperand(0)) && consP->getOperandCount() == 1)
- {
return nullptr;
- }
- SLANG_UNEXPECTED("Attempting to differentiate unsupported constructor");
+ else
+ getSink()->diagnose(consP->sourceLoc,
+ Diagnostics::unimplemented,
+ "this construct instruction cannot be differentiated");
+ return nullptr;
}
// Logic for whether a primal instruction needs to be replicated
@@ -242,13 +260,9 @@ struct JVPTranscriber
bool requiresPrimalClone(IRBuilder*, IRInst* instP)
{
if (as<IRReturn>(instP))
- {
return false;
- }
else
- {
return true;
- }
}
IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP)
@@ -295,11 +309,52 @@ struct JVPTranscriber
return differentiateConstruct(builder, instP);
default:
- SLANG_UNEXPECTED("Attempting to differentiate unrecognized instruction");
+ getSink()->diagnose(instP->sourceLoc,
+ Diagnostics::unimplemented,
+ "this instruction cannot be differentiated");
+ return nullptr;
}
}
};
+struct IRWorkQueue
+{
+ // Work list to hold the active set of insts whose children
+ // need to be looked at.
+ //
+ List<IRInst*> workList;
+ HashSet<IRInst*> workListSet;
+
+ void push(IRInst* inst)
+ {
+ if(!inst) return;
+ if(workListSet.Contains(inst)) return;
+
+ workList.add(inst);
+ workListSet.Add(inst);
+ }
+
+ IRInst* pop()
+ {
+ if (workList.getCount() != 0)
+ {
+ IRInst* topItem = workList.getFirst();
+ // TODO(Sai): Repeatedly calling removeAt() can be really slow.
+ // Consider a specialized data structure or using removeLast()
+ //
+ workList.removeAt(0);
+ workListSet.Remove(topItem);
+ return topItem;
+ }
+ return nullptr;
+ }
+
+ IRInst* peek()
+ {
+ return workList.getFirst();
+ }
+};
+
struct JVPDerivativeContext
{
// This type passes over the module and generates
@@ -315,6 +370,19 @@ struct JVPDerivativeContext
// processing instructions while maintaining state.
//
JVPTranscriber transcriberStorage;
+
+ // Diagnostic object from the compile request for
+ // error messages.
+ DiagnosticSink* sink;
+
+ // Work queue to hold a stream of instructions that need
+ // to be checked for references to derivative functions.
+ IRWorkQueue workQueueStorage;
+
+ DiagnosticSink* getSink()
+ {
+ return sink;
+ }
bool processModule()
{
@@ -324,14 +392,81 @@ struct JVPDerivativeContext
//
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
-
- // Run through all the global-level instructions,
- // looking for callables.
- // Note: We're only processing global callables (IRGlobalValueWithCode)
- // for now.
- //
+
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
+
+ // processMarkedGlobalFunctions(builder);
+ return processReferencedFunctions(builder);
+ }
+
+ IRInst* lookupJVPReference(IRInst* primalFunction)
+ {
+ if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ return jvpDefinition->getJVPFunc();
+
+ return nullptr;
+ }
+
+ // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate),
+ // then check that the referenced function is marked correctly for differentiation.
+ //
+ bool processReferencedFunctions(IRBuilder* builder)
+ {
+ IRWorkQueue* workQueue = &(workQueueStorage);
+
+ // Put the top-level inst into the queue.
+ workQueue->push(module->getModuleInst());
+
+ // Keep processing items until the queue is complete.
+ while (IRInst* workItem = workQueue->pop())
+ {
+ for(auto child = workItem->getFirstChild(); child; child = child->getNextInst())
+ {
+ // Either the child instruction has more children (func/block etc..)
+ // and we add it to the work list for further processing, or
+ // it's an ordinary inst in which case we check if it's a JVPDifferentiate
+ // instruction.
+ //
+ if (child->getFirstChild() != nullptr)
+ workQueue->push(child);
+
+ if (auto jvpDiffInst = as<IRJVPDifferentiate>(child))
+ {
+ auto baseFunction = jvpDiffInst->getBaseFn();
+ // If the JVP Reference already exists, no need to
+ // differentiate again.
+ //
+ if(lookupJVPReference(baseFunction)) continue;
+
+ if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction)))
+ {
+ IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction));
+ builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction);
+ workQueue->push(jvpFunction);
+ }
+ else
+ {
+ // TODO(Sai): This would probably be better with a more specific
+ // error code.
+ getSink()->diagnose(jvpDiffInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Cannot differentiate functions not marked for differentiation");
+ }
+ }
+ }
+ }
+
+ return true;
+ }
+
+ // Run through all the global-level instructions,
+ // looking for callables.
+ // Note: We're only processing global callables (IRGlobalValueWithCode)
+ // for now.
+ //
+ bool processMarkedGlobalFunctions(IRBuilder* builder)
+ {
for (auto inst : module->getGlobalInsts())
{
// If the instr is a callable, get all the basic blocks
@@ -340,7 +475,8 @@ struct JVPDerivativeContext
if (isFunctionMarkedForJVP(callable))
{
SLANG_ASSERT(as<IRFunc>(callable));
- IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable));
+
+ IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable));
builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);
unmarkForJVP(callable);
@@ -485,10 +621,12 @@ struct JVPDerivativeContext
//
bool processJVPDerivativeMarkers(
IRModule* module,
+ DiagnosticSink* sink,
IRJVPDerivativePassOptions const&)
{
JVPDerivativeContext context;
context.module = module;
+ context.sink = sink;
return context.processModule();
}
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
index 9bbcad4fb..8ae6e949a 100644
--- a/source/slang/slang-ir-diff-jvp.h
+++ b/source/slang/slang-ir-diff-jvp.h
@@ -1,6 +1,9 @@
// slang-ir-diff-jvp.h
#pragma once
+#include "slang-ir.h"
+#include "slang-compiler.h"
+
namespace Slang
{
struct IRModule;
@@ -12,6 +15,7 @@ namespace Slang
bool processJVPDerivativeMarkers(
IRModule* module,
+ DiagnosticSink* sink,
IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions());
}
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 7984c5037..14724046a 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -446,6 +446,7 @@ static void cloneExtraDecorations(
case kIROp_LayoutDecoration:
case kIROp_PublicDecoration:
case kIROp_SequentialIDDecoration:
+ case kIROp_JVPDerivativeReferenceDecoration:
if(!clonedInst->findDecorationImpl(decoration->getOp()))
{
cloneInst(context, builder, decoration);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b7c5155b5..e65bcc8e3 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7832,6 +7832,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
}
+ if (auto attr = decl->findModifier<CustomJVPAttribute>())
+ {
+ auto loweredVal = lowerLValueExpr(this->context, attr->funcDeclRef);
+ SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
+ IRFunc* jvpFunc = as<IRFunc>(loweredVal.val);
+ getBuilder()->addDecoration(irFunc, kIROp_JVPDerivativeReferenceDecoration, jvpFunc);
+ }
+
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list
@@ -8477,7 +8485,7 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// Process higher-order-function calls before any optimization passes
// to allow the optimizations to affect the generated funcitons.
// 1. Process JVP derivative functions.
- processJVPDerivativeMarkers(module);
+ processJVPDerivativeMarkers(module, compileRequest->getSink());
// 2. Process VJP derivative functions.
// processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
// 3. Replace JVP & VJP calls.