-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgemm.zig
More file actions
101 lines (86 loc) · 3.22 KB
/
gemm.zig
File metadata and controls
101 lines (86 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
/// cuBLAS GEMM Example: C = α·A·B + β·C
///
/// General matrix-matrix multiply, the most important BLAS operation.
/// Demonstrates SGEMM with matrix setup, computation, and verification.
///
/// Reference: CUDALibrarySamples/cuBLAS/Level-3/gemm
const std = @import("std");
const cuda = @import("zcuda");
/// Runs the SGEMM demo: fills small column-major host matrices A (m×k) and
/// B (k×n) with pseudo-random integers in 0..9, computes C = A·B through
/// cuBLAS, prints all three matrices, and checks C against a CPU reference.
///
/// Returns `error.ValidationFailed` when the largest absolute difference
/// between the GPU and CPU results is above 1e-4 (or is NaN). Errors from the
/// CUDA / cuBLAS calls are propagated unchanged.
pub fn main() !void {
    const allocator = std.heap.page_allocator;

    std.debug.print("=== cuBLAS GEMM Example ===\n\n", .{});

    const ctx = try cuda.driver.CudaContext.new(0);
    defer ctx.deinit();
    const stream = ctx.defaultStream();
    const blas = try cuda.cublas.CublasContext.init(ctx);
    defer blas.deinit();

    // C (m×n) = alpha * A (m×k) * B (k×n) + beta * C (m×n)
    const m: i32 = 4;
    const n: i32 = 3;
    const k: i32 = 5;

    // usize copies of the dimensions for host-side sizing and indexing.
    const rows_a: usize = @intCast(m);
    const cols_b: usize = @intCast(n);
    const inner: usize = @intCast(k);

    // Column-major storage; lengths follow from the dimensions above so the
    // buffers cannot drift out of sync with m, n, k.
    var A: [rows_a * inner]f32 = undefined; // m×k
    var B: [inner * cols_b]f32 = undefined; // k×n
    var C: [rows_a * cols_b]f32 = undefined; // m×n

    // Fixed seed so every run prints the same matrices. A is filled before B.
    var rng = std.Random.DefaultPrng.init(42);
    const random = rng.random();
    for (&A) |*v| v.* = @as(f32, @floatFromInt(random.intRangeAtMost(i32, 0, 9)));
    for (&B) |*v| v.* = @as(f32, @floatFromInt(random.intRangeAtMost(i32, 0, 9)));
    @memset(&C, 0.0);

    std.debug.print("A ({}×{}):\n", .{ m, k });
    printMatrix(" {d:3.0}", &A, rows_a, inner);

    std.debug.print("B ({}×{}):\n", .{ k, n });
    printMatrix(" {d:3.0}", &B, inner, cols_b);

    // Copy to device
    const d_A = try stream.cloneHtoD(f32, &A);
    defer d_A.deinit();
    const d_B = try stream.cloneHtoD(f32, &B);
    defer d_B.deinit();
    const d_C = try stream.allocZeros(f32, allocator, @intCast(m * n));
    defer d_C.deinit();

    // SGEMM: C = 1.0 * A * B + 0.0 * C
    // NOTE(review): the integer following each device buffer is presumably its
    // leading dimension (cuBLAS lda/ldb/ldc) — confirm against the zcuda API.
    try blas.sgemm(.no_transpose, .no_transpose, m, n, k, 1.0, d_A, m, d_B, k, 0.0, d_C, m);

    // Copy back
    try stream.memcpyDtoH(f32, &C, d_C);

    std.debug.print("\nC = A·B ({}×{}):\n", .{ m, n });
    printMatrix(" {d:6.0}", &C, rows_a, cols_b);

    // Verify against CPU computation
    const tolerance: f32 = 1e-4;
    const max_error = maxAbsError(&A, &B, &C, rows_a, cols_b, inner);
    std.debug.print("\nMax error: {e}\n", .{max_error});
    // Written as !(x <= tol) so that a NaN error also counts as a failure.
    if (!(max_error <= tolerance)) {
        std.debug.print("✗ FAILED\n", .{});
        return error.ValidationFailed;
    }
    std.debug.print("✓ cuBLAS GEMM verified\n", .{});
}

/// Prints a `rows`×`cols` matrix stored column-major in `data`, one bracketed
/// row per line, formatting every element with `cell_fmt`.
fn printMatrix(comptime cell_fmt: []const u8, data: []const f32, rows: usize, cols: usize) void {
    for (0..rows) |row| {
        std.debug.print(" [", .{});
        for (0..cols) |col| {
            // Column-major: element (row, col) lives at col * rows + row.
            std.debug.print(cell_fmt, .{data[col * rows + row]});
        }
        std.debug.print(" ]\n", .{});
    }
}

/// Returns the largest absolute difference between `product` (rows×cols) and
/// the CPU-computed `lhs` (rows×inner) times `rhs` (inner×cols); all three
/// are column-major. Returns NaN as soon as any difference is NaN.
fn maxAbsError(
    lhs: []const f32,
    rhs: []const f32,
    product: []const f32,
    rows: usize,
    cols: usize,
    inner: usize,
) f32 {
    var worst: f32 = 0.0;
    for (0..rows) |row| {
        for (0..cols) |col| {
            var expected: f32 = 0.0;
            for (0..inner) |p| {
                expected += lhs[p * rows + row] * rhs[col * inner + p];
            }
            const diff = @abs(expected - product[col * rows + row]);
            // @max skips NaN operands, so a NaN would otherwise be silently
            // dropped and the check would pass; report it explicitly.
            if (std.math.isNan(diff)) return diff;
            worst = @max(worst, diff);
        }
    }
    return worst;
}