ExpressionVisitor.cs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Collections.ObjectModel;
  4. using System.Linq.Expressions;
  namespace Common.LambdaToSQL
  {
      /// <summary>
      /// Base class for walking and rewriting LINQ expression trees.
      /// </summary>
      /// <remarks>
      /// Each <c>Visit*</c> method visits the children of one kind of node. When no child
      /// changed (compared by reference), the original node is returned as-is; otherwise a
      /// new node is built from the visited children. Derived classes override only the
      /// node kinds they need to inspect or rewrite.
      /// </remarks>
      internal abstract class ExpressionVisitor
      {
          protected ExpressionVisitor()
          {
          }

          /// <summary>
          /// Dispatches <paramref name="exp"/> to the <c>Visit*</c> method that matches its
          /// <see cref="Expression.NodeType"/>.
          /// </summary>
          /// <param name="exp">The expression to visit; may be null.</param>
          /// <returns>
          /// null when <paramref name="exp"/> is null; otherwise the result of the
          /// node-specific visit method.
          /// </returns>
          /// <exception cref="NotSupportedException">
          /// Thrown by the default <see cref="VisitUnknown"/> for node types not dispatched here.
          /// </exception>
          protected virtual Expression Visit(Expression exp)
          {
              if (exp == null)
              {
                  return null;
              }

              switch (exp.NodeType)
              {
                  case ExpressionType.Negate:
                  case ExpressionType.NegateChecked:
                  case ExpressionType.Not:
                  case ExpressionType.Convert:
                  case ExpressionType.ConvertChecked:
                  case ExpressionType.ArrayLength:
                  case ExpressionType.Quote:
                  case ExpressionType.TypeAs:
                  case ExpressionType.UnaryPlus:
                      return this.VisitUnary((UnaryExpression)exp);
                  case ExpressionType.Add:
                  case ExpressionType.AddChecked:
                  case ExpressionType.Subtract:
                  case ExpressionType.SubtractChecked:
                  case ExpressionType.Multiply:
                  case ExpressionType.MultiplyChecked:
                  case ExpressionType.Divide:
                  case ExpressionType.Modulo:
                  case ExpressionType.Power:
                  case ExpressionType.And:
                  case ExpressionType.AndAlso:
                  case ExpressionType.Or:
                  case ExpressionType.OrElse:
                  case ExpressionType.LessThan:
                  case ExpressionType.LessThanOrEqual:
                  case ExpressionType.GreaterThan:
                  case ExpressionType.GreaterThanOrEqual:
                  case ExpressionType.Equal:
                  case ExpressionType.NotEqual:
                  case ExpressionType.Coalesce:
                  case ExpressionType.ArrayIndex:
                  case ExpressionType.RightShift:
                  case ExpressionType.LeftShift:
                  case ExpressionType.ExclusiveOr:
                      return this.VisitBinary((BinaryExpression)exp);
                  case ExpressionType.TypeIs:
                      return this.VisitTypeIs((TypeBinaryExpression)exp);
                  case ExpressionType.Conditional:
                      return this.VisitConditional((ConditionalExpression)exp);
                  case ExpressionType.Constant:
                      return this.VisitConstant((ConstantExpression)exp);
                  case ExpressionType.Parameter:
                      return this.VisitParameter((ParameterExpression)exp);
                  case ExpressionType.MemberAccess:
                      return this.VisitMemberAccess((MemberExpression)exp);
                  case ExpressionType.Call:
                      return this.VisitMethodCall((MethodCallExpression)exp);
                  case ExpressionType.Lambda:
                      return this.VisitLambda((LambdaExpression)exp);
                  case ExpressionType.New:
                      return this.VisitNew((NewExpression)exp);
                  case ExpressionType.NewArrayInit:
                  case ExpressionType.NewArrayBounds:
                      return this.VisitNewArray((NewArrayExpression)exp);
                  case ExpressionType.Invoke:
                      return this.VisitInvocation((InvocationExpression)exp);
                  case ExpressionType.MemberInit:
                      return this.VisitMemberInit((MemberInitExpression)exp);
                  case ExpressionType.ListInit:
                      return this.VisitListInit((ListInitExpression)exp);
                  default:
                      // Go through the overridable hook rather than throwing inline, so a
                      // derived visitor can add support for node types not listed above.
                      return this.VisitUnknown(exp);
              }
          }

          /// <summary>
          /// Called by <see cref="Visit"/> for node types it does not dispatch.
          /// The default implementation rejects the node; override to handle more node types.
          /// </summary>
          /// <param name="expression">The unhandled expression.</param>
          /// <exception cref="NotSupportedException">Always, in this default implementation.</exception>
          protected virtual Expression VisitUnknown(Expression expression)
          {
              throw new NotSupportedException($"Unhandled expression type: '{expression.NodeType}'");
          }

          /// <summary>
          /// Dispatches a member binding (from an object initializer) by its <see cref="MemberBindingType"/>.
          /// </summary>
          /// <exception cref="NotSupportedException">For an unrecognised binding type.</exception>
          protected virtual MemberBinding VisitBinding(MemberBinding binding)
          {
              switch (binding.BindingType)
              {
                  case MemberBindingType.Assignment:
                      return this.VisitMemberAssignment((MemberAssignment)binding);
                  case MemberBindingType.MemberBinding:
                      return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
                  case MemberBindingType.ListBinding:
                      return this.VisitMemberListBinding((MemberListBinding)binding);
                  default:
                      throw new NotSupportedException($"Unhandled binding type '{binding.BindingType}'");
              }
          }

          /// <summary>
          /// Visits the arguments of a collection-initializer element; rebuilds it with the
          /// same Add method if any argument changed.
          /// </summary>
          protected virtual ElementInit VisitElementInitializer(ElementInit initializer)
          {
              ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
              if (arguments != initializer.Arguments)
              {
                  return Expression.ElementInit(initializer.AddMethod, arguments);
              }
              return initializer;
          }

          /// <summary>
          /// Visits the operand of a unary node; rebuilds it with the same node type, result
          /// type and operator method if the operand changed.
          /// </summary>
          protected virtual Expression VisitUnary(UnaryExpression u)
          {
              Expression operand = this.Visit(u.Operand);
              if (operand != u.Operand)
              {
                  return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
              }
              return u;
          }

          /// <summary>
          /// Visits the left, right and conversion parts of a binary node; rebuilds it if any changed.
          /// </summary>
          protected virtual Expression VisitBinary(BinaryExpression b)
          {
              Expression left = this.Visit(b.Left);
              Expression right = this.Visit(b.Right);
              Expression conversion = this.Visit(b.Conversion);
              if (left != b.Left || right != b.Right || conversion != b.Conversion)
              {
                  // The MakeBinary overload used below carries no conversion lambda, so a
                  // Coalesce that has one is rebuilt via Expression.Coalesce instead.
                  if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)
                  {
                      return Expression.Coalesce(left, right, conversion as LambdaExpression);
                  }
                  return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
              }
              return b;
          }

          /// <summary>
          /// Visits the operand of a <c>TypeIs</c> test; rebuilds it if the operand changed.
          /// </summary>
          protected virtual Expression VisitTypeIs(TypeBinaryExpression b)
          {
              Expression expr = this.Visit(b.Expression);
              if (expr != b.Expression)
              {
                  return Expression.TypeIs(expr, b.TypeOperand);
              }
              return b;
          }

          /// <summary>Constants have no children; returned unchanged.</summary>
          protected virtual Expression VisitConstant(ConstantExpression c)
          {
              return c;
          }

          /// <summary>
          /// Visits the test and both branches of a conditional; rebuilds it if any changed.
          /// </summary>
          protected virtual Expression VisitConditional(ConditionalExpression c)
          {
              Expression test = this.Visit(c.Test);
              Expression ifTrue = this.Visit(c.IfTrue);
              Expression ifFalse = this.Visit(c.IfFalse);
              if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse)
              {
                  return Expression.Condition(test, ifTrue, ifFalse);
              }
              return c;
          }

          /// <summary>Parameters have no children; returned unchanged.</summary>
          protected virtual Expression VisitParameter(ParameterExpression p)
          {
              return p;
          }

          /// <summary>
          /// Visits the instance expression of a member access (null for static members);
          /// rebuilds the access if it changed.
          /// </summary>
          protected virtual Expression VisitMemberAccess(MemberExpression m)
          {
              Expression exp = this.Visit(m.Expression);
              if (exp != m.Expression)
              {
                  return Expression.MakeMemberAccess(exp, m.Member);
              }
              return m;
          }

          /// <summary>
          /// Visits the target object (null for static calls) and the arguments of a method
          /// call; rebuilds the call if any changed.
          /// </summary>
          protected virtual Expression VisitMethodCall(MethodCallExpression m)
          {
              Expression obj = this.Visit(m.Object);
              IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
              if (obj != m.Object || args != m.Arguments)
              {
                  return Expression.Call(obj, m.Method, args);
              }
              return m;
          }

          /// <summary>
          /// Visits every expression in <paramref name="original"/>.
          /// </summary>
          /// <returns>
          /// <paramref name="original"/> itself when every element came back unchanged;
          /// otherwise a new read-only list holding the visited elements.
          /// </returns>
          protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original)
          {
              List<Expression> list = null;
              for (int i = 0, n = original.Count; i < n; i++)
              {
                  Expression p = this.Visit(original[i]);
                  if (list != null)
                  {
                      list.Add(p);
                  }
                  else if (p != original[i])
                  {
                      // First changed element: copy the unchanged prefix, then keep collecting.
                      list = new List<Expression>(n);
                      for (int j = 0; j < i; j++)
                      {
                          list.Add(original[j]);
                      }
                      list.Add(p);
                  }
              }
              if (list != null)
              {
                  return list.AsReadOnly();
              }
              return original;
          }

          /// <summary>
          /// Visits the value expression of a member assignment; rebinds it if it changed.
          /// </summary>
          protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment)
          {
              Expression e = this.Visit(assignment.Expression);
              if (e != assignment.Expression)
              {
                  return Expression.Bind(assignment.Member, e);
              }
              return assignment;
          }

          /// <summary>
          /// Visits the nested bindings of a member-member binding; rebuilds it if any changed.
          /// </summary>
          protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding)
          {
              IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
              if (bindings != binding.Bindings)
              {
                  return Expression.MemberBind(binding.Member, bindings);
              }
              return binding;
          }

          /// <summary>
          /// Visits the element initializers of a member list binding; rebuilds it if any changed.
          /// </summary>
          protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding)
          {
              IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
              if (initializers != binding.Initializers)
              {
                  return Expression.ListBind(binding.Member, initializers);
              }
              return binding;
          }

          /// <summary>
          /// Visits every binding in <paramref name="original"/>; returns the original
          /// collection when nothing changed, otherwise a new list.
          /// </summary>
          protected virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original)
          {
              List<MemberBinding> list = null;
              for (int i = 0, n = original.Count; i < n; i++)
              {
                  MemberBinding b = this.VisitBinding(original[i]);
                  if (list != null)
                  {
                      list.Add(b);
                  }
                  else if (b != original[i])
                  {
                      list = new List<MemberBinding>(n);
                      for (int j = 0; j < i; j++)
                      {
                          list.Add(original[j]);
                      }
                      list.Add(b);
                  }
              }
              if (list != null)
              {
                  return list;
              }
              return original;
          }

          /// <summary>
          /// Visits every element initializer in <paramref name="original"/>; returns the
          /// original collection when nothing changed, otherwise a new list.
          /// </summary>
          protected virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original)
          {
              List<ElementInit> list = null;
              for (int i = 0, n = original.Count; i < n; i++)
              {
                  ElementInit init = this.VisitElementInitializer(original[i]);
                  if (list != null)
                  {
                      list.Add(init);
                  }
                  else if (init != original[i])
                  {
                      list = new List<ElementInit>(n);
                      for (int j = 0; j < i; j++)
                      {
                          list.Add(original[j]);
                      }
                      list.Add(init);
                  }
              }
              if (list != null)
              {
                  return list;
              }
              return original;
          }

          /// <summary>
          /// Visits the body of a lambda (its parameters are not visited); rebuilds the lambda
          /// with the same delegate type and parameters if the body changed.
          /// </summary>
          protected virtual Expression VisitLambda(LambdaExpression lambda)
          {
              Expression body = this.Visit(lambda.Body);
              if (body != lambda.Body)
              {
                  return Expression.Lambda(lambda.Type, body, lambda.Parameters);
              }
              return lambda;
          }

          /// <summary>
          /// Visits constructor arguments; rebuilds the <c>new</c> expression (keeping its
          /// member list, if any) when an argument changed.
          /// </summary>
          protected virtual NewExpression VisitNew(NewExpression nex)
          {
              IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
              if (args != nex.Arguments)
              {
                  // Members is non-null for anonymous-type style constructions that map
                  // arguments to members; preserve that mapping when rebuilding.
                  if (nex.Members != null)
                  {
                      return Expression.New(nex.Constructor, args, nex.Members);
                  }
                  return Expression.New(nex.Constructor, args);
              }
              return nex;
          }

          /// <summary>
          /// Visits the constructor call and bindings of an object initializer; rebuilds it if any changed.
          /// </summary>
          protected virtual Expression VisitMemberInit(MemberInitExpression init)
          {
              NewExpression n = this.VisitNew(init.NewExpression);
              IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
              if (n != init.NewExpression || bindings != init.Bindings)
              {
                  return Expression.MemberInit(n, bindings);
              }
              return init;
          }

          /// <summary>
          /// Visits the constructor call and element initializers of a collection initializer;
          /// rebuilds it if any changed.
          /// </summary>
          protected virtual Expression VisitListInit(ListInitExpression init)
          {
              NewExpression n = this.VisitNew(init.NewExpression);
              IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
              if (n != init.NewExpression || initializers != init.Initializers)
              {
                  return Expression.ListInit(n, initializers);
              }
              return init;
          }

          /// <summary>
          /// Visits the element (NewArrayInit) or bound (NewArrayBounds) expressions of an
          /// array creation; rebuilds it with the same element type if any changed.
          /// </summary>
          /// <exception cref="InvalidOperationException">If the node's type has no element type.</exception>
          protected virtual Expression VisitNewArray(NewArrayExpression na)
          {
              IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
              if (exprs != na.Expressions)
              {
                  Type elementType = na.Type.GetElementType() ?? throw new InvalidOperationException();
                  if (na.NodeType == ExpressionType.NewArrayInit)
                  {
                      return Expression.NewArrayInit(elementType, exprs);
                  }
                  return Expression.NewArrayBounds(elementType, exprs);
              }
              return na;
          }

          /// <summary>
          /// Visits the arguments, then the invoked expression, of an invocation; rebuilds it
          /// if either changed.
          /// </summary>
          protected virtual Expression VisitInvocation(InvocationExpression iv)
          {
              IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
              Expression expr = this.Visit(iv.Expression);
              if (args != iv.Arguments || expr != iv.Expression)
              {
                  return Expression.Invoke(expr, args);
              }
              return iv;
          }
      }
  }