Python N-body模拟代码给出了错误的答案。

4
我已经编写了一些Python代码,使用欧拉方法来解决N体问题。代码运行没有问题,似乎给出了一个合理的答案(例如,如果有两个粒子,则它们开始相互靠近)。但是,当我在大量迭代中运行此模拟时,我发现粒子(比如我用两个粒子运行)彼此经过(我不考虑碰撞),并且在它们各自的方向上无限地继续前进。这违反了能量守恒定律,因此我的代码必须有缺陷,但我无法找到它。请问是否有人可以找到并解释我的错误。
谢谢。
感谢@samgak指出我更新了粒子两次。我现在已经修复了这个问题,但问题仍然存在。我还复制了当我在(0,0)和(1,0)处以1秒和100000次迭代运行这个模拟时得到的输出: 粒子质量为1,位置为[234.8268420043934, 0.0],速度为[0.011249111128594091, 0.0]。

具有质量为1,位置为[-233.82684200439311, 0.0],速度为[-0.011249111128594091, 0.0]的粒子。

同时感谢@PM2Ring指出的一些我可以进行的优化和使用欧拉方法的危险。

代码:

import math
class Particle:
    """
    Class to represent a single particle
    """
    def __init__(self,mass,position,velocity):
        """
        Initialize the particle
        """
        self.G = 6.67408*10**-11 #fixed throughout the simulation
        self.time_interval = 10**0 #fixed throughout the simulation, gives the interval between updates
        self.mass = mass
        self.position = position #should be a list
        self.velocity = velocity #should be a list
        self.updated_position = position
        self.updated_velocity = velocity
    def __str__(self):
        """
        String representation of particle
        """
        return "Particle with mass: " + str(self.mass) + " and position: " + str(self.position) + " and velocity: " + str(self.velocity)
    def get_mass(self):
        """
        Returns the mass of the particle
        """
        return self.mass
    def get_position(self):
        """
        returns the position of the particle
        """
        return self.position
    def get_velocity(self):
        """
        returns the velocity of the particle
        """
        return self.velocity
    def get_updated_position(self):
        """
        calculates the future position of the particle
        """
        for i in range(len(self.position)):
            self.updated_position[i] = self.updated_position[i] + self.time_interval*self.velocity[i]
    def update_position(self):
        """
        updates the position of the particle
        """
        self.position = self.updated_position.copy()
    def get_distance(self,other_particle):
        """
        returns the distance between the particle and another given particle
        """
        tot = 0
        other = other_particle.get_position()
        for i in range(len(self.position)):
            tot += (self.position[i]-other[i])**2
        return math.sqrt(tot)
    def get_updated_velocity(self,other_particle):
        """
        updates the future velocity of the particle due to the acceleration
        by another particle
        """
        distance_vector = []
        other = other_particle.get_position()
        for i in range(len(self.position)):
            distance_vector.append(self.position[i]-other[i])
        distance_squared = 0
        for item in distance_vector:
            distance_squared += item**2
        distance = math.sqrt(distance_squared)
        force = -self.G*self.mass*other_particle.get_mass()/(distance_squared)
        for i in range(len(self.velocity)):
            self.updated_velocity[i] = self.updated_velocity[i]+self.time_interval*force*(distance_vector[i])/(self.mass*(distance))
    def update_velocity(self):
        """
        updates the velocity of the particle
        """
        self.velocity = self.updated_velocity.copy()
def update_particles(particle_list):
    """
    updates the position of all the particles
    """
    for i in range(len(particle_list)):
        for j in range(i+1,len(particle_list)):
            particle_list[i].get_updated_velocity(particle_list[j])
            particle_list[j].get_updated_velocity(particle_list[i])
    for i in range(len(particle_list)):
        particle_list[i].update_velocity()
        particle_list[i].get_updated_position()
    for i in range(len(particle_list)):
        particle_list[i].update_position()      
#the list of particles
partList = [Particle(1,[0,0],[0,0]),Particle(1,[1,0],[0,0])]
#how many iterations I perform
for i in range(100000):
    update_particles(partList)
#prints out the final position of all the particles
for item in partList:
    print(item)

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------进一步编辑:

我决定实现Leapfrog方法,并编写了一些代码,再次运行并且似乎在命令行中工作得很好。但是当我添加了绘图功能并进行分析时,似乎又出现了另一个问题。系统似乎又走得太远了,能量再次无限增加。我已经附上了输出图片以展示问题。如果我只有两个相等质量的粒子,它们会再次互相穿过,然后继续相互远离而不停止。因此,我的代码中肯定存在我没有发现的错误。

如果有人能帮忙,将不胜感激。

我的代码:

import math
import matplotlib.pyplot as plt

class Particle:
    """
    Represents a single particle
    """
    def __init__(self,mass,position,velocity):
        """
        Initialize the particle
        """
        self.G = 6.67408*10**-11
        self.time_step = 10**2
        self.mass = mass
        self.dimensions = len(position)
        self.position = position
        self.velocity = velocity
        self.acceleration = [0 for i in range(len(position))]
        self.next_position = position
        self.next_velocity = velocity
        self.next_acceleration = [0 for i in range(len(position))]
    def __str__(self):
        """
        A string representation of the particle
        """
        return "A Particle with mass: " + str(self.mass) + " and position: " + str(self.position) + " and velocity:" + str(self.velocity) 
    def get_mass(self):
        return self.mass
    def get_position(self):
        return self.position
    def get_velocity(self):
        return self.velocity
    def get_acceleration(self):
        return self.acceleration
    def get_next_position(self):
        return self.next_position
    def put_next_position(self):
        for i in range(self.dimensions):
            self.next_position[i] = self.position[i] + self.time_step*self.velocity[i]+0.5*self.time_step**2*self.acceleration[i]
    def put_next_velocity(self):
        for i in range(self.dimensions):
            self.next_velocity[i] = self.velocity[i] + 0.5*self.time_step*(self.acceleration[i]+self.next_acceleration[i])
    def update_position(self):
        self.position = self.next_position.copy()
    def update_velocity(self):
        self.velocity = self.next_velocity.copy()  
    def update_acceleration(self):
        self.acceleration = self.next_acceleration.copy()
    def reset_acceleration(self):
        self.acceleration = [0 for i in range(self.dimensions)]
    def reset_future_acceleration(self):
        self.next_acceleration = [0 for i in range(self.dimensions)]
    def calculate_acceleration(self,other_particle):
        """
        Increments the acceleration of the particle due to the force from 
        a single other particle
        """
        distances = []
        other = other_particle.get_position()
        distance_squared = 0
        for i in range(self.dimensions):
            distance_squared += (self.position[i]-other[i])**2
            distances.append(self.position[i]-other[i])
        distance = math.sqrt(distance_squared)
        force = -self.G*self.mass*other_particle.get_mass()/distance_squared
        acc = []
        for i in range(self.dimensions):
            acc.append(force*distances[i]/(distance*self.mass))
        for i in range(self.dimensions):
            self.acceleration[i] += acc[i]
    def calculate_future_acceleration(self,other_particle):
        """
        Increments the future acceleration of the particle due to the force from 
        a single other particle
        """
        distances = []
        other = other_particle.get_next_position()
        distance_squared = 0
        for i in range(self.dimensions):
            distance_squared += (self.next_position[i]-other[i])**2
            distances.append(self.next_position[i]-other[i])
        distance = math.sqrt(distance_squared)
        force = -self.G*self.mass*other_particle.get_mass()/distance_squared
        acc = []
        for i in range(self.dimensions):
            acc.append(force*distances[i]/(distance*self.mass))
        for i in range(self.dimensions):
            self.next_acceleration[i] += acc[i]

def update_all(particleList):
    for i in range(len(particleList)):
        particleList[i].reset_acceleration()
        for j in range(len(particleList)):
            if i != j:
                particleList[i].calculate_acceleration(particleList[j])
    for i in range(len(particleList)):
        particleList[i].put_next_position()
    for i in range(len(particleList)):
        particleList[i].reset_future_acceleration()
        for j in range(len(particleList)):
            if i != j:
                particleList[i].calculate_future_acceleration(particleList[j])
    for i in range(len(particleList)):
        particleList[i].put_next_velocity()
    for i in range(len(particleList)):
        particleList[i].update_position()
        particleList[i].update_velocity()
partList = [Particle(1,[0,0],[0,0]),Particle(1,[1,0],[0,0])]

Alist = [[],[]]
Blist = [[],[]]
for i in range(10000):
    Alist[0].append(partList[0].get_position()[0])
    Alist[1].append(partList[0].get_position()[1])
    Blist[0].append(partList[1].get_position()[0])
    Blist[1].append(partList[1].get_position()[1])
    update_all(partList)

plt.scatter(Alist[0],Alist[1],color="r")
plt.scatter(Blist[0],Blist[1],color="b")
plt.grid() 
plt.show()
for item in partList:
    print(item)

A zoomed in plot

请问有人能告诉我代码中出现了哪些错误吗?


2
欧拉积分的主要缺陷是由于误差累积导致能量不守恒。如果您使用足够小的时间步长,可以模拟一切都具有近似圆形轨道的小系统,这样就没问题了。否则,您需要使用辛积分器,它将确保能量守恒(更准确地说,它会保持与系统真实能量密切相关的哈密顿量不变)。 - PM 2Ring
一个流行的辛波立克积分器是Verlet积分,但我个人最喜欢的是同步跳跃法 - PM 2Ring
2
你正在两次更新每个粒子对。请将条件从 i != j 改为 i < j(或者从 i+1 开始 j 循环)。 - samgak
1
使用 Decimal 可以减少近似误差,但 1000 位可能有些过头了。然而,你正在调用 math.sqrt,它只返回一个 float。相反,你需要使用 Decimal.sqrt 方法。此外,在计算力时,可以通过直接使用平方距离来减少误差,而不是对平方距离取平方根再平方。 - PM 2Ring
@PM2Ring 我会尝试使用Leapfrog并在这里发布结果。 - Hadi Khan
显示剩余12条评论
1个回答

1
代码的主要问题在于它使用了欧拉方法,随着迭代次数的增加,它变得不太准确(只有O(h),而其他方法可以达到O(h^4)甚至更好)。要解决这个问题需要对代码进行基本重构,因此我认为这个代码对于N体模拟来说并不真正准确(对于2个粒子,它可以运行,但是添加越来越多的粒子,误差只会增加)。
感谢@samgak和@PM2Ring帮助我消除错误并优化我的代码,但总体而言,这个代码是无法使用的...
编辑:我从头开始实现了评论中提到的leapfrog方法,并发现它完美地工作。它非常简单易懂,易于实现,而且还有效!
进一步编辑:我以为我已经让leapfrog方法起作用了。结果发现当我添加GUI功能时,还有另一个错误。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接