需要检查代码是否包含特定标识符

8

我将使用Roslyn动态编译和执行代码,类似下面的示例。我想确保代码不违反我的某些规则,例如:

  • 不使用Reflection
  • 不使用HttpClient或WebClient
  • 不使用System.IO命名空间中的File或Directory类
  • 不使用Source Generators
  • 不调用非托管代码

在下面的代码中,我应该在哪里插入我的规则/检查,以及如何实现它们?

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;
using System.Reflection;
using System.Runtime.CompilerServices;

// Sample script that is compiled at runtime. It calls File.Delete on purpose: this is the
// kind of usage the rules are supposed to reject.
string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;

namespace Customization
{
    public class Script
    {
        public async Task<object?> RunAsync(object? data)
        {
            //The following should not be allowed
            File.Delete(@""C:\Temp\log.txt"");

            return await Task.FromResult(data);
        }
    }
}";

// Create the compilation for the script, then emit it into an in-memory byte array.
var compilation = Compile(code);
var bytes = Build(compilation);

Console.WriteLine("Done");

// Parses the given C# source and creates a compilation for it as a dynamically linked
// library, referencing the assemblies that contain the base types used by scripts.
//
// code:    the C# source text of the script.
// returns: the compilation; nothing has been emitted yet.
// throws:  InvalidOperationException when the directory of the core assembly cannot be determined.
CSharpCompilation Compile(string code)
{
    SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);

    string? dotNetCoreDirectoryPath = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (string.IsNullOrWhiteSpace(dotNetCoreDirectoryPath))
    {
        // This is a problem with the running environment, not a null argument. The
        // single-string ArgumentNullException constructor would also interpret the text as
        // a parameter name rather than as the message.
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    string assemblyName = Path.GetRandomFileName();
    List<MetadataReference> references = new()
    {
        MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Console).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Dictionary<,>).Assembly.Location),
        MetadataReference.CreateFromFile(typeof(Task).Assembly.Location),
        MetadataReference.CreateFromFile(Path.Combine(dotNetCoreDirectoryPath, "System.Runtime.dll")),
    };

    CSharpCompilation compilation = CSharpCompilation.Create(
        assemblyName,
        syntaxTrees: new[] { syntaxTree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

    SemanticModel model = compilation.GetSemanticModel(syntaxTree);
    CompilationUnitSyntax root = (CompilationUnitSyntax)syntaxTree.GetRoot();

    //TODO: Check the code for use classes that are not allowed such as File in the System.IO namespace.
    //Not exactly sure how to walk through identifiers.
    IEnumerable<IdentifierNameSyntax> identifiers = root.DescendantNodes()
        .OfType<IdentifierNameSyntax>();

    return compilation;
}

// Emits the compilation into memory and returns the resulting bytes.
// When emit fails, throws an Exception whose message is that of the first diagnostic
// that is an error (or a warning treated as an error).
[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using MemoryStream assemblyStream = new();

    //Emit to catch build errors
    EmitResult outcome = compilation.Emit(assemblyStream);
    if (outcome.Success)
    {
        return assemblyStream.ToArray();
    }

    // Errors proper, plus warnings that were promoted to errors.
    static bool IsError(Diagnostic candidate) =>
        candidate.Severity == DiagnosticSeverity.Error || candidate.IsWarningAsError;

    Diagnostic? firstError = outcome.Diagnostics.FirstOrDefault(IsError);
    throw new Exception(firstError?.GetMessage());
}
2个回答

4

在检查特定类的使用时,您可以使用OfType<>()方法查找IdentifierNameSyntax类型节点,并通过类名过滤结果:

// Collect every IdentifierNameSyntax under root whose identifier text equals className,
// ignoring case. The query is lazy; it runs when names is enumerated.
var names = root.DescendantNodes()
    .OfType<IdentifierNameSyntax>()
    .Where(i => string.Equals(i.Identifier.ValueText, className, StringComparison.OrdinalIgnoreCase));


您可以使用SemanticModel检查类的命名空间:
// For each matching identifier, ask the semantic model for its type and reject the script
// when that type's containing namespace equals containingNamespace (ignoring case).
foreach (var name in names)
{
    var typeInfo = model.GetTypeInfo(name);
    if (string.Equals(typeInfo.Type?.ContainingNamespace?.ToString(), containingNamespace, StringComparison.OrdinalIgnoreCase))
    {
        throw new Exception($"Class {containingNamespace}.{className} is not allowed.");
    }
}

要检查是否使用了反射或非托管代码,您可以检查相关的 using 指令,即 System.Reflection 和 System.Runtime.InteropServices:

// Reject the script when any using directive in root.Usings names disallowedNamespace
// exactly (ignoring case).
if (root.Usings.Any(u => string.Equals(u.Name.ToString(), disallowedNamespace, StringComparison.OrdinalIgnoreCase)))
{
    throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
}

这种检查也会把“写了 using 指令但实际并未使用”的情况判定为违规,即代码中并没有真正用到反射或非托管代码,但这似乎是一个可接受的折衷方案。

我不确定该如何处理源代码生成器的检查,因为它们通常作为项目引用包含在内,所以我不知道它们会如何针对动态编译的代码运行。

将这些检查保留在原来的位置,更新后的代码如下:

using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;

// Sample script that breaks the rules checked in Compile: it uses System.IO.File and
// HttpClient and imports System.Reflection and System.Runtime.InteropServices.
// Every using directive in the sample is terminated with ';' and appears only once, so
// the sample is syntactically valid and the rule violations are the only thing wrong with it.
string code = @"using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.IO;
using System.Net.Http;
using System.Reflection;
using System.Runtime.InteropServices;

namespace Customization
{
    public class Script
    {
        static readonly HttpClient client = new HttpClient();

        public async Task<object?> RunAsync(object? data)
        {
            //The following should not be allowed
            File.Delete(@""C:\Temp\log.txt"");

            return await Task.FromResult(data);
        }
    }
}";

// Compile runs the rule checks and throws on the first violation; Build is only reached
// for a script that passes them.
var compilation = Compile(code);

var bytes = Build(compilation);
Console.WriteLine("Done");


// Parses the script, creates a DLL compilation for it and runs the rule checks.
//
// code:    the C# source text of the script.
// returns: the compilation (not yet emitted) when no rule is violated.
// throws:  InvalidOperationException when the core assembly directory cannot be determined;
//          whatever the ThrowOn* checks throw when the script violates a rule.
CSharpCompilation Compile(string code)
{
    var tree = CSharpSyntaxTree.ParseText(code);

    var frameworkDirectory = Path.GetDirectoryName(typeof(object).GetTypeInfo().Assembly.Location);
    if (string.IsNullOrWhiteSpace(frameworkDirectory))
    {
        throw new InvalidOperationException("Cannot determine path to current assembly.");
    }

    // Assemblies the script is allowed to bind against, in the order they are referenced.
    string[] referencePaths =
    {
        typeof(object).Assembly.Location,
        typeof(Enumerable).Assembly.Location,
        typeof(Console).Assembly.Location,
        typeof(Dictionary<,>).Assembly.Location,
        typeof(Task).Assembly.Location,
        typeof(HttpClient).Assembly.Location,
        Path.Combine(frameworkDirectory, "System.Runtime.dll"),
    };
    List<MetadataReference> references = new(
        referencePaths.Select(path => MetadataReference.CreateFromFile(path)));

    var compilation = CSharpCompilation.Create(
        Path.GetRandomFileName(),
        syntaxTrees: new[] { tree },
        references: references,
        options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

    var model = compilation.GetSemanticModel(tree);
    var root = (CompilationUnitSyntax)tree.GetRoot();

    // Class rules run first, then namespace rules; the first violation throws.
    (string ClassName, string Namespace)[] bannedClasses =
    {
        ("File", "System.IO"),
        ("HttpClient", "System.Net.Http"),
    };
    foreach (var (className, containingNamespace) in bannedClasses)
    {
        ThrowOnDisallowedClass(className, containingNamespace, root, model);
    }

    string[] bannedNamespaces = { "System.Reflection", "System.Runtime.InteropServices" };
    foreach (var bannedNamespace in bannedNamespaces)
    {
        ThrowOnDisallowedNamespace(bannedNamespace, root);
    }

    return compilation;
}

// Emits the compilation to an in-memory stream and returns its contents.
// On failure, throws an Exception with the message of the first diagnostic that is an
// error or a warning treated as an error.
[MethodImpl(MethodImplOptions.NoInlining)]
byte[] Build(CSharpCompilation compilation)
{
    using var buffer = new MemoryStream();

    //Emit to catch build errors
    var result = compilation.Emit(buffer);

    if (!result.Success)
    {
        Diagnostic? culprit = null;
        foreach (var diagnostic in result.Diagnostics)
        {
            if (diagnostic.IsWarningAsError || diagnostic.Severity == DiagnosticSeverity.Error)
            {
                culprit = diagnostic;
                break;
            }
        }

        throw new Exception(culprit?.GetMessage());
    }

    return buffer.ToArray();
}

// Throws when the script uses the class containingNamespace.className.
//
// Every simple name in the tree is resolved through the semantic model instead of being
// matched by its spelling, so the class is also found when it is reached through
//   - a using alias            (using F = System.IO.File; F.Delete(...)),
//   - a using static directive (using static System.IO.File; Delete(...)),
//   - var or any expression whose type is the class.
// Both the type name and its namespace must match, so an identifier that merely happens
// to be spelled like the class is not rejected.
//
// className:           simple name of the disallowed class (compared ignoring case).
// containingNamespace: full name of the namespace that declares it (compared ignoring case).
// root:                root of the script's syntax tree.
// model:               semantic model for that tree.
// throws:              Exception naming the class, on the first use found.
void ThrowOnDisallowedClass(string className, string containingNamespace, CompilationUnitSyntax root, SemanticModel model)
{
    bool IsDisallowed(ITypeSymbol? type) =>
        type is not null
        && string.Equals(type.Name, className, StringComparison.OrdinalIgnoreCase)
        && string.Equals(type.ContainingNamespace?.ToDisplayString(), containingNamespace, StringComparison.OrdinalIgnoreCase);

    // SimpleNameSyntax is the common base of IdentifierNameSyntax and GenericNameSyntax.
    foreach (SimpleNameSyntax name in root.DescendantNodes().OfType<SimpleNameSyntax>())
    {
        ISymbol? symbol = model.GetSymbolInfo(name).Symbol;

        bool isViolation =
            IsDisallowed(symbol as ITypeSymbol)             // the name denotes the class itself (directly, via alias or var)
            || IsDisallowed(symbol?.ContainingType)         // the name denotes a member declared by the class
            || IsDisallowed(model.GetTypeInfo(name).Type);  // the name is an expression whose type is the class

        if (isViolation)
        {
            throw new Exception($"Class {containingNamespace}.{className} is not allowed.");
        }
    }
}

// Throws when the script refers to disallowedNamespace or to one of its child namespaces.
//
// The check is purely syntactic. It looks at
//   - every using directive in the tree, including those nested inside namespace
//     declarations, aliases and using static directives, and
//   - every qualified name or member access, so fully qualified usages such as
//     System.Reflection.Assembly.Load(...) are found even without a using directive.
// A using directive that is present but never used is therefore still rejected.
//
// disallowedNamespace: full namespace name, e.g. "System.Reflection" (compared ignoring case).
// root:                root of the script's syntax tree.
// throws:              Exception naming the namespace when a reference is found.
void ThrowOnDisallowedNamespace(string disallowedNamespace, CompilationUnitSyntax root)
{
    const string GlobalPrefix = "global::";

    bool RefersToDisallowedNamespace(SyntaxNode node)
    {
        // Rebuild the dotted name from token text only, so whitespace and comments
        // inside the name do not hide it.
        string text = string.Concat(node.DescendantTokens().Select(token => token.Text));
        if (text.StartsWith(GlobalPrefix, StringComparison.Ordinal))
        {
            text = text.Substring(GlobalPrefix.Length);
        }

        return text.Equals(disallowedNamespace, StringComparison.OrdinalIgnoreCase)
            || text.StartsWith(disallowedNamespace + ".", StringComparison.OrdinalIgnoreCase);
    }

    bool isReferenced = root.DescendantNodes().Any(node => node switch
    {
        // Name can be null on newer Roslyn versions when the directive aliases a non-name type.
        UsingDirectiveSyntax usingDirective =>
            usingDirective.Name is not null && RefersToDisallowedNamespace(usingDirective.Name),
        QualifiedNameSyntax or MemberAccessExpressionSyntax => RefersToDisallowedNamespace(node),
        _ => false,
    });

    if (isReferenced)
    {
        throw new Exception($"Namespace {disallowedNamespace} is not allowed.");
    }
}

在这里,我使用了 throw 来表示违规,这意味着遇到第一个违规就会中止,多个违规无法一次性全部报告,因此您可能需要调整这一点,例如先收集所有违规再统一报告。


我已经把悬赏奖励授予了您。谢谢。我还发布了一个使用SymbolInfo类的答案。 - Sandy

3

SymbolInfo 类提供了一些元数据,可以用来创建限制某些代码使用的规则。这是我目前想到的,欢迎提出改进建议。

//Check for banned namespaces
// Each identifier is resolved to its symbol. A namespace symbol is rejected when its full
// name is on the blacklist; a named type is rejected when its full name starts with a
// blacklisted namespace followed by '.'.
string[] namespaceBlacklist = new string[] { "System.Net", "System.IO" };

foreach (IdentifierNameSyntax identifier in identifiers)
{
    ISymbol? symbol = semanticModel.GetSymbolInfo(identifier).Symbol;
    if (symbol is null)
    {
        // The identifier did not bind to a single symbol; nothing to check.
        continue;
    }

    string displayName = symbol.ToDisplayString();

    if (symbol.Kind == SymbolKind.Namespace)
    {
        if (namespaceBlacklist.Any(ns => ns == displayName))
        {
            throw new Exception($"Declaration of namespace '{displayName}' is not allowed.");
        }
    }
    else if (symbol.Kind == SymbolKind.NamedType)
    {
        // Ordinal comparison: these are identifiers, not text in a human language.
        if (namespaceBlacklist.Any(ns => displayName.StartsWith(ns + ".", StringComparison.Ordinal)))
        {
            // The symbol here is a type, so report the type's full name rather than
            // calling the identifier a namespace.
            throw new Exception($"Use of type '{displayName}' is not allowed.");
        }
    }
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接