Commit 176f7519 by chengshuyao

read input sample from file

parent b0da060e
......@@ -50,6 +50,7 @@ int GLOBAL_BDD_split_nodes;
int GLOBAL_train_time;
int GLOBAL_program_time;
bool** file_inputs;
//待优化变量
int** variable_order ;
int variable_order_number;
......@@ -136,6 +137,7 @@ int search_partition(int start_order_depth, int* start_order);
bool* default_partition_set;
int* default_start_order;
int main(int argc,char* argv[]){
omp_set_num_threads(parameter_num_threads);
......@@ -160,7 +162,34 @@ int main(int argc,char* argv[]){
// truth_table_name = argv[2];
// }
//
// ifstream truth_table_file(truth_table_name);
ifstream sampling_input_file("sample_input.set");
file_inputs = new bool* [parameter_io_file_lines];
for (int i=0;i<parameter_io_file_lines;i++){
file_inputs[i] = new bool [parameter_input_bit_width];
for(int j=0;j<parameter_input_bit_width;j++){
file_inputs[i][j] = 0;
}
}
std::string line;
int lineCount = 0;
while ((lineCount < parameter_io_file_lines) && std::getline(sampling_input_file, line)){
if (line.length() >= parameter_input_bit_width){
for (int i = 0; i < parameter_input_bit_width; i++) {
if(line[i] == '1'){
file_inputs[lineCount][i] = 1;
}
else if (line[i] == '0'){
file_inputs[lineCount][i] = 0;
}else{
std::cerr << "第 " << lineCount + 1 << " 行第 " << i + 1 << " 个字符无效,应为'0'或'1'" << std::endl;
file_inputs[lineCount][i] = 0; // 默认值
}
}
lineCount ++;
}
}
// string line_data;
// for(int i=0;i<pow(2,20);i++){
// truth_table[i] = 0;
......
......@@ -3,21 +3,31 @@ int BDD_class::set_random_input_data(bool** mask_input_data){
random_device rd;
mt19937 gen(rd());
//#pragma omp parallel for
for (int zj=0;zj<BSD_samples*parameter_input_bit_width;zj++){
int i = int(zj/parameter_input_bit_width);
int j = int(zj%parameter_input_bit_width);
//srand((int)time(0));
int zi = 0;
long randint;
zi = zj%30;
if(zi == 0){
//randint = rand();
randint = gen();
// for (int zj=0;zj<BSD_samples*parameter_input_bit_width;zj++){
// int i = int(zj/parameter_input_bit_width);
// int j = int(zj%parameter_input_bit_width);
for (int i=0;i<BSD_samples;i++){
if(i<parameter_io_file_lines){
for(int j=0;j<parameter_input_bit_width;j++){
mask_input_data[i][j] = file_inputs[i][j];
}
mask_input_data[i][j] = bool((randint >> (zi))%2);
}else{
for(int j=0;j<parameter_input_bit_width;j++){
int zj = i*parameter_input_bit_width + j;
int zi = 0;
long randint;
zi = j%30;
if(zi == 0){
randint = gen();
}
mask_input_data[i][j] = bool((randint >> (zi))%2);
}
}
}
return 0;
};
int BDD_class::mask_random_input_data(int depth,bool* mask,int amount,int* most_influence,bool** mask_input_data){
int i,j;
//#pragma omp parallel for
......
......@@ -14,6 +14,8 @@ extern const double parameter_early_stop_accuracy = 1; //允许的错误率,
//没有特殊需要不要设到<1,会慢一些。
//0.5以下无意义,建议至少设到0.8吧.
extern const bool parameter_early_stop_oneway = 1; //如果设为1,只允许0->1 的错误,不允许1->0的错误
extern const int parameter_io_file_lines = 2; //在sample_input.set文件中,保存了最少parameter_use_io_file行input;保证每次采样,都能采到这些样本。
extern const int parameter_num_threads = 32; //线程数
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment