# generic_tiling_solver.py
"""
Constraint-optimization tiler using Google OR-Tools (CP-SAT).
This provides a minimal generic layer for algorithm-centric tiling models.
---------------------------------------------------------------------------
Dependencies
---------------------------------------------------------------------------
pip install ortools
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Callable, Any
import math
@dataclass
class MemoryBudget:
    """Optional byte capacities for three memory levels: L3, L2 and L1.

    This class only stores the numbers; ``solve_tiling`` forwards the whole
    object unchanged to the problem's ``build_model`` callback, which decides
    how to use it.
    """

    # Declaration order (L3, L2, L1) is also the positional-argument order of
    # the generated constructor, so it must not be rearranged.
    # NOTE(review): ``None`` presumably means "no limit for this level" --
    # confirm against the build_model implementations.
    L3_bytes: Optional[int] = None
    L2_bytes: Optional[int] = None
    L1_bytes: Optional[int] = None
@dataclass(frozen=True)
class TilingProblem:
    """Immutable description of one tiling model, consumed by ``solve_tiling``.

    The instance is frozen: its attributes cannot be rebound after
    construction (the containers they refer to are not copied or frozen).
    """

    # Decision variables: name -> list of allowed integer values. The solver
    # creates one integer variable per entry, restricted to exactly these
    # values.
    var_domains: Dict[str, List[int]]

    # Callback invoked as build_model(model, vars_map, dtype_bytes, budgets).
    # It adds the problem's constraints to ``model`` and returns a dict of
    # extra named expressions/variables (a falsy return is treated as empty);
    # those are merged with the decision variables, overriding equal names.
    build_model: Callable[[Any, Dict[str, Any], int, "MemoryBudget"], Dict[str, Any]]

    # Names (decision or derived) to maximize, most important first.
    objectives: List[str]
@dataclass
class TilingResult:
    """Outcome of a successful ``solve_tiling`` call."""

    # Integer value of every named variable/expression (decision variables
    # plus whatever build_model returned) in the final solution.
    values: Dict[str, int]

    # Objective name -> the value at which that objective was pinned.
    objective_values: Dict[str, int]
# ---------------- Global configuration parameters ----------------
# Per-memory-level overhead and utilization defaults.
# NOTE(review): none of these six values is referenced in the visible part of
# this file; presumably problem-specific build_model callbacks import them --
# confirm. Judging by the names only, OVERHEAD looks like a byte count
# reserved per level and UTIL like the usable fraction of a level's capacity;
# verify units against the consumers.
DEFAULT_L1_OVERHEAD = 0
DEFAULT_L2_OVERHEAD = 0
DEFAULT_L3_OVERHEAD = 0
DEFAULT_L1_UTIL = 0.92
DEFAULT_L2_UTIL = 0.92
DEFAULT_L3_UTIL = 0.92
# Defaults for solve_tiling's keyword-only solver parameters.
DEFAULT_MAX_TIME_SEC = 3.0  # default for max_time_sec
DEFAULT_NUM_WORKERS = 8  # default for num_workers
# ---------------- Utilities ----------------
def divisors(x: int) -> List[int]:
    """Return every positive divisor of ``x`` in ascending order.

    For ``x <= 0`` the result is ``[1]``.
    """
    if x <= 0:
        return [1]
    root = math.isqrt(x)
    # Divisors up to the integer square root, found by trial division.
    low = [d for d in range(1, root + 1) if x % d == 0]
    # Each low divisor d has the cofactor x // d on the other side of the
    # root; walking ``low`` backwards yields those cofactors in ascending
    # order. An exact square root is its own cofactor and is not repeated.
    high = [x // d for d in reversed(low) if d * d != x]
    return low + high
# ===============
# Generic solver
# ===============
def solve_tiling(
    problem: TilingProblem,
    budgets: MemoryBudget,
    dtype_bytes: int = 4,
    *,
    max_time_sec: float = DEFAULT_MAX_TIME_SEC,
    num_workers: int = DEFAULT_NUM_WORKERS,
) -> Optional[TilingResult]:
    """Solve a tiling problem with CP-SAT by lexicographic maximization.

    The objectives in ``problem.objectives`` are handled in order. For each
    one a fresh model is built, every earlier objective is constrained to the
    value it reached, and the current objective is maximized. The assignment
    from the last solve is returned.

    When ``problem.objectives`` is empty, a single feasibility solve is run so
    that the returned ``values`` still hold a valid assignment.

    Args:
        problem: Variable domains, model-building callback and ordered
            objective names.
        budgets: Passed through unchanged to ``problem.build_model``.
        dtype_bytes: Element size passed to ``problem.build_model``; must be
            positive.
        max_time_sec: Time limit applied to EACH solve, so the total run time
            can reach ``len(problem.objectives) * max_time_sec``.
        num_workers: Number of solver search workers; values below 1 are
            raised to 1.

    Returns:
        A ``TilingResult``, or ``None`` if any solve ends without a solution
        (infeasible, or no solution found within the time limit).

    Raises:
        ValueError: If ``dtype_bytes`` is not positive.
        ImportError: If OR-Tools cannot be imported.
        KeyError: If an objective name is neither a decision variable nor a
            name returned by ``problem.build_model``.

    Note:
        A solve that stops at the time limit with a feasible but unproven
        solution is accepted, and its objective value is pinned for the later
        solves. The result is therefore only guaranteed lexicographically
        optimal when every solve finishes within ``max_time_sec``.
    """
    if dtype_bytes <= 0:
        raise ValueError(f"dtype_bytes must be > 0, got {dtype_bytes}")
    # Imported lazily so the module can be loaded without OR-Tools installed.
    try:
        from ortools.sat.python import cp_model
    except Exception as e:
        raise ImportError(
            "Google OR-Tools (ortools) is required. Install it with: pip install ortools"
        ) from e

    def _build() -> Tuple["cp_model.CpModel", Dict[str, Any]]:
        # Create a fresh model with one integer variable per declared domain,
        # then let the problem add its constraints and derived expressions.
        model = cp_model.CpModel()
        vars_map: Dict[str, Any] = {}
        for name, dom in problem.var_domains.items():
            vars_map[name] = model.NewIntVarFromDomain(
                cp_model.Domain.FromValues(list(dom)), name
            )
        derived = problem.build_model(model, vars_map, int(dtype_bytes), budgets) or {}
        # Derived entries take precedence over decision variables of the same name.
        all_vars = dict(vars_map)
        all_vars.update(derived)
        return model, all_vars

    def _solve(
        obj_name: Optional[str],
        equalities: Optional[List[Tuple[str, int]]] = None,
    ) -> Optional[Dict[str, int]]:
        # Solve one model; maximize ``obj_name`` unless it is None (pure
        # feasibility). Returns the full assignment, or None without a solution.
        model, v = _build()
        for k, val in equalities or ():
            model.Add(v[k] == val)
        if obj_name is not None:
            if obj_name not in v:
                raise KeyError(
                    f"objective {obj_name!r} is not a known variable; "
                    f"available names: {sorted(v)}"
                )
            model.Maximize(v[obj_name])
        solver = cp_model.CpSolver()
        solver.parameters.max_time_in_seconds = float(max_time_sec)
        solver.parameters.num_search_workers = int(max(1, num_workers))
        status = solver.Solve(model)
        if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
            return None
        return {k: int(solver.Value(var)) for k, var in v.items()}

    if not problem.objectives:
        # Nothing to optimize: still verify feasibility and report a concrete
        # assignment instead of an empty, unverified result.
        feasible = _solve(None)
        if feasible is None:
            return None
        return TilingResult(values=feasible, objective_values={})

    fixed: List[Tuple[str, int]] = []
    obj_vals: Dict[str, int] = {}
    last_sol: Dict[str, int] = {}
    for obj in problem.objectives:
        sol = _solve(obj, fixed)
        if sol is None:
            return None
        best = sol[obj]
        obj_vals[obj] = best
        # Pin this objective so later, lower-priority solves cannot degrade it.
        fixed.append((obj, best))
        last_sol = sol
    return TilingResult(values=last_sol, objective_values=obj_vals)