tsai is an open-source deep learning package built on top of PyTorch and fastai. It focuses on time series tasks such as classification, regression, forecasting and imputation, among others. This post shows how to run inference with a classification model trained in tsai. Training was straightforward, but loading the exported model for inference failed because of a package version issue, so I am documenting the problem and the fix in detail.
Python version 3.8.18
pip install tsai==0.3.9
pip install torch==2.2.0
pip install matplotlib==3.7.1 (pin this exact version, otherwise load_learner throws an error when loading the exported model for inference)
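Before loading an exported learner, it can help to confirm that the environment actually matches these pins. Below is a minimal sketch that uses only the standard library; `find_mismatches` and the `PINS` dictionary are names I made up for illustration and are not part of tsai.

```python
from importlib import metadata

# Versions pinned in this post (the dictionary name is hypothetical).
PINS = {"tsai": "0.3.9", "torch": "2.2.0", "matplotlib": "3.7.1"}


def find_mismatches(pins, get_version=metadata.version):
    """Return {package: (expected, installed)} for every pin that is not met.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


# Report anything that differs from the pinned versions.
for name, (expected, installed) in find_mismatches(PINS).items():
    print(f"{name}: expected {expected}, found {installed}")
```

The comparison is an exact string match on purpose, so a build suffix such as `+cpu` on the torch version would also be reported and would need to be judged by hand.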
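The data used here is for a time series classification task. The table below shows its layout in simplified form: feature columns plus a label column. Only three feature columns are shown, while the code below takes the first 200 columns as features.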
Feature1 | Feature2 | Feature3 | Label |
---|---|---|---|
1 | 56 | 67 | 1 |
2 | 67 | 33 | 2 |
3 | 45 | 21 | 1 |
from sklearn.model_selection import train_test_split
import pandas as pd
# Load the CSV file
df = pd.read_csv('/content/detrend_df.csv')
# Split the features and labels into separate variables
# (the first 200 columns are used as the time steps of each series)
X = df.iloc[:, :200].values
y = df['label'].values
# Hold out 10% of the samples for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.10)
# Reshape X to (samples, 1, time steps) and y to a column vector
X_train = X_train.reshape((X_train.shape[0], 1, X.shape[1]))
y_train = y_train.reshape((y_train.shape[0], 1))
X_test = X_test.reshape((X_test.shape[0], 1, X.shape[1]))
y_test = y_test.reshape((y_test.shape[0], 1))
# Confirm the resulting shapes
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
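The reshape step is easier to see on a tiny array. The sketch below uses made-up numbers (4 samples with 5 time steps each, not the real dataset) to show how a 2D table becomes the 3D (samples, variables, time steps) array used above, with a single variable per sample.

```python
import numpy as np

# Toy stand-in for the feature matrix: 4 samples, 5 time steps each.
X = np.arange(20, dtype=float).reshape(4, 5)

# Insert a singleton "variables" axis: (samples, steps) -> (samples, 1, steps).
X_3d = X.reshape((X.shape[0], 1, X.shape[1]))

print(X.shape, "->", X_3d.shape)  # (4, 5) -> (4, 1, 5)

# The values are unchanged; each row is simply wrapped in one more dimension.
assert np.array_equal(X_3d[:, 0, :], X)
```

The same pattern applies to the real data, where the last dimension is 200 instead of 5.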
from tsai.all import *
import numpy as np
tfms = [None, TSClassification()]
batch_tfms = TSStandardize()
clf = TSClassifier(X_train,
                   y_train,
                   path='models',
                   arch="InceptionTimePlus",
                   tfms=tfms,
                   batch_tfms=batch_tfms,
                   metrics=accuracy,
                   cbs=ShowGraph(),
                   train_metrics=True)
clf.fit_one_cycle(100, 3e-4)
clf.export("clf.pkl")
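Since loading the exported file later depends on the installed package versions (see the matplotlib note in the install steps), one option is to record those versions next to the export. The sketch below is not a tsai feature: `write_version_record` is a hypothetical helper that relies only on the standard library, and the list of packages is my own choice.

```python
import json
import sys
from importlib import metadata
from pathlib import Path


def write_version_record(out_path, packages=("tsai", "torch", "fastai", "matplotlib")):
    """Write the Python version and installed package versions to a JSON file."""
    record = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            record[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            record[name] = None  # package not installed in this environment
    out_path = Path(out_path)
    # Create the target folder if it does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(record, indent=2))
    return record
```

Calling something like `write_version_record("models/clf_versions.json")` right after the export (the file name is only an example) leaves a small record that can be compared against the inference environment before trying to load the model.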
This section covers the issue I ran into while running inference with the exported model. My first attempt was to load it with load_learner, which failed with the traceback below.
from tsai.inference import load_learner
learn = load_learner("clf.pkl")
/Users/karthikrajedran/anaconda3/envs/segmentationModel/lib/python3.8/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/karthikrajedran/anaconda3/envs/segmentationModel/lib/python3.8/site-packages/fastai/learner.py", line 446, in load_learner
try: res = torch.load(fname, map_location=map_loc, pickle_module=pickle_module)
File "/Users/karthikrajedran/anaconda3/envs/segmentationModel/lib/python3.8/site-packages/torch/serialization.py", line 1026, in load
return _load(opened_zipfile,
File "/Users/karthikrajedran/anaconda3/envs/segmentationModel/lib/python3.8/site-packages/torch/serialization.py", line 1438, in _load
result = unpickler.load()
File "/Users/karthikrajedran/anaconda3/envs/segmentationModel/lib/python3.8/site-packages/matplotlib/cbook/__init__.py", line 229, in __setstate__
self._cid_gen = itertools.count(cid_count)
TypeError: a number is required
Loading the same file directly with torch.load failed with the same error:
import torch
torch.load("clf.pkl", map_location=torch.device('cpu'))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/karthikrajedran/anaconda3/envs/segmentationModel/lib/python3.8/site-packages/torch/serialization.py", line 1026, in load
return _load(opened_zipfile,
File "/Users/karthikrajedran/anaconda3/envs/segmentationModel/lib/python3.8/site-packages/torch/serialization.py", line 1438, in _load
result = unpickler.load()
File "/Users/karthikrajedran/anaconda3/envs/segmentationModel/lib/python3.8/site-packages/matplotlib/cbook/__init__.py", line 229, in __setstate__
self._cid_gen = itertools.count(cid_count)
TypeError: a number is required
The fix was to install the pinned matplotlib version:
pip install matplotlib==3.7.1
Then I ran the same loading code again:
from tsai.inference import load_learner
learn = load_learner("clf.pkl")
This time the model loaded successfully.
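With the learner loaded, predictions on new samples can be obtained through tsai's `get_X_preds` method. The helper below is a minimal sketch: `predict_labels` is a name I made up, and the input is assumed to be a NumPy array with the same (samples, 1, time steps) layout used for training.

```python
def predict_labels(learn, X_new):
    """Run inference with a loaded tsai learner.

    Returns the class probabilities and the decoded predictions.
    """
    # get_X_preds returns (probabilities, targets, decoded predictions);
    # the middle value holds the targets, which are unused here.
    probas, _, preds = learn.get_X_preds(X_new)
    return probas, preds
```

For the held-out split above, this would be called as `probas, preds = predict_labels(learn, X_test)`.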
My goal was to save the learner in pickle format after training and load it later for inference. This works, but it is not an efficient approach, because the exported file only loads when the installed package versions are compatible.
This post shows how I fixed the inference issue in the tsai package. I could not find a solid solution to this problem elsewhere, so I am sharing the method that worked for me.
Written on April 3rd, 2024 by Karthik