Jerry @DOA&INPAC, SJTU

Working out everything from the first principles.

导航

玩玩24点(上)

《玩玩24点》系列:

 

最近班里开始玩24点了。起因是一个在计算器上两人比赛24点的程序,但计算器判断一组数据是否有解需要15秒,于是这个程序就没有判定有解这一功能。

这么慢的速度我当然看不下去,但去优化那个BASIC程序是不可能的,我就开始写自己的24点程序。正好之前的算法课中递归一章提到过24点,我就理所当然地用开始写递归求解算法。

 

第一个版本是一个运行在PC上的非常复杂的C++程序,用上了十多个头文件。由于它太烂了,我把它注释掉以后又删掉了。这个程序最后算出1820组数据中有1362组有解数据,与网上查到的数字是一致的,不过算得很慢,要二十几秒。

这个时候我的想法还是单片机上的程序通过标准库随机数函数产生数据然后跑一遍求解算法。于是我就把这个程序修改了一下,标准库容器换成数组替换掉了,排序就随便写了个冒泡。PC上所有数据遍历一遍需要十几秒。

  1 #include <iostream>
  2 #include <cstdint>
  3 
  4 using Integer = std::uint16_t;
  5 
  6 template <typename T>
  7 inline void swap(T& lhs, T& rhs)
  8 {
  9     auto temp = lhs;
 10     lhs = rhs;
 11     rhs = temp;
 12 }
 13 
 14 template <typename I>
 15 inline void sort(I begin, I end)
 16 {
 17     for (auto pass_end = end - 1; pass_end != begin; --pass_end)
 18     {
 19         bool changed = false;
 20         for (auto iter = begin; iter != pass_end; ++iter)
 21             if (*(iter + 1) < *iter)
 22             {
 23                 swap(*iter, *(iter + 1));
 24                 changed = true;
 25             }
 26         if (!changed)
 27             break;
 28     }
 29 }
 30 
 31 int divide_count = 0, modulo_count = 0;
 32 
 33 class Rational
 34 {
 35 public:
 36     Integer num, den;
 37     Rational(Integer num = 0, Integer den = 1)
 38         : num(num), den(den)
 39     {
 40         // make every object reduced
 41         reduce();
 42     }
 43     Rational& operator=(Integer i)
 44     {
 45         num = i;
 46         den = 1;
 47         return *this;
 48     }
 49     Rational operator+(const Rational& rhs) const
 50     {
 51         // assume it won't overflow
 52         return Rational(num * rhs.den + rhs.num * den, den * rhs.den);
 53     }
 54     Rational operator-(const Rational& rhs) const
 55     {
 56         // assume *this >= rhs
 57         return Rational(num * rhs.den - rhs.num * den, den * rhs.den);
 58     }
 59     Rational operator*(const Rational& rhs) const
 60     {
 61         return Rational(num * rhs.num, den * rhs.den);
 62     }
 63     Rational operator/(const Rational& rhs) const
 64     {
 65         // assume rhs != 0
 66         return Rational(num * rhs.den, den * rhs.num);
 67     }
 68     bool operator==(const Rational& rhs) const
 69     {
 70         return num == rhs.num && den == rhs.den;
 71     }
 72     bool operator==(Integer rhs) const
 73     {
 74         return num == rhs && den == 1;
 75     }
 76     bool operator<(const Rational& rhs) const
 77     {
 78         return num * rhs.den < rhs.num * den;
 79     }
 80     explicit operator bool()
 81     {
 82         return num;
 83     }
 84 private:
 85     void reduce()
 86     {
 87         if (num == 1 || den == 1)
 88             return;
 89         if (num == 0)
 90         {
 91             den = 1;
 92             return;
 93         }
 94         Integer gcd = 1;
 95         auto a = num, b = den;
 96         while (1)
 97         {
 98             if (a == 0 || a == b)
 99             {
100                 gcd = b;
101                 break;
102             }
103             if (b == 0)
104             {
105                 gcd = a;
106                 break;
107             }
108             if (a > b)
109             {
110                 ++modulo_count;
111                 a %= b;
112             }
113             else
114             {
115                 ++modulo_count;
116                 b %= a;
117             }
118         }
119         if (gcd > 1)
120         {
121             divide_count += 2;
122             num /= gcd;
123             den /= gcd;
124         }
125     }
126 };
127 
128 template <typename S>
129 S& operator<<(S& lhs, const Rational& rhs)
130 {
131     lhs << rhs.num;
132     if (rhs.den > 1)
133         lhs << '/' << rhs.den;
134     return lhs;
135 }
136 
137 struct Expression
138 {
139     Expression() = default;
140     Expression(const Rational& lhs, char op, const Rational& rhs,
141         const Rational& res)
142         : lhs(lhs), rhs(rhs), res(res), op(op) { }
143     char op = ' ';
144     Rational lhs, rhs, res;
145 };
146 
147 template <typename S>
148 S& operator<<(S& lhs, const Expression& rhs)
149 {
150     lhs << rhs.lhs << ' ' << rhs.op << ' ' << rhs.rhs << " = " << rhs.res;
151     return lhs;
152 }
153 
154 constexpr Integer target = 24;
155 constexpr Integer max_count = 4;
156 
157 bool solve(Integer count, const Rational* data, Expression* expr)
158 {
159     // assume data is ordered
160     if (count == 1)
161         return *data == target;
162     auto end = data + count;
163     auto before_end = end - 1;
164     --count;
165     Rational new_data[max_count - 1];
166     auto new_end = new_data + count;
167     for (auto lhs = data; lhs != before_end; ++lhs)
168         for (auto rhs = lhs + 1; rhs != end; ++rhs)
169         {
170             auto dst = new_data;
171             for (auto src = data; src != end; ++src)
172                 if (src != lhs && src != rhs)
173                     *dst++ = *src;
174             *dst = *lhs + *rhs;
175             Expression temp(*lhs, '+', *rhs, *dst);
176             sort(new_data, new_end);
177             if (solve(count, new_data, expr + 1))
178             {
179                 *expr = temp;
180                 return true;
181             }
182         }
183     for (auto lhs = data + 1; lhs != end; ++lhs)
184         for (auto rhs = data; rhs != lhs; ++rhs)
185         {
186             auto dst = new_data;
187             for (auto src = data; src != end; ++src)
188                 if (src != lhs && src != rhs)
189                     *dst++ = *src;
190             *dst = *lhs - *rhs;
191             Expression temp(*lhs, '-', *rhs, *dst);
192             sort(new_data, new_end);
193             if (solve(count, new_data, expr + 1))
194             {
195                 *expr = temp;
196                 return true;
197             }
198         }
199     for (auto lhs = data; lhs != before_end; ++lhs)
200         for (auto rhs = lhs + 1; rhs != end; ++rhs)
201         {
202             auto dst = new_data;
203             for (auto src = data; src != end; ++src)
204                 if (src != lhs && src != rhs)
205                     *dst++ = *src;
206             *dst = *lhs * *rhs;
207             Expression temp(*lhs, '*', *rhs, *dst);
208             sort(new_data, new_end);
209             if (solve(count, new_data, expr + 1))
210             {
211                 *expr = temp;
212                 return true;
213             }
214         }
215     for (auto lhs = data; lhs != end; ++lhs)
216         for (auto rhs = data; rhs != end; ++rhs)
217         {
218             if (lhs == rhs || *rhs == Rational(0))
219                 continue;
220             auto dst = new_data;
221             for (auto src = data; src != end; ++src)
222                 if (src != lhs && src != rhs)
223                     *dst++ = *src;
224             *dst = *lhs / *rhs;
225             Expression temp(*lhs, '/', *rhs, *dst);
226             sort(new_data, new_end);
227             if (solve(count, new_data, expr + 1))
228             {
229                 *expr = temp;
230                 return true;
231             }
232         }
233     return false;
234 }
235 
236 bool test(Integer a, Integer b, Integer c, Integer d)
237 {
238     Rational data[6];
239     Expression expr[3];
240     data[0] = a;
241     data[1] = b;
242     data[2] = c;
243     data[3] = d;
244     std::cout << a << ", " << b << ", " << c << ", " << d
245         << ':' << std::endl;
246     bool solved = solve(4, data, expr);
247     if (solved)
248         for (const auto& e : expr)
249             std::cout << '\t' << e << std::endl;
250     else
251         std::cout << "\tno solution" << std::endl;
252     return solved;
253 }
254 
255 int main()
256 {
257     int count = 0;
258     constexpr Integer max_num = 13;
259     for (int a = 1; a <= max_num; ++a)
260         for (int b = a; b <= max_num; ++b)
261             for (int c = b; c <= max_num; ++c)
262                 for (int d = c; d <= max_num; ++d)
263                     if (test(a, b, c, d))
264                         ++count;
265     std::cout << count << ' ' << divide_count << ' '
266         << modulo_count << std::endl;
267 
268     test(1, 3, 4, 6);
269     test(1, 4, 5, 6);
270     test(1, 5, 5, 5);
271     test(1, 6, 11, 13);
272     test(2, 2, 11, 11);
273     test(2, 2, 13, 13);
274     test(2, 7, 7, 10);
275     test(3, 3, 7, 7);
276     test(3, 3, 8, 8);
277     test(3, 7, 9, 13);
278     test(4, 4, 7, 7);
279     test(5, 5, 7, 11);
280 }

稍微解释一下这个程序。先定义了Integer类型别名,本来应该是uint8_t,意思是单片机是8位字长的,但流输入输出中会被当成char处理,而我想要的是整数,就换成了uint16_t。swap和sort用于替换标准库中同名函数,后者是冒泡排序,反正待排序的元素至多4个。然后是Rational类,表示非负有理数,自动约分,用的是辗转相除。由于单片机算除法和取模极慢,我把特殊情况排除掉了,类似于剪枝的思想。另有divide_count和modulo_count两变量用于统计除法和取模次数。Expression类表示表达式,24点的解法由3个表达式组成。以及Rational类和Expression类的流插入运算符重载。

递归求解函数参数有3个:要求解的数字个数count、输入数据,data以及表达式存放位置expr。函数假设输入数据已经有序。递归出口为count == 1,检查唯一的数据是否为24。其他情况下,程序会选两个数相加、相减、相乘、相除(排除除数为0的情况),把这两个数从序列中删除再加入新的运算结果,然后排序传给递归子程序。如果子程序返回true,就把这一步运算存入*expr。

这里还有一个小插曲。写完这个程序后,第一次运行的结果是有1320多组有解,然而答案应该是1362。我试着debug,但递归函数相当难debug,我只能在递归中写输出语句,最后发现是排序算法出了问题。那么简单的冒泡排序我当然不会写错,问题在于*(iter + 1) < *iter这句,一开始写的是*iter > *(iter + 1),然而Rational类并没有重载operator>,实际调用的是两个operator bool。我把比较方向换了过来,又在operator bool前加了explicit关键字以防万一,结果就正确了。这个故事告诉我们重载关系运算符要么乖乖全部写出来,要么写using namespace std::rel_ops。

 

然后移植到了AVR单片机上(没有C++标准库,这就是为什么前一个程序刻意避免了那么好用的标准库),开发板用的是写系列教程那块。输入是硬编码,输出是几个LED,3组数据各跑100遍,通过LED和秒表测运行速度。运行结果是300遍有解的数据求解一共用了15秒,其中最后一组大约用了一半时间。我有充足理由估计对无解的数据跑一遍求解需要至少100毫秒。

如果我要保证提供给用户的数据有解,一定会出现200ms以上的延迟,就算不用保证,也至少要100毫秒计算时间,我认为这是不能接受的。所以这个求解算法太慢了。为了用户体验,需要一种更快的算法,然而我并没有办法把求解算法优化掉一个量级。

那唯一的方法就是不求解了。等等,不求解?那数据怎么来?要手写数据?还是要模板元编程把数据在编译器就存起来?

想多了。让PC端程序数据一定格式的数据作为单片机程序的代码,然后直接读取即可。由于数据量比较大,单片机2KB内存放不下,必须放在flash中(就算数据量不大放在RAM中也是浪费)。数据长成这样:

 1 #ifndef DATA_H
 2 #define DATA_H
 3 
 4 #include <avr/pgmspace.h>
 5 
 6 const uint8_t valid_data[][5] PROGMEM =
 7 {
 8      17, 129,  32, 130, 117,
 9      17, 177,  32,  98, 180,
10      17, 193,  32, 146, 117,
11      17, 209,  32,  75, 180,
12      17,  98,  32, 130, 117,
13     // ...
14     // 1362 lines in all
15     // ...
16     219, 221,  32, 130, 109,
17     204, 204,  32, 130, 109,
18     204, 220,  32,  75, 149,
19     204, 221,  32, 130, 109,
20     220, 221,  32, 122, 172,
21 };
22 
23 const uint8_t invalid_data[][2] PROGMEM =
24 {
25      17,  17,
26      17,  33,
27      17,  49,
28      17,  65,
29      17,  81,
30     // ...
31     // 458 lines in all
32     // ...
33     186, 187,
34     186, 221,
35     187, 187,
36     187, 221,
37     221, 221,
38 };
39 
40 #endif

数据格式是我一拍脑袋定的:前两字节的低4位、高4位分别是从小到大的4个数字;后三字节是3个表达式,最低3位为LHS下标,中间2位为运算符,最高3位为RHS下标。运算符域从0到3分别是加减乘除,LHS和RHS的下标定义为依次存放输入数字和中间结果的数组中相等元素的下标;对于无解的数据,只有前两个字节。

这样单片机就无需对输入数据求解,只需根据压缩成字节码的表达式复原出原来的表达式,涉及到少量分数运算而已。实验证明这样的算法是足够快的,至少没有超过16毫秒的显示屏刷新时间。

 

单片机端的程序框架无非是定时器中断中更新显示,其他如硬件驱动等函数都是现成的。程序用到两个按键,分别用于切换刷新状态与显示答案。左边的按键按一下开始刷新数据,再按一下停止刷新,显示一组数据;对于有解的数据,右边的按键按第一下会显示最后一行,第二下会显示完整答案;对于无解的数据,按右边的按键会显示“no solution”。

我以前玩的24点都是1~10的数字,这次是真实模拟扑克牌环境的A~K。10这个数字如果正常显示需要2位,不美观,因此我用画点和画线的操作组合出了在一个字符空间内画10的操作。

为了增加可玩性,我还加入了无解的数据,概率大约为1/6。一群人围着一道题想了半分钟后发现是“no solution”是最爽的事。

 

实际上这个24点程序还远不完美。单片机经常在屏幕上输出诡异的解法,比如10 * 12 = 120,120 / 5 = 24,这些是不符合人类计算逻辑的,正常人想到的都是10 / 5 = 2,2 * 12 = 24。一个可行的方法是把递归搜索的顺序换一下,先减再加,先除后乘,在除法中优先用最大的数除以最小的数。但还是会出现12 / 5 = 12/5,12/5 * 10 = 24这样的式子,最根本的算法还是根据表达式建立树,在树上调整顺序。也许4个数算24点的情况不需要这么复杂,但这是万能的、具有可扩展性的做法(也有可能是我想多了)。

 

点一下,玩一年。24点这么好玩,我肯定不能止步于4个1~13的数加减乘除算出1个24这种简单的游戏。这句话暗示得很清楚了吧,我们中篇再见。

posted on 2019-10-15 00:48  Jerry_SJTU  阅读(418)  评论(0编辑  收藏  举报