Skip to content
This repository was archived by the owner on Feb 21, 2026. It is now read-only.

Commit 466b19e

Browse files
committed
[CIR] Support mixed scalar/vector init-list vector construction
Extend vector init-list emission to accept a mix of scalar and vector-valued initializers by flattening sub-vectors elementwise. This enables nested vector construction patterns such as (int3)(int2, scalar) and aligns CIR behavior with Clang vector semantics.
1 parent 5237bd4 commit 466b19e

2 files changed

Lines changed: 79 additions & 7 deletions

File tree

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,21 +2239,52 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
22392239
assert(!cir::MissingFeatures::scalableVectors() &&
22402240
"NYI: scalable vector init");
22412241
assert(!cir::MissingFeatures::vectorConstants() && "NYI: vector constants");
2242+
22422243
auto VectorType =
2243-
mlir::dyn_cast<cir::VectorType>(CGF.convertType(E->getType()));
2244+
mlir::cast<cir::VectorType>(CGF.convertType(E->getType()));
2245+
auto ElemTy = VectorType.getElementType();
2246+
22442247
SmallVector<mlir::Value, 16> Elements;
2248+
22452249
for (Expr *init : E->inits()) {
2246-
Elements.push_back(Visit(init));
2250+
mlir::Value V = Visit(init);
2251+
mlir::Type VTy = V.getType();
2252+
2253+
// Scalar element: one lane
2254+
if (VTy == ElemTy) {
2255+
Elements.push_back(V);
2256+
continue;
2257+
}
2258+
2259+
// Subvector: flatten into scalar lanes
2260+
if (auto SubVecTy = mlir::dyn_cast<cir::VectorType>(VTy)) {
2261+
assert(SubVecTy.getElementType() == ElemTy &&
2262+
"vector element type mismatch in init");
2263+
2264+
for (unsigned i = 0; i < SubVecTy.getSize(); ++i) {
2265+
auto Idx = CGF.getBuilder().getUInt32(i, CGF.getLoc(E->getExprLoc()));
2266+
Elements.push_back(cir::VecExtractOp::create(
2267+
CGF.getBuilder(), CGF.getLoc(E->getExprLoc()), ElemTy, V, Idx));
2268+
}
2269+
continue;
2270+
}
2271+
2272+
llvm_unreachable("invalid vector initializer element");
22472273
}
2248-
// Zero-initialize any remaining values.
2249-
if (NumInitElements < VectorType.getSize()) {
2274+
2275+
// Zero-initialize remaining lanes
2276+
if (Elements.size() < VectorType.getSize()) {
22502277
mlir::Value ZeroValue = cir::ConstantOp::create(
22512278
CGF.getBuilder(), CGF.getLoc(E->getSourceRange()),
2252-
CGF.getBuilder().getZeroInitAttr(VectorType.getElementType()));
2253-
for (uint64_t i = NumInitElements; i < VectorType.getSize(); ++i) {
2279+
CGF.getBuilder().getZeroInitAttr(ElemTy));
2280+
2281+
while (Elements.size() < VectorType.getSize())
22542282
Elements.push_back(ZeroValue);
2255-
}
22562283
}
2284+
2285+
// Truncate excess lanes if any
2286+
Elements.resize(VectorType.getSize());
2287+
22572288
return cir::VecCreateOp::create(CGF.getBuilder(),
22582289
CGF.getLoc(E->getSourceRange()), VectorType,
22592290
Elements);
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: %clang -cc1 -triple spirv64-unknown-unknown -cl-std=CL2.0 -finclude-default-header -O0 -emit-cir -fclangir -o - %s | FileCheck %s --check-prefix=CIR
2+
// RUN: %clang -cc1 -triple spirv64-unknown-unknown -cl-std=CL2.0 -finclude-default-header -O2 -emit-llvm -fclangir -o - %s | FileCheck %s --check-prefix=LLVM
3+
// RUN: %clang -cc1 -triple spirv64-unknown-unknown -cl-std=CL2.0 -finclude-default-header -O2 -emit-llvm -o - %s | FileCheck %s --check-prefix=OG-LLVM
4+
5+
int test_scalar(int val, char n) {
6+
return val >> (n & 0x1f);
7+
}
8+
9+
int2 test_vec2(int2 val, char2 n) {
10+
return (int2)(test_scalar(val.x, n.x), test_scalar(val.y, n.y));
11+
}
12+
13+
int3 test_vec3(int3 val, char3 n) {
14+
return (int3)(test_vec2(val.xy, n.xy), test_scalar(val.z, n.z));
15+
}
16+
17+
// CIR-LABEL: cir.func no_inline optnone @test_vec3
18+
// CIR: %[[IDX0:.*]] = cir.const #cir.int<0> : !u32i
19+
// CIR: %[[E0:.*]] = cir.vec.extract %{{.*}}[%[[IDX0]] : !u32i] : !cir.vector<!s32i x 2>
20+
// CIR: %[[IDX1:.*]] = cir.const #cir.int<1> : !u32i
21+
// CIR: %[[E1:.*]] = cir.vec.extract %{{.*}}[%[[IDX1]] : !u32i] : !cir.vector<!s32i x 2>
22+
// CIR: %[[V3:.*]] = cir.load {{.*}} : !cir.ptr<!cir.vector<!s32i x 3>{{.*}}>, !cir.vector<!s32i x 3>
23+
// CIR: %[[IDX2:.*]] = cir.const #cir.int<2> : !s64i
24+
// CIR: %[[VAL2:.*]] = cir.vec.extract %[[V3]]
25+
// CIR: %[[N3:.*]] = cir.load {{.*}} : !cir.ptr<!cir.vector<!s8i x 3>{{.*}}>, !cir.vector<!s8i x 3>
26+
// CIR: %[[NIDX2:.*]] = cir.const #cir.int<2> : !s64i
27+
// CIR: %[[NVAL2:.*]] = cir.vec.extract %[[N3]]
28+
// CIR: %[[SCALAR:.*]] = cir.call @test_scalar(%[[VAL2]], %[[NVAL2]])
29+
// CIR: cir.vec.create(%[[E0]], %[[E1]], %[[SCALAR]] : !s32i, !s32i, !s32i) : !cir.vector<!s32i x 3>
30+
31+
// LLVM-LABEL: define spir_func <3 x i32> @test_vec3
32+
// LLVM: %[[V0:.*]] = insertelement <3 x i32> poison, i32 %{{.*}}, i64 0
33+
// LLVM: %[[V1:.*]] = insertelement <3 x i32> %[[V0]], i32 %{{.*}}, i64 1
34+
// LLVM: %[[V2:.*]] = insertelement <3 x i32> %[[V1]], i32 %{{.*}}, i64 2
35+
// LLVM: ret <3 x i32> %[[V2]]
36+
37+
// OG-LLVM-LABEL: define spir_func <3 x i32> @test_vec3
38+
// OG-LLVM: %[[V0:.*]] = insertelement <3 x i32> poison, i32 %{{.*}}, i64 0
39+
// OG-LLVM: %[[V1:.*]] = insertelement <3 x i32> %[[V0]], i32 %{{.*}}, i64 1
40+
// OG-LLVM: %[[V2:.*]] = insertelement <3 x i32> %[[V1]], i32 %{{.*}}, i64 2
41+
// OG-LLVM: ret <3 x i32> %[[V2]]

0 commit comments

Comments
 (0)