From 97cb4851eed7a43f10196971b08d3d311386ce9f Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 23 Nov 2022 09:39:08 -0800 Subject: Autodiff through simple dynamic dispatch. (#2527) * Autodiff through simple dynamic dispatch. * Revert changes. * Fix. Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 108 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5a1218abe..36a1061c9 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -107,6 +107,9 @@ namespace Slang void visitAccessorDecl(AccessorDecl* decl); void visitSetterDecl(SetterDecl* decl); + + void cloneModifiers(Decl* dest, Decl* src); + void setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType); }; struct SemanticsDeclRedeclarationVisitor @@ -1866,6 +1869,32 @@ namespace Slang return false; } + bool hasBackwardDerivative = false; + bool hasForwardDerivative = false; + if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + if (!satisfyingMemberDeclRef.getDecl()->hasModifier() + && !satisfyingMemberDeclRef.getDecl()->hasModifier()) + { + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. + return false; + } + hasBackwardDerivative = true; + hasForwardDerivative = true; + } + else if (requiredMemberDeclRef.getDecl()->hasModifier()) + { + if (!satisfyingMemberDeclRef.getDecl()->hasModifier() + && !satisfyingMemberDeclRef.getDecl()->hasModifier() + && !satisfyingMemberDeclRef.getDecl()->hasModifier() + && !satisfyingMemberDeclRef.getDecl()->hasModifier()) + { + // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa. + return false; + } + hasForwardDerivative = true; + } + // A signature matches the required one if it has the right number of parameters, // and those parameters have the right types, and also the result/return type // is the required one. @@ -1896,6 +1925,24 @@ namespace Slang witnessTable->add( requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); + + if (hasForwardDerivative) + { + auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType(); + SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); + ForwardDifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + } + + if (hasBackwardDerivative) + { + auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType(); + SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); + BackwardDifferentiateVal* val = m_astBuilder->create(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + } return true; } @@ -5515,6 +5562,43 @@ namespace Slang } } + void SemanticsDeclHeaderVisitor::cloneModifiers(Decl* dest, Decl* src) + { + dest->modifiers = src->modifiers; + } + void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType) + { + if (!funcType) + return; + decl->returnType.type = funcType->getResultType(); + decl->errorType.type = funcType->getErrorType(); + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + auto paramType = funcType->getParamType(i); + if (auto dirType = as(paramType)) + paramType = dirType->getValueType(); + auto param = m_astBuilder->create(); + param->type.type = paramType; + auto paramDir = funcType->getParamDirection(i); + switch (paramDir) + { + case ParameterDirection::kParameterDirection_InOut: + addModifier(param, m_astBuilder->create()); + break; + case ParameterDirection::kParameterDirection_Out: + addModifier(param, m_astBuilder->create()); + break; + case ParameterDirection::kParameterDirection_Ref: + addModifier(param, m_astBuilder->create()); + break; + default: + break; + } + decl->members.add(param); + param->parentDecl = decl; + } + } + void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { for(auto paramDecl : decl->getParameters()) @@ -5532,6 +5616,30 @@ namespace Slang errorType = TypeExp(m_astBuilder->getBottomType()); } decl->errorType = errorType; + + if (isInterfaceRequirement(decl)) + { + if (decl->hasModifier()) + { + auto reqDecl = m_astBuilder->create(); + cloneModifiers(reqDecl, decl); + auto declRef = DeclRef(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); + auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); + setFuncTypeIntoRequirementDecl(reqDecl, as(diffFuncType)); + decl->members.add(reqDecl); + reqDecl->parentDecl = decl; + } + if (decl->hasModifier()) + { + auto reqDecl = m_astBuilder->create(); + cloneModifiers(reqDecl, decl); + auto declRef = DeclRef(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); + auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); + setFuncTypeIntoRequirementDecl(reqDecl, as(diffFuncType)); + decl->members.add(reqDecl); + reqDecl->parentDecl = decl; + } + } } void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl) -- cgit v1.2.3