diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-07-11 23:18:06 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-07-11 23:18:06 -0400 |
| commit | b513d0deef521318ad943d820dd37029075a33c4 (patch) | |
| tree | cc6dc625ae381e0461724c5b137e1a034b03e636 | |
| parent | 9261c7a23ddf061fe9f5bfc3376f09f3c0513bff (diff) | |
Added support for differentiating calls to basic functions, as well as arithmetic on the float3 type (#2313)
* Added support for differentiating calls to basic functions, as well as arithmetic on the float3 type
* Added test expected result
| -rw-r--r-- | source/compiler-core/slang-diagnostic-sink.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 37 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 104 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generic-call.cpp | 2 | ||||
| -rw-r--r-- | tests/autodiff/local-redecl-custom-jvp.slang (renamed from tests/autodiff/redecl-custom-jvp.slang) | 0 | ||||
| -rw-r--r-- | tests/autodiff/local-redecl-custom-jvp.slang.expected.txt (renamed from tests/autodiff/redecl-custom-jvp.slang.expected.txt) | 0 | ||||
| -rw-r--r-- | tests/autodiff/nested-jvp.slang | 54 | ||||
| -rw-r--r-- | tests/autodiff/nested-jvp.slang.expected.txt | 5 | ||||
| -rw-r--r-- | tests/autodiff/test-intrinsics.slang | 5 | ||||
| -rw-r--r-- | tests/autodiff/vector-arithmetic-jvp.slang | 30 | ||||
| -rw-r--r-- | tests/autodiff/vector-arithmetic-jvp.slang.expected.txt | 5 |
11 files changed, 231 insertions, 15 deletions
diff --git a/source/compiler-core/slang-diagnostic-sink.cpp b/source/compiler-core/slang-diagnostic-sink.cpp index 0110b16d7..34a3c4968 100644 --- a/source/compiler-core/slang-diagnostic-sink.cpp +++ b/source/compiler-core/slang-diagnostic-sink.cpp @@ -642,13 +642,13 @@ void DiagnosticSink::diagnoseRaw( // Did the client supply a callback for us to use? if(writer) { - // If so, pass the error string along to them + // If so, pass the error string along to them. writer->write(message.begin(), message.getLength()); } else { // If the user doesn't have a callback, then just - // collect our diagnostic messages into a buffer + // collect our diagnostic messages into a buffer. outputBuffer.append(message); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index df58b11ed..a3fec4802 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,6 +1509,29 @@ namespace Slang return expr; } + Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) + { + // Only float and float3 types can be differentiated for now. + + if(primalType->equals(builder->getFloatType())) + return primalType; + else if(auto primalVectorType = as<VectorExpressionType>(primalType)) + { + // TODO(sai): There's probably a more elegant way to check if a type is a float3? + if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) + return primalVectorType; + } + return nullptr; + } + + Type* primalToJVPReturnType(ASTBuilder* builder, Type* primalType) + { + if(auto jvpType = primalToJVPParamType(builder, primalType)) + return jvpType; + else + return builder->getVoidType(); + } + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) { // Check/Resolve inner function declaration. @@ -1524,18 +1547,10 @@ namespace Slang FuncType* jvpType = astBuilder->create<FuncType>(); - // Only float types can be differentiated for now. - // The JVP return type is float if primal return type is float // void otherwise. // - if (primalType->resultType->equals(astBuilder->getFloatType())) - jvpType->resultType = astBuilder->getFloatType(); - else - { - //TODO(yong): issue proper diagnostic here. - jvpType->resultType = astBuilder->getVoidType(); - } + jvpType->resultType = primalToJVPReturnType(astBuilder, primalType->getResultType()); // No support for differentiating function that throw errors, for now. SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); @@ -1548,8 +1563,8 @@ namespace Slang for (UInt i = 0; i < primalType->getParamCount(); i++) { - if(primalType->getParamType(i)->equals(astBuilder->getFloatType())) - jvpType->paramTypes.add(astBuilder->getFloatType()); + if(auto jvpParamType = primalToJVPParamType(astBuilder, primalType->getParamType(i))) + jvpType->paramTypes.add(jvpParamType); } expr->type = jvpType; diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 00210daaa..5b77d483d 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -38,6 +38,16 @@ struct JVPTranscriber return instMapD[instP]; } + IRInst* getDifferentialInst(IRInst* instP, IRInst* defaultInst) + { + return (hasDifferentialInst(instP)) ? instMapD[instP] : defaultInst; + } + + bool hasDifferentialInst(IRInst* instP) + { + return instMapD.ContainsKey(instP); + } + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) { List<IRType*> parameterTypesD; @@ -74,7 +84,8 @@ struct JVPTranscriber case kIROp_FloatType: case kIROp_DoubleType: return builder->getType(typeP->getOp()); - + case kIROp_VectorType: + return as<IRVectorType>(typeP); default: return nullptr; } @@ -252,6 +263,94 @@ struct JVPTranscriber return nullptr; } + // Differentiating a call instruction here is primarily about generating + // an appropriate call list based on whichever parameters have differentials + // in the current transcription context. + // Note(sai): Currently we don't look at modifiers (in, out, const etc..) in the function + // type, and so only support 'plain' parameters. We need to validte this somewhere to + // avoid weird behaviour + // + IRInst* differentiateCall(IRBuilder* builder, IRCall* callP) + { + if (auto calleeP = as<IRFunc>(callP->getCallee())) + { + + // Build the differential callee + IRInst* calleeD = builder->emitJVPDifferentiateInst( + differentiateFunctionType(builder, as<IRFuncType>(calleeP->getFullType())), + calleeP); + + List<IRInst*> args; + // Go over the parameter list and all primal arguments. + for (UIndex ii = 0; ii < callP->getArgCount(); ii++) + { + args.add(callP->getArg(ii)); + } + + { + IRParam* param = calleeP->getFirstParam(); + // Go over the parameter list again and arguments for types that need differentials. + for (UIndex ii = 0; ii < callP->getArgCount(); ii++) + { + // Look the parameter up in the callee's signature. If it requires a derivative, proceed. + // Otherwise, continue. + // + if (differentiateType(builder, param->getDataType())) + { + // If the corresponding argument does not have a differential, create and place a + // 0 argument. + // + auto argP = callP->getArg(ii); + if (auto argD = getDifferentialInst(argP, nullptr)) + args.add(argD); + else + args.add(getZeroOfType(builder, argP->getDataType())); + } + + param = param->getNextParam(); + } + } + + return builder->emitCallInst(differentiateType(builder, callP->getFullType()), + calleeD, + args); + } + else + { + // Note that this can only happen if the callee is a result + // of a higher-order operation. For now, we assume that we cannot + // differentiate such calls safely. + // TODO(sai): Should probably get checked in the front-end. + // + getSink()->diagnose(callP->sourceLoc, + Diagnostics::internalCompilerError, + "attempting to differentiate unresolved callee"); + } + return nullptr; + } + + // In differential computation, the 'default' differential value is always zero. + // This is a consequence of differential computing being inherently linear. As a + // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // + IRInst* getZeroOfType(IRBuilder* builder, IRType* type) + { + switch (type->getOp()) + { + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + return builder->getFloatValue(type, 0.0); + case kIROp_IntType: + return builder->getIntValue(type, 0); + default: + getSink()->diagnose(type->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; + } + } + // Logic for whether a primal instruction needs to be replicated // in the differential function. For puerly functional blocks with // no side-effects, it's safe to replicate everything except the @@ -307,6 +406,9 @@ struct JVPTranscriber case kIROp_Construct: return differentiateConstruct(builder, instP); + + case kIROp_Call: + return differentiateCall(builder, as<IRCall>(instP)); default: getSink()->diagnose(instP->sourceLoc, diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index 7dbe11f52..2a97bc28a 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -39,7 +39,7 @@ namespace Slang if (auto ptrType = as<IRPtrTypeBase>(paramType)) { paramValType = ptrType->getValueType(); - } + } auto argType = arg->getDataType(); if (auto argPtrType = as<IRPtrTypeBase>(argType)) { diff --git a/tests/autodiff/redecl-custom-jvp.slang b/tests/autodiff/local-redecl-custom-jvp.slang index 2bc7cd582..2bc7cd582 100644 --- a/tests/autodiff/redecl-custom-jvp.slang +++ b/tests/autodiff/local-redecl-custom-jvp.slang diff --git a/tests/autodiff/redecl-custom-jvp.slang.expected.txt b/tests/autodiff/local-redecl-custom-jvp.slang.expected.txt index 965f2cb48..965f2cb48 100644 --- a/tests/autodiff/redecl-custom-jvp.slang.expected.txt +++ b/tests/autodiff/local-redecl-custom-jvp.slang.expected.txt diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang new file mode 100644 index 000000000..c37641c91 --- /dev/null +++ b/tests/autodiff/nested-jvp.slang @@ -0,0 +1,54 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[__custom_jvp(pow_jvp)] +float pow_(float x, float n) +{ + return pow<float>(x, n); +} + + +[__custom_jvp(max_jvp)] +float max_(float x, float y) +{ + return max<float>(x, y); +} + + +float pow_jvp(float x, float n, float dx, float dn) +{ + return dx * n * pow(x, n-1) + dn * pow(x, n) * log(x); +} + + +float max_jvp(float x, float y, float dx, float dy) +{ + return (x > y) ? dx : dy; +} + + +/* Fresnel Schlick example */ +__differentiate_jvp float3 fresnel(float3 f0, float3 f90, float cosTheta) +{ + return f0 + (f90 - f0) * pow_(max_(1 - cosTheta, 0.0), 5); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + float3 f0 = float3(0.2, 0.2, 0.2); + float3 f90 = float3(0.7, 0.7, 0.7); + float cosTheta = 0.5; + + float3 d_f0 = float3(0.1, 0.1, 0.1); + float3 d_f90 = float3(0.9, 0.9, 0.9); + float d_cosTheta = 1.0; + + outputBuffer[0] = __jvp(fresnel)(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta).y; // Expect: -0.031250 + } +} diff --git a/tests/autodiff/nested-jvp.slang.expected.txt b/tests/autodiff/nested-jvp.slang.expected.txt new file mode 100644 index 000000000..107153351 --- /dev/null +++ b/tests/autodiff/nested-jvp.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +-0.031250 +0.0 +0.0 +0.0
\ No newline at end of file diff --git a/tests/autodiff/test-intrinsics.slang b/tests/autodiff/test-intrinsics.slang index 3fa53e85a..189004543 100644 --- a/tests/autodiff/test-intrinsics.slang +++ b/tests/autodiff/test-intrinsics.slang @@ -4,3 +4,8 @@ float pow_(float x, float n) { return pow(x, n); } + +float max_(float x, float y) +{ + return max(x, y); +}
\ No newline at end of file diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang new file mode 100644 index 000000000..5852b49a7 --- /dev/null +++ b/tests/autodiff/vector-arithmetic-jvp.slang @@ -0,0 +1,30 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +__differentiate_jvp float3 f(float3 x) +{ + return x; +} + +__differentiate_jvp float3 g(float3 x, float3 y) +{ + float3 a = x + y; + float3 b = x - y; + return a * b + 2 * x * y; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + float3 a = float3(2.0, 2.0, 2.0); + float3 b = float3(1.5, 1.5, 1.5); + float3 da = float3(1.0, 1.0, 1.0); + + outputBuffer[0] = __jvp(f)(a, da).z; // Expect: 1 + outputBuffer[1] = __jvp(g)(a, b, da, float3(2.0, 1.0, 0.0)).y; // Expect: 8 + } +} diff --git a/tests/autodiff/vector-arithmetic-jvp.slang.expected.txt b/tests/autodiff/vector-arithmetic-jvp.slang.expected.txt new file mode 100644 index 000000000..adbb9a448 --- /dev/null +++ b/tests/autodiff/vector-arithmetic-jvp.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +1.0 +8.0 +0.0 +0.0
\ No newline at end of file |
