diff options
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj | 4 | ||||
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj.filters | 12 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 95 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-call.cpp | 90 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-call.h | 17 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 180 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.h | 17 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 38 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 204 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 20 | ||||
| -rw-r--r-- | tests/ir/derivative-op-ir-test.slang | 21 | ||||
| -rw-r--r-- | tests/ir/derivative-op-ir-test.slang.expected.txt | 5 |
20 files changed, 711 insertions, 88 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index b958753d9..c20633c34 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -350,6 +350,8 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\ <ClInclude Include="..\..\..\source\slang\slang-ir-com-interface.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-constexpr.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dce.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-call.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-jvp.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-import.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dominators.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-eliminate-phis.h" />
@@ -506,6 +508,8 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\ <ClCompile Include="..\..\..\source\slang\slang-ir-constexpr.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dce.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-deduplicate.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-call.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-jvp.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-import.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dominators.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-eliminate-phis.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 0667e6249..cb3e0f278 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -147,6 +147,12 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-dce.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-call.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-jvp.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-import.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -611,6 +617,12 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-deduplicate.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-call.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-jvp.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-import.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 647eb37a4..8f407321e 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -369,6 +369,15 @@ class ExtractExistentialValueExpr: public Expr DeclRef<VarDeclBase> declRef; }; + /// An expression of the form `__jvp(fn)` to access the + /// forward-mode derivative version of the function `fn` + /// +class JVPDerivativeOfExpr: public Expr +{ + SLANG_AST_CLASS(JVPDerivativeOfExpr) + Expr* baseFn; +}; + /// A type expression of the form `__TaggedUnion(A, ...)`. /// /// An expression of this form will resolve to a `TaggedUnionType` diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index c1e6a0132..e8ab51fbd 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1477,6 +1477,15 @@ namespace Slang List<ExtensionDecl*> candidateExtensions; }; + /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` + enum ParameterDirection + { + kParameterDirection_In, ///< Copy in + kParameterDirection_Out, ///< Copy out + kParameterDirection_InOut, ///< Copy in, copy out + kParameterDirection_Ref, ///< By-reference + }; + } // namespace Slang #endif diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 5ee0c5c70..43fe751ee 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -557,6 +557,27 @@ HashCode NamedExpressionType::_getHashCodeOverride() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +ParameterDirection FuncType::getParamDirection(Index index) +{ + auto paramType = getParamType(index); + if (as<RefType>(paramType)) + { + return kParameterDirection_Ref; + } + else if (as<InOutType>(paramType)) + { + return kParameterDirection_InOut; + } + else if (as<OutType>(paramType)) + { + return kParameterDirection_Out; + } + else + { + return kParameterDirection_In; + } +} + void FuncType::_toTextOverride(StringBuilder& out) { out << toSlice("("); diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index dbca6b18b..b82c6b182 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -525,10 +525,16 @@ class PtrType : public PtrTypeBase SLANG_AST_CLASS(PtrType) }; +/// A pointer-like type used to represent a parameter "direction" +class ParamDirectionType : public PtrTypeBase +{ + SLANG_AST_CLASS(ParamDirectionType) +}; + // A type that represents the behind-the-scenes // logical pointer that is passed for an `out` // or `in out` parameter -class OutTypeBase : public PtrTypeBase +class OutTypeBase : public ParamDirectionType { SLANG_AST_CLASS(OutTypeBase) }; @@ -546,7 +552,7 @@ class InOutType : public OutTypeBase }; // The type for an `ref` parameter, e.g., `ref T` -class RefType : public PtrTypeBase +class RefType : public ParamDirectionType { SLANG_AST_CLASS(RefType) }; @@ -595,6 +601,8 @@ class FuncType : public Type Type* getResultType() { return resultType; } Type* getErrorType() { return errorType; } + ParameterDirection getParamDirection(Index index); + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3b308c46a..ff469428b 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,6 +1509,26 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr) + { + // Check/Resolve inner function declaration. + expr->baseFn = CheckTerm(expr->baseFn); + + if(auto funcType = as<FuncType>(expr->baseFn->type)) + { + // Resolve JVP type here. + // Temporarily resolving to the same type as the original function. + expr->type = expr->baseFn->type; + } + else + { + // Error + UNREACHABLE_RETURN(nullptr); + } + + return expr; + } + Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr) { // Check the term we are applying first diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 5ef853b62..3be5ba68b 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -160,6 +160,7 @@ namespace Slang Func, Generic, UnspecializedGeneric, + Expr, }; Flavor flavor; @@ -178,6 +179,9 @@ namespace Slang // Reference to the declaration being applied LookupResultItem item; + // Type of function being applied (for cases where `item` is not used) + FuncType* funcType = nullptr; + // The type of the result expression if this candidate is selected Type* resultType = nullptr; @@ -1728,6 +1732,8 @@ namespace Slang Expr* visitAndTypeExpr(AndTypeExpr* expr); Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); + Expr* visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr); + /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); }; diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index f4a1de3d5..bd27c7df2 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -76,6 +76,14 @@ namespace Slang paramCounts = CountParameters(candidate.item.declRef.as<GenericDecl>()); break; + case OverloadCandidate::Flavor::Expr: + { + auto paramCount = candidate.funcType->getParamCount(); + paramCounts.allowed = paramCount; + paramCounts.required = paramCount; + } + break; + default: SLANG_UNEXPECTED("unknown flavor of overload candidate"); break; @@ -312,11 +320,34 @@ namespace Slang { Index argCount = context.getArgCount(); - List<DeclRef<ParamDecl>> params; + List<Type*> paramTypes; +// List<DeclRef<ParamDecl>> params; switch (candidate.flavor) { case OverloadCandidate::Flavor::Func: - params = getParameters(candidate.item.declRef.as<CallableDecl>()).toArray(); + for (auto param : getParameters(candidate.item.declRef.as<CallableDecl>())) + { + auto paramType = getType(m_astBuilder, param); + paramTypes.add(paramType); + } + break; + + case OverloadCandidate::Flavor::Expr: + { + auto funcType = candidate.funcType; + Count paramCount = funcType->getParamCount(); + for (Index i = 0; i < paramCount; ++i) + { + auto paramType = funcType->getParamType(i); + + if(auto paramDirectionType = as<ParamDirectionType>(paramType)) + { + paramType = paramDirectionType->getValueType(); + } + + paramTypes.add(paramType); + } + } break; case OverloadCandidate::Flavor::Generic: @@ -329,13 +360,13 @@ namespace Slang // Note(tfoley): We might have fewer arguments than parameters in the // case where one or more parameters had defaults. - SLANG_RELEASE_ASSERT(argCount <= params.getCount()); + SLANG_RELEASE_ASSERT(argCount <= paramTypes.getCount()); for (Index ii = 0; ii < argCount; ++ii) { auto& arg = context.getArg(ii); auto argType = context.getArgType(ii); - auto param = params[ii]; + auto paramType = paramTypes[ii]; if (context.mode == OverloadResolveContext::Mode::JustTrying) { @@ -343,10 +374,10 @@ namespace Slang if( context.disallowNestedConversions ) { // We need an exact match in this case. - if(!getType(m_astBuilder, param)->equals(argType)) + if(!paramType->equals(argType)) return false; } - else if (!canCoerce(getType(m_astBuilder, param), argType, arg, &cost)) + else if (!canCoerce(paramType, argType, arg, &cost)) { return false; } @@ -354,7 +385,7 @@ namespace Slang } else { - arg = coerce(getType(m_astBuilder, param), arg); + arg = coerce(paramType, arg); } } return true; @@ -558,11 +589,24 @@ namespace Slang { auto originalAppExpr = as<AppExprBase>(context.originalExpr); - auto baseExpr = ConstructLookupResultExpr( - candidate.item, - context.baseExpr, - context.funcLoc, - originalAppExpr ? originalAppExpr->functionExpr : nullptr); + + + Expr* baseExpr; + switch(candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + case OverloadCandidate::Flavor::Generic: + baseExpr = ConstructLookupResultExpr( + candidate.item, + context.baseExpr, + context.funcLoc, + originalAppExpr ? originalAppExpr->functionExpr : nullptr); + break; + case OverloadCandidate::Flavor::Expr: + default: + baseExpr = nullptr; + break; + } switch(candidate.flavor) { @@ -598,6 +642,25 @@ namespace Slang break; + case OverloadCandidate::Flavor::Expr: + { + AppExprBase* callExpr = as<InvokeExpr>(context.originalExpr); + if (!callExpr) + { + callExpr = m_astBuilder->create<InvokeExpr>(); + callExpr->loc = context.loc; + for (Index aa = 0; aa < context.argCount; ++aa) + callExpr->arguments.add(context.getArg(aa)); + } + + callExpr->originalFunctionExpr = callExpr->functionExpr; + callExpr->type = QualType(candidate.resultType); + + return callExpr; + + } + break; + case OverloadCandidate::Flavor::Generic: return createGenericDeclRef( baseExpr, @@ -996,8 +1059,12 @@ namespace Slang FuncType* funcType, OverloadResolveContext& context) { - SLANG_UNUSED(funcType); - getSink()->diagnose(context.loc, Diagnostics::unimplemented, "call on expression of function type"); + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; + candidate.funcType = funcType; + candidate.resultType = funcType->getResultType(); + + AddOverloadCandidate(context, candidate); } void SemanticsVisitor::AddCtorOverloadCandidate( diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp new file mode 100644 index 000000000..76ffe3c8b --- /dev/null +++ b/source/slang/slang-ir-diff-call.cpp @@ -0,0 +1,90 @@ +// slang-ir-diff-call.cpp +#include "slang-ir-diff-call.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +struct DerivativeCallProcessContext +{ + // This type passes over the module and replaces + // derivative calls with the processed derivative + // function. + // + IRModule* module; + + bool processModule() + { + // Run through all the global-level instructions, + // looking for callable blocks. + for (auto inst : module->getGlobalInsts()) + { + // If the instr is a callable, get all the basic blocks + if (auto callable = as<IRGlobalValueWithCode>(inst)) + { + // Iterate over each block in the callable + for (auto block : callable->getBlocks()) + { + // Iterate over each child instruction. + auto child = block->getFirstInst(); + if (!child) continue; + + do + { + auto nextChild = child->getNextInst(); + // Look for IRJVPDerivativeOf + if (auto derivOf = as<IRJVPDerivativeOf>(child)) + { + processDerivativeOf(derivOf); + } + child = nextChild; + } + while (child); + } + } + } + return true; + } + + // Perform forward-mode automatic differentiation on + // the intstructions. + void processDerivativeOf(IRJVPDerivativeOf* derivOfInst) + { + IRFunc* jvpFunc = nullptr; + + // Resolve the derivative function. + // + // Check for the 'JVPDerivativeReference' decorator on the + // base function. + if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + jvpFunc = jvpRefDecorator->getJVPFunc(); + } + + // Substitute all uses of the 'derivativeOf' operation + // with the resolved derivative function. + while (auto use = derivOfInst->firstUse) + { + use->set(jvpFunc); + } + + // Remove the 'derivativeOf' + derivOfInst->removeAndDeallocate(); + } +}; + +// Set up context and call main process method. +// +bool processDerivativeCalls( + IRModule* module, + IRDerivativeCallProcessOptions const&) +{ + DerivativeCallProcessContext context; + context.module = module; + + return context.processModule(); +} + +} diff --git a/source/slang/slang-ir-diff-call.h b/source/slang/slang-ir-diff-call.h new file mode 100644 index 000000000..d3b7d75a2 --- /dev/null +++ b/source/slang/slang-ir-diff-call.h @@ -0,0 +1,17 @@ +// slang-ir-diff-call.h +#pragma once + +namespace Slang +{ + struct IRModule; + + struct IRDerivativeCallProcessOptions + { + // Nothing for now.. + }; + + bool processDerivativeCalls( + IRModule* module, + IRDerivativeCallProcessOptions const& options = IRDerivativeCallProcessOptions()); + +}
\ No newline at end of file diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp new file mode 100644 index 000000000..431c8e5b2 --- /dev/null +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -0,0 +1,180 @@ +// slang-ir-diff-jvp.cpp +#include "slang-ir-diff-jvp.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +struct JVPDerivativeContext +{ + // This type passes over the module and generates + // forward-mode derivative versions of functions + // that are explicitly marked for it. + // + IRModule* module; + + // Shared builder state for our derivative passes. + SharedIRBuilder sharedBuilderStorage; + + bool processModule() + { + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->init(module); + + // Run through all the global-level instructions, + // looking for callables. + // Note: We're only processing global callables (IRGlobalValueWithCode) + // for now. + IRBuilder builderStorage(sharedBuilderStorage); + IRBuilder* builder = &builderStorage; + for (auto inst : module->getGlobalInsts()) + { + // If the instr is a callable, get all the basic blocks + if (auto callable = as<IRGlobalValueWithCode>(inst)) + { + if (isFunctionMarkedForJVP(callable)) + { + SLANG_ASSERT(as<IRFunc>(callable)); + IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable)); + builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction); + } + } + } + return true; + } + + // Checks decorators to see if the function should + // be differentiated (kIROp_JVPDerivativeMarkerDecoration) + // + bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable) + { + for(auto decoration = callable->getFirstDecoration(); + decoration; + decoration = decoration->getNextDecoration()) + { + if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration) + { + return true; + } + // TODO: Need to remove this decoration or check for + // JVPDerivativeReferenceDecoration to avoid re-generating code. + } + return false; + } + + List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType) + { + List<IRParam*> params; + for(UIndex i = 0; i < dataType->getParamCount(); i++) + { + params.add( + builder->emitParam(dataType->getParamType(i))); + } + return params; + } + + // Perform forward-mode automatic differentiation on + // the intstructions. + IRFunc* emitJVPFunction(IRBuilder* builder, + IRFunc* primalFn) + { + // Note (sai): Is this safe? Should we use setInsertInto? + builder->setInsertBefore(primalFn->getNextInst()); + + auto jvpFn = builder->createFunc(); + IRType* jvpFuncType = primalTypeToJVPType(primalFn->getFullType()); + jvpFn->setFullType(jvpFuncType); + if (auto jvpName = getJVPFuncName(builder, primalFn)) + builder->addNameHintDecoration(jvpFn, jvpName); + + builder->setInsertInto(jvpFn); + + // Start with _extremely_ basic functions + SLANG_ASSERT(primalFn->getFirstBlock() == primalFn->getLastBlock()); + + for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock()) + { + IRBlock* newJVPBlock = nullptr; + if (block == primalFn->getFirstBlock()) + { + newJVPBlock = builder->emitBlock(); + emitFuncParameters(builder, as<IRFuncType>(jvpFuncType)); + } + newJVPBlock = emitJVPBlock(builder, primalFn->getFirstBlock(), newJVPBlock); + } + + return jvpFn; + } + + IRStringLit* getJVPFuncName(IRBuilder* builder, + IRFunc* func) + { + auto oldLoc = builder->getInsertLoc(); + builder->setInsertBefore(func); + + IRStringLit* name = nullptr; + if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) + { + name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice()); + } + else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) + { + name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice()); + } + + builder->setInsertLoc(oldLoc); + + return name; + } + + + IRBlock* emitJVPBlock(IRBuilder* builder, + IRBlock* primalBlock, + IRBlock* jvpBlock = nullptr) + { + // Create if not already provided, and insert into new block. + if (!jvpBlock) + jvpBlock = builder->emitBlock(); + else + builder->setInsertInto(jvpBlock); + + // Temporarily, we're going to just emit a single return 0 instruction. + for(auto child = primalBlock->getFirstInst(); child; child = child->getNextInst()) + { + if (auto returnOp = as<IRReturn>(child)) + { + auto zeroVal = builder->getFloatValue(returnOp->getVal()->getDataType(), 0.0); + builder->emitReturn(zeroVal); + } + } + + return jvpBlock; + } + + IRType* primalTypeToJVPType(IRType* primalType) + { + // Temporarily, we're going to implement the identity transform. + // The return type is the same as the primal type. + return primalType; + } +}; + +// Set up context and call main process method. +// +bool processJVPDerivativeMarkers( + IRModule* module, + IRJVPDerivativePassOptions const&) +{ + JVPDerivativeContext context; + context.module = module; + + return context.processModule(); +} + +} diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h new file mode 100644 index 000000000..9bbcad4fb --- /dev/null +++ b/source/slang/slang-ir-diff-jvp.h @@ -0,0 +1,17 @@ +// slang-ir-diff-jvp.h +#pragma once + +namespace Slang +{ + struct IRModule; + + struct IRJVPDerivativePassOptions + { + // Nothing for now.. + }; + + bool processJVPDerivativeMarkers( + IRModule* module, + IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions()); + +} diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 6304e65d2..793f1f78f 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -668,7 +668,11 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(SPIRVOpDecoration, spirvOpDecoration, 1, 0) /// Decorated function is marked for the forward-mode differentiation pass. - INST(JVPDerivativeDecoration, differentiateJvp, 0, 0) + INST(JVPDerivativeMarkerDecoration, differentiateJvp, 0, 0) + + /// Used by the auto-diff pass to hold a reference to the + /// generated derivative function. + INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0) /// Marks a struct type as being used as a structured buffer block. /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. @@ -713,6 +717,8 @@ INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) INST(BitCast, bitCast, 1, 0) INST(Reinterpret, reinterpret, 1, 0) +INST(JVPDerivativeOf, jvpDerivativeOf, 1, 0) + // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 081c67d03..82d0d5a0e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -515,6 +515,33 @@ struct IRSequentialIDDecoration : IRDecoration IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); } }; +struct IRJVPDerivativeReferenceDecoration : IRDecoration +{ + enum + { + kOp = kIROp_JVPDerivativeReferenceDecoration + }; + IR_LEAF_ISA(JVPDerivativeReferenceDecoration) + + IRFunc* getJVPFunc() { return as<IRFunc>(getOperand(0)); } +}; + + +// An instruction that replaces the function symbol +// with it's derivative function. +struct IRJVPDerivativeOf : IRInst +{ + enum + { + kOp = kIROp_JVPDerivativeOf + }; + // The base function for the call. + IRUse base; + IRInst* getBaseFn() { return getOperand(0); } + + IR_LEAF_ISA(JVPDerivativeOf) +}; + // An instruction that specializes another IR value // (representing a generic) to a particular set of generic arguments // (instructions representing types, witness tables, etc.) @@ -2319,6 +2346,8 @@ public: IRInst* emitExtractExistentialWitnessTable( IRInst* existentialValue); + IRInst* emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn); + IRInst* emitSpecializeInst( IRType* type, IRInst* genericVal, @@ -2985,9 +3014,14 @@ public: addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName)); } - void addJVPDerivativeDecoration(IRInst* value, UnownedStringSlice const& mangledName) + void addJVPDerivativeMarkerDecoration(IRInst* value) + { + addDecoration(value, kIROp_JVPDerivativeMarkerDecoration); + } + + void addJVPDerivativeReferenceDecoration(IRInst* value, IRInst* jvpFn) { - addDecoration(value, kIROp_JVPDerivativeDecoration, getStringValue(mangledName)); + addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn); } void addDllImportDecoration(IRInst* value, UnownedStringSlice const& libraryName, UnownedStringSlice const& functionName) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 42ff5823b..950061d4f 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3027,6 +3027,17 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitJVPDerivativeOfInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst<IRJVPDerivativeOf>( + this, + kIROp_JVPDerivativeOf, + type, + baseFn); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitSpecializeInst( IRType* type, IRInst* genericVal, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 791180890..d845342f0 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7,6 +7,8 @@ #include "slang-ir.h" #include "slang-ir-constexpr.h" #include "slang-ir-dce.h" +#include "slang-ir-diff-call.h" +#include "slang-ir-diff-jvp.h" #include "slang-ir-inline.h" #include "slang-ir-insts.h" #include "slang-ir-missing-return.h" @@ -755,16 +757,6 @@ LoweredValInfo emitCallToDeclRef( tryEnv); } - /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` -enum ParameterDirection -{ - kParameterDirection_In, ///< Copy in - kParameterDirection_Out, ///< Copy out - kParameterDirection_InOut, ///< Copy in, copy out - kParameterDirection_Ref, ///< By-reference -}; - - /// Emit a call to the given `accessorDeclRef`. /// /// The `base` value represents the object on which the accessor is being invoked. @@ -1151,7 +1143,7 @@ static void addLinkageDecoration( } if (decl->findModifier<JVPDerivativeModifier>()) { - builder->addJVPDerivativeDecoration(inst, mangledName); + builder->addJVPDerivativeMarkerDecoration(inst); } if (as<InterfaceDecl>(decl->parentDecl) && decl->parentDecl->hasModifier<ComInterfaceAttribute>()) @@ -2947,6 +2939,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return info; } + // Emit IR to denote the forward-mode derivative + // of the inner func-expr. This will be resolved + // to a concrete function during the derivative + // pass. + LoweredValInfo visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr) + { + auto baseVal = lowerSubExpr(expr->baseFn); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + + return LoweredValInfo::simple( + getBuilder()->emitJVPDerivativeOfInst( + lowerType(context, expr->type), + baseVal.val)); + } + LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/) { SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); @@ -3403,74 +3410,112 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // TODO: also need to handle this-type substitution here? } + /// Create IR instructions for an argument at a call site, based on + /// AST-level expressions plus function signature information. + /// + /// The `funcType` parameter is always required, and specifies the types + /// of all the parameters. The `funcDeclRef` parameter is only required + /// if there are parameter positions for which the matching argument is + /// absent. + /// void addDirectCallArgs( InvokeExpr* expr, - DeclRef<CallableDecl> funcDeclRef, - List<IRInst*>* ioArgs, + Index argIndex, + IRType* paramType, + ParameterDirection paramDirection, + DeclRef<ParamDecl> paramDeclRef, + List<IRInst*>* ioArgs, List<OutArgumentFixup>* ioFixups) { - UInt argCount = expr->arguments.getCount(); - UInt argCounter = 0; - for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef)) + Count argCount = expr->arguments.getCount(); + if (argIndex < argCount) { - auto paramDecl = paramDeclRef.getDecl(); - IRType* paramType = lowerType(context, getType(getASTBuilder(), paramDeclRef)); - auto paramDirection = getParameterDirection(paramDecl); + auto argExpr = expr->arguments[argIndex]; + addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups); + } + else + { + // We have run out of arguments supplied at the call site, + // but there are still parameters remaining. This must mean + // that these parameters have default argument expressions + // associated with them. + // + // Currently we simply extract the initial-value expression + // from the parameter declaration and then lower it in + // the context of the caller. + // + // Note that the expression could involve subsitutions because + // in the general case it could depend on the generic parameters + // used the specialize the callee. For now we do not handle that + // case, and simply ignore generic arguments. + // + SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef); + SLANG_ASSERT(argExpr); - UInt argIndex = argCounter++; - if(argIndex < argCount) - { - auto argExpr = expr->arguments[argIndex]; - addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups); - } - else - { - // We have run out of arguments supplied at the call site, - // but there are still parameters remaining. This must mean - // that these parameters have default argument expressions - // associated with them. - // - // Currently we simply extract the initial-value expression - // from the parameter declaration and then lower it in - // the context of the caller. - // - // Note that the expression could involve subsitutions because - // in the general case it could depend on the generic parameters - // used the specialize the callee. For now we do not handle that - // case, and simply ignore generic arguments. - // - SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef); - SLANG_ASSERT(argExpr); + IRGenEnv subEnvStorage; + IRGenEnv* subEnv = &subEnvStorage; + subEnv->outer = context->env; - IRGenEnv subEnvStorage; - IRGenEnv* subEnv = &subEnvStorage; - subEnv->outer = context->env; + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->env = subEnv; - IRGenContext subContextStorage = *context; - IRGenContext* subContext = &subContextStorage; - subContext->env = subEnv; + _lowerSubstitutionEnv(subContext, argExpr.getSubsts()); - _lowerSubstitutionEnv(subContext, argExpr.getSubsts()); + addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups); - addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups); + // TODO: The approach we are taking here to default arguments + // is simplistic, and has consequences for the front-end as + // well as binary serialization of modules. + // + // We could consider some more refined approaches where, e.g., + // functions with default arguments generate multiple IR-level + // functions, that compute and provide the default values. + // + // Alternatively, each parameter with defaults could be generated + // into its own callable function that provides the default value, + // so that calling modules can call into a pre-generated function. + // + // Each of these options involves trade-offs, and we need to + // make a conscious decision at some point. - // TODO: The approach we are taking here to default arguments - // is simplistic, and has consequences for the front-end as - // well as binary serialization of modules. - // - // We could consider some more refined approaches where, e.g., - // functions with default arguments generate multiple IR-level - // functions, that compute and provide the default values. - // - // Alternatively, each parameter with defaults could be generated - // into its own callable function that provides the default value, - // so that calling modules can call into a pre-generated function. - // - // Each of these options involves trade-offs, and we need to - // make a conscious decision at some point. + // Assert that such an expression must have been present. + } + } - // Assert that such an expression must have been present. - } + void addDirectCallArgs( + InvokeExpr* expr, + FuncType* funcType, + List<IRInst*>* ioArgs, + List<OutArgumentFixup>* ioFixups) + { + Count argCount = expr->arguments.getCount(); + SLANG_ASSERT(argCount == static_cast<Count>(funcType->getParamCount())); + + for(Index i = 0; i < argCount; ++i) + { + IRType* paramType = lowerType(context, funcType->getParamType(i)); + ParameterDirection paramDirection = funcType->getParamDirection(i); + addDirectCallArgs(expr, i, paramType, paramDirection, DeclRef<ParamDecl>(), ioArgs, ioFixups); + } + } + + + void addDirectCallArgs( + InvokeExpr* expr, + DeclRef<CallableDecl> funcDeclRef, + List<IRInst*>* ioArgs, + List<OutArgumentFixup>* ioFixups) + { + Count argCounter = 0; + for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef)) + { + auto paramDecl = paramDeclRef.getDecl(); + IRType* paramType = lowerType(context, getType(getASTBuilder(), paramDeclRef)); + auto paramDirection = getParameterDirection(paramDecl); + + Index argIndex = argCounter++; + addDirectCallArgs(expr, argIndex, paramType, paramDirection, paramDeclRef, ioArgs, ioFixups); } } @@ -3636,7 +3681,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> auto funcExpr = expr->functionExpr; ResolvedCallInfo resolvedInfo; - if( tryResolveDeclRefForCall(funcExpr, &resolvedInfo) ) + if (tryResolveDeclRefForCall(funcExpr, &resolvedInfo)) { // In this case we know exactly what declaration we // are going to call, and so we can resolve things @@ -3690,7 +3735,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // First comes the `this` argument if we are calling // a member function: - if( baseExpr ) + if (baseExpr) { // The base expression might be an "upcast" to a base interface, in // which case we don't want to emit the result of the cast, but instead @@ -3725,6 +3770,17 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> applyOutArgumentFixups(context, argFixups); return result; } + else if(auto funcType = as<FuncType>(expr->functionExpr->type)) + { + auto funcVal = lowerRValueExpr(context, expr->functionExpr); + addDirectCallArgs(expr, funcType, &irArgs, &argFixups); + + auto result = emitCallToVal(context, type, funcVal, irArgs.getCount(), irArgs.getBuffer(), tryEnv); + + applyOutArgumentFixups(context, argFixups); + return result; + } + // TODO: In this case we should be emitting code for the callee as // an ordinary expression, then emitting the arguments according @@ -8417,6 +8473,16 @@ RefPtr<IRModule> generateIRForTranslationUnit( #endif validateIRModuleIfEnabled(compileRequest, module); + + // Process higher-order-function calls before any optimization passes + // to allow the optimizations to affect the generated funcitons. + // 1. Process JVP derivative functions. + processJVPDerivativeMarkers(module); + // 2. Process VJP derivative functions. + // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet. + // 3. Replace JVP & VJP calls. + processDerivativeCalls(module); + // We will perform certain "mandatory" optimization passes now. // These passes serve two purposes: diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index f6d85ade9..ee34eac6f 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2056,6 +2056,25 @@ namespace Slang { return parseTaggedUnionType(parser); } + /// Parse an expression of the form __jvp(fn) where fn is an + /// identifier pointing to a function. + static Expr* parseJVPDerivativeOf(Parser* parser) + { + JVPDerivativeOfExpr* jvpExpr = parser->astBuilder->create<JVPDerivativeOfExpr>(); + + parser->ReadToken(TokenType::LParent); + + jvpExpr->baseFn = parser->ParseExpression(); + + parser->ReadToken(TokenType::RParent); + + return jvpExpr; + } + + static NodeBase* parseJVPDerivativeOf(Parser* parser, void* /* unused */) + { + return parseJVPDerivativeOf(parser); + } /// Parse a `This` type expression static Expr* parseThisTypeExpr(Parser* parser) @@ -6473,6 +6492,7 @@ namespace Slang _makeParseExpr("nullptr", parseNullPtrExpr), _makeParseExpr("try", parseTryExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), + _makeParseExpr("__jvp", parseJVPDerivativeOf) }; ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos() diff --git a/tests/ir/derivative-op-ir-test.slang b/tests/ir/derivative-op-ir-test.slang new file mode 100644 index 000000000..209446765 --- /dev/null +++ b/tests/ir/derivative-op-ir-test.slang @@ -0,0 +1,21 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +__differentiate_jvp float f(float x) +{ + return x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + float a = 1.0; + float b = -2.0; + outputBuffer[0] = __jvp(f)(a); + outputBuffer[1] = __jvp(f)(b); + } +} diff --git a/tests/ir/derivative-op-ir-test.slang.expected.txt b/tests/ir/derivative-op-ir-test.slang.expected.txt new file mode 100644 index 000000000..f095a0071 --- /dev/null +++ b/tests/ir/derivative-op-ir-test.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +0.0 +0.0 +0.0 +0.0
\ No newline at end of file |
