Commit ab81c75a by xuelu

rebuild sat bsd

parents
bin
build
CXX := g++
# -Iinclude 是关键,让头文件直接可见
CXXFLAGS := -O3 -std=c++17 -fopenmp -Wall -Wextra -Iinclude
# 依赖库
LDLIBS := -lz3 -lcryptominisat5 -lgmp -lcadiback -lcadical
# 目录
SRC_DIR := src
BUILD_DIR := build
BIN_DIR := bin
TARGET := $(BIN_DIR)/bsd_run
# 自动探测 Conda 环境
ifneq ($(CONDA_PREFIX),)
CXXFLAGS += -I$(CONDA_PREFIX)/include
LDFLAGS += -L$(CONDA_PREFIX)/lib
endif
# 自动扫描源文件
SRCS := $(wildcard $(SRC_DIR)/*.cpp)
OBJS := $(patsubst $(SRC_DIR)/%.cpp, $(BUILD_DIR)/%.o, $(SRCS))
DEPS := $(OBJS:.o=.d)
# ==========================================
# Rules
# ==========================================
all: $(TARGET)
$(TARGET): $(OBJS)
@mkdir -p $(BIN_DIR)
@echo "Linking $@"
@$(CXX) $(CXXFLAGS) $(LDFLAGS) $^ $(LDLIBS) -o $@
$(BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp
@mkdir -p $(BUILD_DIR)
@echo "Compiling $<"
@$(CXX) $(CXXFLAGS) -MMD -MP -c $< -o $@
-include $(DEPS)
# 一键运行 (自动设置库路径)
run: $(TARGET)
@echo "--- Running BSD ---"
@export LD_LIBRARY_PATH=$(CONDA_PREFIX)/lib:$(LD_LIBRARY_PATH) && ./$(TARGET)
clean:
@rm -rf $(BUILD_DIR) $(BIN_DIR)
.PHONY: all clean run
\ No newline at end of file
# moduleBSD
**moduleBSD demo:** 用于测试模块式拓展序对BSD生成的影响。
## 目前支持/进度:
1. 支持简单的自定义拓展序,每层可以是任意比特的特定op组合或单比特,如:`1 and 2``1 and 2 and 3``1 xor 2`等,具体支持见`Expansion.h`。默认为默认黑盒函数的完备与门组合。
2. 支持自定义黑盒函数,但输出必须是单比特,默认为乘法中间位,即:`c[15:0] = a[7:0] * b[7:0], output = c[7]`
3. 目前仅用sat判断合并,可能纯单比特效果不如random。
4. 没有加入bsd转verilog逻辑。
## 可能的优化:
1. 加入反向边。
2. random和sat组合验证。
## 命令:
1. 运行:`make run`
2. 清理:`make clean`
#pragma once
#include <cstdint>
#include <vector>
#include <string>
#include <chrono>
#include <unordered_set>
#include <unordered_map>
#include <omp.h>
#include <thread>
#include "Utils.h"
#include "Expansion.h"
#include "SatSolver.h"
struct BDDNode {
int lson, rson;
BDDNode(int l, int r) : lson(l), rson(r) {}
std::vector<int> parents;
std::vector<int> values;
void add_parents(int parent, int value) {
parents.push_back(parent);
values.push_back(value);
}
};
class BSD {
const bool use_predict = true;
uint64_t num_test = 10000;
uint64_t num_compare = 1000;
uint64_t num_check_leaf = 10000;
Utils utils;
ExpansionManager split_bit_record;
// sat求解相关,用于check leaf
std::vector<std::string> vars;
std::vector<std::vector<std::string>> bdd_vars;
// sat求解相关,用于merge node
std::vector<std::string> vars_A;
std::vector<std::vector<std::string>> bdd_vars_A;
std::vector<std::string> vars_B;
std::vector<std::vector<std::string>> bdd_vars_B;
// 树结构
std::vector<std::vector<BDDNode>> bdd;
// --- 辅助功能 ---
// 递归添加sat limit
void add_pre_limit_of_node(
SatSolver& solver, int depth, int nodeId,
const std::vector<std::vector<std::string>> &bdd_vars,
const std::vector<uint64_t> &gate_to_idx
);
// 求给定input的后缀值
std::string bdd_infer_split(const std::vector<int>& input_bits, ExpansionManager& lower_split_bit);
// 求给定input的bsd推理值
int bdd_infer(std::vector<int> input_bits);
// 求给定input的bsd推测节点
int bdd_infer_node(std::vector<int> input_bits, int depth);
// 测试当前树的正确率
double test_result();
// --- 核心功能 ---
// step1: predict
void predict(int layer_width, int depth);
// step2: check leaf
void set_leaf(int depth, int nodeId, uint64_t res);
void set_none(int depth, int nodeId);
void check_is_leaf(int layer_width, int depth);
// step3: merge node
bool module_fit(int node_a, int node_b, int depth);
int merge_node(int depth);
void build_tree_recursive(int layer_width, int depth);
void build_tree();
public:
BSD(int input_width, bool use_predict=true);
BSD(int input_width, BlackBoxFunc func, ExpansionManager split_bit_record, bool use_predict=true);
void work();
};
\ No newline at end of file
#pragma once
#include <cstdint>
#include <vector>
#include <string>
// 全局枚举,用起来方便
enum OpType { OP_BIT, OP_AND, OP_OR, OP_XOR };
struct ExpNode {
OpType type;
std::vector<uint64_t> inputs; // 存放依赖节点的索引
int infer(const std::vector<int>& values);
};
struct ExpansionManager {
std::vector<ExpNode> nodes;
// 声明函数,实现放 cpp
uint64_t add_bit(const uint64_t var);
uint64_t add_gate(OpType type, const std::vector<uint64_t>& inputs);
uint64_t size() const { return nodes.size(); }
ExpansionManager slice(size_t start_index, size_t end_index) const;
// 获取某一层的详细信息(用于调试)
void print_layer(uint64_t index) const;
int infer_layer(uint64_t index, const std::vector<int>& values);
OpType get_layer_op(uint64_t index);
std::vector<uint64_t> trans_layer(uint64_t index, const std::vector<uint64_t>& vars_to_idx);
};
\ No newline at end of file
#pragma once
#include <cstdint>
#include <vector>
#include <cryptominisat5/cryptominisat.h>
#include "Expansion.h"
class SatSolver {
CMSat::SATSolver solver;
uint64_t nums_var = 0;
public:
// 构造函数
SatSolver(uint64_t num_threads = 1);
// 基础功能
uint64_t create_vars(uint64_t count = 1);
uint64_t size();
void add_limits(const std::vector<CMSat::Lit>& limits);
void add_pos_var(uint64_t var);
void add_neg_var(uint64_t var);
// 添加bsd树上限制辅助函数
void add_c_equal_a_and_b(uint64_t a, uint64_t b, uint64_t c);
void add_c_equal_a_and_not_b(uint64_t a, uint64_t b, uint64_t c);
void add_b_equal_or_of_a(const std::vector<uint64_t>& a, uint64_t b);
void add_b_equal_and_of_a(const std::vector<uint64_t>& a, uint64_t b);
void add_b_equal_xor_of_a(const std::vector<uint64_t>& a, uint64_t b);
void add_a_equal_b(uint64_t a, uint64_t b);
// 添加拓展序限制辅助函数
void add_b_equal_op_of_a(const std::vector<uint64_t>& a, uint64_t b, OpType op);
// 求解
std::vector<int> work(const std::vector<uint64_t> &vars);
};
\ No newline at end of file
#pragma once
#include <vector>
#include <iostream>
#include <algorithm>
#include <random>
#include <functional>
#include <cstdint>
#include <cassert>
// 定义黑盒函数签名
using BlackBoxFunc = std::function<int(const std::vector<int>&)>;
struct Utils {
int input_width;
std::mt19937 rng;
BlackBoxFunc logic;
// 模式 A: 【默认模式】 (乘法器)
Utils(int width, uint32_t seed = 0);
// 模式 B: 【自定义模式】 (完全接管)
Utils(int width, BlackBoxFunc custom_logic, uint32_t seed = 0);
// --- 核心功能 ---
std::vector<int> get_random_input();
int query(const std::vector<int>& input);
// --- 静态工具箱 (Helper Utils) ---
static std::vector<int> dec2bin(uint64_t val, int width);
static uint64_t bin2dec(const std::vector<int>& bin);
// 基础运算模拟
static std::vector<int> op_multiply(const std::vector<int>& a, const std::vector<int>& b);
static std::vector<int> op_add(const std::vector<int>& a, const std::vector<int>& b);
static int op_parity(const std::vector<int>& input); // 奇偶校验
// 打印工具
template <typename T>
static void print(const std::vector<T>& list) {
if (list.empty()) { std::cout << "[]\n"; return; }
std::cout << "[ ";
for (const auto& x : list) std::cout << x << " ";
std::cout << "]\n";
}
// vector去重
template <typename T>
static std::vector<T> unique(const std::vector<T>& vec) {
std::vector<T> res = vec;
std::sort(res.begin(), res.end());
res.erase(std::unique(res.begin(), res.end()), res.end());
return res;
}
};
\ No newline at end of file
This diff is collapsed. Click to expand it.
#include <iostream>
#include "Expansion.h"
int ExpNode::infer(const std::vector<int>& values) {
switch(this->type) {
case OP_BIT:
return values[this->inputs[0]];
case OP_AND: {
int res = values[this->inputs[0]];
for(size_t i = 1; i < inputs.size(); ++i) {
res &= values[this->inputs[i]];
}
return res;
}
case OP_OR: {
int res = values[this->inputs[0]];
for(size_t i = 1; i < inputs.size(); ++i) {
res |= values[this->inputs[i]];
}
return res;
}
case OP_XOR: {
int res = values[this->inputs[0]];
for(size_t i = 1; i < inputs.size(); ++i) {
res ^= values[this->inputs[i]];
}
return res;
}
default:
throw std::runtime_error("Unknown operation type");
}
}
uint64_t ExpansionManager::add_bit(const uint64_t var) {
nodes.push_back({OP_BIT, {var}});
return nodes.size() - 1;
}
uint64_t ExpansionManager::add_gate(OpType type, const std::vector<uint64_t>& inputs) {
nodes.push_back({type, inputs});
return nodes.size() - 1;
}
ExpansionManager ExpansionManager::slice(size_t start_index, size_t end_index) const {
ExpansionManager new_mgr;
if (start_index < nodes.size() && start_index <= end_index) {
size_t real_end = std::min(end_index, nodes.size());
// 复制指定范围的节点
new_mgr.nodes.assign(nodes.begin() + start_index, nodes.begin() + real_end);
}
return new_mgr;
}
void ExpansionManager::print_layer(uint64_t index) const {
const auto& node = nodes[index];
std::cout << "Node[" << index << "] Type=" << node.type << " Inputs: ";
for(auto in : node.inputs) {
std::cout << in << " ";
}
std::cout << std::endl;
}
int ExpansionManager::infer_layer(uint64_t index, const std::vector<int>& values) {
if(index >= nodes.size()) {
throw std::runtime_error("Index out of bounds");
}
return nodes[index].infer(values);
}
OpType ExpansionManager::get_layer_op(uint64_t index) {
return nodes[index].type;
}
std::vector<uint64_t> ExpansionManager::trans_layer(uint64_t index, const std::vector<uint64_t>& vars_to_idx) {
std::vector<uint64_t> result(nodes[index].inputs.size());
for(size_t i = 0; i < nodes[index].inputs.size(); i++) {
result[i] = vars_to_idx[nodes[index].inputs[i]];
}
return result;
}
\ No newline at end of file
#include <iostream>
#include "SatSolver.h"
SatSolver::SatSolver(uint64_t num_threads) {
solver.set_num_threads(num_threads);
}
// ------------------------------- 基础功能 -----------------------------------
uint64_t SatSolver::create_vars(uint64_t count) {
uint64_t old_size = nums_var;
nums_var += count;
solver.new_vars(count);
return old_size;
}
uint64_t SatSolver::size() {
return nums_var;
}
void SatSolver::add_limits(const std::vector<CMSat::Lit>& limits) {
solver.add_clause(limits);
}
void SatSolver::add_pos_var(uint64_t var) {
solver.add_clause({CMSat::Lit(var, false)});
}
void SatSolver::add_neg_var(uint64_t var) {
solver.add_clause({CMSat::Lit(var, true)});
}
// ------------------------------- bsd辅助函数 -----------------------------------
void SatSolver::add_c_equal_a_and_b(uint64_t a, uint64_t b, uint64_t c) {
// c == (a and b)
CMSat::Lit var_a = CMSat::Lit(a, false);
CMSat::Lit var_b = CMSat::Lit(b, false);
CMSat::Lit var_c = CMSat::Lit(c, false);
solver.add_clause({~var_a, ~var_b, var_c});
solver.add_clause({~var_c, var_a});
solver.add_clause({~var_c, var_b});
}
void SatSolver::add_c_equal_a_and_not_b(uint64_t a, uint64_t b, uint64_t c) {
// c == (a and !b)
CMSat::Lit var_a = CMSat::Lit(a, false);
CMSat::Lit var_b = CMSat::Lit(b, false);
CMSat::Lit var_c = CMSat::Lit(c, false);
solver.add_clause({~var_a, var_b, var_c});
solver.add_clause({~var_c, var_a});
solver.add_clause({~var_c, ~var_b});
}
// ==========================================
// OR 逻辑 (保留你原本的代码)
// Logic: b <-> (a1 | a2 | ... | an)
// ==========================================
void SatSolver::add_b_equal_or_of_a(const std::vector<uint64_t>& a, uint64_t b) {
CMSat::Lit var_b = CMSat::Lit(b, false);
// 1. Forward: b -> (a0 | a1 ... | an) => ~b | a0 | a1 ...
std::vector<CMSat::Lit> clause1;
clause1.push_back(~var_b);
for (uint64_t a_var : a) {
clause1.push_back(CMSat::Lit(a_var, false));
}
solver.add_clause(clause1);
// 2. Backward: (a0 | ... | an) -> b
// Equivalent to: (a0 -> b) AND (a1 -> b) ...
// Clause: ~ai | b
for (uint64_t a_var : a) {
solver.add_clause({var_b, CMSat::Lit(a_var, true)});
}
}
// ==========================================
// AND 逻辑
// Logic: b <-> (a1 & a2 & ... & an)
// ==========================================
void SatSolver::add_b_equal_and_of_a(const std::vector<uint64_t>& a, uint64_t b) {
CMSat::Lit var_b = CMSat::Lit(b, false);
// 1. Forward: b -> (a1 & a2 ... & an)
// Equivalent to: (b -> a1) AND (b -> a2) ...
// Clause: ~b | ai
for (uint64_t a_var : a) {
// 如果 b 为真,那么每一个 ai 都必须为真
solver.add_clause({~var_b, CMSat::Lit(a_var, false)});
}
// 2. Backward: (a1 & ... & an) -> b
// Equivalent to: ~(a1 & ... & an) | b
// De Morgan: (~a1 | ~a2 | ... | ~an) | b
std::vector<CMSat::Lit> clause_back;
clause_back.push_back(var_b); // b
for (uint64_t a_var : a) {
// 如果所有 ai 都为真,则 b 必须为真
// 也就是:只要有一个 ai 为假,b 就可以不为真(这句解释不严谨,看逻辑式)
// 逻辑式:b OR ~a1 OR ~a2 ...
clause_back.push_back(CMSat::Lit(a_var, true)); // ~ai
}
solver.add_clause(clause_back);
}
// ==========================================
// XOR 逻辑 (使用 CMSat 原生接口)
// Logic: b <-> (a1 ^ a2 ^ ... ^ an)
// ==========================================
void SatSolver::add_b_equal_xor_of_a(const std::vector<uint64_t>& a, uint64_t b) {
// 逻辑转换: b = a1 ^ a2 ...
// 等价于: b ^ a1 ^ a2 ... = 0 (False)
std::vector<uint32_t> xor_clause;
// CMSat 的 add_xor_clause 接受的是变量索引 (uint32_t),不是 Lit
xor_clause.push_back((uint32_t)b);
for (uint64_t a_var : a) {
xor_clause.push_back((uint32_t)a_var);
}
// rhs = false 表示所有变量异或的结果为 0 (偶校验)
// rhs = true 表示所有变量异或的结果为 1 (奇校验)
// 这里因为 b 也在左边,所以 sum 应该是 0
solver.add_xor_clause(xor_clause, false);
}
void SatSolver::add_a_equal_b(uint64_t a, uint64_t b) {
CMSat::Lit var_a = CMSat::Lit(a, false);
CMSat::Lit var_b = CMSat::Lit(b, false);
// a == b 等价于两个蕴含:
// 1. a → b (~a ∨ b)
// 2. b → a (~b ∨ a)
solver.add_clause({~var_a, var_b}); // ~a ∨ b
solver.add_clause({~var_b, var_a}); // ~b ∨ a
}
// ------------------------------- 扩展序辅助函数 -----------------------------------
void SatSolver::add_b_equal_op_of_a(const std::vector<uint64_t>& a, uint64_t b, OpType op) {
switch (op) {
case OP_BIT:
assert(a.size() == 1); // OP_BIT 只能有一个输入变量
add_a_equal_b(a[0], b);
break;
case OP_AND:
add_b_equal_and_of_a(a, b);
break;
case OP_OR:
add_b_equal_or_of_a(a, b);
break;
case OP_XOR:
add_b_equal_xor_of_a(a, b);
break;
default:
std::cerr << "[Error] Unsupported OpType in SatSolver constraint!" << std::endl;
break;
}
}
// ------------------------------- 求解函数 -----------------------------------
std::vector<int> SatSolver::work(const std::vector<uint64_t> &vars) {
CMSat::lbool ret = solver.solve();
std::vector<int> ans;
if (ret == CMSat::l_True) {
// std::cout << "SAT" << std::endl;
// 获取解
for (size_t i = 0; i < vars.size(); i++) {
ans.push_back(solver.get_model()[vars[i]] == CMSat::l_True ? 1 : 0);
}
} else if (ret == CMSat::l_False) {
// std::cout << "UNSAT" << std::endl;
} else {
// std::cout << "Unknown" << std::endl;
}
return ans;
}
\ No newline at end of file
#include "Utils.h"
// ==========================================
// 构造函数 A: 默认模式 (乘法)
// ==========================================
Utils::Utils(int width, uint32_t seed) : input_width(width), rng(seed) {
this->logic = [width](const std::vector<int>& input) -> int {
assert((int)input.size() == width && "The input length must be equal to input_width");
int half = width / 2;
std::vector<int> a(input.begin(), input.begin() + half);
std::vector<int> b(input.begin() + half, input.end());
auto res = Utils::op_multiply(a, b);
return res[half];
};
}
// ==========================================
// 构造函数 B: 自定义模式
// ==========================================
Utils::Utils(int width, BlackBoxFunc custom_logic, uint32_t seed) : input_width(width), rng(seed), logic(custom_logic) { }
// ---------------- 核心功能 ----------------
std::vector<int> Utils::get_random_input() {
uint32_t lower = rng(); // 生成低32位随机数
uint32_t upper = rng(); // 生成高32位随机数
uint32_t llwer = rng(); // 生成低32位随机数
uint32_t uuper = rng(); // 生成高32位随机数
uint64_t randomUInt64 = (static_cast<uint64_t>(upper) << 32) | lower;
uint64_t randomUInt64_2 = (static_cast<uint64_t>(uuper) << 32) | llwer;
std::vector<int> result(input_width), result_2(input_width);
if (input_width > 64) {
result = Utils::dec2bin(randomUInt64, 63);
result_2 = Utils::dec2bin(randomUInt64_2, input_width - 63);
result.insert(result.end(), result_2.begin(), result_2.end());
} else {
result = Utils::dec2bin(randomUInt64, input_width);
}
return result;
}
int Utils::query(const std::vector<int>& input) {
return logic(input);
}
// ---------------- 静态工具实现 ----------------
std::vector<int> Utils::dec2bin(uint64_t val, int width) {
std::vector<int> bin(width);
for (int i = 0; i < width; i++) {
bin[i] = val & 1;
val /= 2;
}
std::reverse(bin.begin(), bin.end()); // 假设 vector[0] 是 MSB
return bin;
}
uint64_t Utils::bin2dec(const std::vector<int>& bin) {
uint64_t decimal = 0;
for (size_t i = 0; i < bin.size(); i++) {
decimal = decimal * 2 + bin[i];
}
return decimal;
}
std::vector<int> Utils::op_multiply(const std::vector<int>& a, const std::vector<int>& b) {
uint64_t val_a = bin2dec(a);
uint64_t val_b = bin2dec(b);
return dec2bin(val_a * val_b, a.size() + b.size());
}
std::vector<int> Utils::op_add(const std::vector<int>& a, const std::vector<int>& b) {
uint64_t val_a = bin2dec(a);
uint64_t val_b = bin2dec(b);
int w = std::max(a.size(), b.size()) + 1;
return dec2bin(val_a + val_b, w);
}
int Utils::op_parity(const std::vector<int>& input) {
int p = 0;
for(int bit : input) p ^= bit;
return p;
}
\ No newline at end of file
#include <bits/stdc++.h>
#include "BSD.h"
int main(int argc, char* argv[]) {
std::cout << "8bit mul test !!!" << std::endl;
// 默认模式下不开启predict的方案更优,因为predict本质是贪心
BSD mul_bsd(16, false);
mul_bsd.work();
// 测试16位加法
std::cout << "16bit Adder Test !!!" << std::endl;
int input_width = 32;
BlackBoxFunc add_logic = [](const std::vector<int>& in) -> int {
int half = in.size() / 2;
std::vector<int> a(in.begin(), in.begin() + half);
std::vector<int> b(in.begin() + half, in.end());
// 调用工具箱里的加法
auto sum = Utils::op_add(a, b);
return sum[0];
};
Utils utils(input_width);
ExpansionManager split_bit_record;
// for (int j = 0; j < input_width / 2 - 1; j++) {
// for (int i = 0; i < input_width / 2 - j; i++) {
// split_bit_record.add_gate(OP_AND, {static_cast<unsigned long>(j + i), static_cast<unsigned long>(input_width - 1 - i)});
// }
// }
// for (int i = 0; i < input_width; i++) {
// for (int j = i + 1; j < input_width; j++) {
// split_bit_record.add_gate(OP_OR, {static_cast<unsigned long>(i), static_cast<unsigned long>(j)});
// split_bit_record.add_gate(OP_XOR, {static_cast<unsigned long>(i), static_cast<unsigned long>(j)});
// split_bit_record.add_gate(OP_AND, {static_cast<unsigned long>(i), static_cast<unsigned long>(j)});
// }
// }
for (int i = 0; i < input_width; i++) {
split_bit_record.add_bit(i);
}
BSD add_bsd(input_width, add_logic, split_bit_record);
add_bsd.work();
return 0;
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment