From 739c3a7b53dc6489065fcd5e9f0a04370c5f9c8f Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 19 Sep 2023 18:51:24 -0400 Subject: Added `[AutoPyBindCUDA]` for automatic kernel binding + `[PyExport]` for exporting type information (#3209) * Initial: add a DiffTensor impl * Auto-binding and diff tensor implementations now work * Refactored diff-tensor implementation + added py-export for struct types * Cleanup * Update slang-ir-pytorch-cpp-binding.cpp * Updated test names * Update autodiff-data-flow.slang.expected * Add more versions of load/store & default generic args for DiffTensorView. * Add diagnostic for default generic arg and more tests * Add more `[AutoPyBind]` tests --- source/slang/slang-check-type.cpp | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-check-type.cpp') diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index d5d3e5a5d..5967da8b0 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -273,7 +273,8 @@ namespace Slang auto genericDeclRef = genericDeclRefType->getDeclRef(); ensureDecl(genericDeclRef, DeclCheckState::CanSpecializeGeneric); - List args; + List args; + List witnessArgs; for (Decl* member : genericDeclRef.getDecl()->members) { if (auto typeParam = as(member)) @@ -290,7 +291,7 @@ namespace Slang // TODO: this is one place where syntax should get cloned! if (outProperType) - args.add(typeParam->initType.exp); + args.add(ExtractGenericArgVal(typeParam->initType.exp)); } else if (auto valParam = as(member)) { @@ -305,14 +306,42 @@ namespace Slang } // TODO: this is one place where syntax should get cloned! if (outProperType) - args.add(valParam->initExpr); + args.add(ExtractGenericArgVal(valParam->initExpr)); + } + else if (auto constraintParam = as(member)) + { + auto genericParam = as(constraintParam->sub.type)->getDeclRef(); + if (!genericParam) + return false; + auto genericTypeParamDecl = as(genericParam.getDecl()); + if (!genericTypeParamDecl) + return false; + auto defaultType = CheckProperType(genericTypeParamDecl->initType); + auto witness = tryGetSubtypeWitness(defaultType, CheckProperType(constraintParam->sup)); + if (!witness) + { + // diagnose + getSink()->diagnose( + genericTypeParamDecl->initType.exp, + Diagnostics::typeArgumentDoesNotConformToInterface, + defaultType, + constraintParam->sup); + + SLANG_UNEXPECTED("default type argument does not conform to interface"); + return false; + } + witnessArgs.add(witness); } else { // ignore non-parameter members } } - result = InstantiateGenericType(genericDeclRef, args); + // Combine args and witnessArgs + args.addRange(witnessArgs); + + result = DeclRefType::create(getASTBuilder(), + getASTBuilder()->getGenericAppDeclRef(genericDeclRef, args.getArrayView())); } // default case: we expect this to already be a proper type -- cgit v1.2.3