Predicting Titanic survival with lightning-flash
Published on Aug. 22, 2023, 12:08 p.m.
lightning-flash
PyTorchLightning/lightning-flash: a collection of tasks for fast prototyping, baselining, fine-tuning, and solving deep learning problems.
https://github.com/PyTorchLightning/lightning-flash
The documentation is here:
https://lightning-flash.readthedocs.io/en/latest/reference/tabular_classification.html
https://lightning-flash.readthedocs.io/
It's pretty good overall.
from torchmetrics.classification import Accuracy, Precision, Recall
import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier, TabularData
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
# 2. Load the data
datamodule = TabularData.from_csv(
    categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_fields="Fare",
    target_fields="Survived",
    train_file="./data/titanic/titanic.csv",
    test_file="./data/titanic/test.csv",
    val_split=0.25,  # hold out 25% of train_file for validation
)
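Step 2 tells flash which columns are categorical, which are numerical (`Fare`), which column is the target (`Survived`), and to hold out 25% of the training file for validation (`val_split=0.25`). To make those arguments concrete, here is a rough plain-Python sketch of the same bookkeeping. It is not how flash implements them internally. It does a shuffled 75/25 split, builds an integer vocabulary for each categorical column from the training rows only (id 0 for values never seen there), and keeps numerical columns as floats. The toy rows, the shortened column list, and the helper names `split_rows`, `build_vocab`, and `encode` are all hypothetical.

```python
import random

# Column roles follow the from_csv call above, shortened to three columns.
CATEGORICAL = ["Sex", "Embarked"]
NUMERICAL = ["Fare"]
TARGET = "Survived"

# Made-up toy rows for illustration only; these are NOT records from titanic.csv.
ROWS = [
    {"Sex": "male", "Embarked": "S", "Fare": 7.0, "Survived": 0},
    {"Sex": "female", "Embarked": "C", "Fare": 70.0, "Survived": 1},
    {"Sex": "female", "Embarked": "S", "Fare": 8.0, "Survived": 1},
    {"Sex": "male", "Embarked": "Q", "Fare": 9.0, "Survived": 0},
    {"Sex": "male", "Embarked": "S", "Fare": 50.0, "Survived": 1},
    {"Sex": "female", "Embarked": "Q", "Fare": 12.0, "Survived": 0},
    {"Sex": "male", "Embarked": "C", "Fare": 30.0, "Survived": 0},
    {"Sex": "female", "Embarked": "S", "Fare": 25.0, "Survived": 1},
]


def split_rows(rows, val_split, seed=0):
    """Shuffle, then hold out a val_split fraction of the rows for validation."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_split)
    return shuffled[n_val:], shuffled[:n_val]  # (train, val)


def build_vocab(rows, fields):
    """Give each categorical value seen in `rows` an id >= 1; id 0 means 'unseen'."""
    return {
        field: {value: i + 1 for i, value in enumerate(sorted({r[field] for r in rows}))}
        for field in fields
    }


def encode(row, vocab):
    """Turn one row into (categorical ids, numerical floats, target)."""
    cat_ids = [vocab[f].get(row[f], 0) for f in CATEGORICAL]
    numbers = [float(row[f]) for f in NUMERICAL]
    return cat_ids, numbers, row[TARGET]


if __name__ == "__main__":
    train, val = split_rows(ROWS, val_split=0.25)
    print(f"train rows: {len(train)}, val rows: {len(val)}")  # train rows: 6, val rows: 2
    vocab = build_vocab(train, CATEGORICAL)  # vocabulary built from training rows only
    print([encode(r, vocab) for r in val])
```

Building the vocabulary from the training split alone mirrors the usual rule that validation data must not influence preprocessing. A validation value the training rows never contained simply maps to id 0.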
# 3. Build the model
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])
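Step 3 attaches three torchmetrics metrics to the classifier. For a binary target like `Survived`, the standard definitions are accuracy = correct / total, precision = TP / (TP + FP), and recall = TP / (TP + FN). The plain-Python sketch below computes them for the positive class `Survived == 1`. The function name `binary_metrics` and the toy labels are hypothetical. torchmetrics' own behavior depends on the installed version: averaging mode, thresholds, and, in newer releases, a required `task` argument such as `task="binary"`. The values reported during training may therefore differ from this sketch.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Accuracy, precision and recall for one positive class (here Survived == 1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)  # predicted survived, did survive
    fp = sum(t != positive and p == positive for t, p in pairs)  # predicted survived, did not
    fn = sum(t == positive and p != positive for t, p in pairs)  # missed an actual survivor
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return accuracy, precision, recall


# Toy labels for illustration only (not model output).
acc, prec, rec = binary_metrics([1, 1, 1, 0, 0], [1, 0, 0, 0, 1])
print(f"accuracy={acc:.3f} precision={prec:.3f} recall={rec:.3f}")
# accuracy=0.400 precision=0.500 recall=0.333
```

The three numbers diverge here because the toy predictor misses two of three survivors (low recall) while half of its "survived" calls are right (precision 0.5). Accuracy alone would hide that trade-off.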
# 4. Create the trainer and run 10 epochs over the data
trainer = flash.Trainer(max_epochs=10)
# 5. Train the model
trainer.fit(model, datamodule=datamodule)
# 6. Test the model on test_file
trainer.test(model, datamodule=datamodule)
# 7. Predict!
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)