Skip to content

Commit fc799fc

Browse files
committed
src: gpu: intel: include: make philox overloadable
1 parent d115200 commit fc799fc

7 files changed

Lines changed: 31 additions & 23 deletions

File tree

src/gpu/intel/eltwise/ref.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ __kernel void ref_eltwise_fwd(__global SRC_DATA_T *src,
9999
uint dropout_threshold = get_dropout_threshold(dropout_p);
100100
float dropout_inv_q = (dropout_p != 1.f) ? 1.f / (1.f - dropout_p) : 0.f;
101101
#if USE_OFFSET
102-
uint res = philox_4x32_u64_w_offset(
102+
uint res = philox_4x32_w_offset(
103103
(ulong)data_off, (ulong)dropout_seed, (ulong)dropout_offset);
104104
#else
105-
uint res = philox_4x32_u64((ulong)data_off, (ulong)dropout_seed);
105+
uint res = philox_4x32((ulong)data_off, (ulong)dropout_seed);
106106
#endif
107107
uchar dropout = res > dropout_threshold;
108108
tmp_s = (dropout) ? tmp_s * dropout_inv_q : 0;

src/gpu/intel/gemm/with_post_ops.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src,
123123
uint dropout_threshold = get_dropout_threshold(dropout_p);
124124
float dropout_inv_q = (dropout_p != 1.f) ? 1.f / (1.f - dropout_p) : 0.f;
125125
#if DROPOUT_USE_OFFSET
126-
uint res = philox_4x32_u64_w_offset(
126+
uint res = philox_4x32_w_offset(
127127
(ulong)data_idx, (ulong)dropout_seed, (ulong)dropout_offset);
128128
#else
129-
uint res = philox_4x32_u64((ulong)data_idx, (ulong)dropout_seed);
129+
uint res = philox_4x32((ulong)data_idx, (ulong)dropout_seed);
130130
#endif
131131
uchar dropout = res > dropout_threshold;
132132
accumulator = (dropout) ? accumulator * dropout_inv_q : 0;

src/gpu/intel/include/philox.h

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#define DT_UNDEF 1
2121
#include "gpu/intel/include/types.h"
2222

23-
uint4 philox_4x32_s64_vec4(ulong idx, ulong seed, ulong offset) {
23+
uint4 philox_4x32_vec4_w_offset(ulong idx, ulong seed, ulong offset) {
2424
#define PHILOX_4UINT_ROUND(mul, ctr, key) \
2525
as_uint4(convert_ulong2(ctr.s02) * mul).s3210 \
2626
^ (uint4)(ctr.s1 ^ key.s0, 0, ctr.s3 ^ key.s1, 0)
@@ -53,40 +53,45 @@ uint4 philox_4x32_s64_vec4(ulong idx, ulong seed, ulong offset) {
5353
return ctr;
5454
}
5555

56-
uint philox_4x32_u64_w_offset(ulong idx, ulong seed, ulong offset) {
57-
return philox_4x32_s64_vec4(idx, seed, offset)[idx & 3L];
56+
uint philox_4x32_w_offset(ulong idx, ulong seed, ulong offset) {
57+
return philox_4x32_vec4_w_offset(idx, seed, offset)[idx & 3L];
5858
}
5959

60-
uint philox_4x32_u64(ulong idx, ulong seed) {
60+
uint __attribute__((overloadable)) philox_4x32(ulong idx, ulong seed) {
6161
// Note: this is for compatibility with impls that don't support s64 rand
6262
ulong x = idx & ~3L;
6363
ulong idx_64 = ((x + 3) << 32) + (x + 2);
6464
ulong offset_64 = ((x + 1) << 32) + x;
6565
ulong seed_64 = (seed << 32) + seed;
66-
return philox_4x32_s64(idx_64, seed_64, offset_64);
66+
return philox_4x32_w_offset(idx_64, seed_64, offset_64);
6767
}
6868

69-
uint philox_4x32(uint idx, uint seed) {
70-
// Note: preserve old signature for compatibility
71-
return philox_4x32_u64((ulong)idx, (ulong)seed);
69+
uint __attribute__((overloadable)) philox_4x32(long idx, long seed) {
70+
// Convert long to ulong and call the existing function
71+
return philox_4x32((ulong)idx, (ulong)seed);
72+
}
73+
74+
uint __attribute__((overloadable)) philox_4x32(int idx, int seed) {
75+
// Convert int to ulong and call the existing overloadable function
76+
return philox_4x32((ulong)idx, (ulong)seed);
7277
}
7378

7479
uint4 philox_4x32_vec4(uint idx, uint seed) {
7580
ulong x = idx & ~3L;
7681
ulong idx_64 = ((x + 3) << 32) + (x + 2);
7782
ulong offset_64 = ((x + 1) << 32) + x;
7883
ulong seed_64 = ((ulong)(seed) << 32) + seed;
79-
return philox_4x32_s64_vec4(idx_64, seed_64, offset_64);
84+
return philox_4x32_vec4_w_offset(idx_64, seed_64, offset_64);
8085
}
8186

8287
ushort philox_8x16(long idx, uint seed) {
8388
ulong idx_ = (ulong)idx;
84-
return as_ushort2(philox_4x32_u64(idx_ >> 1, (ulong)seed))[idx_ & 1];
89+
return as_ushort2(philox_4x32(idx_ >> 1, (ulong)seed))[idx_ & 1];
8590
}
8691

8792
uchar philox_16x8(long idx, uint seed) {
8893
ulong idx_ = (ulong)idx;
89-
return as_uchar4(philox_4x32_u64(idx_ >> 2, (ulong)seed))[idx_ & 3];
94+
return as_uchar4(philox_4x32(idx_ >> 2, (ulong)seed))[idx_ & 3];
9095
}
9196

9297
#if WITH_SROUND
@@ -120,4 +125,7 @@ uint get_dropout_threshold(float p) {
120125
+ !!(mantissa & ((1u << exponent) - 1u));
121126
}
122127
#endif
128+
129+
130+
123131
#endif

src/gpu/intel/matmul/ref.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,10 @@ __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B,
289289

290290
#if WITH_DROPOUT
291291
#if WITH_SEED_S64 && USE_OFFSET
292-
uint res = philox_4x32_u64_w_offset(
292+
uint res = philox_4x32_w_offset(
293293
(ulong)dst_off, (ulong)dropout_seed, (ulong)dropout_offset);
294294
#else
295-
uint res = philox_4x32_u64((ulong)dst_off, (ulong)dropout_seed);
295+
uint res = philox_4x32((ulong)dst_off, (ulong)dropout_seed);
296296
#endif
297297
uchar dropout = res > dropout_threshold;
298298
po_acc = (dropout) ? po_acc * dropout_inv_q : 0;

src/gpu/intel/sdpa/micro.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ inline void apply_dropout_s_tile(
5656
ulong _goff = batch_head_base + (ulong)offset_c * (ulong)k_stride \
5757
+ (ulong)offset_r; \
5858
uint _philox = use_dropout_offset \
59-
? philox_4x32_u64_w_offset( \
59+
? philox_4x32_w_offset( \
6060
(ulong)_goff, (ulong)seed, (ulong)offset) \
61-
: philox_4x32_u64((ulong)_goff, (ulong)seed); \
61+
: philox_4x32((ulong)_goff, (ulong)seed); \
6262
(offset_r < max_r && offset_c < max_c) && (_philox > threshold); \
6363
})
6464

src/gpu/intel/sdpa/micro_bwd.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ typedef ugemm_ktq_c_type ktq_tile_type; // D*Br tile
5252
ulong _goff = batch_head_base + (ulong)offset_c * (ulong)k_stride \
5353
+ (ulong)offset_r; \
5454
uint _philox = use_dropout_offset \
55-
? philox_4x32_u64_w_offset( \
55+
? philox_4x32_w_offset( \
5656
(ulong)_goff, (ulong)seed, (ulong)offset) \
57-
: philox_4x32_u64((ulong)_goff, (ulong)seed); \
57+
: philox_4x32((ulong)_goff, (ulong)seed); \
5858
(offset_r < max_r && offset_c < max_c) && (_philox > threshold); \
5959
})
6060

src/gpu/intel/softmax/simple.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,10 @@ simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,
190190
float dropout_inv_q
191191
= (dropout_p != 1.f) ? 1.f / (1.f - dropout_p) : 0.f;
192192
#if USE_OFFSET
193-
uint res = philox_4x32_u64_w_offset(
193+
uint res = philox_4x32_w_offset(
194194
(ulong)data_off, (ulong)dropout_seed, (ulong)dropout_offset);
195195
#else
196-
uint res = philox_4x32_u64((ulong)data_off, (ulong)dropout_seed);
196+
uint res = philox_4x32((ulong)data_off, (ulong)dropout_seed);
197197
#endif
198198
uchar dropout = res > dropout_threshold;
199199
tmp = (dropout) ? tmp * dropout_inv_q : 0;

0 commit comments

Comments
 (0)