Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 137 additions & 28 deletions arrow-ord/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -949,40 +949,55 @@ pub fn lexsort_to_indices(
));
};

let mut value_indices = (0..row_count).collect::<Vec<usize>>();
let mut len = value_indices.len();
let len = limit.unwrap_or(row_count).min(row_count);

if let Some(limit) = limit {
len = limit.min(len);
if len == 0 {
return Ok(UInt32Array::from(Vec::<u32>::new()));
}

// Instantiate specialized versions of comparisons for small numbers
// of columns as it helps the compiler generate better code.
match columns.len() {
2 => {
sort_fixed_column::<2>(columns, &mut value_indices, len)?;
}
3 => {
sort_fixed_column::<3>(columns, &mut value_indices, len)?;
}
4 => {
sort_fixed_column::<4>(columns, &mut value_indices, len)?;
}
5 => {
sort_fixed_column::<5>(columns, &mut value_indices, len)?;
// The heap path avoids allocating and partially sorting all row indices
// when the requested limit is a small fraction of the input. For larger
// limits, the existing partial-sort path is preferred because heap
// maintenance costs grow with the requested limit.
let value_indices = if limit.is_some() && len <= row_count / 10 {
match columns.len() {
2 => lexsort_topk_fixed::<2>(columns, row_count, len)?,
3 => lexsort_topk_fixed::<3>(columns, row_count, len)?,
4 => lexsort_topk_fixed::<4>(columns, row_count, len)?,
5 => lexsort_topk_fixed::<5>(columns, row_count, len)?,
_ => {
let lexicographical_comparator = LexicographicalComparator::try_new(columns)?;
lexsort_topk(row_count, len, |a, b| {
lexicographical_comparator.compare(a, b)
})
}
}
_ => {
let lexicographical_comparator = LexicographicalComparator::try_new(columns)?;
// uint32 can be sorted unstably
sort_unstable_by(&mut value_indices, len, |a, b| {
lexicographical_comparator.compare(*a, *b)
});
} else {
let mut value_indices = (0..row_count).collect::<Vec<usize>>();

// Instantiate specialized versions of comparisons for small numbers
// of columns as it helps the compiler generate better code.
match columns.len() {
2 => sort_fixed_column::<2>(columns, &mut value_indices, len)?,
3 => sort_fixed_column::<3>(columns, &mut value_indices, len)?,
4 => sort_fixed_column::<4>(columns, &mut value_indices, len)?,
5 => sort_fixed_column::<5>(columns, &mut value_indices, len)?,
_ => {
let lexicographical_comparator = LexicographicalComparator::try_new(columns)?;
sort_unstable_by(&mut value_indices, len, |a, b| {
lexicographical_comparator.compare(*a, *b)
});
}
}
}

value_indices.truncate(len);
value_indices
};

Ok(UInt32Array::from(
value_indices[..len]
.iter()
.map(|i| *i as u32)
value_indices
.into_iter()
.map(|i| i as u32)
.collect::<Vec<_>>(),
))
}
Expand All @@ -1000,6 +1015,100 @@ fn sort_fixed_column<const N: usize>(
Ok(())
}

// Bounded-heap top-k selection specialised for exactly `N` sort columns.
//
// Builds a `FixedLexicographicalComparator::<N>` over `columns` and hands its
// `compare` method to `lexsort_topk`. Any error raised while constructing the
// comparator is propagated to the caller unchanged.
fn lexsort_topk_fixed<const N: usize>(
    columns: &[SortColumn],
    row_count: usize,
    limit: usize,
) -> Result<Vec<usize>, ArrowError> {
    let comparator = FixedLexicographicalComparator::<N>::try_new(columns)?;
    let selected = lexsort_topk(row_count, limit, |lhs, rhs| comparator.compare(lhs, rhs));
    Ok(selected)
}

// Returns the `limit` smallest row indices from `0..row_count` according to
// `compare`, ordered from smallest to largest.
//
// Candidates are kept in a bounded max-heap whose root (`heap[0]`) is the
// largest retained index, so each remaining row only needs to be compared
// against the root to decide whether it displaces a retained row.
//
// Rows that `compare` reports as equal are ordered by ascending row index, so
// both the selection and the final ordering are deterministic.
//
// If `limit` is zero an empty vector is returned; if `limit >= row_count` all
// indices are returned, fully sorted.
fn lexsort_topk(
    row_count: usize,
    limit: usize,
    mut compare: impl FnMut(usize, usize) -> Ordering,
) -> Vec<usize> {
    // Nothing was requested. Returning here also guarantees the heap below is
    // never empty when its root is inspected.
    if limit == 0 {
        return Vec::new();
    }

    // Every row is selected, so a plain sort of all indices is sufficient.
    if limit >= row_count {
        let mut value_indices = (0..row_count).collect::<Vec<_>>();
        value_indices.sort_unstable_by(|a, b| compare(*a, *b).then_with(|| a.cmp(b)));
        return value_indices;
    }

    // Break ties by row index so equal sort keys are selected consistently
    // during heap updates and final sorting.
    let mut compare_total = |a: usize, b: usize| compare(a, b).then_with(|| a.cmp(&b));
    let mut heap = Vec::with_capacity(limit);

    for idx in 0..row_count {
        if heap.len() < limit {
            // Still filling the heap: append and restore the heap property.
            heap.push(idx);
            let pos = heap.len() - 1;
            sift_up_worst_heap(&mut heap, pos, &mut compare_total);
        } else if compare_total(idx, heap[0]) == Ordering::Less {
            // `idx` beats the worst retained row: replace the root and let the
            // new worst row bubble back to the top.
            heap[0] = idx;
            sift_down_worst_heap(&mut heap, 0, &mut compare_total);
        }
    }

    // The heap contains the selected indices, but not in sorted order.
    heap.sort_unstable_by(|a, b| compare_total(*a, *b));
    heap
}

// Restores the max-heap property after an append: the entry at `pos` is
// swapped with its parent for as long as the parent compares strictly less
// than it. Stops at the root, or as soon as the parent is equal or greater.
fn sift_up_worst_heap(
    heap: &mut [usize],
    mut pos: usize,
    compare: &mut impl FnMut(usize, usize) -> Ordering,
) {
    while pos != 0 {
        let above = (pos - 1) >> 1;

        match compare(heap[above], heap[pos]) {
            Ordering::Less => {
                heap.swap(above, pos);
                pos = above;
            }
            Ordering::Equal | Ordering::Greater => return,
        }
    }
}

// Pushes the entry at `pos` down the max-heap until neither child compares
// greater than it.
//
// At every level the greater of the two children is the swap candidate (the
// left child wins when the two compare equal or when there is no right
// child), which keeps the worst retained row at `heap[0]`.
fn sift_down_worst_heap(
    heap: &mut [usize],
    mut pos: usize,
    compare: &mut impl FnMut(usize, usize) -> Ordering,
) {
    let mut left = pos * 2 + 1;

    while left < heap.len() {
        let right = left + 1;

        // Prefer the right child only when it exists and is strictly greater.
        let candidate = match heap.get(right) {
            Some(&right_value) if compare(heap[left], right_value) == Ordering::Less => right,
            _ => left,
        };

        // The current entry already dominates its greater child: done.
        if compare(heap[pos], heap[candidate]) != Ordering::Less {
            return;
        }

        heap.swap(pos, candidate);
        pos = candidate;
        left = pos * 2 + 1;
    }
}

/// It's unstable_sort, may not preserve the order of equal elements
pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut is_less: F)
where
Expand Down
Loading