-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaxpy.zig
More file actions
68 lines (55 loc) · 2.15 KB
/
axpy.zig
File metadata and controls
68 lines (55 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/// cuBLAS AXPY Example: y = α·x + y
///
/// Demonstrates the fundamental BLAS Level-1 operation.
/// Only the single-precision variant (SAXPY) is exercised here; the
/// double-precision counterpart (DAXPY) is not shown.
///
/// Reference: CUDALibrarySamples/cuBLAS/Level-1/axpy
const std = @import("std");
const cuda = @import("zcuda");
/// Runs the SAXPY demo: uploads two host vectors, computes `y = alpha * x + y`
/// through cuBLAS, copies `y` back and compares it with a host-side reference.
///
/// Returns `error.ValidationFailed` when any element differs from the
/// reference by more than the tolerance (or is NaN). Errors from the context,
/// cuBLAS, and copy calls are propagated unchanged.
pub fn main() !void {
    std.debug.print("=== cuBLAS AXPY Example ===\n\n", .{});

    const ctx = try cuda.driver.CudaContext.new(0);
    defer ctx.deinit();
    std.debug.print("Device: {s}\n\n", .{ctx.name()});

    const stream = ctx.defaultStream();
    const blas = try cuda.cublas.CublasContext.init(ctx);
    defer blas.deinit();

    // --- SAXPY (single precision) ---
    std.debug.print("─── SAXPY: y = 2.0 * x + y ───\n", .{});

    const alpha: f32 = 2.0;
    // Largest absolute difference accepted between device and host results.
    const tolerance: f32 = 1e-5;

    const x_data = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 };
    // Typed with x's length so the two operands cannot drift apart: changing
    // one literal without the other becomes a compile error.
    const y_data = [x_data.len]f32{ 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0 };
    // Element count handed to cuBLAS, derived from the data instead of being
    // repeated as a separate literal.
    const n: i32 = @intCast(x_data.len);

    const d_x = try stream.cloneHtoD(f32, &x_data);
    defer d_x.deinit();
    const d_y = try stream.cloneHtoD(f32, &y_data);
    defer d_y.deinit();

    printVector(" x = ", &x_data);
    printVector(" y = ", &y_data);
    std.debug.print(" α = {d:.1}\n\n", .{alpha});

    try blas.saxpy(n, alpha, d_x, d_y);

    // NOTE(review): assumes memcpyDtoH returns only after the copy has
    // completed, so h_result is fully written before it is read — confirm
    // against the zcuda stream API.
    var h_result: [x_data.len]f32 = undefined;
    try stream.memcpyDtoH(f32, &h_result, d_y);
    printVector(" Result y = ", &h_result);

    // Host-side reference, computed once and reused for display and checking.
    var expected: [x_data.len]f32 = undefined;
    for (&expected, x_data, y_data) |*slot, x, y| slot.* = alpha * x + y;
    printVector(" Expected = ", &expected);

    // Check correctness
    for (expected, h_result) |want, got| {
        // Written as a negated `<=` so that a NaN result (for which every
        // ordered comparison is false) is reported as a failure.
        if (!(@abs(got - want) <= tolerance)) {
            std.debug.print(" ✗ FAILED\n", .{});
            return error.ValidationFailed;
        }
    }
    std.debug.print(" ✓ Verified\n", .{});
    std.debug.print("\n✓ cuBLAS AXPY complete\n", .{});
}

/// Prints `label`, then the elements of `values` inside `[ ... ]` with zero
/// decimal places, followed by a newline.
fn printVector(label: []const u8, values: []const f32) void {
    std.debug.print("{s}[ ", .{label});
    for (values) |v| std.debug.print("{d:.0} ", .{v});
    std.debug.print("]\n", .{});
}