-
Notifications
You must be signed in to change notification settings - Fork 9
/
expr.ml
267 lines (244 loc) · 7.74 KB
/
expr.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
(* ========================================================================== *)
(* FPTaylor: A Tool for Rigorous Estimation of Round-off Errors *)
(* *)
(* Author: Alexey Solovyev, University of Utah *)
(* *)
(* This file is distributed under the terms of the MIT license *)
(* ========================================================================== *)
(* -------------------------------------------------------------------------- *)
(* Symbolic expressions *)
(* -------------------------------------------------------------------------- *)
(* Operations *)
(* Unary operations. Each constructor is built by the matching [mk_*]
   function, and [u_op_name] gives its printable name. *)
type u_op_type =
| Op_neg
| Op_abs
| Op_inv            (* presumably the reciprocal 1/x -- confirm in the evaluators *)
| Op_sqrt
| Op_sin
| Op_cos
| Op_tan
| Op_asin
| Op_acos
| Op_atan
| Op_exp
| Op_log
| Op_sinh
| Op_cosh
| Op_tanh
| Op_asinh
| Op_acosh
| Op_atanh
| Op_floor_power2   (* built by [mk_floor_power2]; name suggests rounding down to a power of 2 -- confirm *)
(* Binary operations; [bin_op_name] gives their printable names. *)
type bin_op_type =
| Op_max
| Op_min
| Op_add
| Op_sub
| Op_mul
| Op_div
| Op_nat_pow        (* printed as "^"; name suggests a natural-number exponent -- confirm *)
| Op_sub2           (* printed as "sub2"; its semantics are defined outside this file *)
| Op_abs_err        (* built by [mk_abs_err t x]; printed as "abs_err" *)
(* Operations whose arguments are stored as a list. *)
type gen_op_type =
| Op_fma            (* [mk_fma a b c] stores the arguments [a; b; c] *)
| Op_ulp            (* [mk_ulp (prec, e_min) x] stores [prec; e_min; x], the first two as constants *)
(* Expression *)
(* Symbolic expression tree. *)
type expr =
| Const of Const.t
| Var of string                          (* variable, identified by its name *)
| Rounding of Rounding.rnd_info * expr   (* rounding of the subexpression described by the [rnd_info] *)
| U_op of u_op_type * expr
| Bin_op of bin_op_type * expr * expr
| Gen_op of gen_op_type * expr list
(* Atomic relation between two expressions
   (presumably <=, < and = respectively -- confirm where formulas are evaluated). *)
type formula =
| Le of expr * expr
| Lt of expr * expr
| Eq of expr * expr
(* Information about input variables, queried by variable name, together
   with a list of additional formulas. How the lookup functions behave on
   unknown names depends on whoever builds the record (not visible here). *)
type constraints = {
  var_interval : string -> Interval.interval;
  var_rat_bounds : string -> Num.num * Num.num;   (* presumably (lower, upper) rational bounds -- confirm *)
  var_uncertainty : string -> Const.t;
  constraints : formula list;
}
(* Smart constructors. Each one only wraps its arguments in the
   corresponding [expr] node; no simplification is performed. *)
let mk_const c = Const c
let mk_var v = Var v
let mk_rounding rnd a = Rounding (rnd, a)

(* Unary operations *)
let mk_neg a = U_op (Op_neg, a)
let mk_abs a = U_op (Op_abs, a)
let mk_sqrt a = U_op (Op_sqrt, a)
let mk_inv a = U_op (Op_inv, a)
let mk_sin a = U_op (Op_sin, a)
let mk_cos a = U_op (Op_cos, a)
let mk_tan a = U_op (Op_tan, a)
let mk_asin a = U_op (Op_asin, a)
let mk_acos a = U_op (Op_acos, a)
let mk_atan a = U_op (Op_atan, a)
let mk_exp a = U_op (Op_exp, a)
let mk_log a = U_op (Op_log, a)
let mk_sinh a = U_op (Op_sinh, a)
let mk_cosh a = U_op (Op_cosh, a)
let mk_tanh a = U_op (Op_tanh, a)
let mk_asinh a = U_op (Op_asinh, a)
let mk_acosh a = U_op (Op_acosh, a)
let mk_atanh a = U_op (Op_atanh, a)
let mk_floor_power2 a = U_op (Op_floor_power2, a)

(* Binary operations *)
let mk_max lhs rhs = Bin_op (Op_max, lhs, rhs)
let mk_min lhs rhs = Bin_op (Op_min, lhs, rhs)
let mk_add lhs rhs = Bin_op (Op_add, lhs, rhs)
let mk_sub lhs rhs = Bin_op (Op_sub, lhs, rhs)
let mk_mul lhs rhs = Bin_op (Op_mul, lhs, rhs)
let mk_div lhs rhs = Bin_op (Op_div, lhs, rhs)
let mk_nat_pow base expn = Bin_op (Op_nat_pow, base, expn)
let mk_sub2 lhs rhs = Bin_op (Op_sub2, lhs, rhs)
let mk_abs_err t x = Bin_op (Op_abs_err, t, x)

(* Operations with an argument list *)
let mk_fma a b c = Gen_op (Op_fma, [a; b; c])
(* [mk_ulp (prec, e_min) x] builds the [Op_ulp] node whose arguments are
   the integer constants [prec] and [e_min] followed by the expression [x]. *)
let mk_ulp (prec, e_min) x =
  let int_node n = mk_const (Const.of_int n) in
  let prec_node = int_node prec in
  let e_min_node = int_node e_min in
  Gen_op (Op_ulp, [prec_node; e_min_node; x])
(* Constant-building shortcuts: convert the value with the matching
   [Const.of_*] function and wrap it with [mk_const]. *)
let mk_int_const n = Const.of_int n |> mk_const
let mk_num_const q = Const.of_num q |> mk_const
let mk_float_const x = Const.of_float x |> mk_const
let mk_interval_const itv = Const.of_interval itv |> mk_const

(* [mk_floor_sub2 a b] is [mk_floor_power2 (mk_sub2 a b)]. *)
let mk_floor_sub2 a b = mk_sub2 a b |> mk_floor_power2
(* Frequently used small integer constants. *)
let const_0 = mk_int_const 0
let const_1 = mk_int_const 1
let const_2 = mk_int_const 2
let const_3 = mk_int_const 3
let const_4 = mk_int_const 4
let const_5 = mk_int_const 5
(* Printable name of a unary operation. *)
let u_op_name op =
  match op with
  (* basic *)
  | Op_neg -> "neg"
  | Op_abs -> "abs"
  | Op_inv -> "inv"
  | Op_sqrt -> "sqrt"
  | Op_floor_power2 -> "floor_power2"
  (* exponential / logarithm *)
  | Op_exp -> "exp"
  | Op_log -> "log"
  (* trigonometric *)
  | Op_sin -> "sin"
  | Op_cos -> "cos"
  | Op_tan -> "tan"
  | Op_asin -> "asin"
  | Op_acos -> "acos"
  | Op_atan -> "atan"
  (* hyperbolic *)
  | Op_sinh -> "sinh"
  | Op_cosh -> "cosh"
  | Op_tanh -> "tanh"
  | Op_asinh -> "asinh"
  | Op_acosh -> "acosh"
  | Op_atanh -> "atanh"
(* Printable name of a binary operation: arithmetic operators use their
   usual symbols, the other operations use a word. *)
let bin_op_name op =
  match op with
  | Op_add -> "+"
  | Op_sub -> "-"
  | Op_mul -> "*"
  | Op_div -> "/"
  | Op_nat_pow -> "^"
  | Op_max -> "max"
  | Op_min -> "min"
  | Op_sub2 -> "sub2"
  | Op_abs_err -> "abs_err"
(* Printable name of an operation with an argument list. *)
let gen_op_name op =
  match op with
  | Op_ulp -> "ulp"
  | Op_fma -> "fma"
(* Structural equality of expressions.
   - Physically equal expressions are equal without further traversal.
   - Constants are compared with [Const.eq_c] and variables by name.
   - [Rounding] nodes additionally require structurally equal rounding
     information, and operation nodes require the same operation.
   - The argument lists of [Gen_op] nodes must have the same length.
     This is checked explicitly because [List.for_all2] raises
     [Invalid_argument] on lists of different lengths (whenever their
     common prefix is equal). This function is the equality of
     [ExprHashtbl], so it must never raise. *)
let rec eq_expr e1 e2 =
  match (e1, e2) with
  | _ when e1 == e2 -> true
  | (Const c1, Const c2) -> Const.eq_c c1 c2
  | (Var v1, Var v2) -> v1 = v2
  | (Rounding (r1, a1), Rounding (r2, a2)) when r1 = r2 ->
    eq_expr a1 a2
  | (U_op (t1, a1), U_op (t2, a2)) when t1 = t2 ->
    eq_expr a1 a2
  | (Bin_op (t1, a1, b1), Bin_op (t2, a2, b2)) when t1 = t2 ->
    eq_expr a1 a2 && eq_expr b1 b2
  | (Gen_op (t1, as1), Gen_op (t2, as2)) when t1 = t2 ->
    (* The length test short-circuits before [List.for_all2] can raise. *)
    List.length as1 = List.length as2
    && List.for_all2 (fun a1 a2 -> eq_expr a1 a2) as1 as2
  | _ -> false
(* Structural hash of an expression, intended for use together with
   [eq_expr] (see [ExprHashtbl]).

   Expressions that [eq_expr] considers equal must hash equally. The
   rounding information of [Rounding] nodes is ignored, which keeps that
   property since it only makes the hash coarser. For [Const] nodes the
   property relies on [Const.eq_c] agreeing with the hashes below.
   NOTE(review): confirm in Const.

   Child hashes are mixed with an order-sensitive [combine] instead of
   [lxor]. The xor-based mixing is commutative and self-cancelling:
   - [Bin_op (op, a, b)] and [Bin_op (op, b, a)] collide;
   - [Bin_op (op, a, a)] (e.g. [x * x]) hashes identically for every [a];
   - degenerate intervals with [low = high] all collide.
   Integer overflow just wraps around, which is fine for a hash value. *)
let rec hash_expr e =
  let combine h1 h2 = (h1 * 65599) + h2 in
  match e with
  | Const (Rat n) -> Hashtbl.hash (Num.string_of_num n)
  | Const (Interval v) ->
    combine (Hashtbl.hash v.low) (Hashtbl.hash v.high) + 12434327
  | Var v -> Hashtbl.hash v
  | Rounding (_, a) -> 541 * hash_expr a + 1012324
  | U_op (op, a) ->
    combine (Hashtbl.hash (u_op_name op)) (hash_expr a) + 1013435
  | Bin_op (op, a, b) ->
    let h = combine (Hashtbl.hash (bin_op_name op)) (hash_expr a) in
    combine h (hash_expr b) + 101343561
  | Gen_op (op, args) ->
    List.fold_left
      (fun h x -> combine h (hash_expr x))
      (Hashtbl.hash (gen_op_name op))
      args
(* Hash tables keyed by expressions. Keys are compared with the
   structural equality [eq_expr] and hashed with [hash_expr]. *)
module ExprHashtbl =
  Hashtbl.Make (struct
    type t = expr
    let equal lhs rhs = eq_expr lhs rhs
    let hash e = hash_expr e
  end)
(* Names of the variables occurring in an expression. Results of
   subexpressions are merged with [Lib.union] (presumably a duplicate-free
   list union -- confirm in Lib). Reference variables are ordinary [Var]
   nodes, so they are included as well. *)
let rec vars_in_expr = function
  | Const _ -> []
  | Var name -> [name]
  | Rounding (_, sub) | U_op (_, sub) -> vars_in_expr sub
  | Bin_op (_, lhs, rhs) -> Lib.union (vars_in_expr lhs) (vars_in_expr rhs)
  | Gen_op (_, args) -> List.fold_left Lib.union [] (List.map vars_in_expr args)
(* Reference variables are ordinary [Var] nodes whose name is this prefix
   followed by a decimal index (see [mk_ref_var]). They are created by
   [expr_ref_list_of_expr] to stand for shared subexpressions. *)
let ref_var_prefix = "ref~"

(* [is_ref_var e] is [true] iff [e] is a [Var] whose name starts with
   [ref_var_prefix]. *)
let is_ref_var = function
  | Var v -> Lib.starts_with v ref_var_prefix
  | _ -> false

(* [mk_ref_var i] builds the reference variable with index [i]. *)
let mk_ref_var i = Var (ref_var_prefix ^ string_of_int i)

(* [index_of_ref_var e] returns [i] for a reference variable built by
   [mk_ref_var i].
   @raise Failure if [e] is not a reference variable, or if the text after
   the prefix is not an integer (raised by [int_of_string]). *)
let index_of_ref_var = function
  | Var v when Lib.starts_with v ref_var_prefix ->
    int_of_string (Lib.slice v (String.length ref_var_prefix))
  | _ -> failwith "index_of_ref_var: not a reference"
(* Finds common subexpressions and returns a list of expressions
   with references.

   [expr_ref_list_of_expr ex] returns [e_0; e_1; ...; e_k; root].
   - Each [e_j] is a non-constant, non-variable subexpression of [ex]
     that occurs at least twice. Occurrences are counted at every position
     in the tree, including positions nested inside other repeated
     subexpressions.
   - Inside [e_j], shared subexpressions are replaced by [mk_ref_var i]
     with [i < j], where [i] is the position of that subexpression in
     the returned list.
   - The last element [root] is [ex] rewritten the same way.
   Constants and variables are never replaced by references. *)
let expr_ref_list_of_expr ex =
  (* Number of occurrences of each subexpression (up to [eq_expr]). *)
  let hc = ExprHashtbl.create 128 in
  (* Reference index assigned to each shared subexpression. Every key is
     added exactly once, so [ExprHashtbl.length hi] is the number of
     references created so far (= length of the accumulator). *)
  let hi = ExprHashtbl.create 128 in
  let get_count ex = try ExprHashtbl.find hc ex with Not_found -> 0 in
  let incr_count ex = ExprHashtbl.replace hc ex (1 + get_count ex) in
  let get_index ex = try ExprHashtbl.find hi ex with Not_found -> -1 in
  let set_index ex i = ExprHashtbl.add hi ex i in
  let rec count ex =
    incr_count ex;
    match ex with
    | Rounding (_, arg) | U_op (_, arg) -> count arg
    | Bin_op (_, arg1, arg2) -> count arg1; count arg2
    | Gen_op (_, args) -> List.iter count args
    | Const _ | Var _ -> ()
  in
  (* [find_common acc ex] returns the extended list of shared
     subexpressions (most recent first) and [ex] rewritten with
     references. *)
  let rec find_common acc ex =
    match ex with
    | Const _ | Var _ -> acc, ex
    | Rounding _ | U_op _ | Bin_op _ | Gen_op _ ->
      let i = get_index ex in
      if i >= 0 then
        (* Already shared. Every non-leaf descendant was indexed during
           the first visit (each occurs at least as often as [ex]), so
           re-traversing the subtree would add nothing. *)
        acc, mk_ref_var i
      else begin
        let acc, ex' = rewrite_children acc ex in
        if get_count ex < 2 then
          acc, ex'
        else begin
          let i = ExprHashtbl.length hi in
          set_index ex i;
          ex' :: acc, mk_ref_var i
        end
      end
  (* Applies [find_common] to the direct children of [ex], left to right,
     and rebuilds the node from the rewritten children. *)
  and rewrite_children acc ex =
    match ex with
    | Const _ | Var _ -> acc, ex
    | Rounding (rnd, arg) ->
      let acc, arg = find_common acc arg in
      acc, Rounding (rnd, arg)
    | U_op (op, arg) ->
      let acc, arg = find_common acc arg in
      acc, U_op (op, arg)
    | Bin_op (op, arg1, arg2) ->
      let acc, arg1 = find_common acc arg1 in
      let acc, arg2 = find_common acc arg2 in
      acc, Bin_op (op, arg1, arg2)
    | Gen_op (op, args) ->
      let acc, rev_args =
        List.fold_left
          (fun (acc, rev_args) arg ->
             let acc, arg = find_common acc arg in
             acc, arg :: rev_args)
          (acc, []) args
      in
      acc, Gen_op (op, List.rev rev_args)
  in
  count ex;
  let acc, ex = find_common [] ex in
  List.rev (ex :: acc)