diff options
| author | Yong He <yonghe@outlook.com> | 2022-09-05 00:38:45 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-09-05 00:38:45 -0700 |
| commit | ea0845285b0307d153a91d6f0a5010fc2d7219ed (patch) | |
| tree | bf2d8f7258b2681deddf3391c551c5ff2b1a7918 | |
| parent | 2a869c105dcc23ede8f5e6e16b08261f45aa5aad (diff) | |
Multi parameter `__subscript` (#2392)
* Multi parameter `__subscript`
* Fix.
* Fix bugs.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ast-expr.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ast-iterator.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 36 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-any-value-marshalling.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-ir-generics-lowering-context.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-generics-lowering-context.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-resources.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-witness-table-wrapper.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 33 | ||||
| -rw-r--r-- | tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected | 1 | ||||
| -rw-r--r-- | tests/language-feature/operators/subscript-multi-dimension.slang | 31 | ||||
| -rw-r--r-- | tests/language-feature/operators/subscript-multi-dimension.slang.expected.txt | 5 |
15 files changed, 114 insertions, 52 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 828ca035f..70390255f 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -212,9 +212,12 @@ class PostfixExpr: public OperatorExpr class IndexExpr: public Expr { SLANG_AST_CLASS(IndexExpr) + Expr* baseExpression; + List<Expr*> indexExprs; - Expr* baseExpression = nullptr; - Expr* indexExpression = nullptr; + // The source location of `(`, `)`, and `,` that marks the start/end of the application op and + // each argument expr. This info is used by language server. + List<SourceLoc> argumentDelimeterLocs; }; class MemberExpr: public DeclRefExpr diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 233ce9a17..4d37b68e7 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -67,7 +67,8 @@ struct ASTIterator { iterator->maybeDispatchCallback(subscriptExpr); dispatchIfNotNull(subscriptExpr->baseExpression); - dispatchIfNotNull(subscriptExpr->indexExpression); + for (auto arg : subscriptExpr->indexExprs) + dispatchIfNotNull(arg); } void visitParenExpr(ParenExpr* expr) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index c7bfdd3a6..11b001c2c 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1303,7 +1303,18 @@ namespace Slang Type* elementType) { auto baseExpr = subscriptExpr->baseExpression; - auto indexExpr = subscriptExpr->indexExpression; + if (subscriptExpr->indexExprs.getCount() < 1) + { + getSink()->diagnose(subscriptExpr, Diagnostics::notEnoughArguments, subscriptExpr->indexExprs.getCount(), 1); + return CreateErrorExpr(subscriptExpr); + } + else if (subscriptExpr->indexExprs.getCount() > 1) + { + getSink()->diagnose(subscriptExpr, Diagnostics::tooManyArguments, subscriptExpr->indexExprs.getCount(), 1); + return CreateErrorExpr(subscriptExpr); + } + + auto indexExpr = subscriptExpr->indexExprs[0]; if (!indexExpr->type->equals(m_astBuilder->getIntType()) && !indexExpr->type->equals(m_astBuilder->getUIntType())) @@ -1325,20 +1336,18 @@ namespace Slang auto baseExpr = subscriptExpr->baseExpression; baseExpr = CheckExpr(baseExpr); - Expr* indexExpr = subscriptExpr->indexExpression; - if (indexExpr) + for (auto& arg : subscriptExpr->indexExprs) { - indexExpr = CheckTerm(indexExpr); + arg = CheckTerm(arg); } - subscriptExpr->baseExpression = baseExpr; - subscriptExpr->indexExpression = indexExpr; - // If anything went wrong in the base expression, // then just move along... if (IsErrorExpr(baseExpr)) return CreateErrorExpr(subscriptExpr); + subscriptExpr->baseExpression = baseExpr; + // Otherwise, we need to look at the type of the base expression, // to figure out how subscripting should work. auto baseType = baseExpr->type.Ptr(); @@ -1348,9 +1357,13 @@ namespace Slang // which should be interpreted as resolving to an array type. IntVal* elementCount = nullptr; - if (indexExpr) + if (subscriptExpr->indexExprs.getCount() == 1) { - elementCount = CheckIntegerConstantExpression(indexExpr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr); + elementCount = CheckIntegerConstantExpression(subscriptExpr->indexExprs[0], IntegerConstantExpressionCoercionType::AnyInteger, nullptr); + } + else if (subscriptExpr->indexExprs.getCount() != 0) + { + getSink()->diagnose(subscriptExpr, Diagnostics::multiDimensionalArrayNotSupported); } auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); @@ -1420,9 +1433,8 @@ namespace Slang InvokeExpr* subscriptCallExpr = m_astBuilder->create<InvokeExpr>(); subscriptCallExpr->loc = subscriptExpr->loc; subscriptCallExpr->functionExpr = subscriptFuncExpr; - - // TODO(tfoley): This path can support multiple arguments easily - subscriptCallExpr->arguments.add(subscriptExpr->indexExpression); + subscriptCallExpr->arguments.addRange(subscriptExpr->indexExprs); + subscriptCallExpr->argumentDelimeterLocs.addRange(subscriptExpr->argumentDelimeterLocs); return CheckInvokeExprWithCheckedOperands(subscriptCallExpr); } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index c3062ea4f..d0c0b8954 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -385,7 +385,7 @@ DIAGNOSTIC(30832, Error, invalidTypeForInheritance, "type '$0' cannot be used fo DIAGNOSTIC(30850, Error, invalidExtensionOnType, "type '$0' cannot be extended. `extension` can only be used to extend a nominal type.") // 309xx: subscripts - +DIAGNOSTIC(30900, Error, multiDimensionalArrayNotSupported, "multi-dimensional array is not supported.") // 310xx: properties // 311xx: accessors @@ -477,7 +477,7 @@ DIAGNOSTIC(38025, Error, mismatchSpecializationArguments, "expected $0 specializ DIAGNOSTIC(38026, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") DIAGNOSTIC(38027, Error, mismatchExistentialSlotArgCount, "expected $0 existential slot arguments ($1 provided)") -DIAGNOSTIC(38029, Error,typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") +DIAGNOSTIC(38029, Error, typeArgumentDoesNotConformToInterface, "type argument '$0' does not conform to the required interface '$1'") DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error") @@ -549,7 +549,7 @@ DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected") DIAGNOSTIC(41010, Warning, missingReturn, "control flow may reach end of non-'void' function") DIAGNOSTIC(41011, Error, typeDoesNotFitAnyValueSize, "type '$0' does not fit in the size required by its conforming interface.") - +DIAGNOSTIC(41012, Note, typeAndLimit, "sizeof($0) is $1, limit is $2") DIAGNOSTIC(41012, Error, typeCannotBePackedIntoAnyValue, "type '$0' contains fields that cannot be packed into an AnyValue.") // diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 39292e2b1..ea1a6cf32 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -86,18 +86,7 @@ namespace Slang virtual void marshalBasicType(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0; // Defines what to do with resource handle elements. virtual void marshalResourceHandle(IRBuilder* builder, IRType* dataType, IRInst* concreteTypedVar) = 0; - // Validates that the type fits in the given AnyValueSize. - // After calling emitMarshallingCode, `fieldOffset` will be increased to the required `AnyValue` size. - // If this is larger than the provided AnyValue size, report a dianogstic. We might want to front load - // this in a separate IR validation pass in the future, but this is the easiest way to report the - // diagnostic now. - void validateAnyTypeSize(DiagnosticSink* sink, IRType* concreteType) - { - if (fieldOffset > static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) - { - sink->diagnose(concreteType->sourceLoc, Diagnostics::typeDoesNotFitAnyValueSize, concreteType); - } - } + void ensureOffsetAt4ByteBoundary() { if (intraFieldOffset) @@ -396,8 +385,6 @@ namespace Slang context.anyValueVar = resultVar; emitMarshallingCode(&builder, &context, concreteTypedVar); - context.validateAnyTypeSize(sharedContext->sink, type); - auto load = builder.emitLoad(resultVar); builder.emitReturn(load); return func; diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp index 0ed7d75d7..d0e1fabaf 100644 --- a/source/slang/slang-ir-generics-lowering-context.cpp +++ b/source/slang/slang-ir-generics-lowering-context.cpp @@ -374,12 +374,15 @@ namespace Slang } - bool SharedGenericsLoweringContext::doesTypeFitInAnyValue(IRType* concreteType, IRInterfaceType* interfaceType) + bool SharedGenericsLoweringContext::doesTypeFitInAnyValue(IRType* concreteType, IRInterfaceType* interfaceType, IRIntegerValue* outTypeSize, IRIntegerValue* outLimit) { auto anyValueSize = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); + if (outLimit) *outLimit = anyValueSize; IRSizeAndAlignment sizeAndAlignment; Result result = getNaturalSizeAndAlignment(targetReq, concreteType, &sizeAndAlignment); + if (outTypeSize) *outTypeSize = sizeAndAlignment.size; + if(SLANG_FAILED(result) || (sizeAndAlignment.size > anyValueSize)) { // The value does not fit, either because it is too large, diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h index 85ba2443d..fbfb42559 100644 --- a/source/slang/slang-ir-generics-lowering-context.h +++ b/source/slang/slang-ir-generics-lowering-context.h @@ -99,7 +99,7 @@ namespace Slang } /// Does the given `concreteType` fit within the any-value size deterined by `interfaceType`? - bool doesTypeFitInAnyValue(IRType* concreteType, IRInterfaceType* interfaceType); + bool doesTypeFitInAnyValue(IRType* concreteType, IRInterfaceType* interfaceType, IRIntegerValue* outTypeSize = nullptr, IRIntegerValue* outLimit = nullptr); }; bool isPolymorphicType(IRInst* typeInst); diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index ad6baea67..734315911 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -1002,6 +1002,9 @@ struct ResourceOutputSpecializationPass if(oldParamInfo.flavor == OutputInfo::Flavor::None) continue; + if (oldParamInfo.flavor == OutputInfo::Flavor::Undefined) + continue; + // For any paraemter that was specialized, we will use // the computed information on the parameter to materialize // a value for the output in the context of the caller. diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index 1d84eee19..81527b89f 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -186,8 +186,13 @@ namespace Slang // we can't consider this case a hard error. // auto concreteType = witnessTable->getConcreteType(); - if(!sharedContext->doesTypeFitInAnyValue(concreteType, interfaceType)) + IRIntegerValue typeSize, sizeLimit; + if (!sharedContext->doesTypeFitInAnyValue(concreteType, interfaceType, &typeSize, &sizeLimit)) + { + sharedContext->sink->diagnose(concreteType, Diagnostics::typeDoesNotFitAnyValueSize, concreteType); + sharedContext->sink->diagnose(concreteType, Diagnostics::typeAndLimit, concreteType, typeSize, sizeLimit); return; + } for (auto child : witnessTable->getChildren()) { diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 353c98fa4..9a42f86f3 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -112,8 +112,9 @@ public: bool visitIncompleteExpr(IncompleteExpr*) { return false; } bool visitIndexExpr(IndexExpr* subscriptExpr) { - if (dispatchIfNotNull(subscriptExpr->indexExpression)) - return true; + for (auto arg : subscriptExpr->indexExprs) + if (dispatchIfNotNull(arg)) + return true; return dispatchIfNotNull(subscriptExpr->baseExpression); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 31472e878..95d2c1cd7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3073,7 +3073,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { auto type = lowerType(context, expr->type); auto baseVal = lowerSubExpr(expr->baseExpression); - auto indexVal = getSimpleVal(context, lowerRValueExpr(context, expr->indexExpression)); + + SLANG_RELEASE_ASSERT(expr->indexExprs.getCount() == 1); + + auto indexVal = getSimpleVal(context, lowerRValueExpr(context, expr->indexExprs[0])); return subscriptValue(type, baseVal, indexVal); } @@ -6714,7 +6717,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Allocate an IRInterfaceType with the `operandCount` operands. IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr); - + // Add `irInterface` to decl mapping now to prevent cyclic lowering. setValue(subContext, decl, LoweredValInfo::simple(irInterface)); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 594ca4cc3..77a2319d2 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1476,7 +1476,8 @@ namespace Slang void visitIndexExpr(IndexExpr * expr) { expr->baseExpression->accept(this, nullptr); - expr->indexExpression->accept(this, nullptr); + for (auto arg : expr->indexExprs) + arg->accept(this, nullptr); } void visitMemberExpr(MemberExpr * expr) { @@ -1824,7 +1825,8 @@ namespace Slang auto arrayTypeExpr = astBuilder->create<IndexExpr>(); arrayTypeExpr->loc = arrayDeclarator->openBracketLoc; arrayTypeExpr->baseExpression = ioInfo->typeSpec; - arrayTypeExpr->indexExpression = arrayDeclarator->elementCountExpr; + if (arrayDeclarator->elementCountExpr) + arrayTypeExpr->indexExprs.add(arrayDeclarator->elementCountExpr); ioInfo->typeSpec = arrayTypeExpr; declarator = arrayDeclarator->inner; @@ -2045,7 +2047,7 @@ namespace Slang parser->ReadToken(TokenType::LBracket); if (!parser->LookAheadToken(TokenType::RBracket)) { - arrType->indexExpression = parser->ParseExpression(); + arrType->indexExprs.add(parser->ParseExpression()); } parser->ReadToken(TokenType::RBracket); typeExpr = arrType; @@ -5779,18 +5781,23 @@ namespace Slang IndexExpr* indexExpr = parser->astBuilder->create<IndexExpr>(); indexExpr->baseExpression = expr; parser->FillPosition(indexExpr); - parser->ReadToken(TokenType::LBracket); - // TODO: eventually we may want to support multiple arguments inside the `[]` - if (!parser->LookAheadToken(TokenType::RBracket)) - { - indexExpr->indexExpression = parser->ParseExpression(); - } - else + auto lBracket = parser->ReadToken(TokenType::LBracket); + indexExpr->argumentDelimeterLocs.add(lBracket.loc); + while (!parser->tokenReader.isAtEnd()) { - indexExpr->indexExpression = parser->astBuilder->create<IncompleteExpr>(); + if (!parser->LookAheadToken(TokenType::RBracket)) + indexExpr->indexExprs.add(parser->ParseArgExpr()); + else + { + break; + } + if (!parser->LookAheadToken(TokenType::Comma)) + break; + auto comma = parser->ReadToken(TokenType::Comma); + indexExpr->argumentDelimeterLocs.add(comma.loc); } - parser->ReadToken(TokenType::RBracket); - + auto rBracket = parser->ReadToken(TokenType::RBracket); + indexExpr->argumentDelimeterLocs.add(rBracket.loc); expr = indexExpr; } break; diff --git a/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected b/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected index c3cb9e9c1..f2000909b 100644 --- a/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected +++ b/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected @@ -3,6 +3,7 @@ standard error = { tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): error 41011: type 'S' does not fit in the size required by its conforming interface. struct S : IInterface ^ +tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): note 41012: sizeof(S) is 12, limit is 8 } standard output = { } diff --git a/tests/language-feature/operators/subscript-multi-dimension.slang b/tests/language-feature/operators/subscript-multi-dimension.slang new file mode 100644 index 000000000..c15e390dc --- /dev/null +++ b/tests/language-feature/operators/subscript-multi-dimension.slang @@ -0,0 +1,31 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj -output-using-type + +struct S +{ + int4 data; + __subscript(int x, int y) -> int + { + [__unsafeForceInlineEarly] get { return data[y * 2 + x]; } + [__unsafeForceInlineEarly] set { data[y * 2 + x] = newValue;} + } +} + +int test() +{ + S s = {}; + s[1, 0] = 1; + s[1, 1] = 2; + let b = s[1, 0] + s[1, 1]; + return b; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int tid = dispatchThreadID.x; + outputBuffer[tid] = test(); +} diff --git a/tests/language-feature/operators/subscript-multi-dimension.slang.expected.txt b/tests/language-feature/operators/subscript-multi-dimension.slang.expected.txt new file mode 100644 index 000000000..2640fdeb6 --- /dev/null +++ b/tests/language-feature/operators/subscript-multi-dimension.slang.expected.txt @@ -0,0 +1,5 @@ +type: int32_t +3 +3 +3 +3 |
