[csu1605]数独(精确覆盖问题)

题意:给定一个部分填好的 9×9 数独(0 表示空格),每个格子有一个权值:最外圈为 6,每向内一圈加 1,中心格为 10。一个解的得分是每个格子上的数字乘以该格权值之和。求所有合法解中的最大得分;若无解则输出 -1。

思路:这是 NOIP 2009 提高组的原题(靶形数独),高中时就写过,位运算也是那时学会的。这题显然要暴力搜索,但需要注意两点:一是常数优化,即用位运算维护每行、每列、每个 3×3 宫中已用过的数字;二是剪枝——填完某个数后若发现某个空格已经无数可填,就立即回溯换一个数,并且优先填可选数字最少的格子,这样能让矛盾尽量在靠近搜索树根的地方暴露出来。今天粗略学了一下舞蹈链(Dancing Links,简称 DLX),这个算法(准确地说是一种数据结构)可以比较高效地解决精确覆盖问题,稍作修改后也适用于重复覆盖问题。用 DLX 写了一遍数独,发现效率比位运算版本略高一点,但差距不明显。

位运算:

  1 #pragma comment(linker, "/STACK:10240000,10240000")
  2 
  3 #include <iostream>
  4 #include <cstdio>
  5 #include <algorithm>
  6 #include <cstdlib>
  7 #include <cstring>
  8 #include <map>
  9 #include <queue>
 10 #include <deque>
 11 #include <cmath>
 12 #include <vector>
 13 #include <ctime>
 14 #include <cctype>
 15 #include <set>
 16 #include <bitset>
 17 #include <functional>
 18 #include <numeric>
 19 #include <stdexcept>
 20 #include <utility>
 21 
 22 using namespace std;
 23 
 24 #define mem0(a) memset(a, 0, sizeof(a))
 25 #define mem_1(a) memset(a, -1, sizeof(a))
 26 #define lson l, m, rt << 1
 27 #define rson m + 1, r, rt << 1 | 1
 28 #define define_m int m = (l + r) >> 1
 29 #define rep_up0(a, b) for (int a = 0; a < (b); a++)
 30 #define rep_up1(a, b) for (int a = 1; a <= (b); a++)
 31 #define rep_down0(a, b) for (int a = b - 1; a >= 0; a--)
 32 #define rep_down1(a, b) for (int a = b; a > 0; a--)
 33 #define all(a) (a).begin(), (a).end()
 34 #define lowbit(x) ((x) & (-(x)))
 35 #define constructInt4(name, a, b, c, d) name(int a = 0, int b = 0, int c = 0, int d = 0): a(a), b(b), c(c), d(d) {}
 36 #define constructInt3(name, a, b, c) name(int a = 0, int b = 0, int c = 0): a(a), b(b), c(c) {}
 37 #define constructInt2(name, a, b) name(int a = 0, int b = 0): a(a), b(b) {}
 38 #define pchr(a) putchar(a)
 39 #define pstr(a) printf("%s", a)
 40 #define sstr(a) scanf("%s", a)
 41 #define sint(a) scanf("%d", &a)
 42 #define sint2(a, b) scanf("%d%d", &a, &b)
 43 #define sint3(a, b, c) scanf("%d%d%d", &a, &b, &c)
 44 #define pint(a) printf("%d\n", a)
 45 #define test_print1(a) cout << "var1 = " << a << endl
 46 #define test_print2(a, b) cout << "var1 = " << a << ", var2 = " << b << endl
 47 #define test_print3(a, b, c) cout << "var1 = " << a << ", var2 = " << b << ", var3 = " << c << endl
 48 
 49 typedef long long LL;
 50 typedef pair<int, int> pii;
 51 typedef vector<int> vi;
 52 
 53 const int dx[8] = {0, 0, -1, 1, 1, 1, -1, -1};
 54 const int dy[8] = {-1, 1, 0, 0, 1, -1, 1, -1 };
 55 const int maxn = 3e4 + 7;
 56 const int md = 10007;
 57 const int inf = 1e9 + 7;
 58 const LL inf_L = 1e18 + 7;
 59 const double pi = acos(-1.0);
 60 const double eps = 1e-6;
 61 
 62 template<class T>T gcd(T a, T b){return b==0?a:gcd(b,a%b);}
 63 template<class T>bool max_update(T &a,const T &b){if(b>a){a = b; return true;}return false;}
 64 template<class T>bool min_update(T &a,const T &b){if(b<a){a = b; return true;}return false;}
 65 template<class T>T condition(bool f, T a, T b){return f?a:b;}
 66 template<class T>void copy_arr(T a[], T b[], int n){rep_up0(i,n)a[i]=b[i];}
 67 int make_id(int x, int y, int n) { return x * n + y; }
 68 
 69 int ans, a[10][10], f[1 << 13], row[10], col[10], block[10], sp[1 << 13];
 70 
 71 int getScore(int i, int j) {
 72     return min(min(i, 8 - i), min(j, 8 - j)) + 6;
 73 }
 74 
 75 void init() {
 76     rep_up0(i, 12) {
 77         f[1 << i] = i;
 78     }
 79 }
 80 
 81 void dfs(int k, int score) {
 82     if (k >= 81) {
 83         max_update(ans, score);
 84         return ;
 85     }
 86     int x, y, c = 10;
 87     rep_up0(i, 9) {
 88         bool ok = false;
 89         rep_up0(j, 9) {
 90             if (a[i][j]) continue;
 91             int tmp = row[i] | col[j] | block[make_id(i / 3, j / 3, 3)];
 92             int tot = 0x3fe ^ tmp;
 93             int cnt = 0;
 94             if (tot == 0) {
 95                 ok = true;
 96                 c = 0;
 97                 break;
 98             }
 99             cnt = sp[tot];
100             if (cnt < c) {
101                 x = i;
102                 y = j;
103                 c = cnt;
104             }
105         }
106         if (ok) break;
107     }
108     if (c == 0 || c == 10) return ;
109     int i = x, j = y;
110     int tmp = row[i] | col[j] | block[make_id(i / 3, j / 3, 3)];
111     int tot = 0x3fe ^ tmp;
112     while (tot) {
113         tmp = lowbit(tot);
114         row[i] ^= 1 << f[tmp];
115         col[j] ^= 1 << f[tmp];
116         block[make_id(i / 3, j / 3, 3)] ^= 1 << f[tmp];
117         a[i][j] = f[tmp];
118         dfs(k + 1, score + f[tmp] * getScore(i, j));
119         row[i] ^= 1 << f[tmp];
120         col[j] ^= 1 << f[tmp];
121         block[make_id(i / 3, j / 3, 3)] ^= 1 << f[tmp];
122         a[i][j] = 0;
123         tot -= tmp;
124     }
125 }
126 
127 int main() {
128     //freopen("in.txt", "r", stdin);
129     sp[0] = 0;
130     rep_up1(i, 1 << 10) {
131         sp[i] = sp[i - lowbit(i)] + 1;
132     }
133     int T;
134     init();
135     cin >> T;
136     while (T --) {
137         int sum = 0, cnt = 0, ok = true;
138         mem0(col);
139         mem0(row);
140         mem0(block);
141         rep_up0(i, 9) {
142             rep_up0(j, 9) {
143                 sint(a[i][j]);
144                 sum += a[i][j] * getScore(i, j);
145                 if (a[i][j]) {
146                     cnt ++;
147                     if (col[j] & (1 << a[i][j])) ok = false;
148                     if (row[i] & (1 << a[i][j])) ok = false;
149                     if (block[make_id(i / 3, j / 3, 3)] & (1 << a[i][j])) ok = false;
150                     col[j] |= 1 << a[i][j];
151                     row[i] |= 1 << a[i][j];
152                     block[make_id(i / 3, j / 3, 3)] |= 1 << a[i][j];
153                 }
154             }
155         }
156         ans = -1;
157         if (ok) dfs(cnt, sum);
158         cout << ans << endl;
159     }
160 }
View Code

DLX(模板):

  1 #pragma comment(linker, "/STACK:102400000,102400000")
  2 
  3 #include <iostream>
  4 #include <cstdio>
  5 #include <algorithm>
  6 #include <cstdlib>
  7 #include <cstring>
  8 #include <map>
  9 #include <queue>
 10 #include <deque>
 11 #include <cmath>
 12 #include <vector>
 13 #include <ctime>
 14 #include <cctype>
 15 #include <set>
 16 #include <bitset>
 17 #include <functional>
 18 #include <numeric>
 19 #include <stdexcept>
 20 #include <utility>
 21 
 22 using namespace std;
 23 
 24 #define mem0(a) memset(a, 0, sizeof(a))
 25 #define mem_1(a) memset(a, -1, sizeof(a))
 26 #define lson l, m, rt << 1
 27 #define rson m + 1, r, rt << 1 | 1
 28 #define define_m int m = (l + r) >> 1
 29 #define rep_up0(a, b) for (int a = 0; a < (b); a++)
 30 #define rep_up1(a, b) for (int a = 1; a <= (b); a++)
 31 #define rep_down0(a, b) for (int a = b - 1; a >= 0; a--)
 32 #define rep_down1(a, b) for (int a = b; a > 0; a--)
 33 #define all(a) (a).begin(), (a).end()
 34 #define lowbit(x) ((x) & (-(x)))
 35 #define constructInt4(name, a, b, c, d) name(int a = 0, int b = 0, int c = 0, int d = 0): a(a), b(b), c(c), d(d) {}
 36 #define constructInt3(name, a, b, c) name(int a = 0, int b = 0, int c = 0): a(a), b(b), c(c) {}
 37 #define constructInt2(name, a, b) name(int a = 0, int b = 0): a(a), b(b) {}
 38 #define pchr(a) putchar(a)
 39 #define pstr(a) printf("%s", a)
 40 #define sstr(a) scanf("%s", a)
 41 #define sint(a) scanf("%d", &a)
 42 #define sint2(a, b) scanf("%d%d", &a, &b)
 43 #define sint3(a, b, c) scanf("%d%d%d", &a, &b, &c)
 44 #define pint(a) printf("%d\n", a)
 45 #define test_print1(a) cout << "var1 = " << a << endl
 46 #define test_print2(a, b) cout << "var1 = " << a << ", var2 = " << b << endl
 47 #define test_print3(a, b, c) cout << "var1 = " << a << ", var2 = " << b << ", var3 = " << c << endl
 48 
 49 typedef long long LL;
 50 typedef pair<int, int> pii;
 51 typedef vector<int> vi;
 52 
 53 const int dx[8] = {0, 0, -1, 1, 1, 1, -1, -1};
 54 const int dy[8] = {-1, 1, 0, 0, 1, -1, 1, -1 };
 55 const int maxn = 1e5 + 7;
 56 const int md = 10007;
 57 const int inf = 1e9 + 7;
 58 const LL inf_L = 1e18 + 7;
 59 const double pi = acos(-1.0);
 60 const double eps = 1e-6;
 61 
 62 template<class T>T gcd(T a, T b){return b==0?a:gcd(b,a%b);}
 63 template<class T>bool max_update(T &a,const T &b){if(b>a){a = b; return true;}return false;}
 64 template<class T>bool min_update(T &a,const T &b){if(b<a){a = b; return true;}return false;}
 65 template<class T>T condition(bool f, T a, T b){return f?a:b;}
 66 template<class T>void copy_arr(T a[], T b[], int n){rep_up0(i,n)a[i]=b[i];}
 67 int make_id(int x, int y, int n) { return x * n + y; }
 68 
 69 ///行编号从1开始,列编号1~n,结点0是表头结点,结点1~n是各列顶部的虚拟结点
 70 int result;
 71 int b[10][10];
 72 
 73 int encode(int a, int b, int c) {
 74     return a * 81 + b * 9 + c + 1;
 75 }
 76 void decode(int code, int &a, int &b, int &c) {
 77     code --;
 78     c = code % 9; code /= 9;
 79     b = code % 9; code /= 9;
 80     a = code;
 81 }
 82 
 83 struct DLX
 84 {
 85     const static int maxn = 1050;
 86     const static int maxnode = 100007;
 87     int n , sz;                                                 // 行数,节点总数
 88     int S[maxn];                                                // 各列节点总数
 89     int row[maxnode],col[maxnode];                              // 各节点行列编号
 90     int L[maxnode],R[maxnode],U[maxnode],D[maxnode];            // 十字链表
 91 
 92     int ansd,ans[maxn];                                         //
 93 
 94     void init(int n )
 95     {
 96         this->n = n ;
 97         for(int i = 0 ; i <= n; i++ )
 98             {
 99               U[i] = i ;
100               D[i] = i ;
101               L[i] = i - 1;
102               R[i] = i + 1;
103         }
104         R[n] = 0 ;
105         L[0] = n;
106         sz = n + 1 ;
107         memset(S,0,sizeof(S));
108     }
109     void addRow(int r,vector<int> c1)
110     {
111         int first = sz;
112         for(int i = 0 ; i < c1.size(); i++ ){
113             int c = c1[i];
114             L[sz] = sz - 1 ; R[sz] = sz + 1 ; D[sz] = c ; U[sz] = U[c];
115             D[U[c]] = sz; U[c] = sz;
116             row[sz] = r; col[sz] = c;
117             S[c] ++ ; sz ++ ;
118         }
119         R[sz - 1] = first ; L[first] = sz - 1;
120     }
121     // 顺着链表A,遍历除s外的其他元素
122     #define FOR(i,A,s) for(int i = A[s]; i != s ; i = A[i])
123 
124     void remove(int c) {
125         L[R[c]] = L[c];
126         R[L[c]] = R[c];
127         FOR(i,D,c)
128             FOR(j,R,i) {U[D[j]] = U[j];D[U[j]] = D[j];--S[col[j]];}
129     }
130     void restore(int c) {
131         FOR(i,U,c)
132             FOR(j,L,i) {++S[col[j]];U[D[j]] = j;D[U[j]] = j; }
133         L[R[c]] = c;
134         R[L[c]] = c;
135     }
136     void update() {
137         int score = 0;
138         rep_up0(i, ansd) {
139             int r, c, v;
140             decode(ans[i], r, c, v);
141             score += (v + 1) * b[r][c];
142         }
143         max_update(result, score);
144     }
145     bool dfs(int d) {
146         if(R[0] == 0) {
147           ansd = d;
148           update();
149           return true;
150         }
151         // 找S最小的列c
152         int c = R[0];
153         FOR(i,R,0) if(S[i] < S[c]) c = i;
154 
155         remove(c);
156         FOR(i,D,c) {
157             ans[d] = row[i];
158             FOR(j,R,i) remove(col[j]);
159             //if(dfs(d + 1)) return true;
160             dfs(d + 1);
161             FOR(j,L,i) restore(col[j]);
162         }
163         restore(c);
164 
165         //return false;
166     }
167     bool solve(vector<int> & v) {
168         v.clear();
169         if(!dfs(0)) return false;
170         for(int i = 0 ; i < ansd ;i ++) v.push_back(ans[i]);
171         return true;
172     }
173 };
174 
175 DLX solver;
176 int a[12][12];
177 
178 
179 int main() {
180     //freopen("in.txt", "r", stdin);
181     rep_up0(i, 9) {
182         rep_up0(j, 9) {
183             b[i][j] = 6 + min(min(i, 8 - i), min(j, 8 - j));
184         }
185     }
186     int T, x;
187     cin >> T;
188     while (T --) {
189         solver.init(324);
190         rep_up0(i, 9) {
191             rep_up0(j, 9) {
192                 int x;
193                 sint(x);
194                 rep_up0(k, 9) {
195                     if (x == 0 || x == k + 1) {
196                         vector<int> col;
197                         col.push_back(encode(0, i, j));
198                         col.push_back(encode(1, i, k));
199                         col.push_back(encode(2, j, k));
200                         col.push_back(encode(3, make_id(i / 3, j / 3, 3), k));
201                         solver.addRow(encode(i, j, k), col);
202                     }
203                 }
204             }
205         }
206         result = -1;
207         solver.dfs(0);
208         cout << result << endl;
209     }
210     return 0;
211 }
View Code

 

posted @ 2015-05-05 05:19  jklongint  阅读(276)  评论(0编辑  收藏  举报