解决mxnet gluonts出现错误ImportError: cannot import name 'Trainer' from 'gluonts'
Published on Aug. 22, 2023, 12:10 p.m.
解决 mxnet gluonts 出现的错误：
ImportError: cannot import name 'Trainer' from 'gluonts' (/opt/conda/lib/python3.7/site-packages/gluonts/__init__.py)
下面这段代码来自 GluonTS 官方示例，运行时就会报上述错误：
# Original GluonTS quick-start example, reproduced as published.
# On GluonTS versions where Trainer lives in gluonts.mx.trainer, the
# `from gluonts.trainer import Trainer` import below fails with an ImportError.
from gluonts.dataset import common
from gluonts.model import deepar
from gluonts.trainer import Trainer  # the failing import: Trainer moved to gluonts.mx.trainer
import pandas as pd
url = "https://raw.githubusercontent.com/numenta/NAB/master/data/realTweets/Twitter_volume_AMZN.csv"
# The first CSV column becomes the index; the series to model is the `value` column.
df = pd.read_csv(url, header=0, index_col=0)
# Single-series dataset: the target is the data up to 2015-04-05 00:00:00, sampled every 5 minutes.
data = common.ListDataset([{
"start": df.index[0],
"target": df.value[:"2015-04-05 00:00:00"]
}],
freq="5min")
trainer = Trainer(epochs=10)
# DeepAR forecasting the next 12 steps (12 x 5min) of the series.
estimator = deepar.DeepAREstimator(
freq="5min", prediction_length=12, trainer=trainer)
predictor = estimator.train(training_data=data)
# predict() returns an iterator of forecasts; the dataset holds a single series.
prediction = next(predictor.predict(data))
print(prediction.mean)
prediction.plot(output_file='graph.png')
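解决方法：新版 GluonTS 中 Trainer 位于 gluonts.mx.trainer 模块，把 `from gluonts.trainer import Trainer` 改成 `from gluonts.mx.trainer import Trainer` 即可：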
# Fixed GluonTS quick-start example: train DeepAR on the NAB Twitter-volume
# (AMZN) series and plot a 12-step forecast.
import pandas as pd
from gluonts.dataset import common

# Newer GluonTS releases keep the MXNet-based components under ``gluonts.mx``;
# importing Trainer from the old ``gluonts.trainer`` path is what fails.
# Try the new locations first and fall back to the old ones, so the script
# runs on both older and newer GluonTS layouts.
try:
    from gluonts.mx.trainer import Trainer
except ImportError:
    from gluonts.trainer import Trainer

try:
    from gluonts.mx.model import deepar
except ImportError:
    from gluonts.model import deepar

url = "https://raw.githubusercontent.com/numenta/NAB/master/data/realTweets/Twitter_volume_AMZN.csv"
# The first CSV column becomes the index; the series to model is ``value``.
df = pd.read_csv(url, header=0, index_col=0)

# Single-series dataset: the target is the data up to 2015-04-05 00:00:00,
# sampled every 5 minutes.
data = common.ListDataset(
    [{"start": df.index[0], "target": df.value[:"2015-04-05 00:00:00"]}],
    freq="5min",
)

trainer = Trainer(epochs=10)
# Forecast the next 12 steps (12 x 5min) of the series.
estimator = deepar.DeepAREstimator(
    freq="5min", prediction_length=12, trainer=trainer
)
predictor = estimator.train(training_data=data)

# predict() returns an iterator of forecasts; the dataset holds a single series.
prediction = next(predictor.predict(data))
print(prediction.mean)
# NOTE(review): ``output_file`` is accepted by Forecast.plot in the GluonTS
# versions this example targets; later releases changed plot()'s signature,
# so confirm against the installed version.
prediction.plot(output_file='graph.png')