使用Pandas DataFrame的线性回归

14

我有一个用pandas创建的dataframe,用于生成散点图,并想为该图包含一个回归线。目前我正在尝试使用polyfit来实现这一目标。

以下是我的代码:

import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from numpy import *

table1 = pd.DataFrame.from_csv('upregulated_genes.txt', sep='\t', header=0, index_col=0)
table2 = pd.DataFrame.from_csv('misson_genes.txt', sep='\t', header=0, index_col=0)
table1 = table1.join(table2, how='outer')

table1 = table1.dropna(how='any')
table1 = table1.replace('#DIV/0!', 0)

# scatterplot
plt.scatter(table1['log2 fold change misson'], table1['log2 fold change'])
plt.ylabel('log2 expression fold change')
plt.xlabel('log2 expression fold change Misson et al. 2005')
plt.title('Root Early Upregulated Genes')
plt.axis([0,12,-5,12])

# this is the part I'm unsure about
regres = polyfit(table1['log2 fold change misson'], table1['log2 fold change'], 1)

plt.show()

但是我遇到了以下错误:
TypeError: cannot concatenate 'str' and 'float' objects

请问有没有人知道我在这里做错了什么?我也不确定如何将回归线添加到我的图表中。如果对我的代码有其他一般性的评论,那将非常赞赏,因为我还是初学者。


你在哪一行出现了错误? - usethedeathstar
@usethedeathstar regres = polyfit(table1['log2 fold change misson'], table1['log2 fold change'], 1) @usethedeathstar regres = polyfit(table1 ['log2折变misson'],table1 ['log2折变'],1) - TimStuart
你确定表格中没有NaN值吗?因为pylab.scatter函数不会绘制x或y为NaN的点(也就是说它不会报错),但是polyfit函数可能不知道这一点。(只是猜测问题可能出在哪里-你的csv文件中如何存储非数字值?) - usethedeathstar
没有 NaN 值。唯一的非数字值是“#DIV/0!”,我已经将其删除。 - TimStuart
table1['log2 fold change misson']和table1['log2 fold change']的类型是什么?(据我所知,它们应该是具有浮点数dtype的numpy.array(并且两者应该具有相同的形状)) - usethedeathstar
它们都是Pandas系列。 - TimStuart
1个回答

28

不要手动替换 '#DIV/0!',强制数据为数字。这样做可以同时实现两个目标:确保结果是数字类型(而非字符串),并用 NaN 替换任何无法解析为数字的条目。示例:

In [5]: Series([1, 2, 'blah', '#DIV/0!']).convert_objects(convert_numeric=True)
Out[5]: 
0     1
1     2
2   NaN
3   NaN
dtype: float64

这应该可以解决你的错误。但是,关于拟合数据线的一般主题,我掌握了两种比polyfit更好的方法。其中第二种更加健壮(并且可能返回更详细的统计信息),但需要使用statsmodels。

from scipy.stats import linregress
def fit_line1(x, y):
    """Return slope, intercept of best fit line."""
    # Remove entries where either x or y is NaN.
    clean_data = pd.concat([x, y], 1).dropna(0) # row-wise
    (_, x), (_, y) = clean_data.iteritems()
    slope, intercept, r, p, stderr = linregress(x, y)
    return slope, intercept # could also return stderr

import statsmodels.api as sm
def fit_line2(x, y):
    """Return slope, intercept of best fit line."""
    X = sm.add_constant(x)
    model = sm.OLS(y, X, missing='drop') # ignores entires where x or y is NaN
    fit = model.fit()
    return fit.params[1], fit.params[0] # could also return stderr in each via fit.bse

为了绘制它,可以像这样进行操作

m, b = fit_line2(x, y)
N = 100 # could be just 2 if you are only drawing a straight line...
points = np.linspace(x.min(), x.max(), N)
plt.plot(points, m*points + b)

谢谢!将数据强制转换为数字已经解决了我遇到的错误,但是从polyfit和你建议的代码中我得到了NaN输出... 有什么想法吗? - TimStuart
一些NaN,还是全部NaN?您能否使用数据的一个小子集来重现问题,并在此共享? - Dan Allan
抱歉,那只是我的一个错误,现在已经可以工作了。你知道我该怎么将这个添加为散点图中的一行吗? - TimStuart
如果底部的答案没有显示在同一图中,请尝试在“plot”中添加关键字参数“ax=plt.gca()”。 - Dan Allan

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接