(** A condition on the value of a single loop index (indexes are
    non-negative: "true" is represented by [Cgeq 0]). *)
type cond =
  | Ceq of int  (** [Ceq k]: the index is equal to [k] *)
  | Cgeq of int  (** [Cgeq k]: the index is greater than or equal to [k] *)
  | Cmod of int * int
      (** [Cmod (a, b)]: the index modulo [a] is equal to [b] *)
  | Cgeqmod of int * int * int
      (** [Cgeqmod (k, a, b)]: the index is greater than or equal to [k]
          and the index modulo [a] is equal to [b] *)

(** Sets of conditions, ordered by OCaml's polymorphic [compare]. *)
module CondSet = Set.Make (struct
  type t = cond
  let compare = compare
end)
---|
11 | |
---|
12 | open CostLabel |
---|
13 | |
---|
14 | |
---|
(* Translate a simple expression into the condition on an index that it
   describes. [Sexpr (a, b)] is presumably the set {a*k + b | k >= 0}, which
   is what the cases below encode (a constant for [a = 0], an interval for
   [a = 1], a congruence otherwise) — confirm against CostLabel. *)
let cond_of_sexpr (Sexpr (a, b)) =
  match a with
  | 0 -> Ceq b
  | 1 -> Cgeq b
  | _ when b < a -> Cmod (a, b)
  | _ -> Cgeqmod (b, a, b mod a)
---|
20 | |
---|
(** Cost expressions.
    - [Exact k] is the constant cost [k];
    - [Ternary (i, c, if_true, if_false)] is [if_true] when index [i]
      satisfies the general condition [c] (a set of conditions read as a
      disjunction), and [if_false] otherwise. *)
type cost_expr =
  | Exact of int
  | Ternary of index * CondSet.t * cost_expr * cost_expr
---|
24 | |
---|
(* [heads_tails_of s] splits the set of lists [s] according to the head [h]
   of its smallest element. It returns [Some (h, s_h, s_rest)] where [s_h]
   holds the tails of the lists of [s] that start with [h], and [s_rest]
   holds the lists of [s] that do not; thus [s] is the union of
   [{h :: tl | tl in s_h}] and [s_rest]. It returns [None] if [s] is empty
   or if its smallest element is the empty list. All lists in [s] are
   expected to have the same length; otherwise [List.hd] may fail on an
   empty list. *)
let heads_tails_of s =
  let split head =
    let starts_with_head l = (List.hd l = head) in
    let matching, s_rest = IndexingSet.partition starts_with_head s in
    let s_head =
      IndexingSet.fold
        (fun l acc -> IndexingSet.add (List.tl l) acc)
        matching IndexingSet.empty
    in
    Some (head, s_head, s_rest)
  in
  if IndexingSet.is_empty s then None
  else
    match IndexingSet.min_elt s with
    | [] -> None
    | head :: _ -> split head
---|
39 | |
---|
(* [cost_mapping_ind atom ind m s] builds the cost expression of [atom]
   restricted to the indexings whose suffixes are in [s]. [ind] is the
   prefix of indexing already fixed by the enclosing branches (callers start
   from [[]]), kept in the same order as the indexings stored in [m]. The
   index tested at each branching is the position [List.length ind]. At a
   leaf the complete indexing is looked up in [m]; labels absent from [m]
   cost 0. *)
let rec cost_mapping_ind atom ind (m : int Map.t) (s : IndexingSet.t) =
  match heads_tails_of s with
  | None ->
    let lbl = {name = atom; i = ind} in
    if Map.mem lbl m then Exact (Map.find lbl m) else Exact 0
  | Some (h, s_head, s_rest) ->
    let i = List.length ind in
    let condition = cond_of_sexpr h in
    (* [h] is the expression at position [i]: append it, so that [ind]
       stays in the order of the keys of [m]. Prepending would reverse the
       indexing, and the leaf lookup would miss every label whose indexing
       has length >= 2 and is not a palindrome. *)
    let if_true = cost_mapping_ind atom (ind @ [h]) m s_head in
    let if_false = cost_mapping_ind atom ind m s_rest in
    Ternary (i, CondSet.singleton condition, if_true, if_false)
---|
51 | |
---|
(* Group the indexings of the labels of [m] by atom: each atom is mapped to
   the set of the indexings [k.i] of the keys [k] of [m] with [k.name] equal
   to it. The costs stored in [m] are not used. *)
let indexing_sets_from_cost_mapping m =
  let record_label lbl _cost by_atom =
    let known =
      try Atom.Map.find lbl.name by_atom with Not_found -> IndexingSet.empty
    in
    Atom.Map.add lbl.name (IndexingSet.add lbl.i known) by_atom
  in
  Map.fold record_label m Atom.Map.empty
---|
61 | |
---|
(* [extended_gcd a b] returns [(x, y, g)] with [x * a + y * b = g], where [g]
   is the gcd of [a] and [b] (non-negative when [a, b >= 0]). It follows
   Euclid's algorithm, recursing on [(b, a mod b)]. *)
let rec extended_gcd a b =
  if b = 0 then (1, 0, a)
  else
    let q = a / b and r = a mod b in
    let (x', y', g) = extended_gcd b r in
    (y', x' - q * y', g)
---|
68 | |
---|
(* Monadic bind on options: [f] is applied to the content of [t] when there
   is one; [None] is propagated otherwise. *)
let opt_bind (t : 'a option) (f : 'a -> 'b option) : 'b option =
  let apply_if_present = function
    | None -> None
    | Some value -> f value
  in
  apply_if_present t
---|
73 | |
---|
(* In the following, a set of conditions is read as a disjunction:
   {c1, ..., ck} stands for c1 || ... || ck. The empty set stands for false,
   while true is represented by the singleton {Cgeq 0}. We call such sets
   general conditions. *)
---|
77 | |
---|
(* Render a condition as a test on an implicit index; for example
   [Cgeqmod (4, 3, 1)] is rendered as ">= 4 & % 3 == 1". *)
let print_cond cond =
  match cond with
  | Cgeqmod (k, m, r) -> Printf.sprintf ">= %d & %% %d == %d" k m r
  | Cmod (m, r) -> Printf.sprintf "%% %d == %d" m r
  | Cgeq k -> Printf.sprintf ">= %d" k
  | Ceq k -> Printf.sprintf "== %d" k
---|
83 | |
---|
(* Render a general condition as a disjunction ending with "F" (false).
   The conditions appear in decreasing set order, e.g. "c2 || c1 || F". *)
let print_gen_cond gc =
  CondSet.fold
    (fun c rest -> Printf.sprintf "%s || %s" (print_cond c) rest)
    gc "F"
---|
87 | |
---|
(* [cond_and_single c1 c2] gives [c1 && c2] as a general condition: the empty
   set when the conjunction is unsatisfiable, a singleton otherwise. The
   inner function is recursive only to re-use match cases. Conditions
   [Cmod (a, b)] and [Cgeqmod (_, a, b)] are assumed to satisfy
   [0 <= b < a], as produced by [cond_of_sexpr]. *)
let cond_and_single c1 c2 =
  (* remainder of [x] modulo [m > 0] in [0, m): OCaml's [mod] takes the sign
     of [x] *)
  let pos_mod x m =
    let r = x mod m in
    if r < 0 then r + m else r
  in
  (* smallest [n >= h] such that [n mod a = b] *)
  let first_from h a b = h + pos_mod (b - h) a in
  let rec cond_and_single' c1 c2 =
    match c1, c2 with
    | (Ceq h as c), Ceq k when h = k -> Some c
    | (Ceq h as c), Cgeq k | Cgeq k, (Ceq h as c) when h >= k -> Some c
    | (Ceq h as c), Cmod (a, b) | Cmod (a, b), (Ceq h as c)
      when h mod a = b -> Some c
    | (Ceq h as c), Cgeqmod (k, a, b) | Cgeqmod (k, a, b), (Ceq h as c)
      when h mod a = b && h >= k -> Some c
    | Ceq _, _ | _, Ceq _ -> None
    | Cgeq h, Cgeq k -> Some (Cgeq (max h k))
    | Cgeq h, Cmod (a, b) | Cmod (a, b), Cgeq h ->
      (* every index congruent to [b] is at least [b], so [h <= b] adds
         nothing; otherwise start from the first such index that is [>= h],
         which may be [h] itself *)
      if h <= b then Some (Cmod (a, b))
      else Some (Cgeqmod (first_from h a b, a, b))
    | Cgeq h, Cgeqmod (k, a, b) | Cgeqmod (k, a, b), Cgeq h ->
      cond_and_single' (Cgeq (max h k)) (Cmod (a, b))
    | Cmod (a, b), Cmod (c, d) ->
      (* special case of the Chinese remainder theorem *)
      let (x, _, gcd) = extended_gcd a c in
      if b mod gcd <> d mod gcd then None
      else
        let a_gcd = a / gcd in
        let lcm = a_gcd * c in
        (* [d - b] may be negative: bring the solution back into [0, lcm) *)
        let res = pos_mod (b + a_gcd * x * (d - b)) lcm in
        Some (Cmod (lcm, res))
    | Cmod (a, b), Cgeqmod (k, c, d) | Cgeqmod (k, c, d), Cmod (a, b) ->
      opt_bind (cond_and_single' (Cmod (a, b)) (Cmod (c, d)))
        (fun x -> cond_and_single' (Cgeq k) x)
    | Cgeqmod (h, a, b), Cgeqmod (k, c, d) ->
      opt_bind (cond_and_single' (Cmod (a, b)) (Cmod (c, d)))
        (fun x -> cond_and_single' (Cgeq (max h k)) x)
  in
  match cond_and_single' c1 c2 with
  | None -> CondSet.empty
  | Some x -> CondSet.singleton x
---|
123 | |
---|
(* [cond_and s1 c2] is [s1 && c2] for a general condition [s1], by
   distributing: (c1 || ... || ck) && c = (c1 && c) || ... || (ck && c). *)
let cond_and s1 c2 =
  CondSet.fold
    (fun c1 acc -> CondSet.union (cond_and_single c1 c2) acc)
    s1 CondSet.empty
---|
129 | |
---|
(* [init_set f n] is the set {f 0, ..., f (n-1)}, empty when [n <= 0]. The
   elements are added in increasing order of their argument. *)
let init_set f n =
  let rec fill k acc =
    if k >= n then acc else fill (k + 1) (CondSet.add (f k) acc)
  in
  fill 0 CondSet.empty
---|
134 | |
---|
(* [cond_and_not_single c1 c2] is equivalent to [c1 && !c2], as a general
   condition. Conditions [Cmod (a, b)] and [Cgeqmod (_, a, b)] are assumed to
   satisfy [0 <= b < a], as produced by [cond_of_sexpr]. *)
let rec cond_and_not_single c1 c2 =
  (* smallest [n >= h] such that [n mod a = b]: the least index denoted by
     [Cgeqmod (h, a, b)] *)
  let first_from h a b =
    let r = (b - h) mod a in
    h + (if r < 0 then r + a else r)
  in
  (* number of [x >= 0] with [start + a * x < bound], i.e. the ceiling of
     [(bound - start) / a], or 0 when [bound <= start] *)
  let count_below start a bound =
    if bound <= start then 0 else (bound - start + a - 1) / a
  in
  match c1, c2 with
  | Ceq h, Ceq k when h = k -> CondSet.empty
  | Ceq h, Cgeq k when h >= k -> CondSet.empty
  | Ceq h, Cmod (a, b) when h mod a = b -> CondSet.empty
  | Ceq h, Cgeqmod (k, a, b) when h >= k && h mod a = b -> CondSet.empty
  | Ceq _, _ -> CondSet.singleton c1
  | Cgeq h, Ceq k when k < h -> CondSet.singleton c1
  | Cgeq h, Ceq k ->
    (* Ceq h, Ceq (h+1), ... , Ceq (k-1), Cgeq (k+1) *)
    let s' = init_set (fun x -> Ceq (h + x)) (k - h) in
    CondSet.add (Cgeq (k + 1)) s'
  | Cgeq h, Cgeq k ->
    (* if k <= h init_set correctly gives an empty set, otherwise
       {Ceq h, ... , Ceq (k - 1)} *)
    init_set (fun x -> Ceq (h + x)) (k - h)
  | Cmod (a, b), Ceq k when k mod a <> b -> CondSet.singleton c1
  | Cmod (a, b), Ceq k ->
    (* b, b + a, ..., k - a, then every index congruent to b from k + a *)
    let s' = init_set (fun x -> Ceq (a * x + b)) (k / a) in
    CondSet.add (Cgeqmod (k + a, a, b)) s'
  | Cmod (a, b), Cgeq k ->
    (* b, b + a, ... up to the last one that is < k (which can be larger
       than a * (k / a)) *)
    init_set (fun x -> Ceq (a * x + b)) (count_below b a k)
  | Cgeqmod (h, a, b), Ceq k when k < h || k mod a <> b ->
    CondSet.singleton c1
  | Cgeqmod (h, a, b), Ceq k ->
    (* here k >= h and k is congruent to b, so k >= h' and a divides
       k - h' *)
    let h' = first_from h a b in
    let s' = init_set (fun x -> Ceq (a * x + h')) ((k - h') / a) in
    CondSet.add (Cgeqmod (k + a, a, b)) s'
  | Cgeqmod (h, a, b), Cgeq k ->
    let h' = first_from h a b in
    init_set (fun x -> Ceq (a * x + h')) (count_below h' a k)
  (* in the remaining cases, enumerate the residues other than [b] and fall
     back to cond_and_single *)
  | c1, (Cmod (a, b) as c2) ->
    let s' = CondSet.remove c2 (init_set (fun x -> Cmod (a, x)) a) in
    let f x = CondSet.union (cond_and_single c1 x) in
    CondSet.fold f s' CondSet.empty
  | c1, Cgeqmod (k, a, b) ->
    let c2 = Cmod (a, b) in
    let s' = CondSet.remove c2 (init_set (fun x -> Cmod (a, x)) a) in
    let f x = CondSet.union (cond_and_single c1 x) in
    let s'' = cond_and_not_single c1 (Cgeq k) in
    CondSet.fold f s' s''
---|
179 | |
---|
(* [cond_and_not s1 c2] is [s1 && !c2] for a general condition [s1],
   obtained by distributing the conjunction over the disjuncts of [s1]. *)
let cond_and_not s1 c2 =
  CondSet.fold
    (fun c1 acc -> CondSet.union (cond_and_not_single c1 c2) acc)
    s1 CondSet.empty
---|
184 | |
---|
(* [cond_implied_by_single c2 c1] is true iff [c1 => c2], i.e. iff the set of
   indexes denoted by [c1] is contained in the one denoted by [c2]. A [false]
   result is conservative: it may also mean that the implication was not
   recognised. Conditions [Cmod (a, b)] and [Cgeqmod (_, a, b)] are assumed
   to satisfy [0 <= b < a], as produced by [cond_of_sexpr]. *)
let cond_implied_by_single c2 c1 =
  (* smallest [n >= h] such that [n mod a = b]: the least index denoted by
     [Cgeqmod (h, a, b)] *)
  let first_from h a b =
    let r = (b - h) mod a in
    h + (if r < 0 then r + a else r)
  in
  match c1, c2 with
  | c1, c2 when c1 = c2 -> true (* shortcut *)
  | Ceq h, Ceq k -> h = k
  | Ceq h, Cgeq k
  | Cgeq h, Cgeq k -> h >= k
  | Ceq h, Cmod (a, b) -> h mod a = b
  | Ceq h, Cgeqmod (k, a, b) -> h >= k && h mod a = b
  (* an arithmetic progression lies above [k] iff its least element does *)
  | Cmod (_, b), Cgeq k -> b >= k
  | Cgeqmod (h, a, b), Cgeq k -> first_from h a b >= k
  | Cmod (a, b), Cmod (c, d)
  | Cgeqmod (_, a, b), Cmod (c, d) -> a mod c = 0 && b mod c = d mod c
  | Cmod (a, b), Cgeqmod (k, c, d) ->
    (* the least index denoted by [Cmod (a, b)] is [b] *)
    b >= k && a mod c = 0 && b mod c = d mod c
  | Cgeqmod (h, a, b), Cgeqmod (k, c, d) ->
    first_from h a b >= k && a mod c = 0 && b mod c = d mod c
  | _ -> false
---|
204 | |
---|
(* [cond_implies s1 c2] holds iff [s1 => c2], using the fact that
   (c1 || ... || ck) => c iff (c1 => c) && ... && (ck => c). *)
let cond_implies s1 c2 =
  CondSet.for_all (fun c1 -> cond_implied_by_single c2 c1) s1
---|
208 | |
---|
(* [cond_neg_implied_by_single c2 c1] is true iff [c1 => !c2], i.e. iff
   [c1 && c2] is unsatisfiable; this relation is symmetric in [c1], [c2]. *)
let cond_neg_implied_by_single c2 c1 =
  (* the modulus and remainder of a congruence condition, if any *)
  let congruence = function
    | Cmod (a, b) | Cgeqmod (_, a, b) -> Some (a, b)
    | Ceq _ | Cgeq _ -> None
  in
  match c1, c2 with
  | Ceq h, Ceq k -> h <> k
  | Ceq h, other | other, Ceq h ->
    begin match other with
      | Cgeq k -> h < k
      | Cmod (a, b) -> h mod a <> b
      | Cgeqmod (k, a, b) -> h < k || h mod a <> b
      | Ceq _ -> assert false (* both sides are Ceq: matched above *)
    end
  | _ ->
    begin match congruence c1, congruence c2 with
      | Some (a, b), Some (c, d) ->
        (* two congruences are compatible iff their remainders agree
           modulo the gcd of the moduli *)
        let (_, _, gcd) = extended_gcd a c in
        b mod gcd <> d mod gcd
      | _ -> false
    end
---|
223 | |
---|
(* [cond_implies_neg s1 c2] holds iff [s1 => !c2], i.e. iff every disjunct
   of [s1] is incompatible with [c2]. *)
let cond_implies_neg s1 c2 =
  CondSet.for_all (fun c1 -> cond_neg_implied_by_single c2 c1) s1
---|
226 | |
---|
(* [cond_simpl s c] drops the lower bound of a [Cgeqmod] condition, turning
   it into a [Cmod], when [s] already implies that bound; any other
   condition is returned unchanged. *)
let cond_simpl s c =
  match c with
  | Cgeqmod (k, a, b) when cond_implies s (Cgeq k) -> Cmod (a, b)
  | Ceq _ | Cgeq _ | Cmod _ | Cgeqmod _ -> c
---|
232 | |
---|
(** Simplify the cost expression, removing useless conditions in it: a test
    is dropped when the conditions gathered on the path leading to it
    already decide it. Each test is expected to hold a single condition. *)
let remove_useless_branches e =
  (* [conds] holds, for each index, the general condition known to hold on
     the current path. Every call restores the entries it modifies, so that
     what is learnt inside one subtree never leaks into a sibling one. *)
  let rec simplify' conds = function
    | Exact k -> Exact k
    | Ternary (i, gen_cond, if_true, if_false) ->
      assert (CondSet.cardinal gen_cond = 1);
      let cond = CondSet.choose gen_cond in
      (* if it is the first time a condition on i is encountered, we ensure
         that conds holds a default value that will be the "true" of these
         conditions (i.e. { Cgeq 0 }) *)
      ExtArray.ensure conds i;
      let conds_i = ExtArray.get conds i in
      (* if conds => cond, then we can erase the if_false branch *)
      if cond_implies conds_i cond then simplify' conds if_true
      (* if conds => !cond, then we can erase if_true *)
      else if cond_implies_neg conds_i cond then simplify' conds if_false
      else begin
        let cond' = cond_simpl conds_i cond in
        (* simplify the if_true branch knowing that cond *)
        ExtArray.set conds i (cond_and conds_i cond);
        let if_true' = simplify' conds if_true in
        (* simplify the if_false branch knowing that !cond *)
        ExtArray.set conds i (cond_and_not conds_i cond);
        let if_false' = simplify' conds if_false in
        (* restore what was known on i before this test *)
        ExtArray.set conds i conds_i;
        Ternary (i, CondSet.singleton cond', if_true', if_false')
      end
  in
  (* a fresh table per expression: what is learnt while simplifying one
     expression must not influence another *)
  let conds = ExtArray.make ~buff:4 0 (CondSet.singleton (Cgeq 0)) in
  simplify' conds e
---|
262 | |
---|
(* disjunction of two general conditions *)
let gen_or s1 s2 = CondSet.union s1 s2
---|
264 | |
---|
(* conjunction of two general conditions, distributing over the disjuncts
   of [s2] (taken in increasing order) *)
let gen_and s1 s2 =
  List.fold_left
    (fun acc c -> CondSet.union acc (cond_and s1 c))
    CondSet.empty (CondSet.elements s2)
---|
268 | |
---|
(* [gen_and_not s1 s2] is [s1 && !s2]: since !(c1 || ... || ck) is
   !c1 && ... && !ck, the disjuncts of [s2] are negated one after the
   other, in increasing order *)
let gen_and_not s1 s2 =
  List.fold_left (fun acc c -> cond_and_not acc c) s1 (CondSet.elements s2)
---|
272 | |
---|
(* negation of a general condition: true && !s, with true being {Cgeq 0} *)
let gen_not s =
  let always = CondSet.singleton (Cgeq 0) in
  gen_and_not always s
---|
274 | |
---|
(* Semantic equality of cost expressions. Conditions are compared with
   [CondSet.equal]: polymorphic equality on sets compares their internal
   balanced trees and may tell apart two sets with the same elements. *)
let rec cost_expr_equal e1 e2 =
  match e1, e2 with
  | Exact k1, Exact k2 -> k1 = k2
  | Ternary (i1, c1, t1, f1), Ternary (i2, c2, t2, f2) ->
    i1 = i2 && CondSet.equal c1 c2
    && cost_expr_equal t1 t2 && cost_expr_equal f1 f2
  | Exact _, Ternary _ | Ternary _, Exact _ -> false

(* [remove_useless_branchings e] works bottom-up: a test whose two branches
   are equal is replaced by that branch, and a test on index [i] directly
   nesting another test on [i] that shares one of its branches is merged
   into a single test on the combined general condition. *)
let rec remove_useless_branchings = function
  | Exact k -> Exact k
  | Ternary (i, c1, left, right) ->
    let left = remove_useless_branchings left in
    let right = remove_useless_branchings right in
    let eq = cost_expr_equal in
    match left, right with
    | _, _ when eq left right -> left
    (* c1 ? (c2 ? r : lr) : r  ==  (!c1 || c2) ? r : lr *)
    | Ternary (j, c2, lleft, lright), _ when i = j && eq lleft right ->
      Ternary (i, gen_or (gen_not c1) c2, lleft, lright)
    (* c1 ? (c2 ? ll : r) : r  ==  (c1 && c2) ? ll : r *)
    | Ternary (j, c2, lleft, lright), _ when i = j && eq lright right ->
      Ternary (i, gen_and c1 c2, lleft, lright)
    (* c1 ? l : (c2 ? l : rr)  ==  (c1 || c2) ? l : rr *)
    | _, Ternary (j, c2, rleft, rright) when i = j && eq left rleft ->
      Ternary (i, gen_or c1 c2, rleft, rright)
    (* c1 ? l : (c2 ? rl : l)  ==  (!c1 && c2) ? rl : l *)
    | _, Ternary (j, c2, rleft, rright) when i = j && eq left rright ->
      Ternary (i, gen_and_not c2 c1, rleft, rright)
    | _ -> Ternary (i, c1, left, right)
---|
299 | |
---|
(* Bottom-up, state every test with whichever of its general condition or
   the negation of it has fewer disjuncts, swapping the two branches when
   the negation is chosen. On a tie the original condition is kept. *)
let rec chose_smaller_cond expr =
  match expr with
  | Exact _ -> expr
  | Ternary (i, c, if_true, if_false) ->
    let if_true = chose_smaller_cond if_true in
    let if_false = chose_smaller_cond if_false in
    let neg = gen_not c in
    if CondSet.cardinal neg < CondSet.cardinal c then
      Ternary (i, neg, if_false, if_true)
    else
      Ternary (i, c, if_true, if_false)
---|
310 | |
---|
(* Build, for each atom occurring in the cost mapping [m], its cost
   expression: the decision tree over the atom's indexings is built, then
   redundant tests are pruned, nested tests sharing a branch are merged,
   and each test is finally stated with the smaller of its condition and
   its negation. *)
let cost_expr_mapping_of_cost_mapping m =
  let expr_of_atom at s =
    let raw = cost_mapping_ind at [] m s in
    chose_smaller_cond (remove_useless_branchings (remove_useless_branches raw))
  in
  let sets = indexing_sets_from_cost_mapping m in
  Atom.Map.fold
    (fun at s acc -> Atom.Map.add at (expr_of_atom at s) acc)
    sets Atom.Map.empty
---|
320 | |
---|