summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-08-05 13:19:20 -0400
committerGitHub <noreply@github.com>2022-08-05 13:19:20 -0400
commit2db8c15c04f2aade49636e42f0adee636afb3b73 (patch)
tree774758a9f854ddf655f6c46765a3ef8ca1950857 /source/slang/slang-check-expr.cpp
parent12a846e8facf090aaeb68fcabf55867f5eaed747 (diff)
Added a new differential type system and various improvements (#2343)
* Merge slang-ir-diff-jvp.cpp * Added support and tests for other float vector types * Added swizzle test and code to handle it (tests failing currently) * Fixed one test, the other is still pending * Fixed instruction cloning logic to avoid modifying original function * Fixed an issue with custom 'pow_jvp' and added support for vector contructor * Minor update to comments * Fixed support for division * Fixed an issue with uninitialized diagnostic sink * Moved derivative processing to after mandatory inlining. Skip instructions that don't have side-effects and aren't used by anything. * WIP: Handling unconditional control flow and multi-block functions * Support for unconditional multi-block functions * Added a dead code elimination step to the derivative pass * Changed name of 'hasNoSideEffects()' * Refactored variable names * Added initial IR defs for new type system * Added necessary logic for semantic checking * Overhauled type system to use builtin pair types and conform to the IDifferentiable interface * Automatically replace IRDifferentiablePairType to a custom IRStructType * Added generics handling by expanding the conformance context functionality and allowing for type parameters * Minor fix: early return in processPairTypes() * Minor fixes to differentiable resolution on generic types * Added new instructions for differential pairs. Basic tests work now. Looking into generic types. * Adjusted most tests to the new type system. OutType and InOutType are still not properly working. * Updated __jvp to produce both primal and differential output * Moved autodiff related declarations to diff.meta.slang * Refactored variable names * Added initial IR defs for new type system * Added necessary logic for semantic checking * Overhauled type system to use builtin pair types and conform to the IDifferentiable interface * Automatically replace IRDifferentiablePairType to a custom IRStructType * Added generics handling by expanding the conformance context functionality and allowing for type parameters * Minor fix: early return in processPairTypes() * Minor fixes to differentiable resolution on generic types * Added new instructions for differential pairs. Basic tests work now. Looking into generic types. * Adjusted most tests to the new type system. OutType and InOutType are still not properly working. * Updated __jvp to produce both primal and differential output * Moved autodiff related declarations to diff.meta.slang * Removed external changes * Cleanup the transcription logic: each case returns a pair of insts for the primal and differential computation.
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp66
1 files changed, 27 insertions, 39 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b7f99c4e7..a787af211 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1526,48 +1526,42 @@ namespace Slang
return expr;
}
- // This function proceses primal params (i.e params of the inner function that is being
- // differentiated) that need to be carried over to the function signature for the JVP
- // function. (eg. out types can be discarded)
- //
- Type* primalToInputType(ASTBuilder*, Type* primalType)
- {
- if (auto primalOutType = as<OutType>(primalType))
- return nullptr;
- else if (auto primalInOutType = as<InOutType>(primalType))
- return primalInOutType->getValueType();
-
- return primalType;
- }
- Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
+ Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
{
- // Only float and vector<float> types can be differentiated for now.
-
- if (primalType->equals(builder->getFloatType()))
- return primalType;
- else if (auto primalVectorType = as<VectorExpressionType>(primalType))
- {
- if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType))
- return builder->getVectorType(jvpElementType, primalVectorType->elementCount);
- }
- else if (auto primalOutType = as<OutType>(primalType))
+ // Check for type modifiers like 'out' and 'inout'. We need to differentiate the
+ // nested type.
+ //
+ if (auto primalOutType = as<OutType>(primalType))
{
- return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType()));
+ return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType()));
}
else if (auto primalInOutType = as<InOutType>(primalType))
{
- return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType()));
+ return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType()));
}
- return nullptr;
+
+ // Get a reference to the builtin 'IDifferentiable' interface
+ auto differentiableInterface = builder->getDifferentiableInterface();
+
+ // Check if the provided type inherits from IDifferentiable.
+ // If not, return the original type.
+ if (auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)))
+ return builder->getDifferentialPairType(primalType, conformanceWitness);
+ else
+ return primalType;
+
}
- Type* primalToJVPReturnType(ASTBuilder* builder, Type* primalType)
+ Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType)
{
- if(auto jvpType = primalToJVPParamType(builder, primalType))
- return jvpType;
+ if (auto conformanceWitness =
+ as<Witness>(tryGetInterfaceConformanceWitness(
+ primalType,
+ builder->getDifferentiableInterface())))
+ return builder->getDifferentialPairType(primalType, conformanceWitness);
else
- return builder->getVoidType();
+ return primalType;
}
Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
@@ -1588,7 +1582,7 @@ namespace Slang
// The JVP return type is float if primal return type is float
// void otherwise.
//
- jvpType->resultType = primalToJVPReturnType(astBuilder, primalType->getResultType());
+ jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType());
// No support for differentiating function that throw errors, for now.
SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
@@ -1596,13 +1590,7 @@ namespace Slang
for (UInt i = 0; i < primalType->getParamCount(); i++)
{
- if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i)))
- jvpType->paramTypes.add(primalInputType);
- }
-
- for (UInt i = 0; i < primalType->getParamCount(); i++)
- {
- if(auto jvpParamType = primalToJVPParamType(astBuilder, primalType->getParamType(i)))
+ if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i)))
jvpType->paramTypes.add(jvpParamType);
}