diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-07-13 15:55:30 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-07-13 15:55:30 -0400 |
| commit | 4af61e2296a49876c2d9e7cf192ae825302a83de (patch) | |
| tree | d067b944b9794fe5061bbf51e8ef6a39d5fcefbf | |
| parent | 564f0d84a9c5276c05e8288955a7685f96278d1b (diff) | |
Added support for differentiating out and inout parameters. (#2323)
* Added out/inout tests
* Added support for out and inout parameters. Still untested
* Fixed and tested support for out and inout types
* Removed some comments
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 29 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 106 | ||||
| -rw-r--r-- | tests/autodiff/inout-parameters-jvp.slang | 30 | ||||
| -rw-r--r-- | tests/autodiff/inout-parameters-jvp.slang.expected.txt | 5 | ||||
| -rw-r--r-- | tests/autodiff/out-parameters-jvp.slang | 28 | ||||
| -rw-r--r-- | tests/autodiff/out-parameters-jvp.slang.expected.txt | 5 |
6 files changed, 183 insertions, 20 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 03da084d3..67e8bf650 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,18 +1509,40 @@ namespace Slang return expr; } + // This function proceses primal params (i.e params of the inner function that is being + // differentiated) that need to be carried over to the function signature for the JVP + // function. (eg. out types can be discarded) + // + Type* primalToInputType(ASTBuilder*, Type* primalType) + { + if (auto primalOutType = as<OutType>(primalType)) + return nullptr; + else if (auto primalInOutType = as<InOutType>(primalType)) + return primalInOutType->getValueType(); + + return primalType; + } + Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) { // Only float and float3 types can be differentiated for now. - if(primalType->equals(builder->getFloatType())) + if (primalType->equals(builder->getFloatType())) return primalType; - else if(auto primalVectorType = as<VectorExpressionType>(primalType)) + else if (auto primalVectorType = as<VectorExpressionType>(primalType)) { // TODO(sai): There's probably a more elegant way to check if a type is a float3? if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) return primalVectorType; } + else if (auto primalOutType = as<OutType>(primalType)) + { + return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType())); + } + else if (auto primalInOutType = as<InOutType>(primalType)) + { + return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType())); + } return nullptr; } @@ -1558,7 +1580,8 @@ namespace Slang for (UInt i = 0; i < primalType->getParamCount(); i++) { - jvpType->paramTypes.add(primalType->getParamType(i)); + if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i))) + jvpType->paramTypes.add(primalInputType); } for (UInt i = 0; i < primalType->getParamCount(); i++) diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 5b77d483d..f5afccd0c 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -56,7 +56,9 @@ struct JVPTranscriber // Add all primal parameters to the list. for (UIndex i = 0; i < funcType->getParamCount(); i++) { - parameterTypesD.add(funcType->getParamType(i)); + // TODO(sai): Move this check to a separate function. + if (!as<IROutType>(funcType->getParamType(i))) + parameterTypesD.add(funcType->getParamType(i)); } // Add differential versions for the types we support. @@ -85,7 +87,12 @@ struct JVPTranscriber case kIROp_DoubleType: return builder->getType(typeP->getOp()); case kIROp_VectorType: + // TODO(sai): Call differentiateType() on typeP. return as<IRVectorType>(typeP); + case kIROp_OutType: + return builder->getOutType(differentiateType(builder, as<IROutType>(typeP)->getValueType())); + case kIROp_InOutType: + return builder->getInOutType(differentiateType(builder, as<IRInOutType>(typeP)->getValueType())); default: return nullptr; } @@ -102,21 +109,45 @@ struct JVPTranscriber return nullptr; } + IRInst* emitInputParam(IRBuilder* builder, IRParam* paramP) + { + // Convert primal 'inout' types into pure input types, because a + // JVP transformed function must never have primal side-effects. + // + if (auto inoutTypeP = as<IRInOutType>(paramP->getDataType())) + { + auto newParamP = builder->emitParam(inoutTypeP->getValueType()); + cloneEnv.mapOldValToNew.Add(paramP, newParamP); + + return newParamP; + } + else if (as<IROutType>(paramP->getDataType())) + { + getSink()->diagnose(paramP->sourceLoc, + Diagnostics::unexpected, + "encountered unexpected output parameter"); + return nullptr; + } + else + return as<IRParam>(cloneInst(&cloneEnv, builder, paramP)); + } + List<IRParam*> transcribeParams(IRBuilder* builder, IRInstList<IRParam> paramListP) { // Clone (and emit) all the primal parameters. List<IRParam*> newParamListP; for (auto paramP : paramListP) { - newParamListP.add(as<IRParam>(cloneInst(&cloneEnv, builder, paramP))); + if(requiresPrimalClone(builder, paramP)) + newParamListP.add(as<IRParam>(emitInputParam(builder, paramP))); } // Now emit differentials. List<IRParam*> newParamListD; - for (auto paramP : newParamListP) + for (auto paramP : paramListP) { IRParam* paramD = as<IRParam>(differentiateParam(builder, paramP)); - mapDifferentialInst(paramP, paramD); + mapDifferentialInst(findCloneForOperand(&cloneEnv, paramP), paramD); newParamListD.add(paramD); } @@ -187,15 +218,16 @@ struct JVPTranscriber IRInst* differentiateLoad(IRBuilder* builder, IRLoad* loadP) { - if (auto varP = as<IRVar>(loadP->getPtr())) + auto ptrP = loadP->getPtr(); + if (as<IRVar>(ptrP) || as<IRParam>(ptrP)) { // If the loaded parameter has a differential version, // emit a load instruction for the differential parameter. // Otherwise, emit nothing since there's nothing to load. // - if (auto varD = as<IRVar>(getDifferentialInst(varP))) + if (auto ptrD = getDifferentialInst(ptrP, nullptr)) { - IRLoad* loadD = as<IRLoad>(builder->emitLoad(varD)); + IRLoad* loadD = as<IRLoad>(builder->emitLoad(ptrD)); SLANG_ASSERT(loadD); return loadD; } @@ -212,14 +244,14 @@ struct JVPTranscriber { IRInst* storeLocation = storeP->getPtr(); IRInst* storeVal = storeP->getVal(); - if (auto destParam = as<IRVar>(storeLocation)) + if (as<IRVar>(storeLocation) || as<IRParam>(storeLocation)) { // If the stored value has a differential version, // emit a store instruction for the differential parameter. // Otherwise, emit nothing since there's nothing to load. // IRInst* storeValD = getDifferentialInst(storeVal); - IRVar* storeLocationD = as<IRVar>(getDifferentialInst(destParam)); + IRInst* storeLocationD = getDifferentialInst(storeLocation); if (storeValD && storeLocationD) { IRStore* storeD = as<IRStore>( @@ -239,13 +271,18 @@ struct JVPTranscriber IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP) { IRInst* returnVal = findCloneForOperand(&cloneEnv, returnP->getVal()); - if (auto returnValD = getDifferentialInst(returnVal)) + if (auto returnValD = getDifferentialInst(returnVal, nullptr)) { IRReturn* returnD = as<IRReturn>(builder->emitReturn(returnValD)); SLANG_ASSERT(returnD); return returnD; } - return nullptr; + else + { + // If the differential return value is not available, emit a + // void return. + return builder->emitReturn(); + } } // Since int/float literals are sometimes nested inside an IRConstructor @@ -352,16 +389,38 @@ struct JVPTranscriber } // Logic for whether a primal instruction needs to be replicated - // in the differential function. For puerly functional blocks with - // no side-effects, it's safe to replicate everything except the - // return instruction. - // + // in the differential function. We detect and avoid replicating + // side-effect instructions. + // bool requiresPrimalClone(IRBuilder*, IRInst* instP) { if (as<IRReturn>(instP)) return false; - else - return true; + else if (auto paramP = as<IRParam>(instP)) + { + // Out-type parameters are discarded from the parameter list, + // since pure JVP functions to not write to primal outputs. + // + if (as<IROutType>(paramP->getDataType())) + return false; + } + else if (auto storeP = as<IRStore>(instP)) + { + IRInst* storeLocation = storeP->getPtr(); + + // Writing to a parameter is a side-effect that should be avoided. + if(as<IRParam>(storeLocation)) + return false; + + // If attempting to store to a location without a clone, + // then this instruction likely has side-effects external to the + // current function. + // + if(!lookUp(&cloneEnv, storeLocation)) + return false; + } + + return true; } IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP) @@ -374,6 +433,19 @@ struct JVPTranscriber // if (requiresPrimalClone(builder, oldInstP)) instP = cloneInst(&cloneEnv, builder, oldInstP); + else + { + // We replace the operands of the old instruction with their clones, + // if available. + // + for(UInt ii = 0; ii < oldInstP->getOperandCount(); ++ii) + { + auto oldOperand = oldInstP->getOperand(ii); + auto newOperand = findCloneForOperand(&cloneEnv, oldOperand); + + instP->getOperands()[ii].init(instP, newOperand); + } + } SLANG_ASSERT(instP); IRInst* instD = differentiateInst(builder, instP); diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang new file mode 100644 index 000000000..989e56c02 --- /dev/null +++ b/tests/autodiff/inout-parameters-jvp.slang @@ -0,0 +1,30 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +__differentiate_jvp void h(float x, float y, inout float z) +{ + float m = x + y; + float n = x - y; + z = z + m * n + 2 * x * y; +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float x = 2.0; + float y = 3.5; + float z = 1.0; + float dx = 1.0; + float dy = 0.5; + float dz = 2.5; + + __jvp(h)(x, y, z, dx, dy, dz); + + outputBuffer[0] = dz; // Expect: 12.0 + outputBuffer[1] = z; // Expect: 1.0 + +}
\ No newline at end of file diff --git a/tests/autodiff/inout-parameters-jvp.slang.expected.txt b/tests/autodiff/inout-parameters-jvp.slang.expected.txt new file mode 100644 index 000000000..d8a590c0e --- /dev/null +++ b/tests/autodiff/inout-parameters-jvp.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +12.0 +1.0 +0.0 +0.0
\ No newline at end of file diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang new file mode 100644 index 000000000..58c6cfeb0 --- /dev/null +++ b/tests/autodiff/out-parameters-jvp.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +__differentiate_jvp void h(float x, float y, out float result) +{ + float m = x + y; + float n = x - y; + result = m * n + 2 * x * y; +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float x = 2.0; + float y = 3.5; + float dx = 1.0; + float dy = 0.5; + + float dresult = 0.0f; + __jvp(h)(x, y, dx, dy, dresult); + + outputBuffer[0] = dresult; // Expect: 9.5 + +}
\ No newline at end of file diff --git a/tests/autodiff/out-parameters-jvp.slang.expected.txt b/tests/autodiff/out-parameters-jvp.slang.expected.txt new file mode 100644 index 000000000..555935fc4 --- /dev/null +++ b/tests/autodiff/out-parameters-jvp.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +9.5 +0.0 +0.0 +0.0
\ No newline at end of file |
