train & predict
Function
train
Trains the selected model on the data supplied by the training data loader. The procedure depends on the model type. For logistic regression, the number of epochs is managed internally, so the epochs argument has no effect. For the other models (cnn, simple_cnn), epochs sets how many passes are made over the training data.
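The dispatch described above can be sketched as follows. This is a minimal, hypothetical illustration, not the CRISGI implementation. `EpochModelStub`, `LogisticStub`, `train_one_epoch`, `fit`, and the fixed internal epoch count of 3 are all assumptions. Plain Python lists stand in for a DataLoader so the sketch runs without PyTorch.

```python
from typing import Any, Dict, Iterable, List


class EpochModelStub:
    """Hypothetical CNN-like model: the caller drives it one epoch at a time."""

    def __init__(self) -> None:
        self.epochs_run = 0

    def train_one_epoch(self, loader: Iterable[Any]) -> float:
        # Pretend to consume every batch and return a fake "loss".
        n_batches = sum(1 for _ in loader)
        self.epochs_run += 1
        return 1.0 / (self.epochs_run + n_batches)


class LogisticStub:
    """Hypothetical logistic model: decides its own number of epochs."""

    INTERNAL_EPOCHS = 3  # assumption, chosen only for illustration

    def __init__(self) -> None:
        self.epochs_run = 0

    def fit(self, loader: Iterable[Any]) -> Dict[str, Any]:
        for _ in range(self.INTERNAL_EPOCHS):
            for _batch in loader:
                pass  # a real model would update its weights here
            self.epochs_run += 1
        return {"status": "done", "epochs": self.epochs_run}


def train(model_type: str, model: Any, train_loader: Iterable[Any], epochs: int = 10) -> Any:
    """Dispatch training on the model type (hypothetical sketch)."""
    if model_type == "logistic":
        # The model manages its own epoch count; `epochs` is ignored.
        return model.fit(train_loader)
    # For CNN-style models the caller controls the number of epochs.
    losses: List[float] = [model.train_one_epoch(train_loader) for _ in range(epochs)]
    return {"status": "done", "epochs": epochs, "losses": losses}


batches = [[0.1, 0.2], [0.3, 0.4]]  # any re-iterable stands in for a DataLoader
print(train("cnn", EpochModelStub(), batches, epochs=4)["epochs"])  # 4
print(train("logistic", LogisticStub(), batches, epochs=20)["epochs"])  # 3
```

Note that the loader must be re-iterable (a DataLoader is), because every epoch walks over it again.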
predict
Generates predictions using the trained model on the provided data loader. This function abstracts the prediction process for different model types, ensuring a consistent interface.
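One way to give every model the same prediction interface is to run the model batch by batch and gather the outputs into a single flat list. The sketch below is hypothetical. It assumes a per-batch callable `model_fn` and a toy `threshold_model`, and it does not necessarily reflect how CRISGI's predict works internally.

```python
from typing import Any, Callable, Iterable, List, Sequence


def predict(model_fn: Callable[[Sequence[Any]], Sequence[Any]],
            data_loader: Iterable[Sequence[Any]]) -> List[Any]:
    """Apply `model_fn` to each batch and return one flat list of predictions."""
    predictions: List[Any] = []
    for batch in data_loader:
        # Each batch yields one prediction per sample; append them in order.
        predictions.extend(model_fn(batch))
    return predictions


def threshold_model(batch: Sequence[float]) -> List[int]:
    # Hypothetical stand-in model: class 1 if the score exceeds 0.5.
    return [1 if x > 0.5 else 0 for x in batch]


print(predict(threshold_model, [[0.2, 0.7], [0.9, 0.1]]))  # [0, 1, 1, 0]
```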
Parameters
Name | Type | Used by | Description |
---|---|---|---|
train_loader | DataLoader | train | The data loader that supplies the training dataset. |
epochs | int, optional | train | Number of training epochs, used only for non-logistic models. Default is 10. |
data_loader | DataLoader | predict | The data loader that supplies the data to generate predictions for. |
Return type
- train: Depends on the underlying model's train method; typically training metrics or a status value.
- predict: Model-specific prediction output (e.g., a NumPy array, tensor, or list of predictions).
Returns
- train: Returns the result of the model's training process, which may include training loss, accuracy, or a status indicator.
- predict: Returns the predictions generated by the model for the input data.
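Because the output of predict varies by model, downstream code may want to normalize it. The helper below is a hypothetical sketch (`to_list` is not part of CRISGI) that relies on duck typing. PyTorch tensors provide `detach()`, `cpu()`, and `tolist()`, and NumPy arrays provide `tolist()`. Any other iterable is copied into a list as-is, so per-batch outputs are not flattened.

```python
from typing import Any, List


def to_list(predictions: Any) -> List[Any]:
    """Convert model-specific prediction output into a plain Python list."""
    if hasattr(predictions, "detach"):
        # torch.Tensor: drop the autograd graph and move to host memory first.
        return predictions.detach().cpu().tolist()
    if hasattr(predictions, "tolist"):
        # numpy.ndarray (and similar array types).
        return predictions.tolist()
    # Lists, tuples, generators, ... are copied element by element.
    return list(predictions)


print(to_list((0, 1, 1)))  # [0, 1, 1]
```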
Attributes Set
- self.model_type: The type of model being used (e.g., "cnn", "simple_cnn", "logistic").
- self.model: The instantiated model object corresponding to the selected type.
- self.device: The device (e.g., 'cpu' or 'cuda') on which the model operates.
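These attributes are presumably populated when a model type is chosen, as set_model_type does in the example below. The following is a hypothetical sketch of that bookkeeping, not CRISGI's code. `ModelHolder`, the registry of factory functions, and the dict-valued placeholder models are assumptions. In real PyTorch code, a common choice for the device is `"cuda" if torch.cuda.is_available() else "cpu"`.

```python
from typing import Any, Callable, Dict


class ModelHolder:
    """Hypothetical sketch of how model_type, model and device get set."""

    def __init__(self, registry: Dict[str, Callable[..., Any]]) -> None:
        self._registry = registry
        self.model_type: str = ""
        self.model: Any = None
        self.device: str = "cpu"

    def set_model_type(self, model_type: str, device: str = "cpu", **model_kwargs: Any) -> None:
        # Validate first so an unknown type leaves the object unchanged.
        if model_type not in self._registry:
            raise ValueError(
                f"unknown model_type {model_type!r}; expected one of {sorted(self._registry)}"
            )
        self.model_type = model_type
        self.device = device
        # Build the model for the chosen type, forwarding paths and options.
        self.model = self._registry[model_type](**model_kwargs)


# Placeholder factories; real ones would construct the actual model classes.
registry = {
    "cnn": lambda **kw: {"kind": "cnn", **kw},
    "simple_cnn": lambda **kw: {"kind": "simple_cnn", **kw},
    "logistic": lambda **kw: {"kind": "logistic", **kw},
}
holder = ModelHolder(registry)
holder.set_model_type("logistic", device="cpu", model_path="model.pth")
print(holder.model_type, holder.device, holder.model["kind"])  # logistic cpu logistic
```

Keeping the factories in a dictionary means an unknown model_type fails early with a clear error instead of leaving the object half-configured.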
Example
from crisgi import CRISGI
import pickle
from torchvision import transforms
from torch.utils.data import DataLoader
from src.cnn.CNNModel import CNNModel
from src.simplecnn.SimpleCNNModel import SimpleCNNModel
from src.logistic.LogisticModel import LogisticModel
from src.util import ImageDataset
import src.plotting_sniee_time as pl
# Load a previously saved CRISGI object (the file is closed automatically)
with open("data/GSE30550_H3N2_crisgi_obj.pk", 'rb') as f:
    crisgi_obj = pickle.load(f)
# Set up the training data loader
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
dataset = ImageDataset(
    image_dir='./out1',
    label_csv='./out1/labels.csv',
    transform=transform,
    return_label=True
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)
# Set the model type and initialize the model.
# Choose ONE of the options below: each call to set_model_type replaces the
# previously configured model, so only the last call executed takes effect.
# 3L-CNN
crisgi_obj.set_model_type(model_type='cnn', ae_path="data/model/3L-CNN/GSE30550_H3N2_ae_model.pth", mlp_path="data/model/3L-CNN/GSE30550_H3N2_mlp_model.pth")
# 1L-CNN
# crisgi_obj.set_model_type(model_type='simple_cnn')
# Logistic Regression
# crisgi_obj.set_model_type(model_type='logistic', model_path="data/model/logistic/GSE30550_H3N2_log_model.pth", device='cpu')
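# Train the model (epochs is ignored for the logistic model)
crisgi_obj.train(loader, epochs=20)
# Generate predictions with the trained model. The training images are reused
# here for illustration; a non-shuffled loader keeps the predictions in the
# same order as the samples in the dataset.
pred_loader = DataLoader(dataset, batch_size=16, shuffle=False)
predictions = crisgi_obj.predict(pred_loader)
# Print the predictions
print(predictions)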