使用异步套接字进行TCP端口转发

4
我正在尝试在C#中实现TCP转发器。具体而言,该应用程序:
  1. 监听一个TCP端口并等待客户端连接,
  2. 当客户端连接后,连接到远程主机,
  3. 等待来自两个连接的传入数据并在两个端点之间交换数据(充当代理),
  4. 当另一个端点关闭时,关闭一个连接。
我已经改编了Garcia的Simple TCP Forwader以转发一系列端口。
TCPForwarder.exe 10.1.1.1 192.168.1.100 1000 1100 2000

将在端口1000-1100接收到的任何数据包转发到远程主机192.168.1.100上的端口2000-2100。我使用这种方法来暴露一个位于NAT后面的FTP服务器。

通过运行上述命令,客户端可以连接到FTP服务器,并且控制台输出以下预期模式(参考代码):

0 StartReceive: BeginReceive
1 StartReceive: BeginReceive
1 OnDataReceive: EndReceive
1 OnDataReceive: BeginReceive
1 OnDataReceive: EndReceive
1 OnDataReceive: Close (0 read)
0 OnDataReceive: EndReceive
0 OnDataReceive: Close (exception)

但在成功连接几次后(在Filezilla中按F5),TCPForwarder和FTP服务器没有收到进一步的响应。

我的实现似乎有两个问题,我无法调试:

  1. 在这种情况下,在StartReceive方法中的BeginReceive被调用,但没有从FTP服务器接收到任何数据。我认为这不可能是FTP服务器问题(它是一个著名的ProFTPD服务器)。

  2. 每次建立并关闭连接时,线程数都会增加1。我不认为垃圾回收会解决这个问题。线程数持续增加,强制垃圾回收器运行也无法减少线程数。我认为我的代码中存在一些泄漏,也导致了问题#1。

编辑:

  • 重新启动FTP服务器没有解决问题,因此TCPForwarder中肯定存在错误。

  • @jgauffin指出的一些问题已在以下代码中得到修复。

以下是完整的代码:

using System;
using System.Net;
using System.Net.Sockets;
using System.Collections.Generic;
using System.Threading;

namespace TCPForwarder
{
    class Program
    {
        private class State
        {
            public int ID { get; private set; } // for debugging purposes
            public Socket SourceSocket { get; private set; }
            public Socket DestinationSocket { get; private set; }
            public byte[] Buffer { get; private set; }
            public State(int id, Socket source, Socket destination)
            {
                ID = id;
                SourceSocket = source;
                DestinationSocket = destination;
                Buffer = new byte[8192];
            }
        }

        public class TcpForwarder
        {
            public void Start(IPEndPoint local, IPEndPoint remote)
            {
                Socket MainSocket;
                try
                {
                    MainSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                    MainSocket.Bind(local);
                    MainSocket.Listen(10);
                }
                catch (Exception exp)
                {
                    Console.WriteLine("Error on listening to " + local.Port + ": " + exp.Message);
                    return;
                }

                while (true)
                {
                    // Accept a new client
                    var socketSrc = MainSocket.Accept();
                    var socketDest = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);

                    try
                    {
                        // Connect to the endpoint
                        socketDest.Connect(remote);
                    }
                    catch
                    {
                        socketSrc.Shutdown(SocketShutdown.Both);
                        socketSrc.Close();
                        Console.WriteLine("Exception in connecting to remote host");
                        continue;
                    }

                    // Wait for data sent from client and forward it to the endpoint
                    StartReceive(0, socketSrc, socketDest);

                    // Also, wait for data sent from endpoint and forward it to the client
                    StartReceive(1, socketDest, socketSrc);
                }
            }

            private static void StartReceive(int id, Socket src, Socket dest)
            {
                var state = new State(id, src, dest);

                Console.WriteLine("{0} StartReceive: BeginReceive", id);
                try
                {
                    src.BeginReceive(state.Buffer, 0, state.Buffer.Length, 0, OnDataReceive, state);
                }
                catch
                {
                    Console.WriteLine("{0} Exception in StartReceive: BeginReceive", id);
                }
            }

            private static void OnDataReceive(IAsyncResult result)
            {
                State state = null;
                try
                {
                    state = (State)result.AsyncState;

                    Console.WriteLine("{0} OnDataReceive: EndReceive", state.ID);
                    var bytesRead = state.SourceSocket.EndReceive(result);
                    if (bytesRead > 0)
                    {
                        state.DestinationSocket.Send(state.Buffer, bytesRead, SocketFlags.None);

                        Console.WriteLine("{0} OnDataReceive: BeginReceive", state.ID);
                        state.SourceSocket.BeginReceive(state.Buffer, 0, state.Buffer.Length, 0, OnDataReceive, state);
                    }
                    else
                    {
                        Console.WriteLine("{0} OnDataReceive: Close (0 read)", state.ID);
                        state.SourceSocket.Shutdown(SocketShutdown.Both);
                        state.DestinationSocket.Shutdown(SocketShutdown.Both);
                        state.DestinationSocket.Close();
                        state.SourceSocket.Close();
                    }
                }
                catch
                {
                    if (state!=null)
                    {
                        Console.WriteLine("{0} OnDataReceive: Close (exception)", state.ID);
                        state.SourceSocket.Shutdown(SocketShutdown.Both);
                        state.DestinationSocket.Shutdown(SocketShutdown.Both);
                        state.DestinationSocket.Close();
                        state.SourceSocket.Close();
                    }
                }
            }
        }

        static void Main(string[] args)
        {
            List<Socket> sockets = new List<Socket>();

            int srcPortStart = int.Parse(args[2]);
            int srcPortEnd = int.Parse(args[3]);
            int destPortStart = int.Parse(args[4]);

            List<Thread> threads = new List<Thread>();
            for (int i = 0; i < srcPortEnd - srcPortStart + 1; i++)
            {
                int srcPort = srcPortStart + i;
                int destPort = destPortStart + i;

                TcpForwarder tcpForwarder = new TcpForwarder();

                Thread t = new Thread(new ThreadStart(() => tcpForwarder.Start(
                    new IPEndPoint(IPAddress.Parse(args[0]), srcPort),
                    new IPEndPoint(IPAddress.Parse(args[1]), destPort))));
                t.Start();

                threads.Add(t);
            }

            foreach (var t in threads)
            {
                t.Join();
            }
            Console.WriteLine("All threads are closed");
        }
    }
}
1个回答

2
第一个问题是,代码在目标套接字(接受循环)连接失败时将继续执行。在try/catch中使用continue;。还不能保证当您调用第一个BeginReceive时套接字仍然正常工作。这些调用也需要被包装。
始终在回调方法中使用try/catch,否则您的应用程序可能会失败(在此情况下为OnDataRecieve)。
解决此问题并开始输出异常。它们肯定会给你关于出了什么问题的提示。

好的,做得好!我在while(true)循环中添加了continue语句以处理异常,并将所有方法都包装在try catch块中。唯一抛出的异常仍然与之前的EndReceive有关。仍然存在线程数量和创建新连接的问题。 - Isaac
重新启动FTP服务器无法解决问题,因此这绝对是TCPForwarder中的一个错误。 - Isaac
异常信息是什么? - jgauffin
你仍然有问题。catch块中的代码也可能会抛出异常(这将关闭您的应用程序)。 - jgauffin
catch 块中的代码可能会抛出异常。这正是问题所在!不幸的是,它没有关闭应用程序,只留下了一个套接字开着。我为每个套接字放置了 Shutdown-Close 对,并将其放在单独的 try/catch 块中解决了该问题。谢谢! - Isaac

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接