summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-06-23 16:02:05 -0400
committerGitHub <noreply@github.com>2022-06-23 16:02:05 -0400
commit6cf3d496005c5635b273d9ce6c110f14541a9492 (patch)
tree91d6480cec20ca88e85fb9b4d437e4661fba229c /source
parent4aa6344f772d31c1f7b0676cbaf315104c4b30a2 (diff)
Added basic syntax to mark and request function derivatives, as well as the framework for passes to process them. (#2297)
* Added a decorator to mark functions for forward-mode differentiation * Fill out support for calls to non-decl values The existing compiler logic has a few places (semantic checking plus AST-to-IR lowering) where it assumes that function calls (`InvokeExpr`) are only ever made to expressions that resolve to a specific `Decl` (`DeclRefExpr`). This assumption allows semantic checking and lowering code to inspect things like the parameter list of an actual declaration, rather than just the type signature of the callee, and that infrastructure is used to support various features (e.g., default argument values on parameters). The AST and IR representations themselves have no matching requirement, and the places where the more general case of call expressions would need to be supported were relatively clear in the code. This change attempts to add suitable logic into each of those places. Note that this change does *not* surface any valid way to form input code that would cause these new code paths to be executed, so it is entirely possible that there are bugs in the logic as written here. The primary goal of this change is simply to get a sketch of the correct code checked in so that we have something to build on once we have language features that will require this support. * fixup: warnings-as-errors * Added parser logic for '__jvp(<fn-name>)' operator * Fixed issue with missing overload candidate item and added basic parsing test for the __jvp syntax * Added a blank JVP Auto-diff pass and a pass that replaces 'JVPDerivativeOf' calls with the differentiated function * Added a couple comments * Added parameter handling for the JVP pass Co-authored-by: Theresa Foley <tfoley@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-expr.h9
-rw-r--r--source/slang/slang-ast-support-types.h9
-rw-r--r--source/slang/slang-ast-type.cpp21
-rw-r--r--source/slang/slang-ast-type.h12
-rw-r--r--source/slang/slang-check-expr.cpp20
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--source/slang/slang-check-overload.cpp95
-rw-r--r--source/slang/slang-ir-diff-call.cpp90
-rw-r--r--source/slang/slang-ir-diff-call.h17
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp180
-rw-r--r--source/slang/slang-ir-diff-jvp.h17
-rw-r--r--source/slang/slang-ir-inst-defs.h8
-rw-r--r--source/slang/slang-ir-insts.h38
-rw-r--r--source/slang/slang-ir.cpp11
-rw-r--r--source/slang/slang-lower-to-ir.cpp204
-rw-r--r--source/slang/slang-parser.cpp20
16 files changed, 669 insertions, 88 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 647eb37a4..8f407321e 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -369,6 +369,15 @@ class ExtractExistentialValueExpr: public Expr
DeclRef<VarDeclBase> declRef;
};
+ /// An expression of the form `__jvp(fn)` to access the
+ /// forward-mode derivative version of the function `fn`
+ ///
+class JVPDerivativeOfExpr: public Expr
+{
+ SLANG_AST_CLASS(JVPDerivativeOfExpr)
+ Expr* baseFn;
+};
+
/// A type expression of the form `__TaggedUnion(A, ...)`.
///
/// An expression of this form will resolve to a `TaggedUnionType`
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index c1e6a0132..e8ab51fbd 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1477,6 +1477,15 @@ namespace Slang
List<ExtensionDecl*> candidateExtensions;
};
+ /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
+ enum ParameterDirection
+ {
+ kParameterDirection_In, ///< Copy in
+ kParameterDirection_Out, ///< Copy out
+ kParameterDirection_InOut, ///< Copy in, copy out
+ kParameterDirection_Ref, ///< By-reference
+ };
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 5ee0c5c70..43fe751ee 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -557,6 +557,27 @@ HashCode NamedExpressionType::_getHashCodeOverride()
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ParameterDirection FuncType::getParamDirection(Index index)
+{
+ auto paramType = getParamType(index);
+ if (as<RefType>(paramType))
+ {
+ return kParameterDirection_Ref;
+ }
+ else if (as<InOutType>(paramType))
+ {
+ return kParameterDirection_InOut;
+ }
+ else if (as<OutType>(paramType))
+ {
+ return kParameterDirection_Out;
+ }
+ else
+ {
+ return kParameterDirection_In;
+ }
+}
+
void FuncType::_toTextOverride(StringBuilder& out)
{
out << toSlice("(");
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index dbca6b18b..b82c6b182 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -525,10 +525,16 @@ class PtrType : public PtrTypeBase
SLANG_AST_CLASS(PtrType)
};
+/// A pointer-like type used to represent a parameter "direction"
+class ParamDirectionType : public PtrTypeBase
+{
+ SLANG_AST_CLASS(ParamDirectionType)
+};
+
// A type that represents the behind-the-scenes
// logical pointer that is passed for an `out`
// or `in out` parameter
-class OutTypeBase : public PtrTypeBase
+class OutTypeBase : public ParamDirectionType
{
SLANG_AST_CLASS(OutTypeBase)
};
@@ -546,7 +552,7 @@ class InOutType : public OutTypeBase
};
// The type for an `ref` parameter, e.g., `ref T`
-class RefType : public PtrTypeBase
+class RefType : public ParamDirectionType
{
SLANG_AST_CLASS(RefType)
};
@@ -595,6 +601,8 @@ class FuncType : public Type
Type* getResultType() { return resultType; }
Type* getErrorType() { return errorType; }
+ ParameterDirection getParamDirection(Index index);
+
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Type* _createCanonicalTypeOverride();
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3b308c46a..ff469428b 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1509,6 +1509,26 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr)
+ {
+ // Check/Resolve inner function declaration.
+ expr->baseFn = CheckTerm(expr->baseFn);
+
+ if(auto funcType = as<FuncType>(expr->baseFn->type))
+ {
+ // Resolve JVP type here.
+ // Temporarily resolving to the same type as the original function.
+ expr->type = expr->baseFn->type;
+ }
+ else
+ {
+ // Error
+ UNREACHABLE_RETURN(nullptr);
+ }
+
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr)
{
// Check the term we are applying first
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 5ef853b62..3be5ba68b 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -160,6 +160,7 @@ namespace Slang
Func,
Generic,
UnspecializedGeneric,
+ Expr,
};
Flavor flavor;
@@ -178,6 +179,9 @@ namespace Slang
// Reference to the declaration being applied
LookupResultItem item;
+ // Type of function being applied (for cases where `item` is not used)
+ FuncType* funcType = nullptr;
+
// The type of the result expression if this candidate is selected
Type* resultType = nullptr;
@@ -1728,6 +1732,8 @@ namespace Slang
Expr* visitAndTypeExpr(AndTypeExpr* expr);
Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr);
+ Expr* visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr);
+
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
};
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index f4a1de3d5..bd27c7df2 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -76,6 +76,14 @@ namespace Slang
paramCounts = CountParameters(candidate.item.declRef.as<GenericDecl>());
break;
+ case OverloadCandidate::Flavor::Expr:
+ {
+ auto paramCount = candidate.funcType->getParamCount();
+ paramCounts.allowed = paramCount;
+ paramCounts.required = paramCount;
+ }
+ break;
+
default:
SLANG_UNEXPECTED("unknown flavor of overload candidate");
break;
@@ -312,11 +320,34 @@ namespace Slang
{
Index argCount = context.getArgCount();
- List<DeclRef<ParamDecl>> params;
+ List<Type*> paramTypes;
+// List<DeclRef<ParamDecl>> params;
switch (candidate.flavor)
{
case OverloadCandidate::Flavor::Func:
- params = getParameters(candidate.item.declRef.as<CallableDecl>()).toArray();
+ for (auto param : getParameters(candidate.item.declRef.as<CallableDecl>()))
+ {
+ auto paramType = getType(m_astBuilder, param);
+ paramTypes.add(paramType);
+ }
+ break;
+
+ case OverloadCandidate::Flavor::Expr:
+ {
+ auto funcType = candidate.funcType;
+ Count paramCount = funcType->getParamCount();
+ for (Index i = 0; i < paramCount; ++i)
+ {
+ auto paramType = funcType->getParamType(i);
+
+ if(auto paramDirectionType = as<ParamDirectionType>(paramType))
+ {
+ paramType = paramDirectionType->getValueType();
+ }
+
+ paramTypes.add(paramType);
+ }
+ }
break;
case OverloadCandidate::Flavor::Generic:
@@ -329,13 +360,13 @@ namespace Slang
// Note(tfoley): We might have fewer arguments than parameters in the
// case where one or more parameters had defaults.
- SLANG_RELEASE_ASSERT(argCount <= params.getCount());
+ SLANG_RELEASE_ASSERT(argCount <= paramTypes.getCount());
for (Index ii = 0; ii < argCount; ++ii)
{
auto& arg = context.getArg(ii);
auto argType = context.getArgType(ii);
- auto param = params[ii];
+ auto paramType = paramTypes[ii];
if (context.mode == OverloadResolveContext::Mode::JustTrying)
{
@@ -343,10 +374,10 @@ namespace Slang
if( context.disallowNestedConversions )
{
// We need an exact match in this case.
- if(!getType(m_astBuilder, param)->equals(argType))
+ if(!paramType->equals(argType))
return false;
}
- else if (!canCoerce(getType(m_astBuilder, param), argType, arg, &cost))
+ else if (!canCoerce(paramType, argType, arg, &cost))
{
return false;
}
@@ -354,7 +385,7 @@ namespace Slang
}
else
{
- arg = coerce(getType(m_astBuilder, param), arg);
+ arg = coerce(paramType, arg);
}
}
return true;
@@ -558,11 +589,24 @@ namespace Slang
{
auto originalAppExpr = as<AppExprBase>(context.originalExpr);
- auto baseExpr = ConstructLookupResultExpr(
- candidate.item,
- context.baseExpr,
- context.funcLoc,
- originalAppExpr ? originalAppExpr->functionExpr : nullptr);
+
+
+ Expr* baseExpr;
+ switch(candidate.flavor)
+ {
+ case OverloadCandidate::Flavor::Func:
+ case OverloadCandidate::Flavor::Generic:
+ baseExpr = ConstructLookupResultExpr(
+ candidate.item,
+ context.baseExpr,
+ context.funcLoc,
+ originalAppExpr ? originalAppExpr->functionExpr : nullptr);
+ break;
+ case OverloadCandidate::Flavor::Expr:
+ default:
+ baseExpr = nullptr;
+ break;
+ }
switch(candidate.flavor)
{
@@ -598,6 +642,25 @@ namespace Slang
break;
+ case OverloadCandidate::Flavor::Expr:
+ {
+ AppExprBase* callExpr = as<InvokeExpr>(context.originalExpr);
+ if (!callExpr)
+ {
+ callExpr = m_astBuilder->create<InvokeExpr>();
+ callExpr->loc = context.loc;
+ for (Index aa = 0; aa < context.argCount; ++aa)
+ callExpr->arguments.add(context.getArg(aa));
+ }
+
+ callExpr->originalFunctionExpr = callExpr->functionExpr;
+ callExpr->type = QualType(candidate.resultType);
+
+ return callExpr;
+
+ }
+ break;
+
case OverloadCandidate::Flavor::Generic:
return createGenericDeclRef(
baseExpr,
@@ -996,8 +1059,12 @@ namespace Slang
FuncType* funcType,
OverloadResolveContext& context)
{
- SLANG_UNUSED(funcType);
- getSink()->diagnose(context.loc, Diagnostics::unimplemented, "call on expression of function type");
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Expr;
+ candidate.funcType = funcType;
+ candidate.resultType = funcType->getResultType();
+
+ AddOverloadCandidate(context, candidate);
}
void SemanticsVisitor::AddCtorOverloadCandidate(
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
new file mode 100644
index 000000000..76ffe3c8b
--- /dev/null
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -0,0 +1,90 @@
+// slang-ir-diff-call.cpp
+#include "slang-ir-diff-call.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+struct DerivativeCallProcessContext
+{
+ // This type passes over the module and replaces
+ // derivative calls with the processed derivative
+ // function.
+ //
+ IRModule* module;
+
+ bool processModule()
+ {
+ // Run through all the global-level instructions,
+ // looking for callable blocks.
+ for (auto inst : module->getGlobalInsts())
+ {
+ // If the instr is a callable, get all the basic blocks
+ if (auto callable = as<IRGlobalValueWithCode>(inst))
+ {
+ // Iterate over each block in the callable
+ for (auto block : callable->getBlocks())
+ {
+ // Iterate over each child instruction.
+ auto child = block->getFirstInst();
+ if (!child) continue;
+
+ do
+ {
+ auto nextChild = child->getNextInst();
+ // Look for IRJVPDerivativeOf
+ if (auto derivOf = as<IRJVPDerivativeOf>(child))
+ {
+ processDerivativeOf(derivOf);
+ }
+ child = nextChild;
+ }
+ while (child);
+ }
+ }
+ }
+ return true;
+ }
+
+ // Perform forward-mode automatic differentiation on
+ // the intstructions.
+ void processDerivativeOf(IRJVPDerivativeOf* derivOfInst)
+ {
+ IRFunc* jvpFunc = nullptr;
+
+ // Resolve the derivative function.
+ //
+ // Check for the 'JVPDerivativeReference' decorator on the
+ // base function.
+ if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ jvpFunc = jvpRefDecorator->getJVPFunc();
+ }
+
+ // Substitute all uses of the 'derivativeOf' operation
+ // with the resolved derivative function.
+ while (auto use = derivOfInst->firstUse)
+ {
+ use->set(jvpFunc);
+ }
+
+ // Remove the 'derivativeOf'
+ derivOfInst->removeAndDeallocate();
+ }
+};
+
+// Set up context and call main process method.
+//
+bool processDerivativeCalls(
+ IRModule* module,
+ IRDerivativeCallProcessOptions const&)
+{
+ DerivativeCallProcessContext context;
+ context.module = module;
+
+ return context.processModule();
+}
+
+}
diff --git a/source/slang/slang-ir-diff-call.h b/source/slang/slang-ir-diff-call.h
new file mode 100644
index 000000000..d3b7d75a2
--- /dev/null
+++ b/source/slang/slang-ir-diff-call.h
@@ -0,0 +1,17 @@
+// slang-ir-diff-call.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ struct IRDerivativeCallProcessOptions
+ {
+ // Nothing for now..
+ };
+
+ bool processDerivativeCalls(
+ IRModule* module,
+ IRDerivativeCallProcessOptions const& options = IRDerivativeCallProcessOptions());
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
new file mode 100644
index 000000000..431c8e5b2
--- /dev/null
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -0,0 +1,180 @@
+// slang-ir-diff-jvp.cpp
+#include "slang-ir-diff-jvp.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+struct JVPDerivativeContext
+{
+ // This type passes over the module and generates
+ // forward-mode derivative versions of functions
+ // that are explicitly marked for it.
+ //
+ IRModule* module;
+
+ // Shared builder state for our derivative passes.
+ SharedIRBuilder sharedBuilderStorage;
+
+ bool processModule()
+ {
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->init(module);
+
+ // Run through all the global-level instructions,
+ // looking for callables.
+ // Note: We're only processing global callables (IRGlobalValueWithCode)
+ // for now.
+ IRBuilder builderStorage(sharedBuilderStorage);
+ IRBuilder* builder = &builderStorage;
+ for (auto inst : module->getGlobalInsts())
+ {
+ // If the instr is a callable, get all the basic blocks
+ if (auto callable = as<IRGlobalValueWithCode>(inst))
+ {
+ if (isFunctionMarkedForJVP(callable))
+ {
+ SLANG_ASSERT(as<IRFunc>(callable));
+ IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable));
+ builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);
+ }
+ }
+ }
+ return true;
+ }
+
+ // Checks decorators to see if the function should
+ // be differentiated (kIROp_JVPDerivativeMarkerDecoration)
+ //
+ bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable)
+ {
+ for(auto decoration = callable->getFirstDecoration();
+ decoration;
+ decoration = decoration->getNextDecoration())
+ {
+ if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
+ {
+ return true;
+ }
+ // TODO: Need to remove this decoration or check for
+ // JVPDerivativeReferenceDecoration to avoid re-generating code.
+ }
+ return false;
+ }
+
+ List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType)
+ {
+ List<IRParam*> params;
+ for(UIndex i = 0; i < dataType->getParamCount(); i++)
+ {
+ params.add(
+ builder->emitParam(dataType->getParamType(i)));
+ }
+ return params;
+ }
+
+ // Perform forward-mode automatic differentiation on
+ // the intstructions.
+ IRFunc* emitJVPFunction(IRBuilder* builder,
+ IRFunc* primalFn)
+ {
+ // Note (sai): Is this safe? Should we use setInsertInto?
+ builder->setInsertBefore(primalFn->getNextInst());
+
+ auto jvpFn = builder->createFunc();
+ IRType* jvpFuncType = primalTypeToJVPType(primalFn->getFullType());
+ jvpFn->setFullType(jvpFuncType);
+ if (auto jvpName = getJVPFuncName(builder, primalFn))
+ builder->addNameHintDecoration(jvpFn, jvpName);
+
+ builder->setInsertInto(jvpFn);
+
+ // Start with _extremely_ basic functions
+ SLANG_ASSERT(primalFn->getFirstBlock() == primalFn->getLastBlock());
+
+ for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
+ {
+ IRBlock* newJVPBlock = nullptr;
+ if (block == primalFn->getFirstBlock())
+ {
+ newJVPBlock = builder->emitBlock();
+ emitFuncParameters(builder, as<IRFuncType>(jvpFuncType));
+ }
+ newJVPBlock = emitJVPBlock(builder, primalFn->getFirstBlock(), newJVPBlock);
+ }
+
+ return jvpFn;
+ }
+
+ IRStringLit* getJVPFuncName(IRBuilder* builder,
+ IRFunc* func)
+ {
+ auto oldLoc = builder->getInsertLoc();
+ builder->setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ {
+ name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ {
+ name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice());
+ }
+
+ builder->setInsertLoc(oldLoc);
+
+ return name;
+ }
+
+
+ IRBlock* emitJVPBlock(IRBuilder* builder,
+ IRBlock* primalBlock,
+ IRBlock* jvpBlock = nullptr)
+ {
+ // Create if not already provided, and insert into new block.
+ if (!jvpBlock)
+ jvpBlock = builder->emitBlock();
+ else
+ builder->setInsertInto(jvpBlock);
+
+ // Temporarily, we're going to just emit a single return 0 instruction.
+ for(auto child = primalBlock->getFirstInst(); child; child = child->getNextInst())
+ {
+ if (auto returnOp = as<IRReturn>(child))
+ {
+ auto zeroVal = builder->getFloatValue(returnOp->getVal()->getDataType(), 0.0);
+ builder->emitReturn(zeroVal);
+ }
+ }
+
+ return jvpBlock;
+ }
+
+ IRType* primalTypeToJVPType(IRType* primalType)
+ {
+ // Temporarily, we're going to implement the identity transform.
+ // The return type is the same as the primal type.
+ return primalType;
+ }
+};
+
+// Set up context and call main process method.
+//
+bool processJVPDerivativeMarkers(
+ IRModule* module,
+ IRJVPDerivativePassOptions const&)
+{
+ JVPDerivativeContext context;
+ context.module = module;
+
+ return context.processModule();
+}
+
+}
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
new file mode 100644
index 000000000..9bbcad4fb
--- /dev/null
+++ b/source/slang/slang-ir-diff-jvp.h
@@ -0,0 +1,17 @@
+// slang-ir-diff-jvp.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ struct IRJVPDerivativePassOptions
+ {
+ // Nothing for now..
+ };
+
+ bool processJVPDerivativeMarkers(
+ IRModule* module,
+ IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions());
+
+}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 6304e65d2..793f1f78f 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -668,7 +668,11 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(SPIRVOpDecoration, spirvOpDecoration, 1, 0)
/// Decorated function is marked for the forward-mode differentiation pass.
- INST(JVPDerivativeDecoration, differentiateJvp, 0, 0)
+ INST(JVPDerivativeMarkerDecoration, differentiateJvp, 0, 0)
+
+ /// Used by the auto-diff pass to hold a reference to the
+ /// generated derivative function.
+ INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0)
/// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
@@ -713,6 +717,8 @@ INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0)
INST(BitCast, bitCast, 1, 0)
INST(Reinterpret, reinterpret, 1, 0)
+INST(JVPDerivativeOf, jvpDerivativeOf, 1, 0)
+
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 081c67d03..82d0d5a0e 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -515,6 +515,33 @@ struct IRSequentialIDDecoration : IRDecoration
IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); }
};
+struct IRJVPDerivativeReferenceDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_JVPDerivativeReferenceDecoration
+ };
+ IR_LEAF_ISA(JVPDerivativeReferenceDecoration)
+
+ IRFunc* getJVPFunc() { return as<IRFunc>(getOperand(0)); }
+};
+
+
+// An instruction that replaces the function symbol
+// with it's derivative function.
+struct IRJVPDerivativeOf : IRInst
+{
+ enum
+ {
+ kOp = kIROp_JVPDerivativeOf
+ };
+ // The base function for the call.
+ IRUse base;
+ IRInst* getBaseFn() { return getOperand(0); }
+
+ IR_LEAF_ISA(JVPDerivativeOf)
+};
+
// An instruction that specializes another IR value
// (representing a generic) to a particular set of generic arguments
// (instructions representing types, witness tables, etc.)
@@ -2319,6 +2346,8 @@ public:
IRInst* emitExtractExistentialWitnessTable(
IRInst* existentialValue);
+ IRInst* emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn);
+
IRInst* emitSpecializeInst(
IRType* type,
IRInst* genericVal,
@@ -2985,9 +3014,14 @@ public:
addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName));
}
- void addJVPDerivativeDecoration(IRInst* value, UnownedStringSlice const& mangledName)
+ void addJVPDerivativeMarkerDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_JVPDerivativeMarkerDecoration);
+ }
+
+ void addJVPDerivativeReferenceDecoration(IRInst* value, IRInst* jvpFn)
{
- addDecoration(value, kIROp_JVPDerivativeDecoration, getStringValue(mangledName));
+ addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn);
}
void addDllImportDecoration(IRInst* value, UnownedStringSlice const& libraryName, UnownedStringSlice const& functionName)
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 42ff5823b..950061d4f 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3027,6 +3027,17 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn)
+ {
+ auto inst = createInst<IRJVPDerivativeOf>(
+ this,
+ kIROp_JVPDerivativeOf,
+ type,
+ baseFn);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitSpecializeInst(
IRType* type,
IRInst* genericVal,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 791180890..d845342f0 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7,6 +7,8 @@
#include "slang-ir.h"
#include "slang-ir-constexpr.h"
#include "slang-ir-dce.h"
+#include "slang-ir-diff-call.h"
+#include "slang-ir-diff-jvp.h"
#include "slang-ir-inline.h"
#include "slang-ir-insts.h"
#include "slang-ir-missing-return.h"
@@ -755,16 +757,6 @@ LoweredValInfo emitCallToDeclRef(
tryEnv);
}
- /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
-enum ParameterDirection
-{
- kParameterDirection_In, ///< Copy in
- kParameterDirection_Out, ///< Copy out
- kParameterDirection_InOut, ///< Copy in, copy out
- kParameterDirection_Ref, ///< By-reference
-};
-
-
/// Emit a call to the given `accessorDeclRef`.
///
/// The `base` value represents the object on which the accessor is being invoked.
@@ -1151,7 +1143,7 @@ static void addLinkageDecoration(
}
if (decl->findModifier<JVPDerivativeModifier>())
{
- builder->addJVPDerivativeDecoration(inst, mangledName);
+ builder->addJVPDerivativeMarkerDecoration(inst);
}
if (as<InterfaceDecl>(decl->parentDecl) &&
decl->parentDecl->hasModifier<ComInterfaceAttribute>())
@@ -2947,6 +2939,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return info;
}
+ // Emit IR to denote the forward-mode derivative
+ // of the inner func-expr. This will be resolved
+ // to a concrete function during the derivative
+ // pass.
+ LoweredValInfo visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr)
+ {
+ auto baseVal = lowerSubExpr(expr->baseFn);
+ SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitJVPDerivativeOfInst(
+ lowerType(context, expr->type),
+ baseVal.val));
+ }
+
LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/)
{
SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST");
@@ -3403,74 +3410,112 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// TODO: also need to handle this-type substitution here?
}
+ /// Create IR instructions for an argument at a call site, based on
+ /// AST-level expressions plus function signature information.
+ ///
+ /// The `funcType` parameter is always required, and specifies the types
+ /// of all the parameters. The `funcDeclRef` parameter is only required
+ /// if there are parameter positions for which the matching argument is
+ /// absent.
+ ///
void addDirectCallArgs(
InvokeExpr* expr,
- DeclRef<CallableDecl> funcDeclRef,
- List<IRInst*>* ioArgs,
+ Index argIndex,
+ IRType* paramType,
+ ParameterDirection paramDirection,
+ DeclRef<ParamDecl> paramDeclRef,
+ List<IRInst*>* ioArgs,
List<OutArgumentFixup>* ioFixups)
{
- UInt argCount = expr->arguments.getCount();
- UInt argCounter = 0;
- for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
+ Count argCount = expr->arguments.getCount();
+ if (argIndex < argCount)
{
- auto paramDecl = paramDeclRef.getDecl();
- IRType* paramType = lowerType(context, getType(getASTBuilder(), paramDeclRef));
- auto paramDirection = getParameterDirection(paramDecl);
+ auto argExpr = expr->arguments[argIndex];
+ addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups);
+ }
+ else
+ {
+ // We have run out of arguments supplied at the call site,
+ // but there are still parameters remaining. This must mean
+ // that these parameters have default argument expressions
+ // associated with them.
+ //
+ // Currently we simply extract the initial-value expression
+ // from the parameter declaration and then lower it in
+ // the context of the caller.
+ //
+ // Note that the expression could involve subsitutions because
+ // in the general case it could depend on the generic parameters
+ // used the specialize the callee. For now we do not handle that
+ // case, and simply ignore generic arguments.
+ //
+ SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef);
+ SLANG_ASSERT(argExpr);
- UInt argIndex = argCounter++;
- if(argIndex < argCount)
- {
- auto argExpr = expr->arguments[argIndex];
- addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups);
- }
- else
- {
- // We have run out of arguments supplied at the call site,
- // but there are still parameters remaining. This must mean
- // that these parameters have default argument expressions
- // associated with them.
- //
- // Currently we simply extract the initial-value expression
- // from the parameter declaration and then lower it in
- // the context of the caller.
- //
- // Note that the expression could involve subsitutions because
- // in the general case it could depend on the generic parameters
- // used the specialize the callee. For now we do not handle that
- // case, and simply ignore generic arguments.
- //
- SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef);
- SLANG_ASSERT(argExpr);
+ IRGenEnv subEnvStorage;
+ IRGenEnv* subEnv = &subEnvStorage;
+ subEnv->outer = context->env;
- IRGenEnv subEnvStorage;
- IRGenEnv* subEnv = &subEnvStorage;
- subEnv->outer = context->env;
+ IRGenContext subContextStorage = *context;
+ IRGenContext* subContext = &subContextStorage;
+ subContext->env = subEnv;
- IRGenContext subContextStorage = *context;
- IRGenContext* subContext = &subContextStorage;
- subContext->env = subEnv;
+ _lowerSubstitutionEnv(subContext, argExpr.getSubsts());
- _lowerSubstitutionEnv(subContext, argExpr.getSubsts());
+ addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups);
- addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups);
+ // TODO: The approach we are taking here to default arguments
+ // is simplistic, and has consequences for the front-end as
+ // well as binary serialization of modules.
+ //
+ // We could consider some more refined approaches where, e.g.,
+ // functions with default arguments generate multiple IR-level
+ // functions, that compute and provide the default values.
+ //
+ // Alternatively, each parameter with defaults could be generated
+ // into its own callable function that provides the default value,
+ // so that calling modules can call into a pre-generated function.
+ //
+ // Each of these options involves trade-offs, and we need to
+ // make a conscious decision at some point.
- // TODO: The approach we are taking here to default arguments
- // is simplistic, and has consequences for the front-end as
- // well as binary serialization of modules.
- //
- // We could consider some more refined approaches where, e.g.,
- // functions with default arguments generate multiple IR-level
- // functions, that compute and provide the default values.
- //
- // Alternatively, each parameter with defaults could be generated
- // into its own callable function that provides the default value,
- // so that calling modules can call into a pre-generated function.
- //
- // Each of these options involves trade-offs, and we need to
- // make a conscious decision at some point.
+ // Assert that such an expression must have been present.
+ }
+ }
- // Assert that such an expression must have been present.
- }
+ void addDirectCallArgs(
+ InvokeExpr* expr,
+ FuncType* funcType,
+ List<IRInst*>* ioArgs,
+ List<OutArgumentFixup>* ioFixups)
+ {
+ Count argCount = expr->arguments.getCount();
+ SLANG_ASSERT(argCount == static_cast<Count>(funcType->getParamCount()));
+
+ for(Index i = 0; i < argCount; ++i)
+ {
+ IRType* paramType = lowerType(context, funcType->getParamType(i));
+ ParameterDirection paramDirection = funcType->getParamDirection(i);
+ addDirectCallArgs(expr, i, paramType, paramDirection, DeclRef<ParamDecl>(), ioArgs, ioFixups);
+ }
+ }
+
+
+ void addDirectCallArgs(
+ InvokeExpr* expr,
+ DeclRef<CallableDecl> funcDeclRef,
+ List<IRInst*>* ioArgs,
+ List<OutArgumentFixup>* ioFixups)
+ {
+ Count argCounter = 0;
+ for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
+ {
+ auto paramDecl = paramDeclRef.getDecl();
+ IRType* paramType = lowerType(context, getType(getASTBuilder(), paramDeclRef));
+ auto paramDirection = getParameterDirection(paramDecl);
+
+ Index argIndex = argCounter++;
+ addDirectCallArgs(expr, argIndex, paramType, paramDirection, paramDeclRef, ioArgs, ioFixups);
}
}
@@ -3636,7 +3681,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
auto funcExpr = expr->functionExpr;
ResolvedCallInfo resolvedInfo;
- if( tryResolveDeclRefForCall(funcExpr, &resolvedInfo) )
+ if (tryResolveDeclRefForCall(funcExpr, &resolvedInfo))
{
// In this case we know exactly what declaration we
// are going to call, and so we can resolve things
@@ -3690,7 +3735,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// First comes the `this` argument if we are calling
// a member function:
- if( baseExpr )
+ if (baseExpr)
{
// The base expression might be an "upcast" to a base interface, in
// which case we don't want to emit the result of the cast, but instead
@@ -3725,6 +3770,17 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
applyOutArgumentFixups(context, argFixups);
return result;
}
+ else if(auto funcType = as<FuncType>(expr->functionExpr->type))
+ {
+ auto funcVal = lowerRValueExpr(context, expr->functionExpr);
+ addDirectCallArgs(expr, funcType, &irArgs, &argFixups);
+
+ auto result = emitCallToVal(context, type, funcVal, irArgs.getCount(), irArgs.getBuffer(), tryEnv);
+
+ applyOutArgumentFixups(context, argFixups);
+ return result;
+ }
+
// TODO: In this case we should be emitting code for the callee as
// an ordinary expression, then emitting the arguments according
@@ -8417,6 +8473,16 @@ RefPtr<IRModule> generateIRForTranslationUnit(
#endif
validateIRModuleIfEnabled(compileRequest, module);
+
+ // Process higher-order-function calls before any optimization passes
+ // to allow the optimizations to affect the generated funcitons.
+ // 1. Process JVP derivative functions.
+ processJVPDerivativeMarkers(module);
+ // 2. Process VJP derivative functions.
+ // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
+ // 3. Replace JVP & VJP calls.
+ processDerivativeCalls(module);
+
// We will perform certain "mandatory" optimization passes now.
// These passes serve two purposes:
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index f6d85ade9..ee34eac6f 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2056,6 +2056,25 @@ namespace Slang
{
return parseTaggedUnionType(parser);
}
+ /// Parse an expression of the form __jvp(fn) where fn is an
+ /// identifier pointing to a function.
+ static Expr* parseJVPDerivativeOf(Parser* parser)
+ {
+ JVPDerivativeOfExpr* jvpExpr = parser->astBuilder->create<JVPDerivativeOfExpr>();
+
+ parser->ReadToken(TokenType::LParent);
+
+ jvpExpr->baseFn = parser->ParseExpression();
+
+ parser->ReadToken(TokenType::RParent);
+
+ return jvpExpr;
+ }
+
+ static NodeBase* parseJVPDerivativeOf(Parser* parser, void* /* unused */)
+ {
+ return parseJVPDerivativeOf(parser);
+ }
/// Parse a `This` type expression
static Expr* parseThisTypeExpr(Parser* parser)
@@ -6473,6 +6492,7 @@ namespace Slang
_makeParseExpr("nullptr", parseNullPtrExpr),
_makeParseExpr("try", parseTryExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
+ _makeParseExpr("__jvp", parseJVPDerivativeOf)
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()