-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrsm.zig
More file actions
78 lines (66 loc) · 2.46 KB
/
trsm.zig
File metadata and controls
78 lines (66 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/// cuBLAS TRSM Example: Triangular Solve
///
/// Solves op(A) * X = α·B where A is triangular.
/// Real use case: solving linear systems after LU factorization.
///
/// Reference: CUDALibrarySamples/cuBLAS/Level-3/trsm
const std = @import("std");
const cuda = @import("zcuda");
/// Runs the TRSM demo end to end.
///
/// Uploads a 3×3 lower-triangular matrix A and a 3×1 right-hand side B,
/// solves A·X = B with `strsm`, copies the solution back to the host and
/// verifies it by recomputing A·X row by row.
///
/// Returns `error.ValidationFailed` when a recomputed row differs from the
/// original right-hand side by more than 1e-3. Errors returned by the
/// `zcuda` calls are propagated unchanged via `try`.
pub fn main() !void {
    std.debug.print("=== cuBLAS TRSM Example ===\n\n", .{});

    // NOTE(review): the `0` passed to `CudaContext.new` is presumably the
    // device ordinal — confirm against the zcuda API.
    const ctx = try cuda.driver.CudaContext.new(0);
    defer ctx.deinit();
    const stream = ctx.defaultStream();
    const blas = try cuda.cublas.CublasContext.init(ctx);
    defer blas.deinit();

    // Solve A * X = B where A is lower triangular 3×3, B is 3×1.
    const m: i32 = 3;
    const n: i32 = 1;
    // Host-side index arithmetic needs usize; convert once instead of per access.
    const rows: usize = @intCast(m);

    // A (lower triangular):
    //   | 2 0 0 |
    //   | 3 4 0 |
    //   | 1 5 6 |
    // Stored column-major, so element (r, c) lives at index c * rows + r.
    const A_data = [_]f32{ 2, 3, 1, 0, 4, 5, 0, 0, 6 };

    // B = |  4 |
    //     | 23 |
    //     | 58 |
    //
    // Forward substitution gives the expected solution:
    //   x1 = 4 / 2                      = 2
    //   x2 = (23 - 3·2) / 4             = 4.25
    //   x3 = (58 - 1·2 - 5·4.25) / 6    ≈ 5.7917
    //
    // The host array is never written: only the device copy is handed to
    // `strsm`, so it also serves as the reference for the verification below.
    const B_data = [_]f32{ 4, 23, 58 };

    std.debug.print("A (lower triangular):\n", .{});
    for (0..rows) |r| {
        std.debug.print(" [", .{});
        for (0..rows) |c| {
            std.debug.print(" {d:3.0}", .{A_data[c * rows + r]});
        }
        std.debug.print(" ]\n", .{});
    }

    std.debug.print("B = [ ", .{});
    for (B_data) |v| std.debug.print("{d:.0} ", .{v});
    std.debug.print("]\n\n", .{});

    const d_A = try stream.cloneHtoD(f32, &A_data);
    defer d_A.deinit();
    const d_B = try stream.cloneHtoD(f32, &B_data);
    defer d_B.deinit();

    // Solve: A * X = 1.0 * B (result stored in the device buffer d_B).
    try blas.strsm(.left, .lower, .no_transpose, .non_unit, m, n, 1.0, d_A, m, d_B, m);

    // Filled entirely by the device-to-host copy before it is read.
    var X: [B_data.len]f32 = undefined;
    try stream.memcpyDtoH(f32, &X, d_B);

    std.debug.print("X (solution of A·X = B):\n [ ", .{});
    for (X) |v| std.debug.print("{d:.4} ", .{v});
    std.debug.print("]\n", .{});

    // Verify: A * X should reproduce the original B.
    std.debug.print("\nVerification A·X:\n", .{});
    for (0..rows) |r| {
        var sum: f32 = 0.0;
        // A is lower triangular, so only columns 0..=r contribute to row r.
        for (0..r + 1) |c| {
            sum += A_data[c * rows + r] * X[c];
        }
        std.debug.print(" Row {}: {d:.4} (expected {d:.0})\n", .{ r, sum, B_data[r] });
        if (@abs(sum - B_data[r]) > 1e-3) return error.ValidationFailed;
    }

    std.debug.print("\n✓ cuBLAS TRSM verified\n", .{});
}