diff --git a/DatapathVerification/BitHeap/Chain.lean b/DatapathVerification/BitHeap/Chain.lean index 88a24e0..a7a1fd1 100644 --- a/DatapathVerification/BitHeap/Chain.lean +++ b/DatapathVerification/BitHeap/Chain.lean @@ -55,6 +55,49 @@ theorem applyChain_correct (steps : List Adder) (h : BitHeap) simp [applyAdder, fullAdder_correct, hwf] grind +/-- Check if a single step of the chain is applicable -/ +def isApplicable (step: Adder) (h : BitHeap) : Bool := + match step with + | .halfAdder col i j => + let column := h.get col + column.contains i && column.contains j && i != j + | .fullAdder col i j k => + let column := h.get col + column.contains i && column.contains j && column.contains k && i != j && i != k && j != k + +/-- Apply a chain of adders if they are applicable, otherwise return none -/ +def applyChainSafe (steps : List Adder) (h : BitHeap) : Option BitHeap := + match steps with + | [] => some h + | s :: rest => + if isApplicable s h then + applyChainSafe rest (applyAdder s h) + else + none + +/-- If a chain of adders is applicable (it does not return none), then it preserves the heap's value -/ +theorem applyChainSafe_correct (steps : List Adder) (h h' : BitHeap) (env : BitEnv) + (heq : applyChainSafe steps h = some h') : + h'.eval env = h.eval env := by + induction steps generalizing h with + | nil => + simp [applyChainSafe] at heq + rw [heq] + | cons s rest ih => + simp [applyChainSafe] at heq + obtain ⟨hleft, hright⟩ := heq + have ih_applied := ih (applyAdder s h) hright + rw [ih_applied] + simp_all [applyAdder, isApplicable] + cases s with + simp at hleft + | halfAdder => + obtain ⟨⟨i_in_col, j_in_col⟩, not_eq⟩ := hleft + rw [halfAdder_correct _ _ _ _ i_in_col j_in_col not_eq] + | fullAdder => + obtain ⟨⟨⟨⟨⟨hi, hj⟩, hk⟩, hij⟩, hik⟩, hjk⟩ := hleft + exact fullAdder_correct _ _ _ _ _ hi hj hk hij hik hjk env + end Chain end BitHeap diff --git a/DatapathVerification/BitHeap/Examples.lean b/DatapathVerification/BitHeap/Examples.lean index c6f0721..f6c996e 100644 --- a/DatapathVerification/BitHeap/Examples.lean +++ b/DatapathVerification/BitHeap/Examples.lean @@ -51,6 +51,35 @@ info: 6 #guard_msgs in #eval (applyChain compressionChain fourBitsInCol1).eval (show BitEnv from fun n => n = 1 || n = 2 || n = 3) +---------------------------- +-- Examples of incorrect chains -- + +def badChain : List Adder := + [.halfAdder 1 (Circuit.bit 1) (Circuit.const true)] + +-- The result 8 does not make sense here since we compressed a bit was not a part of the bitheap. +/-- +info: 8 +-/ +#guard_msgs in +#eval (applyChain badChain fourBitsInCol1).eval + (show BitEnv from fun n => n = 1 || n = 2 || n = 3) + +-- Returns none since the half adder is not applicable (constant bit is not in the heap). +/-- +info: none +-/ +#guard_msgs in +#eval applyChainSafe badChain fourBitsInCol1 + +-- Returns the correct value. +/-- +info: 6 +-/ +#guard_msgs in +#eval (applyChain compressionChain fourBitsInCol1).eval + (show BitEnv from fun n => n = 1 || n = 2 || n = 3) + end Examples end BitHeap