【翻译】创建IQUERYABLE提供器系列文章

  翻译自http://blogs.msdn.com/mattwar/archive/2007/07/30/linq-building-an-iqueryable-provider-part-i.aspx,持续更新

可重用的IQueryable基类

 

       很久就想开始一个使用IQueryable介绍创建LINQ提供器的系列文章了。一直有人通过微软内部邮件或论坛提问问我相关的建议。当然,我也一直回答他们说,我正在做一个简单示例,很快就会让你们知道一切。然而,我希望一步一步来深入并解释一切,而不是一下子给你们一个完整的示例,让你们自己去探索。

       首先,我应该指出的是在Beta2IQueryable有改变。它不再只一个接口,而是分成了两个:IQueryableIQueryProvider。在实现它们之前,让我们先来看看。

       如果你使用Visual Studio的“转到定义”,会看到如下代码:

    public interface IQueryable : IEnumerable {       

        Type ElementType { get; }

        Expression Expression { get; }

        IQueryProvider Provider { get; }

    }

    public interface IQueryable<T> : IEnumerable<T>, IQueryable, IEnumerable {

    }

       当然,IQueryable不再那么有趣,好东西都被移到了IQueryProvider新接口中。在介绍它之前,仍然有必要先看看IQueryable。可以看到,IQueryable只有三个只读接口。第一个给出元素类型(或IQueryable<T>中的T)。需要知道,所有实现IQueryable的类必须为某个T实现IQueryable<T>,反之亦然。泛型IQueryable<T>在方法签名等中非常常用。开始的非泛型在动态查询创建的应用中给出了弱类型的入口点。

       第二个属性给出了查询对应的表达式。这是IQueryable的本质。IQueryable的幕后英雄其实是表示LINQ查询运算符或方法调用的表达式。提供器必须包含这部分才能进行有用的操作。如果我们继续深入的话会看到整个IQueryable构架(包括LINQ标准查询运算符的System.Linq.Queryable版本)只是一种自动构建表达式树的机制。当我们使用Queryable.Where方法来为IQueryable应用过滤的时候,它只是为我们创建了一个新的IQueryable,在其树的顶部增加了一个表示Queryable.Where调用的方法调用表达式节点。不相信的话可以自己尝试一下。

       现在只剩下最后一个属性了,它返回新的IQueryProvider接口的实例。实现构建新的IQueryable以及执行它们的方法被移到了独立的接口中。

    public interface IQueryProvider {

        IQueryable CreateQuery(Expression expression);

        IQueryable<TElement> CreateQuery<TElement>(Expression expression);

        object Execute(Expression expression);

        TResult Execute<TResult>(Expression expression);

    }

       看到IQueryProvider接口你可能会想怎么这么多方法,其实只有两个操作,CreateQueryExecute,两者都有泛型和非泛型的形式。在编程语言中直接写查询的时候,泛型形式就佷常用了,并且由于不需要使用反射来构建实例也就性能更好。

       CreateQuery方法做的工作和它的名字一样。它根据某个表达式树创建IQueryable查询的实例。我们调用方法的时候其实是在让提供器创建一个新的IQueryable实例,在枚举的时候会调用我们的查询提供器并处理某个查询表达式。标准查询运算符的Queryable形式使用这个方法来构建新的和提供器相关的IQueryable。注意,调用者可以把任何表达式树传给API。对于我们的提供器,这可能不是一个合法的查询。然而,有一个是肯定的,那就是表达式本身必须返回/产生正确的IQueryable类型。我们知道,IQueryable包含了表示一段代码的表达式,转化成实际代码并执行后会重新构建为相同的IQueryable(或等价形式)。

       Execute方法是提供器的入口点,用于实际执行查询表达式。例如,查询“myquery.Count()”返回单个整数。查询的表达式树是对一个返回整数的Count方法的方法调用。Queryable.Count(以及其它一些相似的聚合)使用这个方法来立即执行查询。

There, that doesn’t seem to frightening does it?   You could implement all those methods easily, right? Sure you could, but why bother.  I’ll do it for you.  Well all except for the execute method.  I’ll show you how to do that in a later post.

       看起来不是佷吓人吧。我们是不是可以轻松实现这些方法?是的,但是为什么要这么麻烦呢。我会为你实现这些。当然,除了Execute方法之外。我会在之后的文章中介绍如何实现。

       首先,让我们从IQuerayble开始。由于这个接口被分成了两个,我们就可以只实现IQueryable部分一次,然后任何提供器都可以进行重用。我会实现一个叫做Query<T>的类,它实现IQueryable<T>等接口。

    public class Query<T> : IQueryable<T>, IQueryable, IEnumerable<T>, IEnumerable, IOrderedQueryable<T>, IOrderedQueryable {

        QueryProvider provider;

        Expression expression;

 

        public Query(QueryProvider provider) {

            if (provider == null) {

                throw new ArgumentNullException("provider");

            }

            this.provider = provider;

            this.expression = Expression.Constant(this);

        }

 

        public Query(QueryProvider provider, Expression expression) {

            if (provider == null) {

                throw new ArgumentNullException("provider");

            }

            if (expression == null) {

                throw new ArgumentNullException("expression");

            }

            if (!typeof(IQueryable<T>).IsAssignableFrom(expression.Type)) {

                throw new ArgumentOutOfRangeException("expression");

            }

            this.provider = provider;

            this.expression = expression;

        }

 

        Expression IQueryable.Expression {

            get { return this.expression; }

        }

 

        Type IQueryable.ElementType {

            get { return typeof(T); }

        }

 

        IQueryProvider IQueryable.Provider {

            get { return this.provider; }

        }

 

        public IEnumerator<T> GetEnumerator() {

            return ((IEnumerable<T>)this.provider.Execute(this.expression)).GetEnumerator();

        }

 

        IEnumerator IEnumerable.GetEnumerator() {

            return ((IEnumerable)this.provider.Execute(this.expression)).GetEnumerator();

        }

 

        public override string ToString() {

            return this.provider.GetQueryText(this.expression);

        }

    }

 

       我们可以看到,IQueryable的实现是非常简单明了的。这个小对象只是保存了一个表达式树和提供器实例的引用。提供器就有趣多了。

       那么,让我们来看个提供器。我实现了一个之前Query<T>用到的叫做QueryProvider的基类。一个实际的提供器可以从这个类继承并实现Execute方法。

    public abstract class QueryProvider : IQueryProvider {

        protected QueryProvider() {

        }

 

        IQueryable<S> IQueryProvider.CreateQuery<S>(Expression expression) {

            return new Query<S>(this, expression);

        }

 

        IQueryable IQueryProvider.CreateQuery(Expression expression) {

            Type elementType = TypeSystem.GetElementType(expression.Type);

            try {

                return (IQueryable)Activator.CreateInstance(typeof(Query<>).MakeGenericType(elementType), new object[] { this, expression });

            }

            catch (TargetInvocationException tie) {

                throw tie.InnerException;

            }

        }

 

        S IQueryProvider.Execute<S>(Expression expression) {

            return (S)this.Execute(expression);

        }

 

        object IQueryProvider.Execute(Expression expression) {

            return this.Execute(expression);

        }

 

        public abstract string GetQueryText(Expression expression);

        public abstract object Execute(Expression expression);

    }

 

       在我的QueryProvider基类中实现了IQueryProvider接口。CreateQuery方法创建了Query<T>的新实例并且Execute方法把执行转交给未实现的Execute方法。

 

       我觉得你可以把这段代码看作是开始创建LINQ IQueryable提供器的模板。真的操作在Execute方法内部执行。在那个时候,提供器通过分析表达式树来理解查询。下回会继续分解。

 

更新:

       似乎我忘记定义了类实现中使用的帮助类,下面是代码:

    internal static class TypeSystem {

        internal static Type GetElementType(Type seqType) {

            Type ienum = FindIEnumerable(seqType);

            if (ienum == null) return seqType;

            return ienum.GetGenericArguments()[0];

        }

        private static Type FindIEnumerable(Type seqType) {

            if (seqType == null || seqType == typeof(string))

                return null;

            if (seqType.IsArray)

                return typeof(IEnumerable<>).MakeGenericType(seqType.GetElementType());

            if (seqType.IsGenericType) {

                foreach (Type arg in seqType.GetGenericArguments()) {

                    Type ienum = typeof(IEnumerable<>).MakeGenericType(arg);

                    if (ienum.IsAssignableFrom(seqType)) {

                        return ienum;

                    }

                }

            }

            Type[] ifaces = seqType.GetInterfaces();

            if (ifaces != null && ifaces.Length > 0) {

                foreach (Type iface in ifaces) {

                    Type ienum = FindIEnumerable(iface);

                    if (ienum != null) return ienum;

                }

            }

            if (seqType.BaseType != null && seqType.BaseType != typeof(object)) {

                return FindIEnumerable(seqType.BaseType);

            }

            return null;

        }

    }

 
 

Where以及可重用的表达式访问器

 

       既然我已经定义了叫做Query<T>QueryProvider的可重用版本的IQueryableIQueryProvider。我将会创建有实际操作的提供器。我之前说过,查询表达式真正做的只是执行一小段定义为表达式树而不是实际IL的“代码”。当然,从传统意义上说不一定要真正执行。例如,LINQ to SQL把查询表达式翻译为SQL并且发送给服务器来执行。

       我下面的示例将要做和LINQ to SQL差不多的事情,它翻译并通过ADO提供器执行查询。然而,我要声明这个示例不是完整的。我只会处理Where运算符,并且不会尝试进行任何比允许谓词包含字段引用以及一些简单运算符更复杂的事情。我会在将来扩展这个提供器,但是现在只是用于演示。请不要剪切粘帖代码或期望它用于产品。

       这个提供器将会做两个事情:1、翻译查询为SQL命名文本;2、翻译命令的执行结果为对象。

QueryTranslator

       QueryTranslator访问查询表达式树中的每一个节点,并且使用StringBuilder把提供的运算符翻译为文本。为了简单期间,我们假设有一个叫做ExpressionVisitor的类,它定义了表达式节点的访问者模式。

internal class QueryTranslator : ExpressionVisitor {

    StringBuilder sb;

 

    internal QueryTranslator() {

    }

 

    internal string Translate(Expression expression) {

        this.sb = new StringBuilder();

        this.Visit(expression);

        return this.sb.ToString();

    }

 

    private static Expression StripQuotes(Expression e) {

        while (e.NodeType == ExpressionType.Quote) {

            e = ((UnaryExpression)e).Operand;

        }

        return e;

    }

 

    protected override Expression VisitMethodCall(MethodCallExpression m) {

        if (m.Method.DeclaringType == typeof(Queryable) && m.Method.Name == "Where") {

            sb.Append("SELECT * FROM (");

            this.Visit(m.Arguments[0]);

            sb.Append(") AS T WHERE ");

            LambdaExpression lambda = (LambdaExpression)StripQuotes(m.Arguments[1]);

            this.Visit(lambda.Body);

            return m;

        }

        throw new NotSupportedException(string.Format("The method '{0}' is not supported", m.Method.Name));

    }

 

    protected override Expression VisitUnary(UnaryExpression u) {

        switch (u.NodeType) {

            case ExpressionType.Not:

                sb.Append(" NOT ");

                this.Visit(u.Operand);

                break;

            default:

                throw new NotSupportedException(string.Format("The unary operator '{0}' is not supported", u.NodeType));

        }

        return u;

    }

 

    protected override Expression VisitBinary(BinaryExpression b) {

        sb.Append("(");

        this.Visit(b.Left);

        switch (b.NodeType) {

            case ExpressionType.And:

                sb.Append(" AND ");

                break;

            case ExpressionType.Or:

                sb.Append(" OR");

                break;

            case ExpressionType.Equal:

                sb.Append(" = ");

                break;

            case ExpressionType.NotEqual:

                sb.Append(" <> ");

                break;

            case ExpressionType.LessThan:

                sb.Append(" < ");

                break;

            case ExpressionType.LessThanOrEqual:

                sb.Append(" <= ");

                break;

            case ExpressionType.GreaterThan:

                sb.Append(" > ");

                break;

            case ExpressionType.GreaterThanOrEqual:

                sb.Append(" >= ");

                break;

            default:

                throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", b.NodeType));

        }

        this.Visit(b.Right);

        sb.Append(")");

        return b;

    }

 

    protected override Expression VisitConstant(ConstantExpression c) {

        IQueryable q = c.Value as IQueryable;

        if (q != null) {

            // assume constant nodes w/ IQueryables are table references

            sb.Append("SELECT * FROM ");

            sb.Append(q.ElementType.Name);

        }

        else if (c.Value == null) {

            sb.Append("NULL");

        }

        else {

            switch (Type.GetTypeCode(c.Value.GetType())) {

                case TypeCode.Boolean:

                    sb.Append(((bool)c.Value) ? 1 : 0);

                    break;

                case TypeCode.String:

                    sb.Append("'");

                    sb.Append(c.Value);

                    sb.Append("'");

                    break;

                case TypeCode.Object:

                    throw new NotSupportedException(string.Format("The constant for '{0}' is not supported", c.Value));

                default:

                    sb.Append(c.Value);

                    break;

            }

        }

        return c;

    }

 

    protected override Expression VisitMemberAccess(MemberExpression m) {

        if (m.Expression != null && m.Expression.NodeType == ExpressionType.Parameter) {

            sb.Append(m.Member.Name);

            return m;

        }

        throw new NotSupportedException(string.Format("The member '{0}' is not supported", m.Member.Name));

    }

}

 

       可以看到,代码不是很多,但也确实有点复杂。我希望在表达式树中获取的是一个方法调用节点,它的参数指向源(参数0)和谓词(参数1)。看一下之前的VisitMethodCall方法。在这里,我显式处理Queryable.Where方法,生成“SELECT * FROM(递归访问源代并且追加)AS T WHERE”,然后访问谓词。这就允许了源表达式中其它查询运算符作为嵌套子查询。我没有处理其它运算符,但是如果Where被调用多次,我就会进行优雅的处理。至于表别名使用什么没有关系(我使用T),因为我不会再对它生成引用。当然,一个更完整的提供器可能会需要这么做。

       有一个叫做StripQutotes的帮助方法。它的工作就是方法参数中的任何ExpressionType.Quote节点,这样我就能从中获取lambda表达式。

       VisitUnaryVisitBinary方法简单明了。它们仅仅是为某个我们支持的一元或二元运算符插入相应的文本。VisitConstant方法中的翻译就有点有趣了。我们看到,在这里表根IQueryable的引用就是表的引用。我假设保存Query<T>实例的节点就是表示根表,因此操我追加“SELECT * FROM”和表的名字,也就是查询元素类型名。常量节点的其余翻译代码只是处理实际的常量。注意,这些常量被当作字面值加入到了命令文本中。代码对于注入攻击没有采取任何措施,在实际提供器中需要处理。

       最后,VisitMemberAccess假设所有字段或属性就是命名文本中的列引用。没有作任何检测来证明这个假设。字段或属性名被假设认为是匹配数据库中列名的。

       假设有一个“Customers”的类,它的字段匹配Northwind示例数据库中的列名,QueryTranslator会生成类似下面的查询:

对于查询:

Query<Customers> customers = ...;

IQueryable<Customers> q = customers.Where(c => c.City == "London");

 

è

“SELECT * FROM (SELECT *FROM Customers) AS T WHERE (city = ‘London’)”

 

ObjectReader

       ObjectReader的工作就是把SQL查询的结果转换为对象。我会创建一个简单的类,它采用DbDataReader以及类型T,我还会让它实现IEnumerable<T>。实现没有什么特别需要注意的地方,我只是通过反射来写入类字段。字段名必须匹配读取器中列的名,并且类型必须匹配DataReader认为的正确类型。

internal class ObjectReader<T> : IEnumerable<T>, IEnumerable where T : class, new() {

    Enumerator enumerator;

 

    internal ObjectReader(DbDataReader reader) {

        this.enumerator = new Enumerator(reader);

    }

 

    public IEnumerator<T> GetEnumerator() {

        Enumerator e = this.enumerator;

        if (e == null) {

            throw new InvalidOperationException("Cannot enumerate more than once");

        }

        this.enumerator = null;

        return e;

    }

 

    IEnumerator IEnumerable.GetEnumerator() {

        return this.GetEnumerator();

    }

 

    class Enumerator : IEnumerator<T>, IEnumerator, IDisposable {

        DbDataReader reader;

        FieldInfo[] fields;

        int[] fieldLookup;

        T current;

 

        internal Enumerator(DbDataReader reader) {

            this.reader = reader;

            this.fields = typeof(T).GetFields();

        }

 

        public T Current {

            get { return this.current; }

        }

 

        object IEnumerator.Current {

            get { return this.current; }

        }

 

        public bool MoveNext() {

            if (this.reader.Read()) {

                if (this.fieldLookup == null) {

                    this.InitFieldLookup();

                }

                T instance = new T();

                for (int i = 0, n = this.fields.Length; i < n; i++) {

                    int index = this.fieldLookup[i];

                    if (index >= 0) {

                        FieldInfo fi = this.fields[i];

                        if (this.reader.IsDBNull(index)) {

                            fi.SetValue(instance, null);

                        }

                        else {

                            fi.SetValue(instance, this.reader.GetValue(index));

                        }

                    }

                }

                this.current = instance;

                return true;

            }

            return false;

        }

 

        public void Reset() {

        }

 

        public void Dispose() {

            this.reader.Dispose();

        }

 

        private void InitFieldLookup() {

            Dictionary<string, int> map = new Dictionary<string, int>(StringComparer.InvariantCultureIgnoreCase);

            for (int i = 0, n = this.reader.FieldCount; i < n; i++) {

                map.Add(this.reader.GetName(i), i);

            }

            this.fieldLookup = new int[this.fields.Length];

            for (int i = 0, n = this.fields.Length; i < n; i++) {

                int index;

                if (map.TryGetValue(this.fields[i].Name, out index)) {

                    this.fieldLookup[i] = index;

                }

                else {

                    this.fieldLookup[i] = -1;

                }

            }

        }

    }

}

 

       ObjectReader创建了类型T的新实例,每行都通过DbDataReader读取。它使用FieldInfo.SetValue反射API来为对象的每一个字段进行赋值。当ObjectReader被首次创建的时候,它实例化嵌套枚举数类的实例。枚举数在GetEnumerator方法被调用的时候使用。由于DataReader不能重置重新执行,所以枚举数只能调用一次,如果GetEnumerator被第二次调用就会抛出异常。

       ObjectReader没有规定字段的排序。由于QueryTranslator使用“SELECT *”构建查询,代码也就不知道结果中首先出现的列是哪个。注意,在生产代码中使用“SELECT *”是不推荐的。要记住,这只是一个用于演示的示例。要允许列的不同序列,就需要在DataReader读取第一行的运行时精确找出序列。InitFieldLookup函数构建了从列名到列序号的映射,然后构建了一个叫做“fieldLookup”的映射对象字段和字段顺序的查找表。

 

提供器

       既然我们已经有了这些代码,把它们组成一个实际的IQueryable LINQ提供器就佷简单了。

public class DbQueryProvider : QueryProvider {

    DbConnection connection;

 

    public DbQueryProvider(DbConnection connection) {

        this.connection = connection;

    }

 

    public override string GetQueryText(Expression expression) {

        return this.Translate(expression);

    }

 

    public override object Execute(Expression expression) {

        DbCommand cmd = this.connection.CreateCommand();

        cmd.CommandText = this.Translate(expression);

        DbDataReader reader = cmd.ExecuteReader();

        Type elementType = TypeSystem.GetElementType(expression.Type);

        return Activator.CreateInstance(

            typeof(ObjectReader<>).MakeGenericType(elementType),

            BindingFlags.Instance | BindingFlags.NonPublic, null,

            new object[] { reader },

            null);

    }

 

    private string Translate(Expression expression) {

        return new QueryTranslator().Translate(expression);

    }

}

       我们看到,创建提供器就是把这些代码组装在一起。GetQueryText只需要使用QueryTranslator来获取命名文本。Execute方法使用QueryTranslatorObjectReader来创建DbCommand对象,执行它并把结果以Ienumerable返回。

 

试用一下

       既然我们已经有了自己的提供器,那么就来测试一下。我会按照LINQ to SQL模型来定一个用于Customers表的类以及一个“上下文”来保存表(根查询),和一些试用它们的程序。

public class Customers {

    public string CustomerID;

    public string ContactName;

    public string Phone;

    public string City;

    public string Country;

}

 

public class Orders {

    public int OrderID;

    public string CustomerID;

    public DateTime OrderDate;

}

 

public class Northwind {

    public Query<Customers> Customers;

    public Query<Orders> Orders;

 

    public Northwind(DbConnection connection) {

        QueryProvider provider = new DbQueryProvider(connection);

        this.Customers = new Query<Customers>(provider);

        this.Orders = new Query<Orders>(provider);

    }

}


class Program {

    static void Main(string[] args) {

        string constr = @"…";

        using (SqlConnection con = new SqlConnection(constr)) {

            con.Open();

            Northwind db = new Northwind(con);

 

            IQueryable<Customers> query =
                 db.Customers.Where(c => c.City == "London");

 

            Console.WriteLine("Query:\n{0}\n", query);

 

            var list = query.ToList();

            foreach (var item in list) {

                Console.WriteLine("Name: {0}", item.ContactName);

            }

 

            Console.ReadLine();

        }

    }
}

 

       如果运行代码会得到如下输出:(注意,我们需要为上面的程序增加自己的连接字符串)

Query:
SELECT * FROM (SELECT * FROM Customers) AS T WHERE (City = 'London')

Name: Thomas Hardy
Name: Victoria Ashworth
Name: Elizabeth Brown
Name: Ann Devon
Name: Simon Crowther
Name: Hari Kumar

       不错,就是我想要的。这只是最简单,你可以做更多。

 

附录——ExpressionVisitor

       我想我收到有关这个类的很多请求来帮助构建查询表达式。在System.Linq.Expressions中有一个ExpressionVisitor类,由于它是internal的,所以我们不能直接使用。如果大家都希望使用的话在下次版本中我们可能会把它作为public的。

       ExpressionVisitor基于经典的访问者模式。每一个节点类型有自己的方法,例如,所有二元运算符都由VisitBinary方法处理。节点本身不直接参与访问过程。它们被当作数据来处理。这样做的原因是访问者的数量是可扩展的。我们可以写自己的。XXX节点的默认访问行为发回基类版本的VisitXXXVisitXXX方法返回节点。表达式树节点是不变的。要改变树就必须构建一个新的。默认的VisitXXX方法在子树改变的时候会构建一个新的节点。如果没有改变则返回原来的节点。

代码如下:

public abstract class ExpressionVisitor {

    protected ExpressionVisitor() {

    }

 

    protected virtual Expression Visit(Expression exp) {

        if (exp == null)

            return exp;

        switch (exp.NodeType) {

            case ExpressionType.Negate:

            case ExpressionType.NegateChecked:

            case ExpressionType.Not:

            case ExpressionType.Convert:

            case ExpressionType.ConvertChecked:

            case ExpressionType.ArrayLength:

            case ExpressionType.Quote:

            case ExpressionType.TypeAs:

                return this.VisitUnary((UnaryExpression)exp);

            case ExpressionType.Add:

            case ExpressionType.AddChecked:

            case ExpressionType.Subtract:

            case ExpressionType.SubtractChecked:

            case ExpressionType.Multiply:

            case ExpressionType.MultiplyChecked:

            case ExpressionType.Divide:

            case ExpressionType.Modulo:

            case ExpressionType.And:

            case ExpressionType.AndAlso:

            case ExpressionType.Or:

            case ExpressionType.OrElse:

            case ExpressionType.LessThan:

            case ExpressionType.LessThanOrEqual:

            case ExpressionType.GreaterThan:

            case ExpressionType.GreaterThanOrEqual:

            case ExpressionType.Equal:

            case ExpressionType.NotEqual:

            case ExpressionType.Coalesce:

            case ExpressionType.ArrayIndex:

            case ExpressionType.RightShift:

            case ExpressionType.LeftShift:

            case ExpressionType.ExclusiveOr:

                return this.VisitBinary((BinaryExpression)exp);

            case ExpressionType.TypeIs:

                return this.VisitTypeIs((TypeBinaryExpression)exp);

            case ExpressionType.Conditional:

                return this.VisitConditional((ConditionalExpression)exp);

            case ExpressionType.Constant:

                return this.VisitConstant((ConstantExpression)exp);

            case ExpressionType.Parameter:

                return this.VisitParameter((ParameterExpression)exp);

            case ExpressionType.MemberAccess:

                return this.VisitMemberAccess((MemberExpression)exp);

            case ExpressionType.Call:

                return this.VisitMethodCall((MethodCallExpression)exp);

            case ExpressionType.Lambda:

                return this.VisitLambda((LambdaExpression)exp);

            case ExpressionType.New:

                return this.VisitNew((NewExpression)exp);

            case ExpressionType.NewArrayInit:

            case ExpressionType.NewArrayBounds:

                return this.VisitNewArray((NewArrayExpression)exp);

            case ExpressionType.Invoke:

                return this.VisitInvocation((InvocationExpression)exp);

            case ExpressionType.MemberInit:

                return this.VisitMemberInit((MemberInitExpression)exp);

            case ExpressionType.ListInit:

                return this.VisitListInit((ListInitExpression)exp);

            default:

                throw new Exception(string.Format("Unhandled expression type: '{0}'", exp.NodeType));

        }

    }

 

    protected virtual MemberBinding VisitBinding(MemberBinding binding) {

        switch (binding.BindingType) {

            case MemberBindingType.Assignment:

                return this.VisitMemberAssignment((MemberAssignment)binding);

            case MemberBindingType.MemberBinding:

                return this.VisitMemberMemberBinding((MemberMemberBinding)binding);

            case MemberBindingType.ListBinding:

                return this.VisitMemberListBinding((MemberListBinding)binding);

            default:

                throw new Exception(string.Format("Unhandled binding type '{0}'", binding.BindingType));

        }

    }

 

    protected virtual ElementInit VisitElementInitializer(ElementInit initializer) {

        ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);

        if (arguments != initializer.Arguments) {

            return Expression.ElementInit(initializer.AddMethod, arguments);

        }

        return initializer;

    }

 

    protected virtual Expression VisitUnary(UnaryExpression u) {

        Expression operand = this.Visit(u.Operand);

        if (operand != u.Operand) {

            return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);

        }

        return u;

    }

 

    protected virtual Expression VisitBinary(BinaryExpression b) {

        Expression left = this.Visit(b.Left);

        Expression right = this.Visit(b.Right);

        Expression conversion = this.Visit(b.Conversion);

        if (left != b.Left || right != b.Right || conversion != b.Conversion) {

            if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)

                return Expression.Coalesce(left, right, conversion as LambdaExpression);

            else

                return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);

        }

        return b;

    }

 

    protected virtual Expression VisitTypeIs(TypeBinaryExpression b) {

        Expression expr = this.Visit(b.Expression);

        if (expr != b.Expression) {

            return Expression.TypeIs(expr, b.TypeOperand);

        }

        return b;

    }

 

    protected virtual Expression VisitConstant(ConstantExpression c) {

        return c;

    }

 

    protected virtual Expression VisitConditional(ConditionalExpression c) {

        Expression test = this.Visit(c.Test);

        Expression ifTrue = this.Visit(c.IfTrue);

        Expression ifFalse = this.Visit(c.IfFalse);

        if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse) {

            return Expression.Condition(test, ifTrue, ifFalse);

        }

        return c;

    }

 

    protected virtual Expression VisitParameter(ParameterExpression p) {

        return p;

    }

 

    protected virtual Expression VisitMemberAccess(MemberExpression m) {

        Expression exp = this.Visit(m.Expression);

        if (exp != m.Expression) {

            return Expression.MakeMemberAccess(exp, m.Member);

        }

        return m;

    }

 

    protected virtual Expression VisitMethodCall(MethodCallExpression m) {

        Expression obj = this.Visit(m.Object);

        IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);

        if (obj != m.Object || args != m.Arguments) {

            return Expression.Call(obj, m.Method, args);

        }

        return m;

    }

 

    protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original) {

        List<Expression> list = null;

        for (int i = 0, n = original.Count; i < n; i++) {

            Expression p = this.Visit(original[i]);

            if (list != null) {

                list.Add(p);

            }

            else if (p != original[i]) {

                list = new List<Expression>(n);

                for (int j = 0; j < i; j++) {

                    list.Add(original[j]);

                }

                list.Add(p);

            }

        }

        if (list != null) {

            return list.AsReadOnly();

        }

        return original;

    }

 

    protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment) {

        Expression e = this.Visit(assignment.Expression);

        if (e != assignment.Expression) {

            return Expression.Bind(assignment.Member, e);

        }

        return assignment;

    }

 

    protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) {

        IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);

        if (bindings != binding.Bindings) {

            return Expression.MemberBind(binding.Member, bindings);

        }

        return binding;

    }

 

    protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding) {

        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);

        if (initializers != binding.Initializers) {

            return Expression.ListBind(binding.Member, initializers);

        }

        return binding;

    }

 

    protected virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original) {

        List<MemberBinding> list = null;

        for (int i = 0, n = original.Count; i < n; i++) {

            MemberBinding b = this.VisitBinding(original[i]);

            if (list != null) {

                list.Add(b);

            }

            else if (b != original[i]) {

                list = new List<MemberBinding>(n);

                for (int j = 0; j < i; j++) {

                    list.Add(original[j]);

                }

                list.Add(b);

            }

        }

        if (list != null)

            return list;

        return original;

    }

 

    protected virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original) {

        List<ElementInit> list = null;

        for (int i = 0, n = original.Count; i < n; i++) {

            ElementInit init = this.VisitElementInitializer(original[i]);

            if (list != null) {

                list.Add(init);

            }

            else if (init != original[i]) {

                list = new List<ElementInit>(n);

                for (int j = 0; j < i; j++) {

                    list.Add(original[j]);

                }

                list.Add(init);

            }

        }

        if (list != null)

            return list;

        return original;

    }

 

    protected virtual Expression VisitLambda(LambdaExpression lambda) {

        Expression body = this.Visit(lambda.Body);

        if (body != lambda.Body) {

            return Expression.Lambda(lambda.Type, body, lambda.Parameters);

        }

        return lambda;

    }

 

    protected virtual NewExpression VisitNew(NewExpression nex) {

        IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);

        if (args != nex.Arguments) {

            if (nex.Members != null)

                return Expression.New(nex.Constructor, args, nex.Members);

            else

                return Expression.New(nex.Constructor, args);

        }

        return nex;

    }

 

    protected virtual Expression VisitMemberInit(MemberInitExpression init) {

        NewExpression n = this.VisitNew(init.NewExpression);

        IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);

        if (n != init.NewExpression || bindings != init.Bindings) {

            return Expression.MemberInit(n, bindings);

        }

        return init;

    }

 

    protected virtual Expression VisitListInit(ListInitExpression init) {

        NewExpression n = this.VisitNew(init.NewExpression);

        IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);

        if (n != init.NewExpression || initializers != init.Initializers) {

            return Expression.ListInit(n, initializers);

        }

        return init;

    }

 

    protected virtual Expression VisitNewArray(NewArrayExpression na) {

        IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);

        if (exprs != na.Expressions) {

            if (na.NodeType == ExpressionType.NewArrayInit) {

                return Expression.NewArrayInit(na.Type.GetElementType(), exprs);

            }

            else {

                return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);

            }

        }

        return na;

    }

 

    protected virtual Expression VisitInvocation(InvocationExpression iv) {

        IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);

        Expression expr = this.Visit(iv.Expression);

        if (args != iv.Arguments || expr != iv.Expression) {

            return Expression.Invoke(expr, args);

        }

        return iv;

    }

}

 

posted @ 2008-03-28 15:16  lovecherry  阅读(4977)  评论(10编辑  收藏  举报