summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--build/visual-studio/slang/slang.vcxproj4
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters12
-rw-r--r--source/slang/slang-ast-expr.h9
-rw-r--r--source/slang/slang-ast-support-types.h9
-rw-r--r--source/slang/slang-ast-type.cpp21
-rw-r--r--source/slang/slang-ast-type.h12
-rw-r--r--source/slang/slang-check-expr.cpp20
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--source/slang/slang-check-overload.cpp95
-rw-r--r--source/slang/slang-ir-diff-call.cpp90
-rw-r--r--source/slang/slang-ir-diff-call.h17
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp180
-rw-r--r--source/slang/slang-ir-diff-jvp.h17
-rw-r--r--source/slang/slang-ir-inst-defs.h8
-rw-r--r--source/slang/slang-ir-insts.h38
-rw-r--r--source/slang/slang-ir.cpp11
-rw-r--r--source/slang/slang-lower-to-ir.cpp204
-rw-r--r--source/slang/slang-parser.cpp20
-rw-r--r--tests/ir/derivative-op-ir-test.slang21
-rw-r--r--tests/ir/derivative-op-ir-test.slang.expected.txt5
20 files changed, 711 insertions, 88 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index b958753d9..c20633c34 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -350,6 +350,8 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\
<ClInclude Include="..\..\..\source\slang\slang-ir-com-interface.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-constexpr.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dce.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-call.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-jvp.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-import.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dominators.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-eliminate-phis.h" />
@@ -506,6 +508,8 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\
<ClCompile Include="..\..\..\source\slang\slang-ir-constexpr.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dce.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-deduplicate.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-call.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-jvp.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-import.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dominators.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-eliminate-phis.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 0667e6249..cb3e0f278 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -147,6 +147,12 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-dce.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-call.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-jvp.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-import.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -611,6 +617,12 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-deduplicate.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-call.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-jvp.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-import.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 647eb37a4..8f407321e 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -369,6 +369,15 @@ class ExtractExistentialValueExpr: public Expr
DeclRef<VarDeclBase> declRef;
};
+ /// An expression of the form `__jvp(fn)` to access the
+ /// forward-mode derivative version of the function `fn`
+ ///
+class JVPDerivativeOfExpr: public Expr
+{
+ SLANG_AST_CLASS(JVPDerivativeOfExpr)
+ Expr* baseFn;
+};
+
/// A type expression of the form `__TaggedUnion(A, ...)`.
///
/// An expression of this form will resolve to a `TaggedUnionType`
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index c1e6a0132..e8ab51fbd 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1477,6 +1477,15 @@ namespace Slang
List<ExtensionDecl*> candidateExtensions;
};
+ /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
+ enum ParameterDirection
+ {
+ kParameterDirection_In, ///< Copy in
+ kParameterDirection_Out, ///< Copy out
+ kParameterDirection_InOut, ///< Copy in, copy out
+ kParameterDirection_Ref, ///< By-reference
+ };
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 5ee0c5c70..43fe751ee 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -557,6 +557,27 @@ HashCode NamedExpressionType::_getHashCodeOverride()
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ParameterDirection FuncType::getParamDirection(Index index)
+{
+ auto paramType = getParamType(index);
+ if (as<RefType>(paramType))
+ {
+ return kParameterDirection_Ref;
+ }
+ else if (as<InOutType>(paramType))
+ {
+ return kParameterDirection_InOut;
+ }
+ else if (as<OutType>(paramType))
+ {
+ return kParameterDirection_Out;
+ }
+ else
+ {
+ return kParameterDirection_In;
+ }
+}
+
void FuncType::_toTextOverride(StringBuilder& out)
{
out << toSlice("(");
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index dbca6b18b..b82c6b182 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -525,10 +525,16 @@ class PtrType : public PtrTypeBase
SLANG_AST_CLASS(PtrType)
};
+/// A pointer-like type used to represent a parameter "direction"
+class ParamDirectionType : public PtrTypeBase
+{
+ SLANG_AST_CLASS(ParamDirectionType)
+};
+
// A type that represents the behind-the-scenes
// logical pointer that is passed for an `out`
// or `in out` parameter
-class OutTypeBase : public PtrTypeBase
+class OutTypeBase : public ParamDirectionType
{
SLANG_AST_CLASS(OutTypeBase)
};
@@ -546,7 +552,7 @@ class InOutType : public OutTypeBase
};
// The type for an `ref` parameter, e.g., `ref T`
-class RefType : public PtrTypeBase
+class RefType : public ParamDirectionType
{
SLANG_AST_CLASS(RefType)
};
@@ -595,6 +601,8 @@ class FuncType : public Type
Type* getResultType() { return resultType; }
Type* getErrorType() { return errorType; }
+ ParameterDirection getParamDirection(Index index);
+
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3b308c46a..ff469428b 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1509,6 +1509,26 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr)
+ {
+ // Check/Resolve inner function declaration.
+ expr->baseFn = CheckTerm(expr->baseFn);
+
+ if(auto funcType = as<FuncType>(expr->baseFn->type))
+ {
+ // Resolve JVP type here.
+ // Temporarily resolving to the same type as the original function.
+ expr->type = expr->baseFn->type;
+ }
+ else
+ {
+ // Error
+ UNREACHABLE_RETURN(nullptr);
+ }
+
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr)
{
// Check the term we are applying first
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 5ef853b62..3be5ba68b 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -160,6 +160,7 @@ namespace Slang
Func,
Generic,
UnspecializedGeneric,
+ Expr,
};
Flavor flavor;
@@ -178,6 +179,9 @@ namespace Slang
// Reference to the declaration being applied
LookupResultItem item;
+ // Type of function being applied (for cases where `item` is not used)
+ FuncType* funcType = nullptr;
+
// The type of the result expression if this candidate is selected
Type* resultType = nullptr;
@@ -1728,6 +1732,8 @@ namespace Slang
Expr* visitAndTypeExpr(AndTypeExpr* expr);
Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr);
+ Expr* visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr);
+
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
};
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index f4a1de3d5..bd27c7df2 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -76,6 +76,14 @@ namespace Slang
paramCounts = CountParameters(candidate.item.declRef.as<GenericDecl>());
break;
+ case OverloadCandidate::Flavor::Expr:
+ {
+ auto paramCount = candidate.funcType->getParamCount();
+ paramCounts.allowed = paramCount;
+ paramCounts.required = paramCount;
+ }
+ break;
+
default:
SLANG_UNEXPECTED("unknown flavor of overload candidate");
break;
@@ -312,11 +320,34 @@ namespace Slang
{
Index argCount = context.getArgCount();
- List<DeclRef<ParamDecl>> params;
+ List<Type*> paramTypes;
+// List<DeclRef<ParamDecl>> params;
switch (candidate.flavor)
{
case OverloadCandidate::Flavor::Func:
- params = getParameters(candidate.item.declRef.as<CallableDecl>()).toArray();
+ for (auto param : getParameters(candidate.item.declRef.as<CallableDecl>()))
+ {
+ auto paramType = getType(m_astBuilder, param);
+ paramTypes.add(paramType);
+ }
+ break;
+
+ case OverloadCandidate::Flavor::Expr:
+ {
+ auto funcType = candidate.funcType;
+ Count paramCount = funcType->getParamCount();
+ for (Index i = 0; i < paramCount; ++i)
+ {
+ auto paramType = funcType->getParamType(i);
+
+ if(auto paramDirectionType = as<ParamDirectionType>(paramType))
+ {
+ paramType = paramDirectionType->getValueType();
+ }
+
+ paramTypes.add(paramType);
+ }
+ }
break;
case OverloadCandidate::Flavor::Generic:
@@ -329,13 +360,13 @@ namespace Slang
// Note(tfoley): We might have fewer arguments than parameters in the
// case where one or more parameters had defaults.
- SLANG_RELEASE_ASSERT(argCount <= params.getCount());
+ SLANG_RELEASE_ASSERT(argCount <= paramTypes.getCount());
for (Index ii = 0; ii < argCount; ++ii)
{
auto& arg = context.getArg(ii);
auto argType = context.getArgType(ii);
- auto param = params[ii];
+ auto paramType = paramTypes[ii];
if (context.mode == OverloadResolveContext::Mode::JustTrying)
{
@@ -343,10 +374,10 @@ namespace Slang
if( context.disallowNestedConversions )
{
// We need an exact match in this case.
- if(!getType(m_astBuilder, param)->equals(argType))
+ if(!paramType->equals(argType))
return false;
}
- else if (!canCoerce(getType(m_astBuilder, param), argType, arg, &cost))
+ else if (!canCoerce(paramType, argType, arg, &cost))
{
return false;
}
@@ -354,7 +385,7 @@ namespace Slang
}
else
{
- arg = coerce(getType(m_astBuilder, param), arg);
+ arg = coerce(paramType, arg);
}
}
return true;
@@ -558,11 +589,24 @@ namespace Slang
{
auto originalAppExpr = as<AppExprBase>(context.originalExpr);
- auto baseExpr = ConstructLookupResultExpr(
- candidate.item,
- context.baseExpr,
- context.funcLoc,
- originalAppExpr ? originalAppExpr->functionExpr : nullptr);
+
+
+ Expr* baseExpr;
+ switch(candidate.flavor)
+ {
+ case OverloadCandidate::Flavor::Func:
+ case OverloadCandidate::Flavor::Generic:
+ baseExpr = ConstructLookupResultExpr(
+ candidate.item,
+ context.baseExpr,
+ context.funcLoc,
+ originalAppExpr ? originalAppExpr->functionExpr : nullptr);
+ break;
+ case OverloadCandidate::Flavor::Expr:
+ default:
+ baseExpr = nullptr;
+ break;
+ }
switch(candidate.flavor)
{
@@ -598,6 +642,25 @@ namespace Slang
break;
+ case OverloadCandidate::Flavor::Expr:
+ {
+ AppExprBase* callExpr = as<InvokeExpr>(context.originalExpr);
+ if (!callExpr)
+ {
+ callExpr = m_astBuilder->create<InvokeExpr>();
+ callExpr->loc = context.loc;
+ for (Index aa = 0; aa < context.argCount; ++aa)
+ callExpr->arguments.add(context.getArg(aa));
+ }
+
+ callExpr->originalFunctionExpr = callExpr->functionExpr;
+ callExpr->type = QualType(candidate.resultType);
+
+ return callExpr;
+
+ }
+ break;
+
case OverloadCandidate::Flavor::Generic:
return createGenericDeclRef(
baseExpr,
@@ -996,8 +1059,12 @@ namespace Slang
FuncType* funcType,
OverloadResolveContext& context)
{
- SLANG_UNUSED(funcType);
- getSink()->diagnose(context.loc, Diagnostics::unimplemented, "call on expression of function type");
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
+ candidate.funcType = funcType;
+ candidate.resultType = funcType->getResultType();
+
+ AddOverloadCandidate(context, candidate);
}
void SemanticsVisitor::AddCtorOverloadCandidate(
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
new file mode 100644
index 000000000..76ffe3c8b
--- /dev/null
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -0,0 +1,90 @@
+// slang-ir-diff-call.cpp
+#include "slang-ir-diff-call.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+struct DerivativeCallProcessContext
+{
+ // This type passes over the module and replaces
+ // derivative calls with the processed derivative
+ // function.
+ //
+ IRModule* module;
+
+ bool processModule()
+ {
+ // Run through all the global-level instructions,
+ // looking for callable blocks.
+ for (auto inst : module->getGlobalInsts())
+ {
+ // If the instr is a callable, get all the basic blocks
+ if (auto callable = as<IRGlobalValueWithCode>(inst))
+ {
+ // Iterate over each block in the callable
+ for (auto block : callable->getBlocks())
+ {
+ // Iterate over each child instruction.
+ auto child = block->getFirstInst();
+ if (!child) continue;
+
+ do
+ {
+ auto nextChild = child->getNextInst();
+ // Look for IRJVPDerivativeOf
+ if (auto derivOf = as<IRJVPDerivativeOf>(child))
+ {
+ processDerivativeOf(derivOf);
+ }
+ child = nextChild;
+ }
+ while (child);
+ }
+ }
+ }
+ return true;
+ }
+
+ // Perform forward-mode automatic differentiation on
+ // the intstructions.
+ void processDerivativeOf(IRJVPDerivativeOf* derivOfInst)
+ {
+ IRFunc* jvpFunc = nullptr;
+
+ // Resolve the derivative function.
+ //
+ // Check for the 'JVPDerivativeReference' decorator on the
+ // base function.
+ if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ jvpFunc = jvpRefDecorator->getJVPFunc();
+ }
+
+ // Substitute all uses of the 'derivativeOf' operation
+ // with the resolved derivative function.
+ while (auto use = derivOfInst->firstUse)
+ {
+ use->set(jvpFunc);
+ }
+
+ // Remove the 'derivativeOf'
+ derivOfInst->removeAndDeallocate();
+ }
+};
+
+// Set up context and call main process method.
+//
+bool processDerivativeCalls(
+ IRModule* module,
+ IRDerivativeCallProcessOptions const&)
+{
+ DerivativeCallProcessContext context;
+ context.module = module;
+
+ return context.processModule();
+}
+
+}
diff --git a/source/slang/slang-ir-diff-call.h b/source/slang/slang-ir-diff-call.h
new file mode 100644
index 000000000..d3b7d75a2
--- /dev/null
+++ b/source/slang/slang-ir-diff-call.h
@@ -0,0 +1,17 @@
+// slang-ir-diff-call.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ struct IRDerivativeCallProcessOptions
+ {
+ // Nothing for now..
+ };
+
+ bool processDerivativeCalls(
+ IRModule* module,
+ IRDerivativeCallProcessOptions const& options = IRDerivativeCallProcessOptions());
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
new file mode 100644
index 000000000..431c8e5b2
--- /dev/null
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -0,0 +1,180 @@
+// slang-ir-diff-jvp.cpp
+#include "slang-ir-diff-jvp.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+struct JVPDerivativeContext
+{
+ // This type passes over the module and generates
+ // forward-mode derivative versions of functions
+ // that are explicitly marked for it.
+ //
+ IRModule* module;
+
+ // Shared builder state for our derivative passes.
+ SharedIRBuilder sharedBuilderStorage;
+
+ bool processModule()
+ {
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->init(module);
+
+ // Run through all the global-level instructions,
+ // looking for callables.
+ // Note: We're only processing global callables (IRGlobalValueWithCode)
+ // for now.
+ IRBuilder builderStorage(sharedBuilderStorage);
+ IRBuilder* builder = &builderStorage;
+ for (auto inst : module->getGlobalInsts())
+ {
+ // If the instr is a callable, get all the basic blocks
+ if (auto callable = as<IRGlobalValueWithCode>(inst))
+ {
+ if (isFunctionMarkedForJVP(callable))
+ {
+ SLANG_ASSERT(as<IRFunc>(callable));
+ IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable));
+ builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);
+ }
+ }
+ }
+ return true;
+ }
+
+ // Checks decorators to see if the function should
+ // be differentiated (kIROp_JVPDerivativeMarkerDecoration)
+ //
+ bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable)
+ {
+ for(auto decoration = callable->getFirstDecoration();
+ decoration;
+ decoration = decoration->getNextDecoration())
+ {
+ if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
+ {
+ return true;
+ }
+ // TODO: Need to remove this decoration or check for
+ // JVPDerivativeReferenceDecoration to avoid re-generating code.
+ }
+ return false;
+ }
+
+ List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType)
+ {
+ List<IRParam*> params;
+ for(UIndex i = 0; i < dataType->getParamCount(); i++)
+ {
+ params.add(
+ builder->emitParam(dataType->getParamType(i)));
+ }
+ return params;
+ }
+
+ // Perform forward-mode automatic differentiation on
+ // the intstructions.
+ IRFunc* emitJVPFunction(IRBuilder* builder,
+ IRFunc* primalFn)
+ {
+ // Note (sai): Is this safe? Should we use setInsertInto?
+ builder->setInsertBefore(primalFn->getNextInst());
+
+ auto jvpFn = builder->createFunc();
+ IRType* jvpFuncType = primalTypeToJVPType(primalFn->getFullType());
+ jvpFn->setFullType(jvpFuncType);
+ if (auto jvpName = getJVPFuncName(builder, primalFn))
+ builder->addNameHintDecoration(jvpFn, jvpName);
+
+ builder->setInsertInto(jvpFn);
+
+ // Start with _extremely_ basic functions
+ SLANG_ASSERT(primalFn->getFirstBlock() == primalFn->getLastBlock());
+
+ for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
+ {
+ IRBlock* newJVPBlock = nullptr;
+ if (block == primalFn->getFirstBlock())
+ {
+ newJVPBlock = builder->emitBlock();
+ emitFuncParameters(builder, as<IRFuncType>(jvpFuncType));
+ }
+ newJVPBlock = emitJVPBlock(builder, primalFn->getFirstBlock(), newJVPBlock);
+ }
+
+ return jvpFn;
+ }
+
+ IRStringLit* getJVPFuncName(IRBuilder* builder,
+ IRFunc* func)
+ {
+ auto oldLoc = builder->getInsertLoc();
+ builder->setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ {
+ name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ {
+ name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice());
+ }
+
+ builder->setInsertLoc(oldLoc);
+
+ return name;
+ }
+
+
+ IRBlock* emitJVPBlock(IRBuilder* builder,
+ IRBlock* primalBlock,
+ IRBlock* jvpBlock = nullptr)
+ {
+ // Create if not already provided, and insert into new block.
+ if (!jvpBlock)
+ jvpBlock = builder->emitBlock();
+ else
+ builder->setInsertInto(jvpBlock);
+
+ // Temporarily, we're going to just emit a single return 0 instruction.
+ for(auto child = primalBlock->getFirstInst(); child; child = child->getNextInst())
+ {
+ if (auto returnOp = as<IRReturn>(child))
+ {
+ auto zeroVal = builder->getFloatValue(returnOp->getVal()->getDataType(), 0.0);
+ builder->emitReturn(zeroVal);
+ }
+ }
+
+ return jvpBlock;
+ }
+
+ IRType* primalTypeToJVPType(IRType* primalType)
+ {
+ // Temporarily, we're going to implement the identity transform.
+ // The return type is the same as the primal type.
+ return primalType;
+ }
+};
+
+// Set up context and call main process method.
+//
+bool processJVPDerivativeMarkers(
+ IRModule* module,
+ IRJVPDerivativePassOptions const&)
+{
+ JVPDerivativeContext context;
+ context.module = module;
+
+ return context.processModule();
+}
+
+}
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
new file mode 100644
index 000000000..9bbcad4fb
--- /dev/null
+++ b/source/slang/slang-ir-diff-jvp.h
@@ -0,0 +1,17 @@
+// slang-ir-diff-jvp.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ struct IRJVPDerivativePassOptions
+ {
+ // Nothing for now..
+ };
+
+ bool processJVPDerivativeMarkers(
+ IRModule* module,
+ IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions());
+
+}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 6304e65d2..793f1f78f 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -668,7 +668,11 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(SPIRVOpDecoration, spirvOpDecoration, 1, 0)
/// Decorated function is marked for the forward-mode differentiation pass.
- INST(JVPDerivativeDecoration, differentiateJvp, 0, 0)
+ INST(JVPDerivativeMarkerDecoration, differentiateJvp, 0, 0)
+
+ /// Used by the auto-diff pass to hold a reference to the
+ /// generated derivative function.
+ INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0)
/// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
@@ -713,6 +717,8 @@ INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0)
INST(BitCast, bitCast, 1, 0)
INST(Reinterpret, reinterpret, 1, 0)
+INST(JVPDerivativeOf, jvpDerivativeOf, 1, 0)
+
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 081c67d03..82d0d5a0e 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -515,6 +515,33 @@ struct IRSequentialIDDecoration : IRDecoration
IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); }
};
+struct IRJVPDerivativeReferenceDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_JVPDerivativeReferenceDecoration
+ };
+ IR_LEAF_ISA(JVPDerivativeReferenceDecoration)
+
+ IRFunc* getJVPFunc() { return as<IRFunc>(getOperand(0)); }
+};
+
+
+// An instruction that replaces the function symbol
+// with it's derivative function.
+struct IRJVPDerivativeOf : IRInst
+{
+ enum
+ {
+ kOp = kIROp_JVPDerivativeOf
+ };
+ // The base function for the call.
+ IRUse base;
+ IRInst* getBaseFn() { return getOperand(0); }
+
+ IR_LEAF_ISA(JVPDerivativeOf)
+};
+
// An instruction that specializes another IR value
// (representing a generic) to a particular set of generic arguments
// (instructions representing types, witness tables, etc.)
@@ -2319,6 +2346,8 @@ public:
IRInst* emitExtractExistentialWitnessTable(
IRInst* existentialValue);
+ IRInst* emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn);
+
IRInst* emitSpecializeInst(
IRType* type,
IRInst* genericVal,
@@ -2985,9 +3014,14 @@ public:
addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName));
}
- void addJVPDerivativeDecoration(IRInst* value, UnownedStringSlice const& mangledName)
+ void addJVPDerivativeMarkerDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_JVPDerivativeMarkerDecoration);
+ }
+
+ void addJVPDerivativeReferenceDecoration(IRInst* value, IRInst* jvpFn)
{
- addDecoration(value, kIROp_JVPDerivativeDecoration, getStringValue(mangledName));
+ addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn);
}
void addDllImportDecoration(IRInst* value, UnownedStringSlice const& libraryName, UnownedStringSlice const& functionName)
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 42ff5823b..950061d4f 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3027,6 +3027,17 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn)
+ {
+ auto inst = createInst<IRJVPDerivativeOf>(
+ this,
+ kIROp_JVPDerivativeOf,
+ type,
+ baseFn);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitSpecializeInst(
IRType* type,
IRInst* genericVal,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 791180890..d845342f0 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7,6 +7,8 @@
#include "slang-ir.h"
#include "slang-ir-constexpr.h"
#include "slang-ir-dce.h"
+#include "slang-ir-diff-call.h"
+#include "slang-ir-diff-jvp.h"
#include "slang-ir-inline.h"
#include "slang-ir-insts.h"
#include "slang-ir-missing-return.h"
@@ -755,16 +757,6 @@ LoweredValInfo emitCallToDeclRef(
tryEnv);
}
- /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
-enum ParameterDirection
-{
- kParameterDirection_In, ///< Copy in
- kParameterDirection_Out, ///< Copy out
- kParameterDirection_InOut, ///< Copy in, copy out
- kParameterDirection_Ref, ///< By-reference
-};
-
-
/// Emit a call to the given `accessorDeclRef`.
///
/// The `base` value represents the object on which the accessor is being invoked.
@@ -1151,7 +1143,7 @@ static void addLinkageDecoration(
}
if (decl->findModifier<JVPDerivativeModifier>())
{
- builder->addJVPDerivativeDecoration(inst, mangledName);
+ builder->addJVPDerivativeMarkerDecoration(inst);
}
if (as<InterfaceDecl>(decl->parentDecl) &&
decl->parentDecl->hasModifier<ComInterfaceAttribute>())
@@ -2947,6 +2939,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return info;
}
+ // Emit IR to denote the forward-mode derivative
+ // of the inner func-expr. This will be resolved
+ // to a concrete function during the derivative
+ // pass.
+ LoweredValInfo visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr)
+ {
+ auto baseVal = lowerSubExpr(expr->baseFn);
+ SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitJVPDerivativeOfInst(
+ lowerType(context, expr->type),
+ baseVal.val));
+ }
+
LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/)
{
SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST");
@@ -3403,74 +3410,112 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// TODO: also need to handle this-type substitution here?
}
+ /// Create IR instructions for an argument at a call site, based on
+ /// AST-level expressions plus function signature information.
+ ///
+ /// The `funcType` parameter is always required, and specifies the types
+ /// of all the parameters. The `funcDeclRef` parameter is only required
+ /// if there are parameter positions for which the matching argument is
+ /// absent.
+ ///
void addDirectCallArgs(
InvokeExpr* expr,
- DeclRef<CallableDecl> funcDeclRef,
- List<IRInst*>* ioArgs,
+ Index argIndex,
+ IRType* paramType,
+ ParameterDirection paramDirection,
+ DeclRef<ParamDecl> paramDeclRef,
+ List<IRInst*>* ioArgs,
List<OutArgumentFixup>* ioFixups)
{
- UInt argCount = expr->arguments.getCount();
- UInt argCounter = 0;
- for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
+ Count argCount = expr->arguments.getCount();
+ if (argIndex < argCount)
{
- auto paramDecl = paramDeclRef.getDecl();
- IRType* paramType = lowerType(context, getType(getASTBuilder(), paramDeclRef));
- auto paramDirection = getParameterDirection(paramDecl);
+ auto argExpr = expr->arguments[argIndex];
+ addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups);
+ }
+ else
+ {
+ // We have run out of arguments supplied at the call site,
+ // but there are still parameters remaining. This must mean
+ // that these parameters have default argument expressions
+ // associated with them.
+ //
+ // Currently we simply extract the initial-value expression
+ // from the parameter declaration and then lower it in
+ // the context of the caller.
+ //
+ // Note that the expression could involve subsitutions because
+ // in the general case it could depend on the generic parameters
+ // used the specialize the callee. For now we do not handle that
+ // case, and simply ignore generic arguments.
+ //
+ SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef);
+ SLANG_ASSERT(argExpr);
- UInt argIndex = argCounter++;
- if(argIndex < argCount)
- {
- auto argExpr = expr->arguments[argIndex];
- addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups);
- }
- else
- {
- // We have run out of arguments supplied at the call site,
- // but there are still parameters remaining. This must mean
- // that these parameters have default argument expressions
- // associated with them.
- //
- // Currently we simply extract the initial-value expression
- // from the parameter declaration and then lower it in
- // the context of the caller.
- //
- // Note that the expression could involve subsitutions because
- // in the general case it could depend on the generic parameters
- // used the specialize the callee. For now we do not handle that
- // case, and simply ignore generic arguments.
- //
- SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef);
- SLANG_ASSERT(argExpr);
+ IRGenEnv subEnvStorage;
+ IRGenEnv* subEnv = &subEnvStorage;
+ subEnv->outer = context->env;
- IRGenEnv subEnvStorage;
- IRGenEnv* subEnv = &subEnvStorage;
- subEnv->outer = context->env;
+ IRGenContext subContextStorage = *context;
+ IRGenContext* subContext = &subContextStorage;
+ subContext->env = subEnv;
- IRGenContext subContextStorage = *context;
- IRGenContext* subContext = &subContextStorage;
- subContext->env = subEnv;
+ _lowerSubstitutionEnv(subContext, argExpr.getSubsts());
- _lowerSubstitutionEnv(subContext, argExpr.getSubsts());
+ addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups);
- addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups);
+ // TODO: The approach we are taking here to default arguments
+ // is simplistic, and has consequences for the front-end as
+ // well as binary serialization of modules.
+ //
+ // We could consider some more refined approaches where, e.g.,
+ // functions with default arguments generate multiple IR-level
+ // functions, that compute and provide the default values.
+ //
+ // Alternatively, each parameter with defaults could be generated
+ // into its own callable function that provides the default value,
+ // so that calling modules can call into a pre-generated function.
+ //
+ // Each of these options involves trade-offs, and we need to
+ // make a conscious decision at some point.
- // TODO: The approach we are taking here to default arguments
- // is simplistic, and has consequences for the front-end as
- // well as binary serialization of modules.
- //
- // We could consider some more refined approaches where, e.g.,
- // functions with default arguments generate multiple IR-level
- // functions, that compute and provide the default values.
- //
- // Alternatively, each parameter with defaults could be generated
- // into its own callable function that provides the default value,
- // so that calling modules can call into a pre-generated function.
- //
- // Each of these options involves trade-offs, and we need to
- // make a conscious decision at some point.
+ // Assert that such an expression must have been present.
+ }
+ }
- // Assert that such an expression must have been present.
- }
+ void addDirectCallArgs(
+ InvokeExpr* expr,
+ FuncType* funcType,
+ List<IRInst*>* ioArgs,
+ List<OutArgumentFixup>* ioFixups)
+ {
+ Count argCount = expr->arguments.getCount();
+ SLANG_ASSERT(argCount == static_cast<Count>(funcType->getParamCount()));
+
+ for(Index i = 0; i < argCount; ++i)
+ {
+ IRType* paramType = lowerType(context, funcType->getParamType(i));
+ ParameterDirection paramDirection = funcType->getParamDirection(i);
+ addDirectCallArgs(expr, i, paramType, paramDirection, DeclRef<ParamDecl>(), ioArgs, ioFixups);
+ }
+ }
+
+
+ void addDirectCallArgs(
+ InvokeExpr* expr,
+ DeclRef<CallableDecl> funcDeclRef,
+ List<IRInst*>* ioArgs,
+ List<OutArgumentFixup>* ioFixups)
+ {
+ Count argCounter = 0;
+ for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
+ {
+ auto paramDecl = paramDeclRef.getDecl();
+ IRType* paramType = lowerType(context, getType(getASTBuilder(), paramDeclRef));
+ auto paramDirection = getParameterDirection(paramDecl);
+
+ Index argIndex = argCounter++;
+ addDirectCallArgs(expr, argIndex, paramType, paramDirection, paramDeclRef, ioArgs, ioFixups);
}
}
@@ -3636,7 +3681,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
auto funcExpr = expr->functionExpr;
ResolvedCallInfo resolvedInfo;
- if( tryResolveDeclRefForCall(funcExpr, &resolvedInfo) )
+ if (tryResolveDeclRefForCall(funcExpr, &resolvedInfo))
{
// In this case we know exactly what declaration we
// are going to call, and so we can resolve things
@@ -3690,7 +3735,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// First comes the `this` argument if we are calling
// a member function:
- if( baseExpr )
+ if (baseExpr)
{
// The base expression might be an "upcast" to a base interface, in
// which case we don't want to emit the result of the cast, but instead
@@ -3725,6 +3770,17 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
applyOutArgumentFixups(context, argFixups);
return result;
}
+ else if(auto funcType = as<FuncType>(expr->functionExpr->type))
+ {
+ auto funcVal = lowerRValueExpr(context, expr->functionExpr);
+ addDirectCallArgs(expr, funcType, &irArgs, &argFixups);
+
+ auto result = emitCallToVal(context, type, funcVal, irArgs.getCount(), irArgs.getBuffer(), tryEnv);
+
+ applyOutArgumentFixups(context, argFixups);
+ return result;
+ }
+
// TODO: In this case we should be emitting code for the callee as
// an ordinary expression, then emitting the arguments according
@@ -8417,6 +8473,16 @@ RefPtr<IRModule> generateIRForTranslationUnit(
#endif
validateIRModuleIfEnabled(compileRequest, module);
+
+ // Process higher-order-function calls before any optimization passes
+ // to allow the optimizations to affect the generated funcitons.
+ // 1. Process JVP derivative functions.
+ processJVPDerivativeMarkers(module);
+ // 2. Process VJP derivative functions.
+ // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
+ // 3. Replace JVP & VJP calls.
+ processDerivativeCalls(module);
+
// We will perform certain "mandatory" optimization passes now.
// These passes serve two purposes:
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index f6d85ade9..ee34eac6f 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2056,6 +2056,25 @@ namespace Slang
{
return parseTaggedUnionType(parser);
}
+ /// Parse an expression of the form __jvp(fn) where fn is an
+ /// identifier pointing to a function.
+ static Expr* parseJVPDerivativeOf(Parser* parser)
+ {
+ JVPDerivativeOfExpr* jvpExpr = parser->astBuilder->create<JVPDerivativeOfExpr>();
+
+ parser->ReadToken(TokenType::LParent);
+
+ jvpExpr->baseFn = parser->ParseExpression();
+
+ parser->ReadToken(TokenType::RParent);
+
+ return jvpExpr;
+ }
+
+ static NodeBase* parseJVPDerivativeOf(Parser* parser, void* /* unused */)
+ {
+ return parseJVPDerivativeOf(parser);
+ }
/// Parse a `This` type expression
static Expr* parseThisTypeExpr(Parser* parser)
@@ -6473,6 +6492,7 @@ namespace Slang
_makeParseExpr("nullptr", parseNullPtrExpr),
_makeParseExpr("try", parseTryExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
+ _makeParseExpr("__jvp", parseJVPDerivativeOf)
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()
diff --git a/tests/ir/derivative-op-ir-test.slang b/tests/ir/derivative-op-ir-test.slang
new file mode 100644
index 000000000..209446765
--- /dev/null
+++ b/tests/ir/derivative-op-ir-test.slang
@@ -0,0 +1,21 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+__differentiate_jvp float f(float x)
+{
+ return x;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ float a = 1.0;
+ float b = -2.0;
+ outputBuffer[0] = __jvp(f)(a);
+ outputBuffer[1] = __jvp(f)(b);
+ }
+}
diff --git a/tests/ir/derivative-op-ir-test.slang.expected.txt b/tests/ir/derivative-op-ir-test.slang.expected.txt
new file mode 100644
index 000000000..f095a0071
--- /dev/null
+++ b/tests/ir/derivative-op-ir-test.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+0.0
+0.0
+0.0
+0.0 \ No newline at end of file