Haskell中的记忆化搜索

Memoization是动态规划(Dynamic Programming)中自顶向下处理问题采用的策略, 其基本想法是通过将子问题的解保存起来避免重复计算来优化算法. 这个概念本身很简单, 在其他有明显mutable语义的语言中, 实现起来也非常简单. 但是在Haskell中问题就变的复杂了不少, 对于一个原始的函数f :: a -> b你如果要用ref, 比如说IORef, 你必须要把它放到IO monad中, 你的memoize函数就变成了... -> IO (a -> b). 我们希望是能够找到一个memoize :: ... -> (a -> b), 这样memoize之后得到的和原函数类型是一致的. 为了讨论的方便, 我们主要关注两个例子的memoization, 一个是经典的Fibonacci数列:

fib :: Int -> Integer
fib 0 = 0
fib 1 = 1
fib n = fib (n - 2) + fib (n - 1)

另一个则是动态规划(自底向上)中典型的最小编辑距离的问题, 所谓的最小编辑距离就是一个字符串通过增加, 删除, 替换的操作得到另一个字符串所需要的操作次数:

minEditDist :: String -> String -> Int
minEditDist []     []     = 0
minEditDist s      []     = length s
minEditDist []     s      = length s
minEditDist (x:xs) (y:ys) | x == y    = minEditDist xs ys
                          | otherwise = 1 + minimum [minEditDist xs ys, minEditDist xs (y:ys), minEditDist (x:xs) ys]

Memoizing with specific problem

首先来看fib的问题, wiki给出了一个非常elegant的解(就fib本身而言, 还有更经典的解, fib = (fibs !!) where fibs = 0 : 1 : zipWith (+) fibs (tail fibs)):

import Data.Function (fix)

memoize :: (Int -> a) -> (Int -> a)
memoize f = (map f [0..] !!)

fib :: (Int -> Integer) -> Int -> Integer
fib f 0 = 1
fib f 1 = 1
fib f n = f (n - 1) + f (n - 2)

fibMemo :: Int -> Integer
fibMemo = fix (memoize . fib)

虽然这个memoize和我们想要的(a -> b) -> a -> b有点差距, 但仍然值得分析一下. 首先来看fix, fix的定义很简单:

fix :: (a -> a) -> a
fix f = let x = f x in x

关于fix的详细解释这里略去, 简单而言, 可以将fix理解为一个构建递归的函数. 例如, fix (1:)按定义展开后就是1:(1:(1:(...))), 很容易看到是一个元素为1的无限列表. 这里的fibMemo = fix (memoize . fib)同样我们按定义展开:

fibMemo = fix (memoize . fib)
        -- fix定义
        = let x = (memoize . fib) x in x
        = (memoize . fib) fibMemo
        = memoize (fib fibMemo)
        -- memoize定义
        = (map (fib fibMemo) [0..] !!)

-- 等价于
fibMemo = (map fib [0..] !!) where
  fib 0 = 0
  fib 1 = 1
  fib n = fibMemo (n - 2) + fibMemo (n - 1)

这种memoization实现利用了Haskell的laziness, fibMemo变成了从一个无限的列表里面取值, 我们已经构建好了每一个元素的表达式, 在需要的时候计算, 这样那些已经计算过的元素就保存在列表里面. 更详细的讲, 我们在定义完fibMemo时其结构为:

fibMemo = ([0, 1,
            fibMemo 0 + fibMemo 1,
            fibMemo 1 + fibMemo 2..] !!)

在调用fibMemo 3之后其结构变为:

fibMemo = ([0, 1, 1, 2, fib 2 + fib 3..] !!)

可以看到fibMemo 2的结果已经被保存了, 这就实现了memoization.

我们再来看最小编辑距离的问题, 我们显然没法把fib中的memoize直接拿过来. 因为在这个问题上, 我们希望保存的是任意两个子串的最小编辑距离, 从之前fib的memoization借鉴, 开始我们的第一次尝试:

minEditDistMemo :: String -> String -> Int
minEditDistMemo s1 s2 = lookupS s1 s2
  where lookupS x1 x2 = maybe undefined id $ lookup (x1, x2) ds
        ds            = map g [(x1, x2) | x1 <- tails s1, x2 <- tails s2]
        g (s1, s2)    = ((s1, s2), f s1 s2)
        f [] []       = 0
        f s []        = length s
        f [] s        = length s
        f (x:xs) (y:ys) | x == y    = minEditDistMemo xs ys
                        | otherwise = 1 + minimum [minEditDistMemo xs ys, minEditDistMemo xs (y:ys), minEditDistMemo (x:xs) ys]

可以看到, 每次递归调用minEditDistMemo, 它都会构建一个新的ds, 而这是有问题的. 当然这也很容易解决, 只要把每次递归调用minEditDistMemo的地方换成lookupS就行:

minEditDistMemo :: String -> String -> Int
minEditDistMemo s1 s2 = lookupS s1 s2
  where lookupS x1 x2 = maybe undefined id $ lookup (x1, x2) ds
        ds            = map g [(x1, x2) | x1 <- tails s1, x2 <- tails s2]
        g (s1, s2)    = ((s1, s2), f s1 s2)
        f []     []   = 0
        f s      []   = length s
        f []     s    = length s
        f (x:xs) (y:ys) | x == y    = lookupS xs ys
                        | otherwise =1 + minimum [lookupS xs ys, lookupS xs (y:ys), lookupS (x:xs) ys]

generic memoization

通过上面的分析, 可以看到, 我们总是可以根据特定的问题构建特定的数据结构来实现memoization. 也就是说, 对于任意的一个函数f :: a -> b(如果f有多个参数, 可以先uncurry), 我们希望能够用一个数据结构来保存计算结果, 也就是(a, b), 显然, Map就是最理想的数据结构. 问题是Haskell的Map是immutable, 我们没法像imperative programming那样方便的修改, 这个时候就需要用到State, State能够帮助我们解决共享状态的问题(以下实现来源于Memoizing function in Haskell):

import qualified Data.Map as M
import Control.Monad.State

type MemoState a b = State (M.Map a b) b

memorize :: Ord a => ((a -> MemoState a b) -> (a -> MemoState a b)) -> a -> b
memorize t x = evalState (f x) M.empty where
  f x = get >>= \m -> maybe (g x) return (M.lookup x m)
  g x = do
        y <- t f x
        m <- get
        put $ M.insert x y m
        return y

这里t就是我们要memoized的函数, xt的参数. memorize从一个empty的Map开始运行f x :: MemoState a b并返回它的值. 而f首先用get拿到了当前的状态(也就是Map), 随后检查是否计算过参数为x的结果, 如果是则返回包含结果的MemoState a b, 否则返回g x :: MemoState a b. g的话, 它首先计算参数为x的值, 注意到这个t的类型是(a -> MemoState a b) -> (a -> MemoState a b), 这和我们之前讨论利用fix的函数类似, 都不递归调用自身, 而是调用额外的函数. 随后, 用get拿到了当前的状态(Map), 再用put更新状态(Map), 最后返回了一个包含结果和新状态的MemoState a b.

注意到这个t的类型, 意味着我们要改写原函数, 我们原先的minEditDist需要改为:

-- minEditDistM :: ((String, String) -> MemoState (String, String) Int) -> (String, String) -> MemoState (String, String) Int
minEditDistM :: Monad m => ((String, String) -> m Int) -> (String, String) -> m Int
minEditDistM f ([],     [])     = return 0
minEditDistM f (s,      [])     = return $ length s
minEditDistM f ([],     s)      = return $ length s
minEditDistM f ((x:xs), (y:ys)) | x == y    = f (xs, ys)
                                | otherwise = (+1) . minimum <$> (sequenceA $ f <$> [(xs, ys), (xs, (y:ys)), ((x:xs), ys)])

所幸的是, 我们可以把minEditDistM, 也就是t的类型定义的更generic. 这样一来, 我们的minEditDist就可以实现为:

-- memoized version
minEditDist :: String -> String -> Int
minEditDist s1 s2 = memorize minEditDistM (s1, s2)

至此, 我们就得到了泛用的memorize, 我们要做的仅仅是改写原先的函数, 即:

origin :: a1 -> a2 ... -> b
-- 1. uncurry所有参数, (a1, a2, ...) -> b
-- 2. 添加额外的f, 替换调用自身的情况, ((a1, a2, ...) -> b) -> (a1, a2, ...) -> b
-- 3. 修改返回值为monad
modified :: Monad m => ((a1, a2, ...) -> m b) -> (a1, a2, ...) -> m b

-- memoized version
originMemo a1 a2 ... = memorize modified (a1, a2, ...)

总结

本文讨论了Haskell中两种memoization的手段, 一种根据具体问题具体的分析, 构建需要的数据结构来保存子问题的结果; 另外一种则利用一个泛用的memoize函数, 按特定的模式修改原函数即可实现memoization. 总体而言, 两种方式各有优劣, 第一种方法需要更精致能够得到更适合问题的解, 第二种方法则提供了泛用性.

参考

  1. Memoization
  2. Dynamic programming
  3. Memoizing function in Haskell
  4. Lazy Dynamic Programming
  5. Haskell/Understanding monads/State
posted @ 2021-03-27 16:27  Christophe1997  阅读(226)  评论(0编辑  收藏  举报