Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions src/smtml/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ let rec ty (hte : t) : Ty.t =
| Cvtop (ty, _, _)
| Naryop (ty, _, _) ->
ty
| Extract (_, h, l) -> Ty_bitv ((h - l) * 8)
| Extract (_, h, l) -> Ty_bitv (h - l + 1)
| Concat (e1, e2) -> (
match (ty e1, ty e2) with
| Ty_bitv n1, Ty_bitv n2 -> Ty_bitv (n1 + n2)
Expand Down Expand Up @@ -582,21 +582,19 @@ let[@inline] raw_extract (hte : t) ~(high : int) ~(low : int) : t =

let extract (hte : t) ~(high : int) ~(low : int) : t =
match (view hte, high, low) with
| Val (Bitv bv), high, low ->
let high = (high * 8) - 1 in
let low = low * 8 in
value (Bitv (Bitvector.extract bv ~high ~low))
| Val (Bitv bv), high, low -> value (Bitv (Bitvector.extract bv ~high ~low))
| ( Cvtop
( _
, (Zero_extend 24 | Sign_extend 24)
, ({ node = Symbol { ty = Ty_bitv 8; _ }; _ } as sym) )
, 1
, 7
, 0 ) ->
sym
| Concat (_, e), h, l when Ty.size (ty e) = h - l -> e
| Concat (e, _), 8, 4 when Ty.size (ty e) = 4 -> e
| Concat (_, e), h, l when l = 0 && Ty.bitsize (ty e) = h - l + 1 -> e
| Concat (e, _), 63, 32 when Ty.bitsize (ty e) = 32 -> e
| _ ->
if high - low = Ty.size (ty hte) then hte else raw_extract hte ~high ~low
if high - low + 1 = Ty.bitsize (ty hte) then hte
else raw_extract hte ~high ~low

let raw_concat (msb : t) (lsb : t) : t = make (Concat (msb, lsb)) [@@inline]

Expand All @@ -606,8 +604,8 @@ let rec concat (msb : t) (lsb : t) : t =
| Val (Bitv a), Val (Bitv b) -> value (Bitv (Bitvector.concat a b))
| Val (Bitv _), Concat (({ node = Val (Bitv _); _ } as b), se) ->
raw_concat (concat msb b) se
| Extract (s1, h, m1), Extract (s2, m2, l) when equal s1 s2 && m1 = m2 ->
if h - l = Ty.size (ty s1) then s1 else raw_extract s1 ~high:h ~low:l
| Extract (s1, h, m1), Extract (s2, m2, l) when equal s1 s2 && m1 = m2 + 1 ->
if h - l + 1 = Ty.bitsize (ty s1) then s1 else raw_extract s1 ~high:h ~low:l
| Extract (_, _, _), Concat (({ node = Extract (_, _, _); _ } as e2), e3) ->
raw_concat (concat msb e2) e3
| _ -> raw_concat msb lsb
Expand Down
4 changes: 2 additions & 2 deletions src/smtml/mappings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -748,9 +748,9 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
(* This is needed so arguments don't end up out of order in the operator *)
let es = List.rev es in
(ctx, naryop ty op es)
| Extract (e, h, l) ->
| Extract (e, high, low) ->
let ctx, e = encode_expr ctx e in
(ctx, M.Bitv.extract e ~high:((h * 8) - 1) ~low:(l * 8))
(ctx, M.Bitv.extract e ~high ~low)
| Concat (e1, e2) ->
let ctx, e1 = encode_expr ctx e1 in
let ctx, e2 = encode_expr ctx e2 in
Expand Down
8 changes: 2 additions & 6 deletions src/smtml/smtlib.ml
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,10 @@ module Term = struct
begin match (basename, indices, args) with
| "extract", [ h; l ], [ a ] ->
let high =
match int_of_string_opt h with
| None -> assert false
| Some h -> (h + 1) / 8
match int_of_string_opt h with None -> assert false | Some h -> h
in
let low =
match int_of_string_opt l with
| None -> assert false
| Some l -> l / 8
match int_of_string_opt l with None -> assert false | Some l -> l
in
Expr.raw_extract a ~high ~low
| "zero_extend", [ bits ], [ a ] ->
Expand Down
9 changes: 6 additions & 3 deletions src/smtml/ty.ml
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,17 @@ let of_string = function
end
else Fmt.error_msg "can not parse type %s" s

let size (ty : t) : int =
let bitsize (ty : t) : int =
match ty with
| Ty_bitv n | Ty_fp n -> n / 8
| Ty_int | Ty_bool -> 4
| Ty_bool -> 1
| Ty_int -> 32
| Ty_bitv n | Ty_fp n -> n
| Ty_real | Ty_str | Ty_list | Ty_app | Ty_unit | Ty_none | Ty_regexp
| Ty_roundingMode ->
assert false

let size ty = bitsize ty / 8

module Unop = struct
type t =
| Neg
Expand Down
5 changes: 4 additions & 1 deletion src/smtml/ty.mli
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ val of_string : string -> (t, [> `Msg of string ]) Result.t

(** {1 Type Size} *)

(** [size t] returns the size (in bits) of the type [t], if applicable. *)
(** [size t] returns the size (in bytes) of the type [t], if applicable. *)
val size : t -> int

(** [bitsize t] returns the size (in bits) of the type [t], if applicable. *)
val bitsize : t -> int

(** {1 Unary Operations} *)

module Unop : sig
Expand Down
40 changes: 20 additions & 20 deletions src/smtml/typed.ml
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,10 @@ module Bitv32 = struct
let[@inline] of_int16_u x = Expr.cvtop ty (Zero_extend 16) x

let[@inline] to_bytes x =
[ extract x ~high:1 ~low:0
; extract x ~high:2 ~low:1
; extract x ~high:3 ~low:2
; extract x ~high:4 ~low:3
[ extract x ~high:7 ~low:0
; extract x ~high:15 ~low:8
; extract x ~high:23 ~low:16
; extract x ~high:31 ~low:24
]

let[@inline] trunc_f32_s_exn x = Expr.cvtop ty TruncSF32 x
Expand Down Expand Up @@ -317,7 +317,7 @@ module Bitv32 = struct
let[@inline] wrap_i64 x = Expr.cvtop ty WrapI64 x

let[@inline] extend_s n x =
Expr.cvtop ty (Sign_extend (32 - n)) (Expr.extract x ~high:(n / 8) ~low:0)
Expr.cvtop ty (Sign_extend (32 - n)) (Expr.extract x ~high:(n - 1) ~low:0)
end

module Bitv64 = struct
Expand All @@ -334,14 +334,14 @@ module Bitv64 = struct
let[@inline] to_int32 x = Expr.cvtop Bitv32.ty WrapI64 x

let[@inline] to_bytes x =
[ extract x ~high:1 ~low:0
; extract x ~high:2 ~low:1
; extract x ~high:3 ~low:2
; extract x ~high:4 ~low:3
; extract x ~high:5 ~low:4
; extract x ~high:6 ~low:5
; extract x ~high:7 ~low:6
; extract x ~high:8 ~low:7
[ extract x ~high:7 ~low:0
; extract x ~high:15 ~low:8
; extract x ~high:23 ~low:16
; extract x ~high:31 ~low:24
; extract x ~high:39 ~low:32
; extract x ~high:47 ~low:40
; extract x ~high:55 ~low:48
; extract x ~high:63 ~low:56
]

let[@inline] trunc_f32_s_exn x = Expr.cvtop ty TruncSF32 x
Expand Down Expand Up @@ -387,7 +387,7 @@ module Bitv64 = struct
let[@inline] reinterpret_f64 x = Expr.cvtop ty Reinterpret_float x

let[@inline] extend_s n x =
Expr.cvtop ty (Sign_extend (64 - n)) (Expr.extract x ~high:(n / 8) ~low:0)
Expr.cvtop ty (Sign_extend (64 - n)) (Expr.extract x ~high:(n - 1) ~low:0)

let[@inline] extend_i32_s x = Expr.cvtop ty (Sign_extend 32) x

Expand All @@ -404,10 +404,10 @@ module Bitv128 = struct
let of_i32x4 a b c d = Bitv64.concat (Bitv32.concat a b) (Bitv32.concat c d)

let to_i32x4 v =
let a = extract v ~low:12 ~high:16 in
let b = extract v ~low:8 ~high:12 in
let c = extract v ~low:4 ~high:8 in
let d = extract v ~low:0 ~high:4 in
let a = extract v ~low:96 ~high:127 in
let b = extract v ~low:64 ~high:95 in
let c = extract v ~low:32 ~high:63 in
let d = extract v ~low:0 ~high:31 in
(a, b, c, d)

let of_int64x2 a b =
Expand All @@ -418,8 +418,8 @@ module Bitv128 = struct
let of_i64x2 a b = Bitv64.concat a b

let to_i64x2 v =
let a = extract v ~low:8 ~high:16 in
let b = extract v ~low:0 ~high:8 in
let a = extract v ~low:64 ~high:127 in
let b = extract v ~low:0 ~high:63 in
(a, b)
end

Expand Down
10 changes: 5 additions & 5 deletions src/smtml/typed.mli
Original file line number Diff line number Diff line change
Expand Up @@ -497,11 +497,11 @@ module Bitv : sig
Example: [concat (v8 0xAA) (v8 0xBB)] results in [0xAABB] (16-bit). *)
val concat : 'a expr -> 'b expr -> 'c expr

(** [extract t ~high ~low] extracts the bytes from index [high] down to
[low] (inclusive).
(** [extract t ~high ~low] extracts the bits from index [high] down to[low]
(inclusive).

Example: [extract (i32 0xAABBCCDD) ~high:2 ~low:1] results in [0xCC]
(1-byte). *)
Example:[extract (i32 0xAABBCCDD) ~high:15 ~low:8] results in [0xCC]
(8-bit). *)
val extract : t -> high:int -> low:int -> 'a expr

(** [zero_extend n t] extends [t] to a width of [width(t) + n] by padding
Expand Down Expand Up @@ -625,7 +625,7 @@ module Bitv64 : sig
(** [to_int32 t] extracts the lower 32 bits of [t]. *)
val to_int32 : t -> bitv32 expr

(** [to_bytes t] splits the 32-bit vector into 4 bytes (little-endian). *)
(** [to_bytes t] splits the 64-bit vector into 8 bytes (little-endian). *)
val to_bytes : t -> bitv8 expr list

(** Truncate float to signed integer (raises exception on overflow/NaN). *)
Expand Down
124 changes: 124 additions & 0 deletions test/integration/test_solver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -431,4 +431,128 @@ module Make (M : Mappings_intf.S_with_fresh) = struct
assert_sat ~f:"test_uninterpreted_function"
(Solver.check solver []) )
]

let test_extract_bit_level solver_module =
let open Infix in
let module Solver = (val solver_module : Solver_intf.S) in
let create_solver () =
Solver.create ~params:(Params.default ()) ~logic:QF_BVFP ()
in

(* Test 1: Basic extraction - extract bits 3-0 from 0xAF should give 0xF *)
let solver = create_solver () in
let x = int8 0xAF in
let extracted = Expr.raw_extract x ~high:3 ~low:0 in
(* Result type should be 4 bits *)
assert_equal (Expr.ty extracted) (Ty.Ty_bitv 4);
Solver.add solver
[ Expr.raw_relop (Ty_bitv 4) Eq extracted
(Expr.value (Bitv (Bitvector.make (Z.of_int 0xF) 4)))
];
assert_sat ~f:"test_extract_low_bits" (Solver.check solver []);

(* Test 2: Basic extraction 2 - extract bits 7-4 from 0xAF should give 0xA *)
let solver = create_solver () in
let extracted_high = Expr.raw_extract x ~high:7 ~low:4 in
assert_equal (Expr.ty extracted_high) (Ty.Ty_bitv 4);
Solver.add solver
[ Expr.raw_relop (Ty_bitv 4) Eq extracted_high
(Expr.value (Bitv (Bitvector.make (Z.of_int 0xA) 4)))
];
assert_sat ~f:"test_extract_high_bits" (Solver.check solver []);

(* Test 3: Non-byte-aligned extraction - bits 5-2 from 0xAB (10101011) are 1010 = 0xA *)
let solver = create_solver () in
let y = int8 0xAB in
let extracted_mid = Expr.raw_extract y ~high:5 ~low:2 in
assert_equal (Expr.ty extracted_mid) (Ty.Ty_bitv 4);
Solver.add solver
[ Expr.raw_relop (Ty_bitv 4) Eq extracted_mid
(Expr.value (Bitv (Bitvector.make (Z.of_int 0xA) 4)))
];
assert_sat ~f:"test_extract_non_aligned" (Solver.check solver []);

(* Test 4: Single bit extraction - bit 0 from 0xF should be 1 *)
let solver = create_solver () in
let z = int32 0xFl in
let single_bit = Expr.raw_extract z ~high:0 ~low:0 in
assert_equal (Expr.ty single_bit) (Ty.Ty_bitv 1);
Solver.add solver
[ Expr.raw_relop (Ty_bitv 1) Eq single_bit
(Expr.value (Bitv (Bitvector.make Z.one 1)))
];
assert_sat ~f:"test_extract_single_bit" (Solver.check solver []);

(* Test 5: Full 32-bit extraction *)
let solver = create_solver () in
let w = int32 0xDEADBEEFl in
let full_extract = Expr.raw_extract w ~high:31 ~low:0 in
assert_equal (Expr.ty full_extract) (Ty.Ty_bitv 32);
Solver.add solver
[ Expr.raw_relop (Ty_bitv 32) Eq full_extract (int32 0xDEADBEEFl) ];
assert_sat ~f:"test_extract_full_width" (Solver.check solver []);

(* Test 6: Symbolic extraction with solver verification *)
let solver = create_solver () in
let sym_x = symbol "bv_x" (Ty_bitv 32) in
Solver.add solver [ Expr.relop (Ty_bitv 32) Eq sym_x (int32 0x12345678l) ];
let sym_extracted = Expr.extract sym_x ~high:15 ~low:8 in
(* Bits 15-8 of 0x12345678 should be 0x56 *)
Solver.add solver [ Expr.relop (Ty_bitv 8) Eq sym_extracted (int8 0x56) ];
assert_sat ~f:"test_extract_symbolic" (Solver.check solver [])

let test_extract =
"test_extract"
>::: [ "test_extract_bit_level" >:: with_solver test_extract_bit_level ]

let test_bitv32_to_bytes solver_module =
let open Typed in
let module Solver = (val solver_module : Solver_intf.S) in
let solver = Solver.create ~params:(Params.default ()) ~logic:QF_BVFP () in

let bv_val = Bitvector.make (Z.of_int32 0xDEADBEEFl) 32 in
let bv = Bitv32.v bv_val in

match Bitv32.to_bytes bv with
| [ b0; b1; b2; b3 ] ->
let v8 i = Bitv8.v (Bitvector.make (Z.of_int i) 8) in
Solver.add solver
[ (Bool.eq b0 (v8 0xEF) :> Expr.t)
; (Bool.eq b1 (v8 0xBE) :> Expr.t)
; (Bool.eq b2 (v8 0xAD) :> Expr.t)
; (Bool.eq b3 (v8 0xDE) :> Expr.t)
];
assert_sat ~f:"test_bitv32_to_bytes" (Solver.check solver [])
| _ -> OUnit2.assert_failure "Bitv32.to_bytes should return exactly 4 bytes"

let test_bitv64_to_bytes solver_module =
let open Typed in
let module Solver = (val solver_module : Solver_intf.S) in
let solver = Solver.create ~params:(Params.default ()) ~logic:QF_BVFP () in

let bv_val = Bitvector.make (Z.of_int64 0x0123456789ABCDEFL) 64 in
let bv = Bitv64.v bv_val in

(* Match para extrair os 8 bytes da lista *)
match Bitv64.to_bytes bv with
| [ b0; b1; b2; b3; b4; b5; b6; b7 ] ->
let v8 i = Bitv8.v (Bitvector.make (Z.of_int i) 8) in
Solver.add solver
[ (Bool.eq b0 (v8 0xEF) :> Expr.t)
; (Bool.eq b1 (v8 0xCD) :> Expr.t)
; (Bool.eq b2 (v8 0xAB) :> Expr.t)
; (Bool.eq b3 (v8 0x89) :> Expr.t)
; (Bool.eq b4 (v8 0x67) :> Expr.t)
; (Bool.eq b5 (v8 0x45) :> Expr.t)
; (Bool.eq b6 (v8 0x23) :> Expr.t)
; (Bool.eq b7 (v8 0x01) :> Expr.t)
];
assert_sat ~f:"test_bitv64_to_bytes" (Solver.check solver [])
| _ -> OUnit2.assert_failure "Bitv64.to_bytes should return exactly 8 bytes"

let test_typed_api_consistency =
"test_typed_api_consistency"
>::: [ "test_bitv32_to_bytes" >:: with_solver test_bitv32_to_bytes
; "test_bitv64_to_bytes" >:: with_solver test_bitv64_to_bytes
]
end
2 changes: 2 additions & 0 deletions test/integration/test_solver_altergo.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ let test_suite =
; Alt_ergo.test_cached
; Alt_ergo.test_lia
; Alt_ergo.test_bv
; Alt_ergo.test_extract
; Alt_ergo.test_typed_api_consistency
]

let () = run_test_tt_main test_suite
2 changes: 2 additions & 0 deletions test/integration/test_solver_bitwuzla.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ let test_suite =
; Bitwuzla.test_params
; Bitwuzla.test_bv
; Bitwuzla.test_fp
; Bitwuzla.test_extract
; Bitwuzla.test_typed_api_consistency
]

let () = run_test_tt_main test_suite
2 changes: 2 additions & 0 deletions test/integration/test_solver_colibri2.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ let test_suite =
; C2.test_bv
; C2.test_fp
; C2.test_lia
; C2.test_extract
; C2.test_typed_api_consistency
]

let () = run_test_tt_main test_suite
2 changes: 2 additions & 0 deletions test/integration/test_solver_cvc5.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ let test_suite =
; Cvc5_solv.test_lia
; Cvc5_solv.test_bv
; Cvc5_solv.test_regexp
; Cvc5_solv.test_extract
; Cvc5_solv.test_typed_api_consistency
]

let () = run_test_tt_main test_suite
2 changes: 2 additions & 0 deletions test/integration/test_solver_z3.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ let test_suite =
; Z3_solv.test_regexp
; Z3_solv.test_uninterpreted
; Z3_bindings.test_adt
; Z3_solv.test_extract
; Z3_solv.test_typed_api_consistency
]

let () = run_test_tt_main test_suite
Loading