[C++]一份Linq to object的C++实现

几个月的构想+0.5小时的设计+4小时的linq.h编码+3小时的测试编码。

大量使用C++11的特性,在GCC 4.7.2下编译通过。

 

关于实现细节就不展开了：我表达能力差，要讲清楚恐怕还得花好几个小时。具体用法请参见下面的测试代码。

上代码:

(1) linq.h

View Code
#ifndef LINQ_H
#define LINQ_H

#include <cassert>
#include <cstddef>
#include <cstdlib>

#include <utility>
#include <functional>
#include <memory>
#include <algorithm>
#include <iterator>
#include <random>
#include <type_traits>

#include <set>
#include <vector>
#include <map>

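/*
 * ISSUES:
 *  1. A non-delayed (eager) operation breaks the deferred chain: it reads the
 *  data source once, at the moment it is called, so later changes to the data
 *  source are not seen by the resulting query. (See the "@ non-delay" markers.)
 */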

// Maps T to the plain value type stored by Enumerable by stripping any
// reference and then top-level const/volatile, e.g. `const std::string&`
// becomes `std::string`. Arrays and function types are not decayed.
template<typename T>
struct DeclareType
{
    typedef typename std::remove_cv<
        typename std::remove_reference<T>::type>::type Type;
};

template<typename T> class Enumerable;
template<typename T>
auto range(T end) -> Enumerable<T>;
template<typename ContainerT>
auto from(const ContainerT& c) -> Enumerable<
    typename DeclareType<decltype(*std::begin(c))>::Type >;

/*
 * Enumerable<T>: a lazily evaluated sequence of T values, modeled on .NET
 * LINQ to Objects.
 *
 * A query wraps a closure that returns std::pair<bool, T>: {true, v} yields
 * the next element v, {false, T()} marks the end of the sequence. begin()
 * hands each iterator its own copy of that closure, so every new iteration
 * over a query (or over a copy of it) re-runs the pipeline from its source.
 * Operators marked "@ non-delay" read their whole source when they are
 * called; all other operators defer work until the result is iterated.
 */
template<typename T>
class Enumerable
{
private:
    // Generator protocol: {true, v} yields v, {false, T()} ends the sequence.
    typedef std::function<std::pair<bool, T>()> ClosureT;
public:
    /*
     * Single-pass iterator over a query. It owns a private copy of the
     * query's closure and invokes it lazily: the next element is only pulled
     * when the current value, or end-ness, is first inspected after
     * construction or after ++.
     * Equality only detects exhaustion: two iterators compare equal exactly
     * when both have run out, so compare against end() only.
     */
    struct iterator
    {
    public:
        // Input, not forward: equality gives no multi-pass guarantee, and a
        // forward tag lets algorithms such as the std::vector range
        // constructor traverse twice (std::distance, then the copy), which
        // re-runs the whole pipeline and every user callback in it.
        typedef std::input_iterator_tag iterator_category;
        typedef T value_type;
        typedef std::ptrdiff_t difference_type;
        // operator* returns a const reference, so advertise const types.
        typedef const T* pointer;
        typedef const T& reference;
    public:
        // End iterator: no closure and nothing pending.
        iterator(): m_closure(), m_v(), m_advanced(false) {}
        // Begin iterator: the first element is pending (m_advanced == true)
        // and is pulled on first inspection.
        iterator(const ClosureT& c): m_closure(c), m_v(), m_advanced(true)
        {
            assert(m_closure);
        }
        iterator& operator ++ ()
        {
            // Settle the pending pull first; incrementing an exhausted
            // iterator is a caller bug.
            _doAdvance();
            assert(m_closure && !m_advanced);
            m_advanced = true;
            return *this;
        }
        iterator operator ++ (int)
        {
            // Copies the closure and everything it captures; prefer ++it.
            iterator r(*this);
            ++*this;
            return r;
        }
        const T& operator * () const
        {
            _doAdvance();
            return m_v;
        }
        pointer operator -> () const
        {
            return std::addressof(**this);
        }
        bool operator == (const iterator& o) const
        {
            _doAdvance();
            o._doAdvance();
            // Only "exhausted == exhausted" is detected, i.e. it == end().
            return m_closure == nullptr && o.m_closure == nullptr;
        }
        bool operator != (const iterator& o) const
        {
            return !(*this == o);
        }

    private:
        // Pulls the pending element, if any. An exhausted closure is dropped
        // so the iterator then compares equal to end(). Declared const
        // because operator* and operator== must realize the pending value;
        // the state it touches is mutable for that reason (no const_cast).
        void _doAdvance() const
        {
            if (!m_advanced) return;
            m_advanced = false;
            assert(m_closure);
            auto r = m_closure();
            if (!r.first) m_closure = nullptr;
            else m_v = r.second;
        }

        mutable ClosureT m_closure;
        // Value-initialized so it is never copied or read while
        // indeterminate (first() on an empty query yields T()).
        mutable T m_v;
        mutable bool m_advanced;
    };

public:
    // Wraps a generator closure (see ClosureT for the protocol).
    Enumerable(
            const ClosureT& c):
        m_closure(c)
    { }

    // An empty query: iterating it yields no elements.
    Enumerable() = default;

public:
    // Starts a fresh pass over the query.
    iterator begin() const
    {
        // A default-constructed query has no closure; treat it as empty
        // instead of invoking an empty std::function.
        return m_closure ? iterator(m_closure) : iterator();
    }
    iterator end() const
    {
        return iterator();
    }

public:
    // Maps every element through f. The element type becomes f's return
    // type with references and cv-qualifiers stripped.
    template<typename FuncT>
    auto select(const FuncT &f) const -> Enumerable<typename DeclareType<decltype(f(std::declval<const T&>()))>::Type>
    {
        typedef typename DeclareType<decltype(f(std::declval<const T&>()))>::Type RType;
        auto iter = this->begin(), end = this->end();
        return Enumerable<RType>([iter, end, f]() mutable -> std::pair<bool, RType>
        {
            if (iter == end) return std::make_pair(false, RType());
            RType r = f(*iter);
            ++iter;
            return std::make_pair(true, r);
        });
    }

    // Keeps only the elements for which f returns true.
    template<typename FuncT>
    auto where(const FuncT& f) const -> Enumerable
    {
        auto iter = this->begin(), end = this->end();
        return Enumerable([iter, end, f]() mutable -> std::pair<bool, T>
        {
            while (iter != end && !f(*iter)) ++iter;
            if (iter == end) return std::make_pair(false, T());
            return Enumerable::_yield(iter);
        });
    }

    // True when f holds for every element (vacuously true when empty).
    template<typename FuncT>
    auto all(const FuncT& f) const -> bool
    {
        for (const T& v : *this) {
            if (!f(v)) return false;
        }
        return true;
    }

    // True when f holds for at least one element; stops at the first hit.
    template<typename FuncT>
    auto any(const FuncT& f) const -> bool
    {
        for (const T& v : *this) {
            if (f(v)) return true;
        }
        return false;
    }

    // Converts each element with the functional cast DestT(value).
    template<typename DestT>
    auto cast() const -> Enumerable<DestT>
    {
        auto iter = this->begin(), end = this->end();
        return Enumerable<DestT>([iter, end]() mutable -> std::pair<bool, DestT>
        {
            if (iter == end) return std::make_pair(false, DestT());
            DestT v = DestT(*iter);
            ++iter;
            return std::make_pair(true, v);
        });
    }

    // Arithmetic mean computed as sum / count in T, so an integral T
    // truncates. Asserts that the query is non-empty.
    auto average() const -> T
    {
        T sum = T();
        int n = 0;
        for (const T& v : *this) {
            sum += v;
            ++n;
        }
        assert(n > 0);
        return sum / n;
    }

    // True when some element compares equal (==) to v.
    auto contain(const T& v) const -> bool
    {
        return any([&v](const T& item) { return item == v; });
    }

    // Number of elements equal (==) to v.
    auto count(const T& v) const -> int
    {
        return count([&v](const T& item) { return item == v; });
    }

    // Number of elements satisfying f. The enable_if keeps this overload
    // out of the way when the argument is a value convertible to T.
    template<typename FuncT>
    auto count(const FuncT& f, typename std::enable_if<!std::is_convertible<FuncT, T>::value>::type* = nullptr) const -> int
    {
        int n = 0;
        for (const T& v : *this) {
            if (f(v)) ++n;
        }
        return n;
    }

    // First element, or T() when the query is empty.
    auto first() const -> T
    {
        return *this->begin();
    }

    // Last element, or T() when the query is empty. Walks the whole query.
    auto last() const -> T
    {
        T v = T();
        for (const T& i : *this) v = i;
        return v;
    }

    // At most the first n elements. The budget is checked before touching
    // the source, so nothing beyond the n-th element is ever pulled.
    auto head(int n) const -> Enumerable
    {
        auto iter = this->begin(), end = this->end();
        return Enumerable([iter, end, n]() mutable -> std::pair<bool, T>
        {
            if (--n < 0 || iter == end) return std::make_pair(false, T());
            return Enumerable::_yield(iter);
        });
    }

    // The last n elements: all of them when n exceeds the length, none when
    // n <= 0. Deferred: on the first pull of each pass the whole source is
    // read once and buffered, then the trailing n elements are yielded.
    auto tail(int n) const -> Enumerable
    {
        Enumerable src(*this);
        const int want = std::max(n, 0);
        std::vector<T> buf;
        std::size_t pos = 0;
        bool filled = false;
        return Enumerable([src, want, buf, pos, filled]() mutable -> std::pair<bool, T>
        {
            if (!filled) {
                filled = true;
                buf.assign(src.begin(), src.end());
                const std::size_t keep =
                    std::min(buf.size(), static_cast<std::size_t>(want));
                pos = buf.size() - keep;
            }
            if (pos == buf.size()) return std::make_pair(false, T());
            return std::make_pair(true, buf[pos++]);
        });
    }

    // @ non-delay
    // Groups the elements by key f(element). Groups come out in ascending
    // key order (std::map); each group keeps source order and owns a copy
    // of its elements.
    template<typename FuncT>
    auto groupBy(const FuncT &f) const -> Enumerable<
        std::pair<typename DeclareType<decltype(f(std::declval<const T&>()))>::Type, Enumerable>>
    {
        typedef typename DeclareType<decltype(f(std::declval<const T&>()))>::Type RType;
        typedef std::pair<RType, Enumerable> RPair;

        std::map<RType, std::vector<T>> buckets;
        for (const T& v : *this) {
            buckets[f(v)].push_back(v);
        }

        auto groups = std::make_shared<std::map<RType, Enumerable>>();
        for (auto& kv : buckets) {
            // Move each bucket into storage owned by its group query.
            (*groups)[kv.first] =
                _fromOwned(std::make_shared<std::vector<T>>(std::move(kv.second)));
        }

        auto iter = groups->begin();
        return Enumerable<RPair>([iter, groups]() mutable -> std::pair<bool, RPair>
        {
            if (iter == groups->end()) return std::make_pair(false, RPair());
            RPair p(*iter);
            ++iter;
            return std::make_pair(true, p);
        });
    }

    // Yields elements until (and excluding) the first one for which f is
    // true; ends the sequence there.
    template<typename FuncT>
    auto takeUntil(const FuncT& f) const -> Enumerable
    {
        auto iter = this->begin(), end = this->end();
        return Enumerable([iter, end, f]() mutable -> std::pair<bool, T>
        {
            if (iter == end) return std::make_pair(false, T());
            T r = *iter;
            ++iter;
            if (f(r)) return std::make_pair(false, T());
            return std::make_pair(true, r);
        });
    }

    // Drops elements until the first one for which f is true, then yields
    // it and everything after it. The skip happens on the first pull of each
    // pass, not at construction, so no source state is captured early.
    template<typename FuncT>
    auto skipUntil(const FuncT& f) const -> Enumerable
    {
        auto iter = this->begin(), end = this->end();
        bool skipped = false;
        return Enumerable([iter, end, f, skipped]() mutable -> std::pair<bool, T>
        {
            if (!skipped) {
                skipped = true;
                while (iter != end && !f(*iter)) ++iter;
            }
            if (iter == end) return std::make_pair(false, T());
            return Enumerable::_yield(iter);
        });
    }

    // Largest element by std::max (the earliest wins ties). Asserts non-empty.
    auto max() const -> T
    {
        auto iter = this->begin(), end = this->end();
        assert(iter != end);
        T best = *iter;
        for (++iter; iter != end; ++iter) best = std::max(best, *iter);
        return best;
    }

    // Smallest element by std::min (the earliest wins ties). Asserts non-empty.
    auto min() const -> T
    {
        auto iter = this->begin(), end = this->end();
        assert(iter != end);
        T best = *iter;
        for (++iter; iter != end; ++iter) best = std::min(best, *iter);
        return best;
    }

    // Left fold: v = f(v, element) over all elements, starting from v.
    template<typename FuncT>
    auto reduce(const FuncT& f, T v = T()) const -> T
    {
        for (const T& i : *this) v = f(v, i);
        return v;
    }

    // @ non-delay
    // The elements in reverse order.
    auto reverse() const -> Enumerable
    {
        auto items = std::make_shared<std::vector<T>>(this->begin(), this->end());
        std::reverse(items->begin(), items->end());
        return _fromOwned(items);
    }

    // @ non-delay
    // Snapshots the current elements into storage co-owned by the returned
    // query and its copies, so the result no longer depends on the original
    // source (e.g. a container passed to from()).
    auto reserve() const -> Enumerable
    {
        return _fromOwned(std::make_shared<std::vector<T>>(this->begin(), this->end()));
    }

    // Ascending order by std::less<T>.
    auto sort() const -> Enumerable
    {
        return sort(std::less<T>());
    }

    // @ non-delay
    // Sorted with comparator f via std::sort (not stable).
    template<typename FuncT>
    auto sort(const FuncT& f) const -> Enumerable
    {
        auto items = std::make_shared<std::vector<T>>(this->begin(), this->end());
        std::sort(items->begin(), items->end(), f);
        return _fromOwned(items);
    }

    // @ non-delay
    // Distinct elements present in both queries, in ascending order.
    auto intersect(const Enumerable& o) const -> Enumerable
    {
        const std::set<T> mine(this->begin(), this->end());
        const std::set<T> theirs(o.begin(), o.end());
        auto common = std::make_shared<std::vector<T>>();
        std::set_intersection(mine.begin(), mine.end(),
                              theirs.begin(), theirs.end(),
                              std::back_inserter(*common), mine.key_comp());
        return _fromOwned(common);
    }

    // @ non-delay
    // Distinct elements present in either query, in ascending order.
    auto _union(const Enumerable& o) const -> Enumerable
    {
        auto merged = std::make_shared<std::set<T>>(this->begin(), this->end());
        merged->insert(o.begin(), o.end());
        return _fromOwned(merged);
    }

    // First occurrence of each distinct element, in source order. The
    // "seen" set lives in the closure, so every pass starts with it empty.
    auto unique() const -> Enumerable
    {
        auto iter = this->begin(), end = this->end();
        std::set<T> seen;
        return Enumerable([iter, end, seen]() mutable -> std::pair<bool, T>
        {
            while (iter != end && seen.count(*iter) > 0) ++iter;
            if (iter == end) return std::make_pair(false, T());
            seen.insert(*iter);
            return Enumerable::_yield(iter);
        });
    }

    // @ non-delay
    // The elements in shuffled order.
    auto random() const -> Enumerable
    {
        auto items = std::make_shared<std::vector<T>>(this->begin(), this->end());
        // std::random_shuffle was deprecated in C++14 and removed in C++17.
        // Seeding the engine from std::rand() keeps results controllable
        // through std::srand().
        std::mt19937 engine(static_cast<std::mt19937::result_type>(std::rand()));
        std::shuffle(items->begin(), items->end(), engine);
        return _fromOwned(items);
    }

private:
    // Copies the element under `iter`, advances `iter`, and wraps the copy
    // as a "has value" result. Pre-increment avoids the copy of the whole
    // closure chain that `*iter++` makes with Enumerable::iterator.
    template<typename IterT>
    static std::pair<bool, T> _yield(IterT& iter)
    {
        T v = *iter;
        ++iter;
        return std::pair<bool, T>(true, std::move(v));
    }

    // Yields the elements of a container that the returned query (and all
    // its copies) co-own through the shared_ptr.
    template<typename ContainerT>
    static Enumerable _fromOwned(const std::shared_ptr<ContainerT>& items)
    {
        auto iter = items->begin();
        return Enumerable([iter, items]() mutable -> std::pair<bool, T>
        {
            if (iter == items->end()) return std::make_pair(false, T());
            return Enumerable::_yield(iter);
        });
    }

    ClosureT m_closure;
};

// Wraps a container (anything std::begin/std::end accept) as a lazy query.
// The container is captured by reference, not copied: it must outlive every
// pass over the returned query (call reserve() on the result to snapshot it).
// Because the start position is resolved on the first pull of each pass,
// every fresh iteration reads the container's current contents.
template<typename ContainerT>
auto from(const ContainerT& c) -> Enumerable<
    typename DeclareType<decltype(*std::begin(c))>::Type >
{
    typedef typename DeclareType<decltype(*std::begin(c))>::Type RType;
    bool started = false;
    auto pos = std::end(c);  // placeholder; replaced on the first pull
    return Enumerable<RType>([started, pos, &c]() mutable
    {
        if (started == false) {
            started = true;
            pos = std::begin(c);
        }
        if (pos == std::end(c)) {
            return std::make_pair(false, RType());
        }
        auto item = std::make_pair(true, *pos);
        ++pos;
        return item;
    });
}

// Lazy arithmetic progression begin, begin + step, begin + 2*step, ...
// With a positive step it stops once a value reaches `end` (value >= end);
// with a negative step it stops once a value reaches `end` from above
// (value <= end). `end` itself is never produced. A zero step never stops.
template<typename T>
auto range(T begin, T end, T step = 1) -> Enumerable<T>
{
    T next = begin;
    return Enumerable<T>([next, end, step]() mutable
    {
        const bool doneAscending = step > 0 && next >= end;
        const bool doneDescending = step < 0 && next <= end;
        if (doneAscending || doneDescending) {
            return std::make_pair(false, T());
        }
        T value = next;
        next += step;
        return std::make_pair(true, value);
    });
}

// Shorthand for range(T(), end): counts from the value-initialized T (zero
// for arithmetic types) up to `end`, exclusive, with the default step.
template<typename T>
auto range(T end) -> Enumerable<T>
{
    const T start = T();
    return range(start, end);
}

#endif

 

(2) 测试代码main.cpp (比我的代码更烂的是我的英语)

#include "pch.h" 

#include "linq.h"

#include <vector>
#include <string>
#include <algorithm>

// Prints every element of a range followed by a comma, then ends the line.
template<typename T>
void printC(const T& v)
{
    for (const auto& elem : v) {
        cout << elem << ',';
    }
    cout << endl;
}

// Prints a single value on its own line (endl also flushes the stream).
template<typename T>
void print(const T& v)
{
    cout << v;
    cout << endl;
}

// True when `s` begins with `prefix` (an empty prefix always matches).
// compare() inspects only the first prefix.size() characters; the previous
// find()-based check searched the whole string on a mismatch.
bool startsWith(const std::string& s, const std::string& prefix)
{
    return s.size() >= prefix.size() &&
           s.compare(0, prefix.size(), prefix) == 0;
}

// Exercises library-level properties of Enumerable: deferred evaluation,
// copy semantics, reading the live data source, and snapshotting through
// reserve(). Each numbered scope is an independent check; failures abort
// through assert.
void featureTest()
{
    // 1. standard: keep the odd numbers of 0..9, then add one to each.
    {
        auto query = range(10)
            .where([](int i){ return i % 2; })
            .select([](int i){ return i + 1; });
        auto ref = {2, 4, 6, 8, 10};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
    }
    // 2. deferred range: first() pulls a single element, so the huge upper
    //    bound is never enumerated.
    {
        assert(range(123LL, 1000000000000LL).first() == 123);
    }
    // 3. deferred action: head(2) stops pulling from upstream once it has
    //    two results, so the select callback runs only for 2, 4, 6 and 8.
    //    The copy query2 re-runs the pipeline on its own pass, adding four
    //    more calls.
    {
        int selectCnt = 0;

        auto query = range(1, 1000)
            .where([](int i){ return i % 2 == 0; })
            .select([&selectCnt](int i)
                { 
                    ++selectCnt;
                    return i; 
                })
            .where([](int i) { return i % 4 == 0; })
            .head(2);
        auto query2 = query;

        for (auto i : query) {}
        assert(selectCnt == 4);

        for (auto i : query2) {}
        assert(selectCnt == 8);
    }
    // 4. copy semantics: copies of a query, and copies of a partially
    //    advanced iterator, can each be iterated independently.
    {
        auto query = range(10).head(5);
        auto query2 = query;

        auto ref = {0, 1, 2, 3, 4};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
        assert(std::equal(ref.begin(), ref.end(), query2.begin()));

        auto iter = query.begin();
        ++iter;
        auto iter2 = iter;

        ref = {1, 2, 3, 4};
        assert(std::equal(ref.begin(), ref.end(), iter));
        assert(std::equal(ref.begin(), ref.end(), iter2));
    }
    // 5. always reference the newest data of the data source: from() keeps a
    //    reference to dataSrc, so iterating again after dataSrc has been
    //    replaced sees the new contents.
    {
        std::vector<std::string> dataSrc{"A_abc", "A_def", "B_abc", "B_def"};

        auto query = from(dataSrc)
            .head(3)
            .where([](const std::string& s) { return startsWith(s, "B_"); });

        auto ref = {"B_abc"};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));

        dataSrc.clear();
        dataSrc.shrink_to_fit();
        dataSrc = {"B#_abc", "B_123", "B_1234", "B_321", "B_111"};
        ref = {"B_123", "B_1234"};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
    }
    // 6. invoke operator new as little as possible (no checks implemented yet).
    {
    }
    // 7. the query stays usable after its data source has been destroyed,
    //    thanks to 'reserve', which copies the elements into storage owned
    //    by the query. Without reserve() (the commented-out line) the query
    //    would refer to the destroyed v.
    {
        Enumerable<int> query;
        {
            std::vector<int> v{1, 2, 3, 4};
            // query = from(v).select([](int i){ return i % 2; });
            query = from(v).reserve().select([](int i){ return i % 2; });

            auto ref = {1, 0, 1, 0};
            assert(std::equal(ref.begin(), ref.end(), query.begin()));
        }

        auto ref = {1, 0, 1, 0};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
    }
    // 8. add an action to an existing query: a query can be extended with
    //    further operators after it has already been iterated.
    {
        auto query = range(10).where([](int i){ return i < 5;});
        auto ref = {0, 1, 2, 3, 4};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));

        auto query2 = query.select([](int i){ return i * i; });
        ref = {0, 1, 4, 9, 16};
        assert(std::equal(ref.begin(), ref.end(), query2.begin()));
    }
}

// Checks the result of each query operator on small, hand-computed inputs.
// Each numbered scope is independent; failures abort through assert.
void functionTest()
{
    // 1. from, select, where (cast is covered by case 3): keep the values
    //    not divisible by 3 (5, 7, 8), then square them.
    {
        int a[]{5, 6, 7, 8, 9};
        auto query = from(a)
            .where([](int i){ return i % 3; })
            .select([](int i) { return i * i;});
        auto ref = {25, 49, 64};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
    }
    // 2. range, all, any
    {
        assert(range(10).all([](int i){ return i >= 0;}));
        assert(!range(10).all([](int i){ return i > 0;}));
        assert(from(std::vector<std::string>{"_a", "b"})
                .any([](const std::string& s){ return startsWith(s, "_"); }));
        assert(!from(std::vector<std::string>{"@a", "b"})
                .any([](const std::string& s){ return startsWith(s, "_"); }));
    }
    // 3. cast, average: the int average truncates (10 / 4 == 2); after
    //    cast<float>() the division is done in float.
    {
        assert(range(1, 5).average() == 2);
        assert(range(1, 5).cast<float>().average() == 2.5);
    }
    // 4. contain, count: count(1) selects the value overload, count(lambda)
    //    the predicate overload.
    {
        int a[]{1, 2, 1, 1, 3, 2, };
        assert(from(a).contain(3));
        assert(!from(a).contain(4));
        assert(from(a).count(1) == 3);
        assert(from(a).count([](int i) { return i % 2; }) == 4);
    }
    // 5. first, last, head, tail
    {
        int a[]{3, 5, 7, 9, 11};
        assert(from(a).first() == 3);
        assert(from(a).last() == 11);

        auto ref = {3, 5};
        auto query = from(a).head(2);
        assert(std::equal(ref.begin(), ref.end(), query.begin()));

        ref = {7, 9, 11};
        query = from(a).tail(3);
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
    }
    // 6. groupBy: groups arrive ordered by key (0, 1, 2), and each group
    //    keeps source order. The first member of group k happens to equal k,
    //    which is why refs[n][0] doubles as the expected key.
    {
        auto query = range(10).groupBy([](int i) { return i % 3;});
        int refs[][4] = {
            {0, 3, 6, 9},
            {1, 4, 7,},
            {2, 5, 8,},
        };
        int n = 0;
        for (auto i : query) {
            assert(i.first == refs[n][0]);
            assert(std::equal(i.second.begin(), i.second.end(), refs[n]));
            ++n;
        }
        assert(n == 3);
    }
    // 7. takeUntil (stops before the first matching element), skipUntil
    //    (starts at the first matching element)
    {
        auto query = range(10).takeUntil([](int i){ return i > 5; });
        auto ref = {0, 1, 2, 3, 4, 5};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));

        query = range(10).skipUntil([](int i){ return i > 5; });
        ref = { 6, 7, 8, 9};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
    }
    // 8. max, min
    {
        int a[]{3, 2, 5, 8, 10, -3};
        assert(from(a).min() == -3);
        assert(from(a).max() == 10);
    }
    // 9. reduce: the seed defaults to T() (0 here), so the product needs an
    //    explicit seed of 1.
    {
        assert(range(1, 11).reduce([](int a, int b){ return a + b; }) == 55);
        assert(range(1, 11).reduce([](int a, int b){ return a * b; }, 0) == 0);
        assert(range(1, 11).reduce([](int a, int b){ return a * b; }, 1) == 3628800);
    }
    // 10. unique, sort, random
    //     NOTE(review): a shuffle can legitimately keep the original order
    //     (1 in 120 for five distinct values), so the last assertion depends
    //     on the random source rather than being guaranteed.
    {
        int a[]{3, 5, 5, 4, 2, 1, 2};

        auto query = from(a).unique();
        auto ref = {3, 5, 4, 2, 1};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));

        ref = {5, 4, 3, 2, 1};
        query = query.sort().sort([](int a, int b){ return a > b; });
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
        query = query.random();
        assert(!std::equal(ref.begin(), ref.end(), query.begin()));
    }
    // 11. intersect, _union: both produce sorted, de-duplicated results.
    {
        int a[]{3, 5, 11};
        auto query = range(10).intersect(from(a));
        auto ref = {3, 5};
        assert(std::equal(ref.begin(), ref.end(), query.begin()));

        ref = {3, 4, 5, 6};
        query = query._union(range(4, 7));
        assert(std::equal(ref.begin(), ref.end(), query.begin()));
    }
}

// Entry point: runs the library-level feature checks, then the
// per-operator checks. Any failed assert aborts the program.
int main()
{
    featureTest();
    functionTest();
    return 0;
}

为什么不把它提交到git hub之类的专门代码仓库?一则我没有用过,二则,这种代码是我的write-only构想实践码,不提供后续维护的:)

posted @ 2012-10-20 23:28 Scan. 阅读(...) 评论(...) 编辑 收藏