在 scikit-learn 中将分类器保存到磁盘

2025-03-10 08:52:00
admin
原创
67
摘要:问题描述:如何将训练好的朴素贝叶斯分类器保存到磁盘并使用它来预测数据?我有来自 scikit-learn 网站的以下示例程序:from sklearn import datasets iris = datasets.load_iris() from sklearn.naive_bayes import Gau...

问题描述:

如何将训练好的朴素贝叶斯分类器保存到磁盘并使用它来预测数据?

我有来自 scikit-learn 网站的以下示例程序:

from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum()

解决方案 1:

分类器只是可以像其他对象一样进行 pickle 和 dump 的对象。继续您的示例:

import cPickle
# save the classifier
with open('my_dumped_classifier.pkl', 'wb') as fid:
    cPickle.dump(gnb, fid)    

# load it again
with open('my_dumped_classifier.pkl', 'rb') as fid:
    gnb_loaded = cPickle.load(fid)

解决方案 2:

您还可以使用joblib.dump和joblib.load,它们处理数值数组比默认的 python pickler 效率更高。

Joblib 包含在 scikit-learn 中:

>>> import joblib
>>> from sklearn.datasets import load_digits
>>> from sklearn.linear_model import SGDClassifier

>>> digits = load_digits()
>>> clf = SGDClassifier().fit(digits.data, digits.target)
>>> clf.score(digits.data, digits.target)  # evaluate training error
0.9526989426822482

>>> filename = '/tmp/digits_classifier.joblib.pkl'
>>> _ = joblib.dump(clf, filename, compress=9)

>>> clf2 = joblib.load(filename)
>>> clf2
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5,
       n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0,
       shuffle=False, verbose=0, warm_start=False)
>>> clf2.score(digits.data, digits.target)
0.9526989426822482

编辑:在 Python 3.8+ 中,如果您使用 pickle 协议 5(这不是默认协议),现在可以使用 pickle 有效地对具有大型数值数组作为属性的对象进行 pickle 。

解决方案 3:

您正在寻找的内容在 sklearn 词汇中称为模型持久性,并且它记录在简介和模型持久性部分中。

因此,您已经初始化了分类器,并使用

clf = some.classifier()
clf.fit(X, y)

此后你有两个选择:

1)使用 Pickle

import pickle
# now you can save it to a file
with open('filename.pkl', 'wb') as f:
    pickle.dump(clf, f)

# and later you can load it
with open('filename.pkl', 'rb') as f:
    clf = pickle.load(f)

2)使用Joblib

from sklearn.externals import joblib
# now you can save it to a file
joblib.dump(clf, 'filename.pkl') 
# and later you can load it
clf = joblib.load('filename.pkl')

再次阅读上述链接是有帮助的

解决方案 4:

在许多情况下,特别是文本分类时,仅存储分类器是不够的,还需要存储矢量化器,以便将来可以对输入进行矢量化。

import pickle
with open('model.pkl', 'wb') as fout:
  pickle.dump((vectorizer, clf), fout)

未来用例:

with open('model.pkl', 'rb') as fin:
  vectorizer, clf = pickle.load(fin)

X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)

在转储矢量化器之前,可以通过以下方式删除矢量化器的stop_words_属性:

vectorizer.stop_words_ = None

使转储更加高效。此外,如果您的分类器参数是稀疏的(如大多数文本分类示例),您可以将参数从密集转换为稀疏,这将在内存消耗、加载和转储方面产生巨大差异。通过以下方式稀疏化模型:

clf.sparsify()

这将自动适用于SGDClassifier,但如果你知道你的模型是稀疏的(clf.coef_ 中有很多零),那么你可以通过以下方式手动将 clf.coef_ 转换为csr scipy 稀疏矩阵:

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_)

然后您就可以更有效地存储它。

解决方案 5:

sklearn估算器实现了方法,让您可以轻松保存估算器的相关训练属性。一些估算器会__getstate__自己实现方法,但其他估算器(例如GMM仅使用基本实现)来简单地保存对象内部字典:

def __getstate__(self):
    try:
        state = super(BaseEstimator, self).__getstate__()
    except AttributeError:
        state = self.__dict__.copy()

    if type(self).__module__.startswith('sklearn.'):
        return dict(state.items(), _sklearn_version=__version__)
    else:
        return state

将模型保存到磁盘的推荐方法是使用以下pickle模块:

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
    pickle.dump(model,f)

但是,您应该保存额外的数据,以便将来可以重新训练您的模型,否则会遭受可怕的后果(例如被锁定在旧版本的 sklearn 中)

来自文档:

为了使用 scikit-learn 的未来版本重建类似的模型,应该与 pickle 模型一起保存额外的元数据:

训练数据,例如对不可变快照的引用

用于生成模型的python源代码

scikit-learn 及其依赖项的版本

在训练数据上获得的交叉验证分数

对于依赖于tree.pyx用 Cython 编写的模块(例如)的 Ensemble 估算器来说尤其如此IsolationForest,因为它会与实现产生耦合,无法保证在 sklearn 的各个版本之间保持稳定。它过去曾出现过向后不兼容的变化。

如果你的模型变得非常大,并且加载变得麻烦,你也可以使用更高效的joblib。摘自文档:

在 scikit 的特定情况下,使用 joblib 替换pickle( joblib.dump& joblib.load) 可能会更有趣,这对于内部携带大型 numpy 数组的对象更有效,就像拟合的 scikit-learn 估计器的情况一样,但只能腌制到磁盘而不能腌制到字符串:

解决方案 6:

sklearn.externals.joblib从那时起已被弃用0.21并将在以下时间被删除v0.23

/usr/local/lib/python3.7/site-packages/sklearn/externals/joblib/ init .py:15: FutureWarning: sklearn.externals.joblib 在 0.21 中已弃用,并将在 0.23 中删除。请直接从 joblib 导入此功能,可以使用以下命令安装:pip install joblib。如果在加载 pickled 模型时出现此警告,您可能需要使用 scikit-learn 0.21+ 重新序列化这些模型。warnings.warn

(msg, category=FutureWarning)


因此,您需要安装joblib

pip install joblib

最后将模型写入磁盘:

import joblib
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier


digits = load_digits()
clf = SGDClassifier().fit(digits.data, digits.target)

with open('myClassifier.joblib.pkl', 'wb') as f:
    joblib.dump(clf, f, compress=9)

现在,为了读取转储的文件,您需要运行:

with open('myClassifier.joblib.pkl', 'rb') as f:
    my_clf = joblib.load(f)

解决方案 7:

一般来说,截至 2024 年 2 月,其他选项可用(根据文档:https ://scikit-learn.org/stable/model_persistence.html )

相关推荐
  政府信创国产化的10大政策解读一、信创国产化的背景与意义信创国产化,即信息技术应用创新国产化,是当前中国信息技术领域的一个重要发展方向。其核心在于通过自主研发和创新,实现信息技术应用的自主可控,减少对外部技术的依赖,并规避潜在的技术制裁和风险。随着全球信息技术竞争的加剧,以及某些国家对中国在科技领域的打压,信创国产化显...
工程项目管理   3998  
  为什么项目管理通常仍然耗时且低效?您是否还在反复更新电子表格、淹没在便利贴中并参加每周更新会议?这确实是耗费时间和精力。借助软件工具的帮助,您可以一目了然地全面了解您的项目。如今,国内外有足够多优秀的项目管理软件可以帮助您掌控每个项目。什么是项目管理软件?项目管理软件是广泛行业用于项目规划、资源分配和调度的软件。它使项...
项目管理软件   2749  
  本文介绍了以下10款项目管理软件工具:禅道项目管理软件、Freshdesk、ClickUp、nTask、Hubstaff、Plutio、Productive、Targa、Bonsai、Wrike。在当今快速变化的商业环境中,项目管理已成为企业成功的关键因素之一。然而,许多企业在项目管理过程中面临着诸多痛点,如任务分配不...
项目管理系统   85  
  本文介绍了以下10款项目管理软件工具:禅道项目管理软件、Monday、TeamGantt、Filestage、Chanty、Visor、Smartsheet、Productive、Quire、Planview。在当今快速变化的商业环境中,项目管理已成为企业成功的关键因素之一。然而,许多项目经理和团队在管理复杂项目时,常...
开源项目管理工具   96  
  本文介绍了以下10款项目管理软件工具:禅道项目管理软件、Smartsheet、GanttPRO、Backlog、Visor、ResourceGuru、Productive、Xebrio、Hive、Quire。在当今快节奏的商业环境中,项目管理已成为企业成功的关键因素之一。然而,许多企业在选择项目管理工具时常常面临困惑:...
项目管理系统   83  
热门文章
项目管理软件有哪些?
曾咪二维码

扫码咨询,免费领取项目管理大礼包!

云禅道AD
禅道项目管理软件

云端的项目管理软件

尊享禅道项目软件收费版功能

无需维护,随时随地协同办公

内置subversion和git源码管理

每天备份,随时转为私有部署

免费试用