Dataset and DataLoader¶
Efficiently managing data is essential for training neural networks. Sorix provides the Dataset and DataLoader classes to simplify the process of loading, batching, and shuffling data.
Sorix's Data API is designed to be API-compatible with PyTorch, so PyTorch users can transfer their existing knowledge directly.
#Uncomment the next line and run this cell to install sorix
#!pip install 'sorix @ git+https://github.com/Mitchell-Mirano/sorix.git@main'
import sorix
from sorix.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
1. Tensors by Default and Transforms¶
In Sorix, the DataLoader automatically converts your batches into sorix.tensor objects. Additionally, the Dataset class supports transform and target_transform functions, allowing you to modify your data on-the-fly.
# A simple transform to normalize data
def normalize(x):
    return x / 255.0
# Create data
X = np.array([[0, 127, 255], [64, 128, 192]], dtype=np.float32)
y = np.array([0, 1])
# Dataset with transform
dataset = Dataset(X, y, transform=normalize)
# DataLoader automatically gives you Tensors
loader = DataLoader(dataset, batch_size=2)
for bx, by in loader:
    print(f"Batch X type: {type(bx)}")
    print(f"Batch X data (normalized):\n{bx}")
    print(f"Batch y: {by}")
Batch X type: <class 'sorix.tensor.Tensor'>
Batch X data (normalized):
tensor([[0. , 0.49803922, 1. ],
[0.2509804 , 0.5019608 , 0.7529412 ]])
Batch y: tensor([0, 1])
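The `target_transform` argument mentioned above works the same way as `transform`, but is applied to the labels. As a minimal sketch (the `one_hot` helper and `num_classes` parameter are illustrative, not part of Sorix), a label transform that one-hot encodes integer class labels could look like this:

```python
import numpy as np

def one_hot(label, num_classes=2):
    """Convert an integer class label into a one-hot float vector."""
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[label] = 1.0
    return vec

# It would be passed to the Dataset alongside transform, e.g.:
# dataset = Dataset(X, y, transform=normalize, target_transform=one_hot)
```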
2. Working with CSV Data¶
One of the most common tasks is loading data from a CSV file. You can easily integrate pandas with the Dataset class.
# Create a dummy CSV for this example
df = pd.DataFrame({
'feature1': np.random.rand(10),
'feature2': np.random.rand(10),
'label': np.random.randint(0, 2, 10)
})
df.to_csv('data.csv', index=False)
class CSVDataset(Dataset):
    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)
        # Extract features and labels
        X = self.df[['feature1', 'feature2']].values
        y = self.df['label'].values.reshape(-1, 1)
        super().__init__(X, y)
csv_ds = CSVDataset('data.csv')
csv_loader = DataLoader(csv_ds, batch_size=5)
for bx, by in csv_loader:
    print(f"CSV Batch shape: {bx.shape}")
CSV Batch shape: sorix.Size([5, 2])
CSV Batch shape: sorix.Size([5, 2])
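In practice you usually split the CSV into training and validation sets before building datasets. A minimal sketch with pandas (the 80/20 split and `random_state` are illustrative choices, not Sorix requirements):

```python
import numpy as np
import pandas as pd

# Same dummy data as above
df = pd.DataFrame({
    'feature1': np.random.rand(10),
    'feature2': np.random.rand(10),
    'label': np.random.randint(0, 2, 10),
})

# Shuffle the rows, then take the first 80% for training
shuffled = df.sample(frac=1.0, random_state=0).reset_index(drop=True)
split = int(0.8 * len(shuffled))
train_df, val_df = shuffled.iloc[:split], shuffled.iloc[split:]

# Each split can then back its own CSVDataset / DataLoader
print(len(train_df), len(val_df))
```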
3. Working with Images¶
For image datasets, you typically load the pixel data and apply augmentations using the transform argument.
class ImageDataset(Dataset):
    def __init__(self, num_images, transform=None):
        # In a real scenario, this would be a list of file paths
        self.image_data = [np.random.randint(0, 256, (3, 32, 32), dtype=np.uint8) for _ in range(num_images)]
        self.labels = np.random.randint(0, 10, num_images)
        super().__init__(self.image_data, self.labels, transform=transform)

def to_float(img):
    return img.astype(np.float32) / 255.0
image_ds = ImageDataset(num_images=20, transform=to_float)
image_loader = DataLoader(image_ds, batch_size=4, shuffle=True)
for images, labels in image_loader:
    print(f"Image batch shape: {images.shape}")
    print(f"Image data range: [{images.data.min():.2f}, {images.data.max():.2f}]")
    break
Image batch shape: sorix.Size([4, 3, 32, 32])
Image data range: [0.00, 1.00]
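The same `transform` hook can also apply data augmentation. As a sketch in plain NumPy (the `random_hflip` and `augment` names are illustrative, not part of Sorix), a transform that randomly flips a CHW image and then converts it to floats could look like:

```python
import numpy as np

rng = np.random.default_rng(0)

def random_hflip(img, p=0.5):
    """Flip a CHW image left-right with probability p."""
    if rng.random() < p:
        return img[:, :, ::-1].copy()  # reverse the width axis
    return img

def augment(img):
    """Compose augmentation with the float conversion used above."""
    img = random_hflip(img)
    return img.astype(np.float32) / 255.0

# Would be passed as ImageDataset(num_images=20, transform=augment)
```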
4. Custom Collate Function¶
Sometimes you need complex logic when merging samples into a batch. Just like in PyTorch, you can provide a collate_fn to the DataLoader.
def my_collate(samples):
    # samples is a list of what Dataset[idx] returns
    batch_x = []
    batch_y = []
    for x, y in samples:
        batch_x.append(x)
        batch_y.append(y)
    # Custom logic: maybe pad sequences, or add metadata
    return sorix.tensor(batch_x), sorix.tensor(batch_y)
custom_loader = DataLoader(dataset, batch_size=2, collate_fn=my_collate)
for bx, by in custom_loader:
print("Custom collate worked!")
break
Custom collate worked!
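A typical real-world use of `collate_fn` is padding variable-length sequences to a common length. The sketch below uses plain NumPy so it is self-contained; in practice the padded arrays would be wrapped in `sorix.tensor` as in `my_collate` above (the `pad_collate` name is illustrative, not part of Sorix):

```python
import numpy as np

def pad_collate(samples):
    """Pad variable-length 1-D sequences in a batch to a common length.

    samples is a list of (sequence, label) pairs, as yielded by Dataset[idx].
    """
    seqs, labels = zip(*samples)
    max_len = max(len(s) for s in seqs)
    padded = np.zeros((len(seqs), max_len), dtype=np.float32)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = s  # zero-pad on the right
    return padded, np.asarray(labels)
```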
Summary¶
- PyTorch API Parity: Dataset supports transforms, and DataLoader supports batching, shuffling, and custom collation.
- Automatic Tensors: DataLoader yields sorix.tensor objects by default.
- Practical Applications: Easily handle CSVs, images, and custom data formats by extending the Dataset class.