Predicting Titanic survival with lightning-flash

Published on Aug. 22, 2023, 12:08 p.m.

lightning-flash

PyTorchLightning/lightning-flash: a collection of tasks for fast prototyping, baselining, fine-tuning, and solving deep learning problems.

https://github.com/PyTorchLightning/lightning-flash

The documentation is here:
https://lightning-flash.readthedocs.io/en/latest/reference/tabular_classification.html

https://lightning-flash.readthedocs.io/

Overall, it's pretty good.

from torchmetrics.classification import Accuracy, Precision, Recall
import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier, TabularData

# 1. Download the data

download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

# 2. Load the data

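datamodule = TabularData.from_csv(
    ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],  # categorical fields
    "Fare",  # numerical field
    target_fields="Survived",  # column to predict (0 = died, 1 = survived)
    train_file="./data/titanic/titanic.csv",
    test_file="./data/titanic/test.csv",
    val_split=0.25,  # hold out 25% of the training rows for validation
)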

# 3. Build the model

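# Written for an older torchmetrics release; from torchmetrics 0.11 on, these classes
# require a task argument, e.g. Accuracy(task="multiclass", num_classes=2).
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])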

# 4. Create the trainer: train for 10 epochs (10 full passes over the training data)

trainer = flash.Trainer(max_epochs=10)

# 5. Train the model

trainer.fit(model, datamodule=datamodule)

# 6. Test the model

trainer.test()

# 7. Predict!

predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
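The `TabularData.from_csv` call hands flash lists of categorical and numerical columns plus `val_split=0.25`, and flash takes care of the splitting and encoding. As a rough illustration of what that step involves conceptually, the plain-Python sketch below holds out a `val_split` fraction of rows for validation, maps each categorical value to an integer code, and standardizes a numerical column using training statistics. This is an assumption about typical tabular preprocessing, not flash's actual code; the toy rows and the helper names (`split_train_val`, `build_codes`, `fit_normalization`, `encode_row`) are hypothetical.

```python
import random
import statistics

# Hypothetical toy rows standing in for titanic.csv (not real Titanic data).
ROWS = [
    {"Sex": "male", "Embarked": "S", "Fare": 7.25, "Survived": 0},
    {"Sex": "female", "Embarked": "C", "Fare": 71.28, "Survived": 1},
    {"Sex": "female", "Embarked": "S", "Fare": 7.93, "Survived": 1},
    {"Sex": "male", "Embarked": "Q", "Fare": 8.46, "Survived": 0},
]
CATEGORICAL = ["Sex", "Embarked"]
NUMERICAL = "Fare"


def split_train_val(rows, val_split, seed=0):
    """Shuffle a copy of the rows and hold out a `val_split` fraction for validation."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_split)
    return shuffled[n_val:], shuffled[:n_val]


def build_codes(rows, field):
    """Map every distinct value of a categorical column to an integer code."""
    return {value: code for code, value in enumerate(sorted({r[field] for r in rows}))}


def fit_normalization(rows, field):
    """Mean and population std of a numerical column; std falls back to 1.0 if zero."""
    values = [r[field] for r in rows]
    return statistics.mean(values), statistics.pstdev(values) or 1.0


def encode_row(row, codes, mean, std):
    """Turn one row into (categorical codes, standardized numerical value)."""
    cats = [codes[field][row[field]] for field in CATEGORICAL]
    num = (row[NUMERICAL] - mean) / std
    return cats, num


train, val = split_train_val(ROWS, val_split=0.25)
# Codes are built from all rows so no validation value is unseen;
# normalization statistics come from the training rows only.
codes = {field: build_codes(ROWS, field) for field in CATEGORICAL}
mean, std = fit_normalization(train, NUMERICAL)
encoded = [encode_row(r, codes, mean, std) for r in train]
print(codes)
print(encoded)
```

A real pipeline also has to decide how to handle missing values (the Titanic `Age` and `Cabin` columns contain gaps), which this sketch skips.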

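Step 3 passes `Accuracy()`, `Precision()` and `Recall()` to the model, and step 6 reports them on the test set. For a binary target like `Survived`, the sketch below shows in plain Python what the three numbers mean; it does not use torchmetrics, and `binary_metrics`, `micro_precision` and the sample labels are hypothetical. It also illustrates one pitfall: precision micro-averaged over both classes (the default `average="micro"` in older torchmetrics releases) reduces to plain accuracy, so the positive-class values are usually the more informative ones.

```python
def binary_metrics(preds, targets, positive=1):
    """Accuracy plus precision and recall for the positive class."""
    pairs = list(zip(preds, targets))
    tp = sum(1 for p, t in pairs if p == positive and t == positive)
    fp = sum(1 for p, t in pairs if p == positive and t != positive)
    fn = sum(1 for p, t in pairs if p != positive and t == positive)
    accuracy = sum(1 for p, t in pairs if p == t) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return accuracy, precision, recall


def micro_precision(preds, targets, classes=(0, 1)):
    """Precision pooled over all classes (micro average)."""
    pairs = list(zip(preds, targets))
    tp = sum(1 for c in classes for p, t in pairs if p == c and t == c)
    fp = sum(1 for c in classes for p, t in pairs if p == c and t != c)
    return tp / (tp + fp)


# Hypothetical predictions (1 = survived) against hypothetical true labels.
preds = [1, 0, 1, 1, 0]
targets = [1, 0, 0, 1, 1]

accuracy, precision, recall = binary_metrics(preds, targets)
print(f"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f}")
# prints: accuracy=0.600 precision=0.667 recall=0.667
print(f"micro precision={micro_precision(preds, targets):.3f}")
# prints: micro precision=0.600 (same as accuracy)
```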