summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-08-17 14:45:13 -0400
committerGitHub <noreply@github.com>2023-08-17 14:45:13 -0400
commit945409c4c6871c18aad24086c594cc66b5913733 (patch)
tree41eed63f115971d82875e23acbec77d78be4cf3a /source/slang/slang-check-decl.cpp
parent216fc18661fd6e05053b4cc864396e6017e85b04 (diff)
Initial support for differentiating existential types (#3111)
* Merge * WIP: Complete auto-diff logic for existential types * Revert "Add compiler option for generating representative hash" This reverts commit 13b09ef4621e73844c96d64d9c111a8ed0d45aae. * More fixes for fwd-mode AD on existential types * Add anyValueSize inference pass * Fix checking of `Differential.Differential==Differential` * In-progress: infer any-value-size for existential types * Existentials now work in forward-mode * Overhaul handling of existential AD types. Fwd-mode works, reverse-mode requires front-end changes * Reverse-mode now works on existentials * Cleanup * Remove diff rules for create existential object for now * Revert treat-as-differentiable changes * Fixes * More fixes * Cleanup * more cleanup * signed/unsigned * Revert "Cleanup" This reverts commit e4f7d71f07bb207736f90708961eeecd09a1b652. * Cleanup (again) * Remove public/export/keep-alive on null differential after AD pass * Minor fix * Update dictionary accessors * Keep export decoration * More fixes + Support for `kIROp_PackAnyValue` * Merge upstream * Update expected-failure.txt
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp22
1 files changed, 20 insertions, 2 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index cb3db9e39..f25821dac 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1656,8 +1656,7 @@ namespace Slang
RequirementWitness witnessValue;
auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType);
if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue))
- return;
-
+ return;
// A type used as differential type must have itself as its own differential type.
if (witnessValue.getFlavor() != RequirementWitness::Flavor::val)
return;
@@ -5781,6 +5780,16 @@ namespace Slang
interfaceDecl->members.add(reqDecl);
reqDecl->parentDecl = interfaceDecl;
+ if (!decl->hasModifier<NoDiffThisAttribute>())
+ {
+ // Build decl-ref-type from interface.
+ auto interfaceType = DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
+
+ // If the interface is differentiable, make the this type a pair.
+ if (tryGetDifferentialType(getASTBuilder(), interfaceType))
+ reqDecl->diffThisType = getDifferentialPairType(interfaceType);
+ }
+
auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
reqRef->referencedDecl = reqDecl;
reqRef->parentDecl = decl;
@@ -5800,6 +5809,15 @@ namespace Slang
setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType);
interfaceDecl->members.add(reqDecl);
reqDecl->parentDecl = interfaceDecl;
+ if (!decl->hasModifier<NoDiffThisAttribute>())
+ {
+ // Build decl-ref-type from interface.
+ auto interfaceType = DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl));
+
+ // If the interface is differentiable, make the this type a pair.
+ if (tryGetDifferentialType(getASTBuilder(), interfaceType))
+ reqDecl->diffThisType = getDifferentialPairType(interfaceType);
+ }
auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
reqRef->referencedDecl = reqDecl;