summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-type.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-09-19 18:51:24 -0400
committerGitHub <noreply@github.com>2023-09-19 18:51:24 -0400
commit739c3a7b53dc6489065fcd5e9f0a04370c5f9c8f (patch)
tree593c86cbc184476479c66554cc6784b454bdec66 /source/slang/slang-check-type.cpp
parent359fdc9d556b4c493c588c5b8f93df85933634f8 (diff)
Added `[AutoPyBindCUDA]` for automatic kernel binding + `[PyExport]` for exporting type information (#3209)
* Initial: add a DiffTensor impl * Auto-binding and diff tensor implementations now work * Refactored diff-tensor implementation + added py-export for struct types * Cleanup * Update slang-ir-pytorch-cpp-binding.cpp * Updated test names * Update autodiff-data-flow.slang.expected * Add more versions of load/store & default generic args for DiffTensorView. * Add diagnostic for default generic arg and more tests * Add more `[AutoPyBind]` tests
Diffstat (limited to 'source/slang/slang-check-type.cpp')
-rw-r--r--source/slang/slang-check-type.cpp37
1 files changed, 33 insertions, 4 deletions
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index d5d3e5a5d..5967da8b0 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -273,7 +273,8 @@ namespace Slang
auto genericDeclRef = genericDeclRefType->getDeclRef();
ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric);
- List<Expr*> args;
+ List<Val*> args;
+ List<Val*> witnessArgs;
for (Decl* member : genericDeclRef.getDecl()->members)
{
if (auto typeParam = as<GenericTypeParamDecl>(member))
@@ -290,7 +291,7 @@ namespace Slang
// TODO: this is one place where syntax should get cloned!
if (outProperType)
- args.add(typeParam->initType.exp);
+ args.add(ExtractGenericArgVal(typeParam->initType.exp));
}
else if (auto valParam = as<GenericValueParamDecl>(member))
{
@@ -305,14 +306,42 @@ namespace Slang
}
// TODO: this is one place where syntax should get cloned!
if (outProperType)
- args.add(valParam->initExpr);
+ args.add(ExtractGenericArgVal(valParam->initExpr));
+ }
+ else if (auto constraintParam = as<GenericTypeConstraintDecl>(member))
+ {
+ auto genericParam = as<DeclRefType>(constraintParam->sub.type)->getDeclRef();
+ if (!genericParam)
+ return false;
+ auto genericTypeParamDecl = as<GenericTypeParamDecl>(genericParam.getDecl());
+ if (!genericTypeParamDecl)
+ return false;
+ auto defaultType = CheckProperType(genericTypeParamDecl->initType);
+ auto witness = tryGetSubtypeWitness(defaultType, CheckProperType(constraintParam->sup));
+ if (!witness)
+ {
+ // diagnose
+ getSink()->diagnose(
+ genericTypeParamDecl->initType.exp,
+ Diagnostics::typeArgumentDoesNotConformToInterface,
+ defaultType,
+ constraintParam->sup);
+
+ SLANG_UNEXPECTED("default type argument does not conform to interface");
+ return false;
+ }
+ witnessArgs.add(witness);
}
else
{
// ignore non-parameter members
}
}
- result = InstantiateGenericType(genericDeclRef, args);
+ // Combine args and witnessArgs
+ args.addRange(witnessArgs);
+
+ result = DeclRefType::create(getASTBuilder(),
+ getASTBuilder()->getGenericAppDeclRef(genericDeclRef, args.getArrayView()));
}
// default case: we expect this to already be a proper type