反射技术用于识别扩展方法

82

在C#中,是否有使用反射来确定方法是否已作为扩展方法添加到类中的技术?

如果有一个像下面所示的扩展方法,是否可以确定Reverse()已被添加到字符串类中?

public static class StringExtensions
{
    public static string Reverse(this string value)
    {
        char[] cArray = value.ToCharArray();
        Array.Reverse(cArray);
        return new string(cArray);
    }
}

我们正在寻找一种机制,在单元测试中确定开发人员是否正确添加了扩展方法。尝试这样做的一个原因是开发人员可能会向实际类中添加类似的方法,如果添加了该方法,编译器将会使用它。

7个回答

126

你需要查看所有可能定义扩展方法的程序集。

查找带有ExtensionAttribute的类,然后在该类中查找同样被装饰上ExtensionAttribute的方法。接着检查第一个参数的类型是否与你感兴趣的类型匹配。

以下是完整的代码。它可能不够严谨(没有检查类型是否嵌套或至少有一个参数),但应该可以帮到你。

using System;
using System.Runtime.CompilerServices;
using System.Reflection;
using System.Linq;
using System.Collections.Generic;

public static class FirstExtensions
{
    public static void Foo(this string x) { }
    public static void Bar(string x) { } // Not an ext. method
    public static void Baz(this int x) { } // Not on string
}

public static class SecondExtensions
{
    public static void Quux(this string x) { }
}

public class Test
{
    static void Main()
    {
        Assembly thisAssembly = typeof(Test).Assembly;
        foreach (MethodInfo method in GetExtensionMethods(thisAssembly, typeof(string)))
        {
            Console.WriteLine(method);
        }
    }
    static IEnumerable<MethodInfo> GetExtensionMethods(Assembly assembly, Type extendedType)
    {
        var isGenericTypeDefinition = extendedType.IsGenericType && extendedType.IsTypeDefinition;
        var query = from type in assembly.GetTypes()
            where type.IsSealed && !type.IsGenericType && !type.IsNested
            from method in type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)
            where method.IsDefined(typeof(ExtensionAttribute), false)
            where isGenericTypeDefinition
                ? method.GetParameters()[0].ParameterType.IsGenericType && method.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == extendedType
                : method.GetParameters()[0].ParameterType == extendedType
            select method;
        return query;
    }
}

5
不错的代码。你可以利用这个事实来排除一堆方法:扩展方法必须在非泛型静态类中定义,其中 !type.IsGenericType && type.IsSealed。 - Amy B
4
@Seb:是的,如果要让它适用于通用方法,需要付出更多的努力。这是可行的,但比较棘手。 - Jon Skeet
1
@nawfal:C#编译器会进行这种检查——如果不在顶级静态非泛型类中,则它不是扩展方法……任何人都可以将[Extension]属性应用于方法——如果不满足这些条件,那就不是扩展方法。 - Jon Skeet
1
@nawfal:就C#编译器而言,这时它是一个扩展方法——即使它不是从带有“this”参数的代码编译而来,编译器也会将其接受为扩展方法。在IL中识别它与编译器的方式一样接近,这已经是你能得到的最接近的了。 - Jon Skeet
1
@nawfal:有趣-我刚试了一下,你是对的。这是一个新的限制;旧的编译器没有这个问题。但你仍然可以通过其他语言或直接作为IL来构建这种方式的IL。我的观点是,我回答中的代码仍然有效地模仿了C#编译器检测扩展方法的方式。 - Jon Skeet
显示剩余5条评论

13

基于John Skeet的回答,我创建了自己的扩展到System.Type类型。

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace System
{
    public static class TypeExtension
    {
        /// <summary>
        /// This Methode extends the System.Type-type to get all extended methods. It searches hereby in all assemblies which are known by the current AppDomain.
        /// </summary>
        /// <remarks>
        /// Insired by Jon Skeet from his answer on https://dev59.com/ZHVC5IYBdhLWcg3wbQfA
        /// </remarks>
        /// <returns>returns MethodInfo[] with the extended Method</returns>

        public static MethodInfo[] GetExtensionMethods(this Type t)
        {
            List<Type> AssTypes = new List<Type>();

            foreach (Assembly item in AppDomain.CurrentDomain.GetAssemblies())
            {
                AssTypes.AddRange(item.GetTypes());
            }

            var query = from type in AssTypes
                where type.IsSealed && !type.IsGenericType && !type.IsNested
                from method in type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)
                where method.IsDefined(typeof(ExtensionAttribute), false)
                where method.GetParameters()[0].ParameterType == t
                select method;
            return query.ToArray<MethodInfo>();
        }

        /// <summary>
        /// Extends the System.Type-type to search for a given extended MethodeName.
        /// </summary>
        /// <param name="MethodeName">Name of the Methode</param>
        /// <returns>the found Methode or null</returns>
        public static MethodInfo GetExtensionMethod(this Type t, string MethodeName)
        {
            var mi = from methode in t.GetExtensionMethods()
                where methode.Name == MethodeName
                select methode;
            if (mi.Count<MethodInfo>() <= 0)
                return null;
            else
                return mi.First<MethodInfo>();
        }
    }
}

它获取当前AppDomain中的所有程序集,并搜索扩展方法。

使用:

Type t = typeof(Type);
MethodInfo[] extendedMethods = t.GetExtensionMethods();
MethodInfo extendedMethodInfo = t.GetExtensionMethod("GetExtensionMethods");

下一步是扩展System.Type方法,返回所有方法(包括“普通”方法和扩展方法)。

5
这将返回指定类型定义的所有扩展方法列表,包括泛型方法:
public static IEnumerable<KeyValuePair<Type, MethodInfo>> GetExtensionMethodsDefinedInType(this Type t)
{
    if (!t.IsSealed || t.IsGenericType || t.IsNested)
        return Enumerable.Empty<KeyValuePair<Type, MethodInfo>>();

    var methods = t.GetMethods(BindingFlags.Public | BindingFlags.Static)
                   .Where(m => m.IsDefined(typeof(ExtensionAttribute), false));

    List<KeyValuePair<Type, MethodInfo>> pairs = new List<KeyValuePair<Type, MethodInfo>>();
    foreach (var m in methods)
    {
        var parameters = m.GetParameters();
        if (parameters.Length > 0)
        {
            if (parameters[0].ParameterType.IsGenericParameter)
            {
                if (m.ContainsGenericParameters)
                {
                    var genericParameters = m.GetGenericArguments();
                    Type genericParam = genericParameters[parameters[0].ParameterType.GenericParameterPosition];
                    foreach (var constraint in genericParam.GetGenericParameterConstraints())
                        pairs.Add(new KeyValuePair<Type, MethodInfo>(parameters[0].ParameterType, m));
                }
            }
            else
                pairs.Add(new KeyValuePair<Type, MethodInfo>(parameters[0].ParameterType, m));
        }
    }

    return pairs;
}

但这里有一个问题:返回的Type类型与typeof(..)预期的不同,因为它是泛型参数类型。为了查找给定类型的所有扩展方法,您必须比较Type的所有基类型和接口的GUID,例如:

public List<MethodInfo> GetExtensionMethodsOf(Type t)
{
    List<MethodInfo> methods = new List<MethodInfo>();
    Type cur = t;
    while (cur != null)
    {

        TypeInfo tInfo;
        if (typeInfo.TryGetValue(cur.GUID, out tInfo))
            methods.AddRange(tInfo.ExtensionMethods);


        foreach (var iface in cur.GetInterfaces())
        {
            if (typeInfo.TryGetValue(iface.GUID, out tInfo))
                methods.AddRange(tInfo.ExtensionMethods);
        }

        cur = cur.BaseType;
    }
    return methods;
}

完整来说:

我保留了所有程序集的所有类型的迭代,并建立了一个类型信息对象的字典:

private Dictionary<Guid, TypeInfo> typeInfo = new Dictionary<Guid, TypeInfo>();

其中TypeInfo的定义如下:

public class TypeInfo
{
    public TypeInfo()
    {
        ExtensionMethods = new List<MethodInfo>();
    }

    public List<ConstructorInfo> Constructors { get; set; }

    public List<FieldInfo> Fields { get; set; }
    public List<PropertyInfo> Properties { get; set; }
    public List<MethodInfo> Methods { get; set; }

    public List<MethodInfo> ExtensionMethods { get; set; }
}

目前你的代码为多个约束条件添加了多个相同的KVPs,而“显而易见”的修复方法(实际上对于单个约束条件更正确)是添加(constraint, m}并不正确,因为约束条件是“与”而不是“或”。 - Mark Hurd
实际上,从编译器的角度来看,因为它不认为约束条件是区分因素,所以当前的代码有一定的优点:只需删除foreach约束行。但这对我们通过反射实现扩展方法没有帮助 :-( - Mark Hurd

3
为了澄清 Jon 没有详细说明的一点... "添加" 扩展方法并不会以任何方式改变类,这只是 C# 编译器执行的一个小技巧。
因此,使用您的示例,您可以编写:
string rev = myStr.Reverse();

但是写入程序集的中间语言代码与您手动编写的代码完全相同:

string rev = StringExtensions.Reverse(myStr);

编译器只是让你自以为在调用String的一个方法。

3
是的,我完全意识到编译器正在进行一些“魔法”来隐藏细节。这是我们对于检测单元测试中的方法是否为扩展方法感兴趣的原因之一。 - Mike Chess

2
尝试进行这项工作的一个原因是,开发人员可能会添加类似的方法,如果添加了该方法,编译器将捕获该方法。
假设定义了扩展方法 void Foo(this Customer someCustomer)。 假设还修改了Customer并添加了方法void Foo()。 然后,Customer上的新方法将覆盖/隐藏扩展方法。
在那时调用旧的Foo方法的唯一方法是:
CustomerExtension.Foo(myCustomer);

0
void Main()
{
    var test = new Test();
    var testWithMethod = new TestWithExtensionMethod();
    Tools.IsExtensionMethodCall(() => test.Method()).Dump();
    Tools.IsExtensionMethodCall(() => testWithMethod.Method()).Dump();
}

public class Test 
{
    public void Method() { }
}

public class TestWithExtensionMethod
{
}

public static class Extensions
{
    public static void Method(this TestWithExtensionMethod test) { }
}

public static class Tools
{
    public static MethodInfo GetCalledMethodInfo(Expression<Action> expr)
    {
        var methodCall = expr.Body as MethodCallExpression;
        return methodCall.Method;
    }

    public static bool IsExtensionMethodCall(Expression<Action> expr)
    {
        var methodInfo = GetCalledMethodInfo(expr);
        return methodInfo.IsStatic;
    }
}

输出:


0

这是基于@Jon Skeet的答案,使用LINQ方法语法而不是查询语法的解决方案。

public static IEnumerable<MethodInfo> GetExtensionMethods(Assembly assembly, Type extendedType)
{
    var methods = assembly.GetTypes()
        .Where(type => type.IsSealed && !type.IsGenericType && !type.IsNested)
        .SelectMany(type => type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic))
        .Where(method => method.IsDefined(typeof(ExtensionAttribute), false) && 
                         method.GetParameters()[0].ParameterType == extendedType);
    return methods;
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接