diff options
| -rw-r--r-- | docs/user-guide/06-interfaces-generics.md | 17 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 52 | ||||
| -rw-r--r-- | tests/language-feature/types/as-is-generic.slang | 33 |
8 files changed, 121 insertions, 41 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md index 59a0352fa..2436dab07 100644 --- a/docs/user-guide/06-interfaces-generics.md +++ b/docs/user-guide/06-interfaces-generics.md @@ -671,6 +671,23 @@ void main() // "success" ``` +In addition to casting from an interface type to a concrete type, `as` and `is` operator can be used on generic types as well to cast a generic type into a concrete type. For example: +```csharp +T compute<T>(T a1, T a2) +{ + if (a1 is float) + { + return reinterpret<T>((a1 as float).value + (a2 as float).value); + } + else if (T is int) + { + return reinterpret<T>((a1 as int).value - (a2 as int).value); + } + return T(); +} +// compute(1.0f, 2.0f) == 3.0f +// compute(3, 1) == 2 +``` Extensions to Interfaces ----------------------------- diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 89a7373ee..33a1fa680 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -3183,8 +3183,12 @@ namespace Slang expr->type = m_astBuilder->getBoolType(); expr->value = originalVal; + auto valueType = expr->value->type.type; + if (auto typeType = as<TypeType>(valueType)) + valueType = typeType->getType(); + // If value is a subtype of `type`, then this expr is always true. - if(isSubtype(expr->value->type.type, expr->typeExpr.type)) + if(isSubtype(valueType, expr->typeExpr.type)) { // Instead of returning a BoolLiteralExpr, we use a field to indicate this scenario, // so that the language server can still see the original syntax tree. @@ -3195,10 +3199,11 @@ namespace Slang return expr; } - // Otherwise, we need to ensure the target type is a subtype of value->type. + // Otherwise, if the target type is a subtype of value->type, we need to grab the + // subtype witness for runtime checks. expr->value = maybeOpenExistential(originalVal); - expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, originalVal->type.type); + expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, valueType); if (expr->witnessArg) { // For now we can only support the scenario where `expr->value` is an interface type. @@ -3208,15 +3213,6 @@ namespace Slang } return expr; } - - if (!as<ErrorType>(expr->typeExpr.type) && !as<ErrorType>(expr->value->type.type)) - { - // The type is not in the same hierarchy, so we evaluate to false. - expr->constantVal = m_astBuilder->create<BoolLiteralExpr>(); - expr->constantVal->type = m_astBuilder->getBoolType(); - expr->constantVal->value = false; - expr->constantVal->loc = expr->loc; - } return expr; } @@ -3241,27 +3237,21 @@ namespace Slang return makeOptional; } - // For now we can only support the scenario where `expr->value` is an interface type. - if (!isInterfaceType(expr->value->type)) - { - getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); - } - - expr->typeExpr = typeExpr.exp; + // If target type is an interface type, we will obtain the witness here for + // runtime casting. expr->witnessArg = tryGetSubtypeWitness(typeExpr.type, expr->value->type.type); if (expr->witnessArg) { + // For now we can only support the scenario where `expr->value` is an interface type. + if (!isInterfaceType(expr->value->type.type)) + { + getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); + } expr->value = maybeOpenExistential(expr->value); return expr; } - if (!as<ErrorType>(typeExpr.type) && !as<ErrorType>(expr->value->type.type)) - { - getSink()->diagnose(expr, Diagnostics::typeNotInTheSameHierarchy, expr->value->type.type, typeExpr.type); - } - - expr->type = m_astBuilder->getErrorType(); - + expr->typeExpr = typeExpr.exp; return expr; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 037e441d4..2150be5ad 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -282,7 +282,6 @@ DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload foun DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.") -DIAGNOSTIC(30018, Error, typeNotInTheSameHierarchy, "invalid use of 'as' operator: expression evaluates to '$0', which is not in the same type hierarchy as target type '$1'.") DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'") DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)") DIAGNOSTIC(30022, Error, invalidTypeCast, "invalid type cast between \"$0\" and \"$1\".") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 763d96840..2217bc143 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -478,6 +478,8 @@ Result linkAndOptimizeIR( } lowerReinterpret(targetProgram, irModule, sink); + if (sink->getErrorCount() != 0) + return SLANG_FAIL; validateIRModuleIfEnabled(codeGenContext, irModule); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index d7c3c6bda..d4369da7a 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -958,8 +958,14 @@ struct PeepholeContext : InstPassBase } case kIROp_TypeEquals: { - auto left = inst->getOperand(0)->getDataType(); - auto right = inst->getOperand(1)->getDataType(); + auto getTypeFromOperand = [](IRInst* operand) -> IRType* + { + if (as<IRTypeType>(operand->getFullType()) || !operand->getFullType()) + return (IRType*)operand; + return operand->getFullType(); + }; + auto left = getTypeFromOperand(inst->getOperand(0)); + auto right = getTypeFromOperand(inst->getOperand(1)); if (isConcreteType(left) && isConcreteType(right)) { IRBuilder builder(module); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ff9c6f6f8..fd5ea0fc7 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -31,6 +31,11 @@ namespace Slang { if (!irObject) return; + if (as<IRType>(irObject)) + { + getTypeNameHint(sb, irObject); + return; + } if (auto nameHint = irObject->findDecoration<IRNameHintDecoration>()) { sb << nameHint->getName(); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 409cd65ee..638a9c577 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4314,9 +4314,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> getBuilder()->emitMakeStruct(irType, args.getCount(), args.getBuffer())); } } - - SLANG_UNEXPECTED("unexpected type when creating default value"); - UNREACHABLE_RETURN(LoweredValInfo()); + return LoweredValInfo::simple(getBuilder()->emitDefaultConstruct(irType)); } LoweredValInfo getDefaultVal(DeclRef<VarDeclBase> decl) @@ -4718,20 +4716,32 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitAsTypeExpr(AsTypeExpr* expr) { auto value = lowerLValueExpr(context, expr->value); - auto existentialInfo = value.getExtractedExistentialValInfo(); + ExtractedExistentialValInfo* existentialInfo = nullptr; auto optType = lowerType(context, expr->type); SLANG_RELEASE_ASSERT(optType->getOp() == kIROp_OptionalType); auto targetType = optType->getOperand(0); - auto witness = lowerSimpleVal(context, expr->witnessArg); auto builder = getBuilder(); auto var = builder->emitVar(optType); - auto isType = builder->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, targetType, witness); + IRInst* isType = nullptr; + if (expr->witnessArg) + { + auto witness = lowerSimpleVal(context, expr->witnessArg); + existentialInfo = value.getExtractedExistentialValInfo(); + isType = builder->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, targetType, witness); + } + else + { + SLANG_ASSERT(value.val); + auto leftType = lowerType(context, expr->value->type); + IRInst* args[] = { leftType, targetType }; + isType = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_TypeEquals, 2, args); + } IRBlock* trueBlock; IRBlock* falseBlock; IRBlock* afterBlock; builder->emitIfElseWithBlocks(isType, trueBlock, falseBlock, afterBlock); builder->setInsertInto(trueBlock); - auto irVal = builder->emitReinterpret(targetType, existentialInfo->extractedVal); + auto irVal = builder->emitReinterpret(targetType, existentialInfo ? existentialInfo->extractedVal : getSimpleVal(context, value)); auto optionalVal = builder->emitMakeOptionalValue(optType, irVal); builder->emitStore(var, optionalVal); builder->emitBranch(afterBlock); @@ -4751,11 +4761,29 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> { return LoweredValInfo::simple(getBuilder()->getBoolValue(expr->constantVal->value)); } - auto value = lowerLValueExpr(context, expr->value); - auto type = lowerType(context, expr->type); - auto witness = lowerSimpleVal(context, expr->witnessArg); - auto existentialInfo = value.getExtractedExistentialValInfo(); - auto irVal = getBuilder()->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, type, witness); + // If expr is a witness, then this is a run-time type check from for an existential type. + if (expr->witnessArg) + { + auto value = lowerLValueExpr(context, expr->value); + auto type = lowerType(context, expr->typeExpr.type); + auto witness = lowerSimpleVal(context, expr->witnessArg); + auto existentialInfo = value.getExtractedExistentialValInfo(); + auto irVal = getBuilder()->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, type, witness); + return LoweredValInfo::simple(irVal); + } + // For all other cases, we map to a simple type equality check in the IR. + IRType* leftType = nullptr; + if (auto typeType = as<TypeType>(expr->value->type)) + { + leftType = lowerType(context, typeType->getType()); + } + else + { + leftType = lowerType(context, expr->value->type); + } + auto rightType = lowerType(context, expr->typeExpr.type); + IRInst* args[] = { leftType, rightType }; + auto irVal = getBuilder()->emitIntrinsicInst(getBuilder()->getBoolType(), kIROp_TypeEquals, 2, args); return LoweredValInfo::simple(irVal); } diff --git a/tests/language-feature/types/as-is-generic.slang b/tests/language-feature/types/as-is-generic.slang new file mode 100644 index 000000000..6790eda37 --- /dev/null +++ b/tests/language-feature/types/as-is-generic.slang @@ -0,0 +1,33 @@ +// optional.slang + +// Test that `is` and `as` operator works on generic types. + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -vk -compute -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -cpu -compute -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +T compute<T>(T a1, T a2) +{ + if (a1 is float) + { + return reinterpret<T>((a1 as float).value + (a2 as float).value); + } + else if (T is int) + { + return reinterpret<T>((a1 as int).value - (a2 as int).value); + } + return T(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + // CHECK: 3.0 + outputBuffer[0] = compute(1.0f, 2.0f); + + // CHECK: 1.0 + outputBuffer[1] = compute(2, 1); +} |
