
Machine Learning API

ML utilities including dataloaders and tokenizers.

DataLoaders

slaf.ml.dataloaders

Classes

SLAFDataLoader

High-performance DataLoader for SLAF data optimized for ML training.

SLAFDataLoader provides efficient streaming of single-cell data for machine learning applications with multiple loading strategies for different use cases. It uses async batch processing and provides device-agnostic CPU tensor output for maximum training flexibility.

Key Features
  • Multiple tokenization strategies (GeneFormer, scGPT)
  • Multiple loading modes for different entropy requirements:
    • Mixture of Scanners (MoS): Maximum entropy, best randomization (default)
    • Fragment-based loading: Higher entropy, moderate performance
    • Sequential loading: Fastest, lowest entropy
  • Pre-tokenized sequences for maximum performance (tokenized mode)
  • Raw data output for external processing (raw mode)
  • Device-agnostic CPU tensor output
  • Async batch processing with background prefetching
  • Memory-efficient streaming
  • Multi-epoch training support
  • Comprehensive error handling and validation
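The "async batch processing with background prefetching" feature can be pictured as a producer thread that fills a bounded queue while the training loop consumes from it. The sketch below uses only the standard library and is not SLAF's implementation. `prefetch` and `load_batch` are hypothetical names. The constructor source passes `max_queue_size=500` to `SLAFIterableDataset`, which is consistent with a bounded queue, but the real internals may differ.

```python
import queue
import threading
from typing import Callable, Iterator, List

_SENTINEL = object()  # marks the end of the stream


def prefetch(load_batch: Callable[[int], List[int]], n_batches: int,
             max_queue_size: int = 4) -> Iterator[List[int]]:
    """Yield batches produced by a background thread (hypothetical helper)."""
    q: "queue.Queue[object]" = queue.Queue(maxsize=max_queue_size)

    def producer() -> None:
        for i in range(n_batches):
            q.put(load_batch(i))  # blocks while the queue is full (backpressure)
        q.put(_SENTINEL)

    worker = threading.Thread(target=producer, daemon=True)
    worker.start()
    while True:
        item = q.get()
        if item is _SENTINEL:
            break
        yield item
    worker.join()


def load_batch(i: int, batch_size: int = 4) -> List[int]:
    # Stand-in for reading cells from storage: returns the cell IDs of batch i
    return list(range(i * batch_size, (i + 1) * batch_size))


batches = list(prefetch(load_batch, n_batches=3))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
```

The producer is a daemon thread, so if the consumer stops early, the blocked producer does not keep the interpreter alive.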

Loading Modes
  1. Mixture of Scanners (default): Randomly samples from multiple fragment generators for maximum entropy and randomization (about 88% of the entropy of fully random sampling)
  2. Fragment-based: Loads complete Lance fragments, giving higher entropy than sequential loading
  3. Sequential: Loads contiguous Lance batches for maximum throughput
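The three loading modes can be illustrated with a toy model in which each "fragment" is just a list of row IDs. This sketch makes two assumptions. First, fragment-based mode visits whole fragments in a shuffled order. Second, MoS keeps up to `n_scanners` fragment iterators open and draws each batch from one picked at random, as the description above says. `sequential`, `by_fragment`, and `mixture_of_scanners` are hypothetical helpers, not SLAF functions. The real loader's fragment ordering and refill policy may differ.

```python
import random
from typing import Iterator, List

Fragment = List[int]  # a toy fragment: a list of row IDs


def batches_of(rows: List[int], batch_size: int) -> Iterator[List[int]]:
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


def sequential(fragments: List[Fragment], batch_size: int) -> Iterator[List[int]]:
    # Contiguous reads in storage order: fastest, least shuffled
    rows = [r for frag in fragments for r in frag]
    yield from batches_of(rows, batch_size)


def by_fragment(fragments: List[Fragment], batch_size: int,
                rng: random.Random) -> Iterator[List[int]]:
    # Whole fragments in a random order; rows inside a fragment stay contiguous
    order = list(range(len(fragments)))
    rng.shuffle(order)
    for idx in order:
        yield from batches_of(fragments[idx], batch_size)


def mixture_of_scanners(fragments: List[Fragment], batch_size: int,
                        n_scanners: int, rng: random.Random) -> Iterator[List[int]]:
    # Keep up to n_scanners fragment iterators open, pull each batch from one
    # chosen at random, and refill from the remaining fragments as they drain
    pending = list(range(len(fragments)))
    rng.shuffle(pending)
    active: List[Iterator[List[int]]] = []
    while pending or active:
        while pending and len(active) < n_scanners:
            active.append(batches_of(fragments[pending.pop()], batch_size))
        i = rng.randrange(len(active))
        batch = next(active[i], None)
        if batch is None:
            active.pop(i)  # this scanner is exhausted
        else:
            yield batch


fragments = [list(range(f * 8, f * 8 + 8)) for f in range(4)]  # 4 fragments x 8 rows
print(list(sequential(fragments, 4))[:2])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(list(mixture_of_scanners(fragments, 4, n_scanners=2, rng=random.Random(0))))
```

All three modes visit every row exactly once per pass. They differ only in how far apart consecutive batches are drawn, which is the "entropy" trade-off the list describes.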

Examples:

>>> # Basic usage with default settings (MoS loading)
>>> slaf_array = SLAFArray("path/to/data.slaf")
>>> dataloader = SLAFDataLoader(slaf_array)
>>> for batch in dataloader:
...     print(f"Batch shape: {batch['input_ids'].shape}")
...     print(f"Cell IDs: {batch['cell_ids']}")
...     break
Batch shape: torch.Size([32, 2048])
Cell IDs: tensor([0, 1, 2, ..., 29, 30, 31])
>>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
MoS enabled: True
>>> # Sequential loading for maximum throughput
>>> dataloader = SLAFDataLoader(
...     slaf_array=slaf_array,
...     use_mixture_of_scanners=False,
...     by_fragment=False
... )
>>> print(f"Sequential loading: {not dataloader.use_mixture_of_scanners}")
Sequential loading: True
>>> # Fragment-based loading for higher entropy
>>> dataloader = SLAFDataLoader(
...     slaf_array=slaf_array,
...     use_mixture_of_scanners=False,
...     by_fragment=True
... )
>>> print(f"Fragment-based loading: {dataloader.by_fragment}")
Fragment-based loading: True
>>> # Raw mode for external processing
>>> dataloader = SLAFDataLoader(
...     slaf_array=slaf_array,
...     raw_mode=True
... )
>>> for batch in dataloader:
...     print(f"Raw data type: {type(batch['x'])}")
...     break
Raw data type: <class 'polars.dataframe.frame.DataFrame'>
>>> # Multi-epoch training
>>> dataloader = SLAFDataLoader(
...     slaf_array=slaf_array,
...     n_epochs=5
... )
>>> print(f"Number of epochs: {dataloader.n_epochs}")
Number of epochs: 5
>>> # Custom configuration for training
>>> dataloader = SLAFDataLoader(
...     slaf_array=slaf_array,
...     tokenizer_type="scgpt",
...     batch_size=64,
...     max_genes=1024
... )
>>> print(f"Tokenizer type: {dataloader.tokenizer_type}")
Tokenizer type: scgpt
>>> # Training loop example
>>> for batch_idx, batch in enumerate(dataloader):
...     input_ids = batch["input_ids"]
...     attention_mask = batch["attention_mask"]
...     cell_ids = batch["cell_ids"]
...     # Your training code here
...     if batch_idx >= 2:  # Just test first few batches
...         break
>>> print("Training loop completed")
Training loop completed
>>> # Error handling for invalid tokenizer type
>>> try:
...     dataloader = SLAFDataLoader(slaf_array, tokenizer_type="invalid")
... except ValueError as e:
...     print(f"Error: {e}")
Error: Unsupported tokenizer type: invalid
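According to the `__iter__` documentation, when `n_epochs > 1` the stream resets after each epoch and each batch carries an `epoch` key. The following is a minimal sketch of that behaviour. It assumes each pass simply re-reads the data and tags every batch with its 0-based epoch index, and only in multi-epoch runs. `iterate_epochs` and `make_batches` are hypothetical names, not part of SLAF.

```python
from typing import Callable, Dict, Iterator, List


def iterate_epochs(make_batches: Callable[[], Iterator[List[int]]],
                   n_epochs: int) -> Iterator[Dict[str, object]]:
    """Replay the batch stream n_epochs times (hypothetical helper)."""
    for epoch in range(n_epochs):
        for cell_ids in make_batches():  # fresh iterator for each epoch
            batch: Dict[str, object] = {"cell_ids": cell_ids}
            if n_epochs > 1:
                batch["epoch"] = epoch  # tagged only in multi-epoch runs
            yield batch


def make_batches() -> Iterator[List[int]]:
    # Stand-in for one pass over a tiny dataset: 6 cells, batch size 3
    yield [0, 1, 2]
    yield [3, 4, 5]


epochs_seen = sorted({b["epoch"] for b in iterate_epochs(make_batches, n_epochs=3)})
print(epochs_seen)  # [0, 1, 2]
```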
Source code in slaf/ml/dataloaders.py
class SLAFDataLoader:
    """
    High-performance DataLoader for SLAF data optimized for ML training.

    SLAFDataLoader provides efficient streaming of single-cell data for machine learning
    applications with multiple loading strategies for different use cases. It uses async
    batch processing and provides device-agnostic CPU tensor output for maximum training flexibility.

    Key Features:
        - Multiple tokenization strategies (GeneFormer, scGPT)
        - Multiple loading modes for different entropy requirements:
            * Mixture of Scanners (MoS): Maximum entropy, best randomization (default)
            * Fragment-based loading: Higher entropy, moderate performance
            * Sequential loading: Fastest, lowest entropy
        - Pre-tokenized sequences for maximum performance (tokenized mode)
        - Raw data output for external processing (raw mode)
        - Device-agnostic CPU tensor output
        - Async batch processing with background prefetching
        - Memory-efficient streaming
        - Multi-epoch training support
        - Comprehensive error handling and validation

    Loading Modes:
        1. Mixture of Scanners (default): Randomly samples from multiple fragment generators
           for maximum entropy and randomization (about 88% of the entropy of fully random sampling)
        2. Fragment-based: Loads complete Lance fragments, giving higher entropy than sequential loading
        3. Sequential: Loads contiguous Lance batches for maximum throughput

    Examples:
        >>> # Basic usage with default settings (MoS loading)
        >>> slaf_array = SLAFArray("path/to/data.slaf")
        >>> dataloader = SLAFDataLoader(slaf_array)
        >>> for batch in dataloader:
        ...     print(f"Batch shape: {batch['input_ids'].shape}")
        ...     print(f"Cell IDs: {batch['cell_ids']}")
        ...     break
        Batch shape: torch.Size([32, 2048])
        Cell IDs: tensor([0, 1, 2, ..., 29, 30, 31])
        >>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
        MoS enabled: True

        >>> # Sequential loading for maximum throughput
        >>> dataloader = SLAFDataLoader(
        ...     slaf_array=slaf_array,
        ...     use_mixture_of_scanners=False,
        ...     by_fragment=False
        ... )
        >>> print(f"Sequential loading: {not dataloader.use_mixture_of_scanners}")
        Sequential loading: True

        >>> # Fragment-based loading for higher entropy
        >>> dataloader = SLAFDataLoader(
        ...     slaf_array=slaf_array,
        ...     use_mixture_of_scanners=False,
        ...     by_fragment=True
        ... )
        >>> print(f"Fragment-based loading: {dataloader.by_fragment}")
        Fragment-based loading: True

        >>> # Raw mode for external processing
        >>> dataloader = SLAFDataLoader(
        ...     slaf_array=slaf_array,
        ...     raw_mode=True
        ... )
        >>> for batch in dataloader:
        ...     print(f"Raw data type: {type(batch['x'])}")
        ...     break
        Raw data type: <class 'polars.dataframe.frame.DataFrame'>

        >>> # Multi-epoch training
        >>> dataloader = SLAFDataLoader(
        ...     slaf_array=slaf_array,
        ...     n_epochs=5
        ... )
        >>> print(f"Number of epochs: {dataloader.n_epochs}")
        Number of epochs: 5

        >>> # Custom configuration for training
        >>> dataloader = SLAFDataLoader(
        ...     slaf_array=slaf_array,
        ...     tokenizer_type="scgpt",
        ...     batch_size=64,
        ...     max_genes=1024
        ... )
        >>> print(f"Tokenizer type: {dataloader.tokenizer_type}")
        Tokenizer type: scgpt

        >>> # Training loop example
        >>> for batch_idx, batch in enumerate(dataloader):
        ...     input_ids = batch["input_ids"]
        ...     attention_mask = batch["attention_mask"]
        ...     cell_ids = batch["cell_ids"]
        ...     # Your training code here
        ...     if batch_idx >= 2:  # Just test first few batches
        ...         break
        >>> print("Training loop completed")
        Training loop completed

        >>> # Error handling for invalid tokenizer type
        >>> try:
        ...     dataloader = SLAFDataLoader(slaf_array, tokenizer_type="invalid")
        ... except ValueError as e:
        ...     print(f"Error: {e}")
        Error: Unsupported tokenizer type: invalid
    """

    device: Optional["torch.device"]  # type: ignore
    tokenizer: Optional["SLAFTokenizer"]  # type: ignore

    def __init__(
        self,
        slaf_array: SLAFArray,
        tokenizer_type: str = "geneformer",
        batch_size: int = 32,
        max_genes: int = 2048,
        vocab_size: int = 50000,
        n_expression_bins: int = 10,
        n_epochs: int = 1,
        raw_mode: bool = False,
        verbose: bool = True,
        batches_per_chunk: int = 1,  # Default is 1 for MoS (was 50 for sequential)
        by_fragment: bool = True,  # Default is True for MoS (was False for sequential)
        use_mixture_of_scanners: bool = True,  # MoS is the default (was False)
        n_scanners: int = 16,  # Only used with MoS
        prefetch_batch_size: int = 4194304,  # Only used with MoS
    ):
        """
        Initialize the SLAF DataLoader with training configuration.

        Args:
            slaf_array: SLAFArray instance containing the single-cell data.
                       Must be a valid SLAFArray with proper Lance dataset structure.

            # Tokenization Configuration
            tokenizer_type: Tokenization strategy to use. Options: "geneformer", "scgpt".
                          Geneformer uses ranked gene sequences, scGPT uses interleaved
                          gene-expression pairs. Ignored when raw_mode=True.
            max_genes: Maximum number of genes to include in each cell's tokenization.
                     For Geneformer: same as sequence length. For scGPT: number of
                     gene-expression pairs (sequence length = 2*max_genes+2).
            vocab_size: Size of the tokenizer vocabulary. Higher values allow more
                       genes but use more memory. Range: 1000-100000, default: 50000.
            n_expression_bins: Number of expression level bins for scGPT discretization.
                             Higher values provide finer expression resolution.
                             Range: 1-1000, default: 10.

            # Training Configuration
            batch_size: Number of cells per batch. Larger batches use more memory
                       but may improve training efficiency. Range: 1-512, default: 32.
            n_epochs: Number of epochs to run. The generator will automatically reset
                     after each epoch, enabling multi-epoch training on small datasets.
                     Default: 1.

            # Output Mode Configuration
            raw_mode: If True, return raw cell × gene data as Polars DataFrames
                     instead of pre-tokenized sequences. This bypasses tokenization
                     and windowing for maximum flexibility. Default: False.

            # Loading Strategy Configuration (MoS is now default)
            batches_per_chunk: Number of Lance batches to load per chunk for sequential loading.
                             Higher values use more memory but may improve throughput.
                             Range: 1-200, default: 1 (optimized for MoS). Only used when by_fragment=False.
            by_fragment: If True, use fragment-based loading instead of batch-based loading.
                        Fragment-based loading provides higher entropy but may be slightly slower.
                        Automatically enabled when use_mixture_of_scanners=True.
                        Default: True (enabled for MoS).
            use_mixture_of_scanners: If True, use mixture of scanners (MoS) approach for higher
                                   entropy by randomly sampling from multiple fragment generators.
                                   This provides the best randomization and is now the default
                                   for foundation model training. Default: True.
            n_scanners: Number of fragment generators to sample from simultaneously when using MoS.
                       Higher values provide better entropy but use more memory.
                       Range: 1-100, default: 16. Only used when use_mixture_of_scanners=True.
            prefetch_batch_size: Target number of rows to load per prefetch batch when using MoS.
                               Higher values improve throughput but use more memory.
                               Range: 1000-10000000, default: 4194304. Only used when
                               use_mixture_of_scanners=True.

            # System Configuration
            verbose: If True, print detailed timing and progress information.
                    If False, suppress all SLAF internal prints for clean output.
                    Default: True.

        Raises:
            ValueError: If tokenizer_type is not supported or parameters are invalid.
            RuntimeError: If PyTorch is not available.
            TypeError: If slaf_array is not a valid SLAFArray instance.
            ImportError: If SLAFIterableDataset (the datasets module) or other required
                         dependencies are not available.

        Loading Strategy Selection Guide:
            - For foundation model training: Use default settings (MoS reaches about 88% of the entropy of fully random sampling)
            - For maximum throughput: Set use_mixture_of_scanners=False, by_fragment=False
            - For external processing: Set raw_mode=True

        Examples:
            >>> # Basic initialization (MoS is now default)
            >>> slaf_array = SLAFArray("path/to/data.slaf")
            >>> dataloader = SLAFDataLoader(slaf_array)
            >>> print(f"Batch size: {dataloader.batch_size}")
            Batch size: 32
            >>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
            MoS enabled: True

            >>> # Sequential loading for maximum throughput
            >>> dataloader = SLAFDataLoader(
            ...     slaf_array=slaf_array,
            ...     use_mixture_of_scanners=False,
            ...     by_fragment=False
            ... )
            >>> print(f"Sequential loading: {not dataloader.use_mixture_of_scanners}")
            Sequential loading: True

            >>> # Fragment-based loading for higher entropy
            >>> dataloader = SLAFDataLoader(
            ...     slaf_array=slaf_array,
            ...     use_mixture_of_scanners=False,
            ...     by_fragment=True
            ... )
            >>> print(f"Fragment-based loading: {dataloader.by_fragment}")
            Fragment-based loading: True

            >>> # Raw mode for external processing
            >>> dataloader = SLAFDataLoader(
            ...     slaf_array=slaf_array,
            ...     raw_mode=True
            ... )
            >>> print(f"Raw mode: {dataloader.raw_mode}")
            Raw mode: True

            >>> # Error handling for invalid parameters
            >>> try:
            ...     dataloader = SLAFDataLoader(slaf_array, n_scanners=0)
            ... except ValueError as e:
            ...     print(f"Error: {e}")
            Error: n_scanners must be at least 1
        """
        self.slaf_array = slaf_array
        self.tokenizer_type = tokenizer_type
        self.batch_size = batch_size
        self.max_genes = max_genes
        self.n_epochs = n_epochs
        self.raw_mode = raw_mode
        self.verbose = verbose
        self.batches_per_chunk = batches_per_chunk
        self.by_fragment = by_fragment
        self.use_mixture_of_scanners = use_mixture_of_scanners
        self.n_scanners = n_scanners
        self.prefetch_batch_size = prefetch_batch_size

        # Validate MoS parameters
        if self.use_mixture_of_scanners:
            if self.n_scanners < 1:
                raise ValueError("n_scanners must be at least 1")
            if self.n_scanners > 100:
                raise ValueError("n_scanners cannot exceed 100")
            # The 1,000-row floor still permits small values for the warm-up strategy
            if self.prefetch_batch_size < 1000:
                raise ValueError("prefetch_batch_size must be at least 1,000")
            if self.prefetch_batch_size > 10000000:
                raise ValueError("prefetch_batch_size cannot exceed 10,000,000")

        # Device-agnostic: always return CPU tensors
        self.device = None

        # Check that required modules are available
        if not DATASETS_AVAILABLE:
            raise ImportError(
                "SLAFIterableDataset is required but not available. Please install required dependencies."
            )

        # Initialize tokenizer (only needed for non-raw mode)
        if not self.raw_mode:
            self.tokenizer = SLAFTokenizer(
                slaf_array=slaf_array,
                tokenizer_type=tokenizer_type,
                vocab_size=vocab_size,
                n_expression_bins=n_expression_bins,
            )

            # Get special tokens from tokenizer
            self.special_tokens = self.tokenizer.special_tokens
        else:
            # For raw mode, we don't need a tokenizer
            self.tokenizer = None
            self.special_tokens = None

        # Use IterableDataset
        self._dataset = SLAFIterableDataset(
            slaf_array=slaf_array,
            tokenizer=self.tokenizer,
            batch_size=batch_size,
            seed=42,  # TODO: make configurable
            max_queue_size=500,
            tokenizer_type=tokenizer_type,
            n_epochs=n_epochs,
            raw_mode=raw_mode,
            verbose=verbose,
            batches_per_chunk=batches_per_chunk,
            by_fragment=by_fragment,
            use_mixture_of_scanners=use_mixture_of_scanners,
            n_scanners=n_scanners,
            prefetch_batch_size=prefetch_batch_size,
        )

    def __iter__(self):
        """
        Iterate through batches of single-cell data based on the configured mode.

        Yields batches of data suitable for machine learning training. The output format
        depends on the configuration:

        - **Tokenized mode** (default): Yields pre-tokenized sequences with attention masks
        - **Raw mode**: Yields raw Polars DataFrames for external processing
        - **Multi-epoch**: Automatically handles epoch transitions when n_epochs > 1

        The loading strategy (sequential, fragment-based, or Mixture of Scanners) affects
        data entropy and throughput but not the output format.

        Yields:
            dict: Batch dictionary containing:
                - **Tokenized mode** (raw_mode=False):
                    - input_ids: Pre-tokenized gene expression data (torch.Tensor)
                    - attention_mask: Boolean mask indicating valid tokens (torch.Tensor)
                    - cell_ids: Integer IDs of cells in the batch (torch.Tensor)
                - **Raw mode** (raw_mode=True):
                    - x: Raw cell × gene data as Polars DataFrame
                    - cell_ids: List of cell integer IDs in the batch
                - **Multi-epoch** (when n_epochs > 1):
                    - epoch: Current epoch number (int)

        Note:
            All tensors are returned on CPU for device-agnostic training.
            The training loop should handle device transfer as needed.

        Raises:
            ValueError: If the tokenizer type is not supported.
            RuntimeError: If batch processing fails.

        Examples:
            >>> # Basic iteration (tokenized mode)
            >>> slaf_array = SLAFArray("path/to/data.slaf")
            >>> dataloader = SLAFDataLoader(slaf_array, batch_size=16)
            >>> for batch in dataloader:
            ...     print(f"Batch keys: {list(batch.keys())}")
            ...     print(f"Input shape: {batch['input_ids'].shape}")
            ...     print(f"Cell IDs: {batch['cell_ids']}")
            ...     break
            Batch keys: ['input_ids', 'attention_mask', 'cell_ids']
            Input shape: torch.Size([16, 2048])
            Cell IDs: tensor([0, 1, 2, ..., 13, 14, 15])

            >>> # Raw mode iteration
            >>> dataloader = SLAFDataLoader(slaf_array, raw_mode=True, batch_size=16)
            >>> for batch in dataloader:
            ...     print(f"Raw data type: {type(batch['x'])}")
            ...     print(f"Cell IDs: {batch['cell_ids']}")
            ...     break
            Raw data type: <class 'polars.dataframe.frame.DataFrame'>
            Cell IDs: [0, 1, 2, ..., 13, 14, 15]

            >>> # Multi-epoch training
            >>> dataloader = SLAFDataLoader(slaf_array, n_epochs=3)
            >>> epochs_seen = set()
            >>> for batch in dataloader:
            ...     if 'epoch' in batch:
            ...         epochs_seen.add(batch['epoch'])
            ...     if len(epochs_seen) >= 3:  # Stop after seeing all epochs
            ...         break
            >>> print(f"Epochs completed: {sorted(epochs_seen)}")
            Epochs completed: [0, 1, 2]

            >>> # Training loop with error handling
            >>> for batch_idx, batch in enumerate(dataloader):
            ...     try:
            ...         if 'input_ids' in batch:  # Tokenized mode
            ...             input_ids = batch["input_ids"]
            ...             attention_mask = batch["attention_mask"]
            ...             cell_ids = batch["cell_ids"]
            ...         else:  # Raw mode
            ...             x = batch["x"]
            ...             cell_ids = batch["cell_ids"]
            ...         # Your training code here
            ...         print(f"Processed batch {batch_idx}")
            ...     except Exception as e:
            ...         print(f"Error in batch {batch_idx}: {e}")
            ...         continue
            ...     if batch_idx >= 2:  # Just first few batches
            ...         break
            Processed batch 0
            Processed batch 1
            Processed batch 2

            >>> # Different tokenizer types
            >>> dataloader_geneformer = SLAFDataLoader(slaf_array, tokenizer_type="geneformer")
            >>> dataloader_scgpt = SLAFDataLoader(slaf_array, tokenizer_type="scgpt")
            >>>
            >>> # Compare batch shapes
            >>> for batch in dataloader_geneformer:
            ...     print(f"Geneformer input shape: {batch['input_ids'].shape}")
            ...     break
            Geneformer input shape: torch.Size([32, 2048])
            >>> for batch in dataloader_scgpt:
            ...     print(f"scGPT input shape: {batch['input_ids'].shape}")
            ...     break
            scGPT input shape: torch.Size([32, 1024])
        """
        yield from self._dataset

    def __len__(self):
        """
        Return the number of batches in the dataset.

        Note: Since SLAFDataLoader uses an IterableDataset that streams data,
        the exact number of batches is not known in advance. This method
        returns 0 to indicate an unknown length for streaming datasets.

        Returns:
            int: Always returns 0 to indicate unknown length for streaming datasets.

        Examples:
            >>> # Check dataset length
            >>> slaf_array = SLAFArray("path/to/data.slaf")
            >>> dataloader = SLAFDataLoader(slaf_array)
            >>> print(f"Dataset length: {len(dataloader)}")
            Dataset length: 0

            >>> # IterableDataset behavior
            >>> batch_count = 0
            >>> for batch in dataloader:
            ...     batch_count += 1
            ...     if batch_count >= 5:  # Just count first 5 batches
            ...         break
            >>> print(f"Actually processed {batch_count} batches")
            Actually processed 5 batches

            >>> # Length is consistent
            >>> print(f"Length check: {len(dataloader)}")
            Length check: 0
        """
        return 0  # Indicates unknown length

    def __del__(self):
        """
        Cleanup hook called when the DataLoader is garbage collected.

        It is currently a no-op: SLAFIterableDataset exposes no stop method,
        so any in-flight async prefetching is left to finish its current epoch
        rather than being cancelled.

        Examples:
            >>> # DataLoader cleanup happens automatically
            >>> slaf_array = SLAFArray("path/to/data.slaf")
            >>> dataloader = SLAFDataLoader(slaf_array)
            >>> print("DataLoader created")
            DataLoader created
            >>> # When dataloader goes out of scope, __del__ is called automatically
            >>> del dataloader
            >>> print("DataLoader destroyed and cleaned up")
            DataLoader destroyed and cleaned up

            >>> # Manual cleanup (not usually needed)
            >>> dataloader = SLAFDataLoader(slaf_array)
            >>> dataloader.__del__()
            >>> print("Manual cleanup completed")
            Manual cleanup completed
        """
        if hasattr(self, "_dataset"):
            # The SLAFIterableDataset doesn't have a stop method,
            # so we just let it finish its current epoch.
            pass
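`__len__` returns 0, and the class shown defines no `__bool__`. As a result, Python's truth testing falls back to `__len__`, so `bool(dataloader)` is `False` even though iteration yields batches. The stand-in class below is a hypothetical `StreamingLoader`, not SLAF code. It reproduces the same `__iter__`/`__len__` pair with the standard library to show why code should check `dataloader is not None` rather than `if dataloader:`.

```python
from typing import Iterator, List


class StreamingLoader:
    """Mimics SLAFDataLoader's __iter__/__len__ pair (illustrative only)."""

    def __init__(self, batches: List[List[int]]) -> None:
        self._batches = batches

    def __iter__(self) -> Iterator[List[int]]:
        yield from self._batches

    def __len__(self) -> int:
        return 0  # unknown length for a streaming source


loader = StreamingLoader([[0, 1], [2, 3]])
print(len(loader), bool(loader))  # 0 False
print(sum(1 for _ in loader))     # 2
print(loader is not None)         # True: the safe existence check
```

The same caveat applies to tools that call `len()` to size progress bars or schedulers: they will see 0 batches.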
Functions
__init__(slaf_array: SLAFArray, tokenizer_type: str = 'geneformer', batch_size: int = 32, max_genes: int = 2048, vocab_size: int = 50000, n_expression_bins: int = 10, n_epochs: int = 1, raw_mode: bool = False, verbose: bool = True, batches_per_chunk: int = 1, by_fragment: bool = True, use_mixture_of_scanners: bool = True, n_scanners: int = 16, prefetch_batch_size: int = 4194304)

Initialize the SLAF DataLoader with training configuration.

Parameters:

Name Type Description Default
slaf_array SLAFArray

SLAFArray instance containing the single-cell data. Must be a valid SLAFArray with proper Lance dataset structure.

required
tokenizer_type str

Tokenization strategy to use. Options: "geneformer", "scgpt". Geneformer uses ranked gene sequences, scGPT uses interleaved gene-expression pairs. Ignored when raw_mode=True.

'geneformer'
max_genes int

Maximum number of genes to include in each cell's tokenization. For Geneformer: same as sequence length. For scGPT: number of gene-expression pairs (sequence length = 2*max_genes+2).

2048
vocab_size int

Size of the tokenizer vocabulary. Higher values allow more genes but use more memory. Range: 1000-100000, default: 50000.

50000
n_expression_bins int

Number of expression level bins for scGPT discretization. Higher values provide finer expression resolution. Range: 1-1000, default: 10.

10
batch_size int

Number of cells per batch. Larger batches use more memory but may improve training efficiency. Range: 1-512, default: 32.

32
n_epochs int

Number of epochs to run. The generator will automatically reset after each epoch, enabling multi-epoch training on small datasets. Default: 1.

1
raw_mode bool

If True, return raw cell × gene data as Polars DataFrames instead of pre-tokenized sequences. This bypasses tokenization and windowing for maximum flexibility. Default: False.

False
batches_per_chunk int

Number of Lance batches to load per chunk for sequential loading. Higher values use more memory but may improve throughput. Range: 1-200, default: 1 (optimized for MoS). Only used when by_fragment=False.

1
by_fragment bool

If True, use fragment-based loading instead of batch-based loading. Fragment-based loading provides higher entropy but may be slightly slower. Automatically enabled when use_mixture_of_scanners=True. Default: True (enabled for MoS).

True
use_mixture_of_scanners bool

If True, use mixture of scanners (MoS) approach for higher entropy by randomly sampling from multiple fragment generators. This provides the best randomization and is now the default for foundation model training. Default: True.

True
n_scanners int

Number of fragment generators to sample from simultaneously when using MoS. Higher values provide better entropy but use more memory. Range: 1-100, default: 16. Only used when use_mixture_of_scanners=True.

16
prefetch_batch_size int

Target number of rows to load per prefetch batch when using MoS. Higher values improve throughput but use more memory. Range: 1000-10000000, default: 4194304. Only used when use_mixture_of_scanners=True.

4194304
verbose bool

If True, print detailed timing and progress information. If False, suppress all SLAF internal prints for clean output. Default: True.

True
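The `max_genes` description implies different token sequence lengths for the two tokenizers. The worked sketch below uses `sequence_length`, a hypothetical helper that only encodes the two stated formulas: `max_genes` for Geneformer and `2*max_genes+2` for scGPT. The `+2` is taken as-is from the description.

```python
def sequence_length(tokenizer_type: str, max_genes: int) -> int:
    """Token sequence length per cell, per the max_genes description."""
    if tokenizer_type == "geneformer":
        return max_genes             # one token per ranked gene
    if tokenizer_type == "scgpt":
        return 2 * max_genes + 2     # gene/expression pairs plus 2 extra tokens
    raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")


print(sequence_length("geneformer", 2048))  # 2048
print(sequence_length("scgpt", 2048))       # 4098
print(sequence_length("scgpt", 1024))       # 2050
```

Under these formulas, the defaults give 2048 tokens per cell for Geneformer and 4098 for scGPT. Observed tensor widths may differ if the tokenizer applies its own limits.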

Raises:

Type Description
ValueError

If tokenizer_type is not supported or parameters are invalid.

RuntimeError

If PyTorch is not available.

TypeError

If slaf_array is not a valid SLAFArray instance.

ImportError

If SLAFIterableDataset (the datasets module) or other required dependencies are not available.
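The constructor source validates the Mixture of Scanners parameters before it builds the dataset. Those checks are restated below as a standalone function. `validate_mos_params` is a hypothetical name, but its bounds and messages mirror the `__init__` source. In the code shown, parameters such as `batch_size` and `vocab_size` are not range-checked, despite the ranges given in their descriptions.

```python
def validate_mos_params(use_mixture_of_scanners: bool, n_scanners: int,
                        prefetch_batch_size: int) -> None:
    """Raise ValueError for out-of-range MoS settings (mirrors __init__)."""
    if not use_mixture_of_scanners:
        return  # the bounds below only apply when MoS is enabled
    if n_scanners < 1:
        raise ValueError("n_scanners must be at least 1")
    if n_scanners > 100:
        raise ValueError("n_scanners cannot exceed 100")
    if prefetch_batch_size < 1000:
        raise ValueError("prefetch_batch_size must be at least 1,000")
    if prefetch_batch_size > 10000000:
        raise ValueError("prefetch_batch_size cannot exceed 10,000,000")


validate_mos_params(True, 16, 4194304)  # defaults pass
validate_mos_params(False, 0, 0)        # ignored when MoS is off
try:
    validate_mos_params(True, 0, 4194304)
except ValueError as e:
    print(f"Error: {e}")                # Error: n_scanners must be at least 1
```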

Loading Strategy Selection Guide
  • For foundation model training: Use default settings (MoS reaches about 88% of the entropy of fully random sampling)
  • For maximum throughput: Set use_mixture_of_scanners=False, by_fragment=False
  • For external processing: Set raw_mode=True
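The selection guide can also be written as keyword-argument presets. This is a sketch, not a SLAF API. `LOADER_PRESETS` and `describe` are hypothetical names. Each dict collects the settings that the guide and the fragment-based example list, and is intended to be passed as `SLAFDataLoader(slaf_array, **LOADER_PRESETS[goal])`.

```python
from typing import Dict

# Keyword arguments for SLAFDataLoader, one preset per goal in the guide
LOADER_PRESETS: Dict[str, Dict[str, bool]] = {
    "foundation_model": {},  # defaults: MoS on, by_fragment=True
    "max_throughput": {"use_mixture_of_scanners": False, "by_fragment": False},
    "fragment_based": {"use_mixture_of_scanners": False, "by_fragment": True},
    "external_processing": {"raw_mode": True},
}


def describe(goal: str) -> str:
    """Render the constructor call a preset corresponds to."""
    kwargs = LOADER_PRESETS[goal]
    args = ", ".join(f"{k}={v}" for k, v in kwargs.items())
    return f"SLAFDataLoader(slaf_array{', ' + args if args else ''})"


print(describe("max_throughput"))
# SLAFDataLoader(slaf_array, use_mixture_of_scanners=False, by_fragment=False)
```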

Examples:

>>> # Basic initialization (MoS is now default)
>>> slaf_array = SLAFArray("path/to/data.slaf")
>>> dataloader = SLAFDataLoader(slaf_array)
>>> print(f"Batch size: {dataloader.batch_size}")
Batch size: 32
>>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
MoS enabled: True
>>> # Sequential loading for maximum throughput
>>> dataloader = SLAFDataLoader(
...     slaf_array=slaf_array,
...     use_mixture_of_scanners=False,
...     by_fragment=False
... )
>>> print(f"Sequential loading: {not dataloader.use_mixture_of_scanners}")
Sequential loading: True
>>> # Fragment-based loading for higher entropy
>>> dataloader = SLAFDataLoader(
...     slaf_array=slaf_array,
...     use_mixture_of_scanners=False,
...     by_fragment=True
... )
>>> print(f"Fragment-based loading: {dataloader.by_fragment}")
Fragment-based loading: True
>>> # Raw mode for external processing
>>> dataloader = SLAFDataLoader(
...     slaf_array=slaf_array,
...     raw_mode=True
... )
>>> print(f"Raw mode: {dataloader.raw_mode}")
Raw mode: True
>>> # Error handling for invalid parameters
>>> try:
...     dataloader = SLAFDataLoader(slaf_array, n_scanners=0)
... except ValueError as e:
...     print(f"Error: {e}")
Error: n_scanners must be at least 1
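`max_genes` sets the token sequence length differently for each tokenizer. According to the docstring, it equals the sequence length for Geneformer. scGPT interleaves gene-expression pairs, giving a length of `2*max_genes+2`. The helper below is hypothetical and not part of SLAF. It only makes the arithmetic explicit. The Geneformer case matches the 2048-token batches in the class-level example.

```python
def expected_sequence_length(tokenizer_type: str, max_genes: int) -> int:
    """Sequence length implied by max_genes, per the SLAFDataLoader docstring."""
    if tokenizer_type == "geneformer":
        # One token per ranked gene
        return max_genes
    if tokenizer_type == "scgpt":
        # A gene token and an expression token per pair, plus two more tokens
        return 2 * max_genes + 2
    raise ValueError(f"Unsupported tokenizer_type: {tokenizer_type!r}")


print(expected_sequence_length("geneformer", 2048))  # 2048
print(expected_sequence_length("scgpt", 1024))  # 2050
```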
Source code in slaf/ml/dataloaders.py
def __init__(
    self,
    slaf_array: SLAFArray,
    tokenizer_type: str = "geneformer",
    batch_size: int = 32,
    max_genes: int = 2048,
    vocab_size: int = 50000,
    n_expression_bins: int = 10,
    n_epochs: int = 1,
    raw_mode: bool = False,
    verbose: bool = True,
    batches_per_chunk: int = 1,  # 1 for MoS; sequential loading previously used 50
    by_fragment: bool = True,  # Fragment-based loading, used by MoS
    use_mixture_of_scanners: bool = True,  # MoS is the default strategy
    n_scanners: int = 16,  # Fragment generators sampled concurrently by MoS
    prefetch_batch_size: int = 4194304,  # Rows per MoS prefetch (2**22)
):
    """
    Initialize the SLAF DataLoader with training configuration.

    Args:
        slaf_array: SLAFArray instance containing the single-cell data.
                   Must be a valid SLAFArray with proper Lance dataset structure.

        # Tokenization Configuration
        tokenizer_type: Tokenization strategy to use. Options: "geneformer", "scgpt".
                      Geneformer uses ranked gene sequences, scGPT uses interleaved
                      gene-expression pairs. Ignored when raw_mode=True.
        max_genes: Maximum number of genes to include in each cell's tokenization.
                 For Geneformer: same as sequence length. For scGPT: number of
                 gene-expression pairs (sequence length = 2*max_genes+2).
        vocab_size: Size of the tokenizer vocabulary. Higher values allow more
                   genes but use more memory. Range: 1000-100000, default: 50000.
        n_expression_bins: Number of expression level bins for scGPT discretization.
                         Higher values provide finer expression resolution.
                         Range: 1-1000, default: 10.

        # Training Configuration
        batch_size: Number of cells per batch. Larger batches use more memory
                   but may improve training efficiency. Range: 1-512, default: 32.
        n_epochs: Number of epochs to run. The generator will automatically reset
                 after each epoch, enabling multi-epoch training on small datasets.
                 Default: 1.

        # Output Mode Configuration
        raw_mode: If True, return raw cell × gene data as Polars DataFrames
                 instead of pre-tokenized sequences. This bypasses tokenization
                 and windowing for maximum flexibility. Default: False.

        # Loading Strategy Configuration (MoS is now default)
        batches_per_chunk: Number of Lance batches to load per chunk for sequential loading.
                         Higher values use more memory but may improve throughput.
                         Range: 1-200, default: 1 (optimized for MoS). Only used when by_fragment=False.
        by_fragment: If True, use fragment-based loading instead of batch-based loading.
                    Fragment-based loading provides higher entropy but may be slightly slower.
                    Automatically enabled when use_mixture_of_scanners=True.
                    Default: True (enabled for MoS).
        use_mixture_of_scanners: If True, use mixture of scanners (MoS) approach for higher
                               entropy by randomly sampling from multiple fragment generators.
                               This provides the best randomization and is now the default
                               for foundation model training. Default: True.
        n_scanners: Number of fragment generators to sample from simultaneously when using MoS.
                   Higher values provide better entropy but use more memory.
                   Range: 1-100, default: 16. Only used when use_mixture_of_scanners=True.
        prefetch_batch_size: Target number of rows to load per prefetch batch when using MoS.
                           Higher values improve throughput but use more memory.
                           Range: 1000-10000000, default: 4194304. Only used when
                           use_mixture_of_scanners=True.

        # System Configuration
        verbose: If True, print detailed timing and progress information.
                If False, suppress all SLAF internal prints for clean output.
                Default: True.

    Raises:
        ValueError: If tokenizer_type is not supported or parameters are invalid.
        RuntimeError: If PyTorch is not available or datasets module is missing.
        TypeError: If slaf_array is not a valid SLAFArray instance.
        ImportError: If required dependencies are not available.

    Loading Strategy Selection Guide:
        - For foundation model training: Use default settings (MoS provides 88% random entropy)
        - For maximum throughput: Set use_mixture_of_scanners=False, by_fragment=False
        - For external processing: Set raw_mode=True

    Examples:
        >>> # Basic initialization (MoS is now default)
        >>> slaf_array = SLAFArray("path/to/data.slaf")
        >>> dataloader = SLAFDataLoader(slaf_array)
        >>> print(f"Batch size: {dataloader.batch_size}")
        Batch size: 32
        >>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
        MoS enabled: True

        >>> # Sequential loading for maximum throughput
        >>> dataloader = SLAFDataLoader(
        ...     slaf_array=slaf_array,
        ...     use_mixture_of_scanners=False,
        ...     by_fragment=False
        ... )
        >>> print(f"Sequential loading: {not dataloader.use_mixture_of_scanners}")
        Sequential loading: True

        >>> # Fragment-based loading for higher entropy
        >>> dataloader = SLAFDataLoader(
        ...     slaf_array=slaf_array,
        ...     use_mixture_of_scanners=False,
        ...     by_fragment=True
        ... )
        >>> print(f"Fragment-based loading: {dataloader.by_fragment}")
        Fragment-based loading: True

        >>> # Raw mode for external processing
        >>> dataloader = SLAFDataLoader(
        ...     slaf_array=slaf_array,
        ...     raw_mode=True
        ... )
        >>> print(f"Raw mode: {dataloader.raw_mode}")
        Raw mode: True

        >>> # Error handling for invalid parameters
        >>> try:
        ...     dataloader = SLAFDataLoader(slaf_array, n_scanners=0)
        ... except ValueError as e:
        ...     print(f"Error: {e}")
        Error: n_scanners must be at least 1
    """
    self.slaf_array = slaf_array
    self.tokenizer_type = tokenizer_type
    self.batch_size = batch_size
    self.max_genes = max_genes
    self.n_epochs = n_epochs
    self.raw_mode = raw_mode
    self.verbose = verbose
    self.batches_per_chunk = batches_per_chunk
    self.by_fragment = by_fragment
    self.use_mixture_of_scanners = use_mixture_of_scanners
    self.n_scanners = n_scanners
    self.prefetch_batch_size = prefetch_batch_size

    # Validate MoS parameters
    if self.use_mixture_of_scanners:
        if self.n_scanners < 1:
            raise ValueError("n_scanners must be at least 1")
        if self.n_scanners > 100:
            raise ValueError("n_scanners cannot exceed 100")
        if (
            self.prefetch_batch_size < 1000
        ):  # Allow smaller values for warm-up strategy
            raise ValueError("prefetch_batch_size must be at least 1,000")
        if self.prefetch_batch_size > 10000000:
            raise ValueError("prefetch_batch_size cannot exceed 10,000,000")

    # Device-agnostic: always return CPU tensors
    self.device = None

    # Check that required modules are available
    if not DATASETS_AVAILABLE:
        raise ImportError(
            "SLAFIterableDataset is required but not available. Please install required dependencies."
        )

    # Initialize tokenizer (only needed for non-raw mode)
    if not self.raw_mode:
        self.tokenizer = SLAFTokenizer(
            slaf_array=slaf_array,
            tokenizer_type=tokenizer_type,
            vocab_size=vocab_size,
            n_expression_bins=n_expression_bins,
        )

        # Get special tokens from tokenizer
        self.special_tokens = self.tokenizer.special_tokens
    else:
        # For raw mode, we don't need a tokenizer
        self.tokenizer = None
        self.special_tokens = None

    # Use IterableDataset
    self._dataset = SLAFIterableDataset(
        slaf_array=slaf_array,
        tokenizer=self.tokenizer,
        batch_size=batch_size,
        seed=42,  # TODO: make configurable
        max_queue_size=500,
        tokenizer_type=tokenizer_type,
        n_epochs=n_epochs,
        raw_mode=raw_mode,
        verbose=verbose,
        batches_per_chunk=batches_per_chunk,
        by_fragment=by_fragment,
        use_mixture_of_scanners=use_mixture_of_scanners,
        n_scanners=n_scanners,
        prefetch_batch_size=prefetch_batch_size,
    )
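`n_epochs` is forwarded to `SLAFIterableDataset`, and that generator resets after each epoch. The pattern can be sketched in plain Python. `iterate_epochs` and `shuffled_cells` are hypothetical. The per-epoch seed offset from 42 is an assumption that only shows how each pass could be reshuffled; the source passes a fixed `seed=42`.

```python
import random
from typing import Callable, Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def iterate_epochs(
    make_epoch: Callable[[int], Iterable[T]], n_epochs: int
) -> Iterator[Tuple[int, T]]:
    """Yield (epoch, item) pairs, rebuilding the per-epoch source after each pass."""
    for epoch in range(n_epochs):
        for item in make_epoch(epoch):
            yield epoch, item


def shuffled_cells(epoch: int, n_cells: int = 4) -> List[int]:
    """Hypothetical epoch source: a fresh permutation of cell IDs for each epoch."""
    order = list(range(n_cells))
    random.Random(42 + epoch).shuffle(order)  # assumption: per-epoch seed offset
    return order


pairs = list(iterate_epochs(shuffled_cells, n_epochs=3))
print(len(pairs))  # 12: every cell once per epoch
```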

Functions

get_optimal_device()

Get the optimal device for PyTorch operations (CUDA > MPS > CPU).

This function determines the best available device for PyTorch operations by checking for CUDA first, then MPS (Apple Silicon), and falling back to CPU if neither is available.

Returns:

Type Description
torch.device | None

The optimal device, or None if PyTorch is not available.

Examples:

>>> # Check optimal device
>>> device = get_optimal_device()
>>> print(f"Optimal device: {device}")
Optimal device: cuda
>>> # Device priority (CUDA > MPS > CPU)
>>> # If CUDA is available: cuda
>>> # If MPS is available but not CUDA: mps
>>> # If neither: cpu
>>> device = get_optimal_device()
>>> if device.type == "cuda":
...     print("Using CUDA GPU")
... elif device.type == "mps":
...     print("Using Apple Silicon GPU")
... else:
...     print("Using CPU")
Using CUDA GPU
>>> # Handle PyTorch not available
>>> # This would return None if PyTorch is not installed
>>> device = get_optimal_device()
>>> if device is None:
...     print("PyTorch not available")
... else:
...     print(f"Device available: {device}")
Device available: cuda
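The priority rule can be tested in isolation by replacing the `torch` queries with boolean flags. Here is a minimal sketch. `pick_device` is a hypothetical stand-in that returns device type strings, not `torch.device` objects.

```python
from typing import Optional


def pick_device(
    torch_available: bool, cuda_available: bool, mps_available: bool
) -> Optional[str]:
    """Apply get_optimal_device()'s CUDA > MPS > CPU priority to plain flags."""
    if not torch_available:
        return None  # mirrors the None return when PyTorch is missing
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


print(pick_device(True, False, True))  # mps
```

The real function can return None. Check `device is None` before reading `device.type`. Without PyTorch, the priority example above would otherwise raise `AttributeError`.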
Source code in slaf/ml/dataloaders.py
def get_optimal_device():
    """
    Get the optimal device for PyTorch operations (CUDA > MPS > CPU).

    This function determines the best available device for PyTorch operations
    by checking for CUDA first, then MPS (Apple Silicon), and falling back
    to CPU if neither is available.

    Returns:
        torch.device | None: The optimal device, or None if PyTorch is not available.

    Examples:
        >>> # Check optimal device
        >>> device = get_optimal_device()
        >>> print(f"Optimal device: {device}")
        Optimal device: cuda

        >>> # Device priority (CUDA > MPS > CPU)
        >>> # If CUDA is available: cuda
        >>> # If MPS is available but not CUDA: mps
        >>> # If neither: cpu
        >>> device = get_optimal_device()
        >>> if device.type == "cuda":
        ...     print("Using CUDA GPU")
        ... elif device.type == "mps":
        ...     print("Using Apple Silicon GPU")
        ... else:
        ...     print("Using CPU")
        Using CUDA GPU

        >>> # Handle PyTorch not available
        >>> # This would return None if PyTorch is not installed
        >>> device = get_optimal_device()
        >>> if device is None:
        ...     print("PyTorch not available")
        ... else:
        ...     print(f"Device available: {device}")
        Device available: cuda
    """
    if not TORCH_AVAILABLE:
        return None

    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")

get_device_info()

Get comprehensive device information for debugging.

This function returns detailed information about the available PyTorch devices, including CUDA and MPS availability, device counts, and capabilities. Useful for debugging device-related issues and understanding the system configuration.

Returns:

Type Description
dict

Device information dictionary containing:
  • torch_available: Whether PyTorch is available
  • cuda_available: Whether CUDA is available
  • mps_available: Whether MPS (Apple Silicon) is available
  • optimal_device: String representation of the optimal device (None if PyTorch is not available)
  • cuda_device_count: Number of CUDA devices (present only if CUDA is available)
  • cuda_device_name: Name of the first CUDA device (present only if CUDA is available)
  • cuda_device_capability: Compute capability of the first CUDA device (present only if CUDA is available)

Examples:

>>> # Get device information
>>> info = get_device_info()
>>> print(f"PyTorch available: {info['torch_available']}")
PyTorch available: True
>>> print(f"CUDA available: {info['cuda_available']}")
CUDA available: True
>>> print(f"Optimal device: {info['optimal_device']}")
Optimal device: cuda
>>> # Check CUDA details
>>> if info['cuda_available']:
...     print(f"CUDA devices: {info['cuda_device_count']}")
...     print(f"Device name: {info['cuda_device_name']}")
...     print(f"Capability: {info['cuda_device_capability']}")
CUDA devices: 1
Device name: NVIDIA GeForce RTX 3080
Capability: (8, 6)
>>> # Check MPS availability
>>> print(f"MPS available: {info['mps_available']}")
MPS available: False
>>> # Handle PyTorch not available
>>> # This would show torch_available: False if PyTorch is not installed
>>> info = get_device_info()
>>> if not info['torch_available']:
...     print("PyTorch not available")
... else:
...     print("PyTorch is available")
PyTorch is available
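The `cuda_device_*` keys are added only when CUDA is available, so code that reads the dictionary should not index them unconditionally. The formatter below is hypothetical and not part of SLAF. It handles the CUDA, non-CUDA and no-PyTorch cases. The sample dictionary reuses the values from the example output above; these are not real measurements.

```python
def summarize_device_info(info: dict) -> str:
    """One-line summary of a get_device_info()-style dict; cuda_* keys are optional."""
    if not info.get("torch_available"):
        return "PyTorch not available"
    parts = [f"optimal={info['optimal_device']}"]
    if info.get("cuda_available"):
        major, minor = info["cuda_device_capability"]
        parts.append(
            f"cuda x{info['cuda_device_count']} "
            f"({info['cuda_device_name']}, sm_{major}{minor})"
        )
    parts.append(f"mps={'yes' if info.get('mps_available') else 'no'}")
    return ", ".join(parts)


sample = {
    "torch_available": True,
    "cuda_available": True,
    "mps_available": False,
    "optimal_device": "cuda",
    "cuda_device_count": 1,
    "cuda_device_name": "NVIDIA GeForce RTX 3080",
    "cuda_device_capability": (8, 6),
}
print(summarize_device_info(sample))
# optimal=cuda, cuda x1 (NVIDIA GeForce RTX 3080, sm_86), mps=no
```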
Source code in slaf/ml/dataloaders.py
def get_device_info():
    """
    Get comprehensive device information for debugging.

    This function returns detailed information about the available PyTorch devices,
    including CUDA and MPS availability, device counts, and capabilities.
    Useful for debugging device-related issues and understanding the system
    configuration.

    Returns:
        dict: Device information dictionary containing:
            - torch_available: Whether PyTorch is available
            - cuda_available: Whether CUDA is available
            - mps_available: Whether MPS (Apple Silicon) is available
            - optimal_device: String representation of the optimal device
            - cuda_device_count: Number of CUDA devices (if CUDA available)
            - cuda_device_name: Name of the first CUDA device (if available)
            - cuda_device_capability: Compute capability of first CUDA device

    Examples:
        >>> # Get device information
        >>> info = get_device_info()
        >>> print(f"PyTorch available: {info['torch_available']}")
        PyTorch available: True
        >>> print(f"CUDA available: {info['cuda_available']}")
        CUDA available: True
        >>> print(f"Optimal device: {info['optimal_device']}")
        Optimal device: cuda

        >>> # Check CUDA details
        >>> if info['cuda_available']:
        ...     print(f"CUDA devices: {info['cuda_device_count']}")
        ...     print(f"Device name: {info['cuda_device_name']}")
        ...     print(f"Capability: {info['cuda_device_capability']}")
        CUDA devices: 1
        Device name: NVIDIA GeForce RTX 3080
        Capability: (8, 6)

        >>> # Check MPS availability
        >>> print(f"MPS available: {info['mps_available']}")
        MPS available: False

        >>> # Handle PyTorch not available
        >>> # This would show torch_available: False if PyTorch is not installed
        >>> info = get_device_info()
        >>> if not info['torch_available']:
        ...     print("PyTorch not available")
        ... else:
        ...     print("PyTorch is available")
        PyTorch is available
    """
    if not TORCH_AVAILABLE:
        return {
            "torch_available": False,
            "cuda_available": False,
            "mps_available": False,
            "optimal_device": None,
        }

    info = {
        "torch_available": True,
        "cuda_available": torch.cuda.is_available(),
        "mps_available": torch.backends.mps.is_available(),
        "optimal_device": str(get_optimal_device()),
    }

    if torch.cuda.is_available():
        info["cuda_device_count"] = torch.cuda.device_count()
        info["cuda_device_name"] = torch.cuda.get_device_name(0)
        info["cuda_device_capability"] = torch.cuda.get_device_capability(0)

    return info

slaf.ml.tiledb_dataloaders

TileDB Dataloader for Single-Cell Data

This module provides efficient streaming of single-cell data from TileDB SOMA format using PyTorch IterableDataset and DataLoader. It follows a similar pattern to SLAF's dataloader implementation for consistency and performance comparison.

Classes

TileDBPrefetchBatch dataclass

Container for a batch of TileDB data with metadata.

This dataclass holds a processed batch of TileDB SOMA data along with associated metadata for tracking performance and debugging. It serves as the primary data structure passed between the batch processor and the async prefetcher.

Attributes:

Name Type Description
batch_id int

Unique identifier for this batch within the current epoch.

batch_df DataFrame

Polars DataFrame containing the cell-gene expression data with columns: cell_integer_id, gene_integer_id, value.

cell_integer_ids list[int]

List of unique cell IDs present in this batch.

process_time float

Time taken to process this batch (in seconds).

memory_mb float

Memory usage at the time of batch creation (in MB).

Examples:

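>>> # Create a batch container
>>> import polars as pl
>>> df = pl.DataFrame({
...     "cell_integer_id": [0, 1, 2, 3],
...     "gene_integer_id": [10, 11, 10, 12],
...     "value": [1.0, 2.0, 3.0, 4.0],
... })
>>> batch = TileDBPrefetchBatch(
...     batch_id=0,
...     batch_df=df,
...     cell_integer_ids=[0, 1, 2, 3],
...     process_time=0.1,
...     memory_mb=128.5
... )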
>>> print(f"Batch {batch.batch_id} has {len(batch.cell_integer_ids)} cells")
Batch 0 has 4 cells
Source code in slaf/ml/tiledb_dataloaders.py
@dataclass
class TileDBPrefetchBatch:
    """
    Container for a batch of TileDB data with metadata.

    This dataclass holds a processed batch of TileDB SOMA data along with
    associated metadata for tracking performance and debugging. It serves as
    the primary data structure passed between the batch processor and the
    async prefetcher.

    Attributes:
        batch_id: Unique identifier for this batch within the current epoch.
        batch_df: Polars DataFrame containing the cell-gene expression data
                 with columns: cell_integer_id, gene_integer_id, value.
        cell_integer_ids: List of unique cell IDs present in this batch.
        process_time: Time taken to process this batch (in seconds).
        memory_mb: Memory usage at the time of batch creation (in MB).

    Examples:
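        >>> # Create a batch container
        >>> import polars as pl
        >>> df = pl.DataFrame({
        ...     "cell_integer_id": [0, 1, 2, 3],
        ...     "gene_integer_id": [10, 11, 10, 12],
        ...     "value": [1.0, 2.0, 3.0, 4.0],
        ... })
        >>> batch = TileDBPrefetchBatch(
        ...     batch_id=0,
        ...     batch_df=df,
        ...     cell_integer_ids=[0, 1, 2, 3],
        ...     process_time=0.1,
        ...     memory_mb=128.5
        ... )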
        >>> print(f"Batch {batch.batch_id} has {len(batch.cell_integer_ids)} cells")
        Batch 0 has 4 cells
    """

    batch_id: int
    batch_df: pl.DataFrame  # Columns: cell_integer_id, gene_integer_id, value
    cell_integer_ids: list[int]  # List of cell IDs in this batch
    process_time: float
    memory_mb: float
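`batch_df` stores expression values in long format, with one row per (`cell_integer_id`, `gene_integer_id`, `value`) entry. `cell_integer_ids` lists the distinct cells. The sketch below regroups such rows by cell without any dependencies. `group_by_cell` is hypothetical. With Polars, the equivalent would typically use `DataFrame.group_by`.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

Row = Tuple[int, int, float]  # (cell_integer_id, gene_integer_id, value)


def group_by_cell(rows: Iterable[Row]) -> Tuple[List[int], Dict[int, Dict[int, float]]]:
    """Regroup long-format rows into per-cell {gene_integer_id: value} maps."""
    cells: Dict[int, Dict[int, float]] = defaultdict(dict)
    for cell_id, gene_id, value in rows:
        cells[cell_id][gene_id] = value
    # Distinct cell IDs in first-seen order, analogous to cell_integer_ids
    return list(cells), dict(cells)


ids, per_cell = group_by_cell([(0, 5, 1.0), (0, 9, 3.0), (2, 5, 2.0)])
print(ids)  # [0, 2]
print(per_cell[0])  # {5: 1.0, 9: 3.0}
```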

TileDBBatchProcessor

High-performance batch processor for TileDB SOMA data with multiple loading strategies.

TileDBBatchProcessor provides efficient streaming and processing of single-cell data from TileDB SOMA format. It supports multiple loading strategies including Mixture of Scanners (MoS) for maximum entropy and sequential loading for maximum throughput.

Key Features
  • Multiple loading strategies:
    • Mixture of Scanners (MoS): Random sampling from multiple generators for maximum entropy and randomization (default)
    • Sequential loading: Contiguous data loading for maximum throughput
  • Streaming data processing with configurable batch sizes
  • Built-in shuffling strategies for data randomization
  • Multi-epoch training support with automatic epoch transitions
  • Comprehensive timing and memory monitoring
  • Error handling and recovery mechanisms
  • Configurable prefetch batch sizes for different dataset sizes
Loading Strategies
  1. Mixture of Scanners (default): Randomly samples from multiple fragment generators for maximum entropy and randomization
  2. Sequential: Loads contiguous data chunks for maximum throughput
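The MoS strategy can be sketched in plain Python. The range partitioning mirrors `_initialize_mos_generators` in the source: the first `total_cells % n_readers` readers each get one extra cell. This excerpt does not show how active scanners are chosen and replaced. The uniform random pick and refill-on-exhaustion below are therefore assumptions, and both function names are hypothetical.

```python
import random
from typing import Iterator, List, Tuple


def partition_ranges(total_cells: int, n_readers: int) -> List[Tuple[int, int]]:
    """Split [0, total_cells) into n_readers contiguous (start, end) ranges."""
    base, remainder = divmod(total_cells, n_readers)
    ranges, start = [], 0
    for i in range(n_readers):
        # The first `remainder` readers receive one extra cell
        end = start + base + (1 if i < remainder else 0)
        ranges.append((start, end))
        start = end
    return ranges


def mixture_of_scanners(
    total_cells: int, n_readers: int, n_scanners: int, chunk_size: int, seed: int = 42
) -> Iterator[Tuple[int, int]]:
    """Yield (start, end) chunks by picking a random active scanner at each step."""
    rng = random.Random(seed)
    pending = [r for r in partition_ranges(total_cells, n_readers) if r[1] > r[0]]
    rng.shuffle(pending)
    # Each active scanner is a mutable [current_position, end_position] pair
    active = [list(r) for r in pending[:n_scanners]]
    pending = pending[n_scanners:]
    while active:
        idx = rng.randrange(len(active))  # random choice supplies the entropy
        scanner = active[idx]
        start = scanner[0]
        end = min(start + chunk_size, scanner[1])
        scanner[0] = end
        yield start, end
        if end >= scanner[1]:
            # Retire the exhausted scanner and activate a pending reader, if any
            active.pop(idx)
            if pending:
                active.append(list(pending.pop()))


chunks = list(mixture_of_scanners(total_cells=103, n_readers=10, n_scanners=4, chunk_size=5))
```

Each cell is visited exactly once per pass. The order across scanners is randomized, but each scanner still reads its own range contiguously.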

Examples:

>>> # Basic usage with default MoS strategy
>>> processor = TileDBBatchProcessor(
...     tiledb_path="path/to/experiment",
...     batch_size=32,
...     prefetch_batch_size=100
... )
>>> batch = processor.load_prefetch_batch()
>>> print(f"Loaded batch with {len(batch.cell_integer_ids)} cells")
Loaded batch with 100 cells
>>> # Sequential loading for maximum throughput
>>> processor = TileDBBatchProcessor(
...     tiledb_path="path/to/experiment",
...     use_mixture_of_scanners=False,
...     batch_size=64
... )
>>> print(f"MoS enabled: {processor.use_mixture_of_scanners}")
MoS enabled: False
>>> # Multi-epoch training
>>> processor = TileDBBatchProcessor(
...     tiledb_path="path/to/experiment",
...     n_epochs=3
... )
>>> print(f"Number of epochs: {processor.n_epochs}")
Number of epochs: 3
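`prefetch_batch_size` is the number of cells read from TileDB per prefetch, and `batch_size` is the number of cells per training batch. The two need not divide evenly: with the defaults of 100 and 32, each prefetch leaves 4 cells over. This excerpt does not show how the processor handles that remainder. The sketch below assumes leftovers carry into the next prefetch and that a final partial batch is emitted. `rebatch` is hypothetical.

```python
from typing import Iterable, Iterator, List


def rebatch(prefetches: Iterable[List[int]], batch_size: int) -> Iterator[List[int]]:
    """Cut prefetched cell-ID lists into batch_size chunks, carrying leftovers forward."""
    buffer: List[int] = []
    for cells in prefetches:
        buffer.extend(cells)
        while len(buffer) >= batch_size:
            yield buffer[:batch_size]
            buffer = buffer[batch_size:]
    if buffer:
        yield buffer  # assumption: the final partial batch is kept, not dropped


batches = list(rebatch([list(range(100)), list(range(100, 200))], batch_size=32))
print([len(b) for b in batches])  # [32, 32, 32, 32, 32, 32, 8]
```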
Source code in slaf/ml/tiledb_dataloaders.py
class TileDBBatchProcessor:
    """
    High-performance batch processor for TileDB SOMA data with multiple loading strategies.

    TileDBBatchProcessor provides efficient streaming and processing of single-cell data
    from TileDB SOMA format. It supports multiple loading strategies including Mixture
    of Scanners (MoS) for maximum entropy and sequential loading for maximum throughput.

    Key Features:
        - Multiple loading strategies:
            * Mixture of Scanners (MoS): Random sampling from multiple generators for
              maximum entropy and randomization (default)
            * Sequential loading: Contiguous data loading for maximum throughput
        - Streaming data processing with configurable batch sizes
        - Built-in shuffling strategies for data randomization
        - Multi-epoch training support with automatic epoch transitions
        - Comprehensive timing and memory monitoring
        - Error handling and recovery mechanisms
        - Configurable prefetch batch sizes for different dataset sizes

    Loading Strategies:
        1. Mixture of Scanners (default): Randomly samples from multiple fragment
           generators for maximum entropy and randomization
        2. Sequential: Loads contiguous data chunks for maximum throughput

    Examples:
        >>> # Basic usage with default MoS strategy
        >>> processor = TileDBBatchProcessor(
        ...     tiledb_path="path/to/experiment",
        ...     batch_size=32,
        ...     prefetch_batch_size=100
        ... )
        >>> batch = processor.load_prefetch_batch()
        >>> print(f"Loaded batch with {len(batch.cell_integer_ids)} cells")
        Loaded batch with 100 cells

        >>> # Sequential loading for maximum throughput
        >>> processor = TileDBBatchProcessor(
        ...     tiledb_path="path/to/experiment",
        ...     use_mixture_of_scanners=False,
        ...     batch_size=64
        ... )
        >>> print(f"MoS enabled: {processor.use_mixture_of_scanners}")
        MoS enabled: False

        >>> # Multi-epoch training
        >>> processor = TileDBBatchProcessor(
        ...     tiledb_path="path/to/experiment",
        ...     n_epochs=3
        ... )
        >>> print(f"Number of epochs: {processor.n_epochs}")
        Number of epochs: 3
    """

    def __init__(
        self,
        tiledb_path: str,
        batch_size: int = 32,
        prefetch_batch_size: int = 100,
        seed: int = 42,
        n_epochs: int = 1,
        verbose: bool = True,
        log_metrics: bool = False,
        use_mixture_of_scanners: bool = True,
        n_readers: int = 50,
        n_scanners: int = 8,
    ):
        """
        Initialize the TileDB batch processor with training configuration.

        Args:
            tiledb_path: Path to the TileDB SOMA experiment directory.
                         Must contain a valid SOMA experiment with RNA measurement data.
            batch_size: Number of cells per training batch. Larger batches use more
                       memory but may improve training efficiency. Range: 1-512, default: 32.
            prefetch_batch_size: Number of cells to load per prefetch batch from TileDB.
                               Higher values improve throughput but use more memory.
                               Range: 10-10000, default: 100.
            seed: Random seed for reproducible shuffling and MoS sampling.
                  Used for consistent data ordering across runs. Default: 42.
            n_epochs: Number of epochs to run. The processor will automatically reset
                     after each epoch, enabling multi-epoch training. Default: 1.
            verbose: If True, print detailed timing and progress information.
                    If False, suppress all internal prints for clean output. Default: True.
            log_metrics: If True, collect detailed timing metrics for performance analysis.
                        Metrics include loading time, shuffle time, and memory usage.
                        Default: False.
            use_mixture_of_scanners: If True, use MoS strategy for higher entropy by
                                   randomly sampling from multiple fragment generators.
                                   Provides better randomization for foundation model training.
                                   Default: True.
            n_readers: Total number of fragment generators to create when using MoS.
                      Higher values provide better entropy but use more memory.
                      Range: 1-1000, default: 50.
            n_scanners: Number of active scanners to sample from simultaneously when using MoS.
                       Higher values provide better entropy but use more memory.
                       Range: 1-100, default: 8.

        Raises:
            ImportError: If TileDB SOMA is not available.
            ValueError: If MoS parameters are invalid (n_readers < 1, n_scanners < 1,
                       or n_scanners > n_readers).
            RuntimeError: If the TileDB experiment cannot be opened or is invalid.

        Examples:
            >>> # Basic initialization with default MoS strategy
            >>> processor = TileDBBatchProcessor(
            ...     tiledb_path="path/to/experiment",
            ...     batch_size=32,
            ...     prefetch_batch_size=100
            ... )
            >>> print(f"Total cells: {processor.total_cells}")
            Total cells: 50000

            >>> # Sequential loading for maximum throughput
            >>> processor = TileDBBatchProcessor(
            ...     tiledb_path="path/to/experiment",
            ...     use_mixture_of_scanners=False,
            ...     batch_size=64
            ... )
            >>> print(f"MoS enabled: {processor.use_mixture_of_scanners}")
            MoS enabled: False

            >>> # High-entropy MoS configuration
            >>> processor = TileDBBatchProcessor(
            ...     tiledb_path="path/to/experiment",
            ...     n_readers=100,
            ...     n_scanners=16
            ... )
            >>> print(f"MoS readers: {processor.n_readers}, scanners: {processor.n_scanners}")
            MoS readers: 100, scanners: 16
        """
        if not TILEDB_AVAILABLE:
            raise ImportError("TileDB SOMA is required but not available")

        self.tiledb_path = tiledb_path
        self.batch_size = batch_size
        self.prefetch_batch_size = prefetch_batch_size
        self.seed = seed
        self.n_epochs = n_epochs
        self.verbose = verbose
        self.log_metrics = log_metrics
        self.use_mixture_of_scanners = use_mixture_of_scanners
        self.n_readers = n_readers
        self.n_scanners = n_scanners

        # Validate MoS parameters
        if self.use_mixture_of_scanners:
            if self.n_readers < 1:
                raise ValueError("n_readers must be at least 1")
            if self.n_scanners < 1:
                raise ValueError("n_scanners must be at least 1")
            if self.n_scanners > self.n_readers:
                raise ValueError("n_scanners cannot exceed n_readers")

        # Initialize state
        self.batch_id = 0
        self.current_epoch = 0
        self.total_cells = 0

        # Open TileDB experiment
        self.experiment = tiledbsoma.Experiment.open(tiledb_path)
        self.X = self.experiment.ms["RNA"].X["data"]

        # Get total number of cells
        self.total_cells = self.X.shape[0]

        # Initialize shuffling strategy (similar to SLAF)
        from slaf.ml.samplers import ShuffleType, create_shuffle

        self.shuffle = create_shuffle(ShuffleType.RANDOM)

        # Initialize MoS generators if enabled
        if self.use_mixture_of_scanners:
            self._initialize_mos_generators()

        # Initialize timing metrics for benchmarking
        self._timing_metrics: dict[str, list[float]] | None
        if self.log_metrics:
            self._timing_metrics = {
                "tiledb_loading": [],
                "shuffle": [],
                "total": [],
                "cells_processed": [],
            }
        else:
            self._timing_metrics = None

        # Initialize timing variables for consolidated reporting
        self._last_load_time = 0.0
        self._last_memory_mb = 0.0

    def _initialize_mos_generators(self):
        """Initialize MoS generators with evenly distributed scan ranges."""
        # Calculate scan ranges for each generator
        cells_per_reader = self.total_cells // self.n_readers
        remainder = self.total_cells % self.n_readers

        self.generators = []
        current_position = 0

        for i in range(self.n_readers):
            # Distribute remainder cells among first few readers
            reader_cell_count = cells_per_reader + (1 if i < remainder else 0)

            generator = {
                "generator_id": i,
                "start_position": current_position,
                "current_position": current_position,
                "end_position": current_position + reader_cell_count,
                "is_active": True,
            }

            self.generators.append(generator)
            current_position += reader_cell_count

        if self.verbose:
            print_prefetch(
                f"TileDB MoS initialized: {self.n_readers} generators, "
                f"{self.n_scanners} active scanners, "
                f"prefetch_batch_size={self.prefetch_batch_size}",
                self.verbose,
            )

    def reset_for_epoch(self, epoch: int) -> None:
        """
        Reset the processor for a new epoch.

        This method resets the batch processor state to start a new epoch,
        including resetting batch counters, MoS generator positions, and
        shuffling seeds. It is called automatically during multi-epoch training.

        Args:
            epoch: The epoch number to start (0-based indexing).
                  Must be 0 <= epoch < n_epochs.

        Raises:
            ValueError: If epoch is invalid (negative or >= n_epochs).

        Examples:
            >>> # Reset for epoch 1
            >>> processor = TileDBBatchProcessor("path/to/experiment", n_epochs=3)
            >>> processor.reset_for_epoch(1)
            >>> print(f"Current epoch: {processor.current_epoch}")
            Current epoch: 1

            >>> # Invalid epoch raises error
            >>> try:
            ...     processor.reset_for_epoch(5)  # n_epochs=3
            ... except ValueError as e:
            ...     print(f"Error: {e}")
            Error: Invalid epoch 5. Must be 0 <= epoch < 3
        """
        if epoch < 0 or epoch >= self.n_epochs:
            raise ValueError(
                f"Invalid epoch {epoch}. Must be 0 <= epoch < {self.n_epochs}"
            )

        self.current_epoch = epoch
        self.batch_id = 0

        # Reset MoS generators if enabled
        if self.use_mixture_of_scanners:
            for generator in self.generators:
                generator["current_position"] = generator["start_position"]
                generator["is_active"] = True

        if self.verbose:
            print(f"🔄 Reset TileDB processor for epoch {epoch}")

    def _record_timing(self, step: str, duration: float, cells_processed: int = 0):
        """Record timing for a processing step."""
        if not self.log_metrics or self._timing_metrics is None:
            return

        if step in self._timing_metrics:
            self._timing_metrics[step].append(duration)

        if cells_processed > 0:
            self._timing_metrics["cells_processed"].append(cells_processed)

    def load_prefetch_batch(self) -> TileDBPrefetchBatch:
        """
        Load and process a chunk of TileDB data into batches using configured strategy.

        This method loads a batch of data from TileDB SOMA format, applies shuffling,
        and returns a processed batch ready for training. It supports both MoS and
        sequential loading strategies and handles epoch transitions automatically.

        The method performs the following steps:
        1. Load data from TileDB using the configured strategy (MoS or sequential)
        2. Convert Arrow data to Polars DataFrame
        3. Apply shuffling strategy for data randomization
        4. Return processed batch with metadata

        Returns:
            TileDBPrefetchBatch: Processed batch containing:
                - batch_df: Polars DataFrame with cell-gene expression data
                - cell_integer_ids: List of unique cell IDs in the batch
                - process_time: Time taken to process the batch
                - memory_mb: Memory usage at batch creation time

        Raises:
            StopIteration: When all epochs are completed and no more data is available.
            RuntimeError: If TileDB data loading fails.

        Examples:
            >>> # Load a batch with MoS strategy
            >>> processor = TileDBBatchProcessor(
            ...     tiledb_path="path/to/experiment",
            ...     use_mixture_of_scanners=True
            ... )
            >>> batch = processor.load_prefetch_batch()
            >>> print(f"Batch {batch.batch_id} has {len(batch.cell_integer_ids)} cells")
            Batch 0 has 100 cells

            >>> # Load a batch with sequential strategy
            >>> processor = TileDBBatchProcessor(
            ...     tiledb_path="path/to/experiment",
            ...     use_mixture_of_scanners=False
            ... )
            >>> batch = processor.load_prefetch_batch()
            >>> print(f"Sequential batch shape: {batch.batch_df.shape}")
            Sequential batch shape: (100, 3)

            >>> # Handle epoch completion
            >>> processor = TileDBBatchProcessor("path/to/experiment", n_epochs=1)
            >>> try:
            ...     while True:
            ...         batch = processor.load_prefetch_batch()
            ...         print(f"Processed batch {batch.batch_id}")
            ... except StopIteration:
            ...     print("All epochs completed")
            All epochs completed
        """
        # Iterative approach to handle epoch transitions
        while True:
            start_time = time.time()

            if self.use_mixture_of_scanners:
                # MoS approach: randomly sample from active generators
                import numpy as np

                # Get indices of currently active generators
                active_generators = [g for g in self.generators if g["is_active"]]

                if not active_generators:
                    # Check if we should start a new epoch
                    if self.current_epoch + 1 < self.n_epochs:
                        if self.verbose:
                            print(
                                f"🔄 Epoch {self.current_epoch} complete, starting epoch {self.current_epoch + 1}"
                            )
                        self.reset_for_epoch(self.current_epoch + 1)
                        continue
                    else:
                        raise StopIteration("No more epochs available") from None

                # Randomly sample from active generators
                n_to_sample = min(self.n_scanners, len(active_generators))
                selected_generators = np.random.choice(
                    active_generators, size=n_to_sample, replace=False
                )

                if self.verbose and self.batch_id % 10 == 0:
                    print_prefetch(
                        f"TileDB MoS sampling: {len(active_generators)} active generators, "
                        f"sampling from {n_to_sample} generators",
                        self.verbose,
                    )

                # Load data from selected generators
                load_start = time.time()
                batch_dfs = []

                for generator in selected_generators:
                    try:
                        start_cell = generator["current_position"]
                        end_cell = min(
                            start_cell + self.prefetch_batch_size,
                            generator["end_position"],
                        )

                        if start_cell >= generator["end_position"]:
                            # Generator exhausted
                            generator["is_active"] = False
                            continue

                        # Read slice from TileDB
                        arrow_data = (
                            self.X.read((slice(start_cell, end_cell),))
                            .tables()
                            .concat()
                        )

                        # Convert Arrow table to Polars DataFrame
                        df = pl.from_arrow(arrow_data)  # type: ignore[assignment]
                        if not isinstance(df, pl.DataFrame):
                            raise TypeError("Expected DataFrame from Arrow table")

                        # Rename SOMA columns to expected names
                        df = df.rename(
                            {
                                "soma_dim_0": "cell_integer_id",
                                "soma_dim_1": "gene_integer_id",
                                "soma_data": "value",
                            }
                        )

                        batch_dfs.append(df)

                        # Update generator position
                        generator["current_position"] = end_cell

                        # Mark as inactive if exhausted
                        if generator["current_position"] >= generator["end_position"]:
                            generator["is_active"] = False

                    except Exception as e:
                        logger.error(
                            f"Error loading TileDB data from generator {generator['generator_id']}: {e}"
                        )
                        generator["is_active"] = False
                        continue

                if not batch_dfs:
                    # All selected generators are exhausted, continue to next iteration
                    continue

                # Combine all batches
                combined_df_mos = pl.concat(batch_dfs, how="vertical")
            else:
                # Sequential approach (original implementation)
                current_position = (
                    self.batch_id * self.prefetch_batch_size
                ) % self.total_cells

                # Only check for epoch transitions when we actually wrap around
                if self.batch_id > 0:
                    prev_position = (
                        (self.batch_id - 1) * self.prefetch_batch_size
                    ) % self.total_cells
                    if current_position < prev_position:  # We wrapped around
                        if self.current_epoch + 1 < self.n_epochs:
                            if self.verbose:
                                print(
                                    f"🔄 Epoch {self.current_epoch} complete, starting epoch {self.current_epoch + 1}"
                                )
                            self.reset_for_epoch(self.current_epoch + 1)
                            continue
                        else:
                            raise StopIteration("No more epochs available") from None

                # Load data from TileDB
                load_start = time.time()
                try:
                    # Read slice from TileDB as Arrow table
                    arrow_data = (
                        self.X.read(
                            (
                                slice(
                                    current_position,
                                    current_position + self.prefetch_batch_size,
                                ),
                            )
                        )
                        .tables()
                        .concat()
                    )

                    # Convert Arrow table to Polars DataFrame
                    combined_df = pl.from_arrow(arrow_data)  # type: ignore[assignment]
                    if not isinstance(combined_df, pl.DataFrame):
                        raise TypeError("Expected DataFrame from Arrow table")

                    # Rename SOMA columns to expected names
                    combined_df = combined_df.rename(
                        {
                            "soma_dim_0": "cell_integer_id",
                            "soma_dim_1": "gene_integer_id",
                            "soma_data": "value",
                        }
                    )

                except Exception as e:
                    logger.error(f"Error loading TileDB data: {e}")
                    raise StopIteration(f"Failed to load TileDB data: {e}") from e

            load_time = time.time() - load_start
            self._record_timing("tiledb_loading", load_time)

            # Print detailed loading breakdown every 10 batches
            if self.batch_id % 10 == 0:
                import psutil

                process = psutil.Process()
                memory_mb = process.memory_info().rss / 1024 / 1024

                # Store timing info for consolidated report
                self._last_load_time = load_time
                self._last_memory_mb = memory_mb

            # Apply shuffling strategy
            shuffle_start = time.time()

            # Apply shuffling with chunking
            if self.use_mixture_of_scanners:
                shuffled_chunks = self.shuffle.apply(
                    combined_df_mos,  # type: ignore[arg-type]
                    self.seed + self.batch_id + self.current_epoch * 10000,
                    batch_size=self.batch_size,
                )
            else:
                shuffled_chunks = self.shuffle.apply(
                    combined_df,  # type: ignore[arg-type]
                    self.seed + self.batch_id + self.current_epoch * 10000,
                    batch_size=self.batch_size,
                )

            shuffle_time = time.time() - shuffle_start
            total_time = time.time() - start_time

            # Record timing metrics
            self._record_timing("shuffle", shuffle_time)
            self._record_timing("total", total_time)

            # Count total cells across all chunks
            total_cells_in_chunks = sum(
                len(chunk.get_column("cell_integer_id").unique())
                for chunk in shuffled_chunks
                if isinstance(chunk, pl.DataFrame)
            )

            # Record cells processed
            self._record_timing("cells_processed", 0, total_cells_in_chunks)

            # Print consolidated prefetch batch reporting
            if self.batch_id % 10 == 0:
                strategy_name = "MoS" if self.use_mixture_of_scanners else "sequential"
                prefetch_report = f"TileDB {strategy_name} prefetch batch {self.batch_id} (epoch {self.current_epoch}):\n"
                prefetch_report += f"   TileDB loading: {self._last_load_time * 1000:.1f}ms ({self.prefetch_batch_size} cells)\n"
                prefetch_report += (
                    f"   Processing: {shuffle_time * 1000:.1f}ms shuffle\n"
                )
                prefetch_report += f"   Total: {total_time * 1000:.1f}ms, {len(shuffled_chunks)} chunks, {total_cells_in_chunks} cells, {self._last_memory_mb:.1f} MB"

                print_prefetch(prefetch_report, self.verbose)

            self.batch_id += 1

            # Return the first chunk as a batch (we'll handle multiple chunks in the dataloader)
            if shuffled_chunks:
                first_chunk = shuffled_chunks[0]
                return TileDBPrefetchBatch(
                    batch_id=self.batch_id - 1,
                    batch_df=first_chunk,
                    cell_integer_ids=first_chunk["cell_integer_id"].unique().to_list(),
                    process_time=shuffle_time,
                    memory_mb=self._last_memory_mb,
                )
            else:
                # No data in this chunk, continue to next iteration
                continue
Functions
__init__(tiledb_path: str, batch_size: int = 32, prefetch_batch_size: int = 100, seed: int = 42, n_epochs: int = 1, verbose: bool = True, log_metrics: bool = False, use_mixture_of_scanners: bool = True, n_readers: int = 50, n_scanners: int = 8)

Initialize the TileDB batch processor with training configuration.

Parameters:

  • tiledb_path (str, required): Path to the TileDB SOMA experiment directory. Must contain a valid SOMA experiment with RNA measurement data.
  • batch_size (int, default 32): Number of cells per training batch. Larger batches use more memory but may improve training efficiency. Range: 1-512.
  • prefetch_batch_size (int, default 100): Number of cells to read from TileDB per read. In MoS mode, each selected generator reads up to this many cells, so one prefetch batch loads up to n_scanners × prefetch_batch_size cells. Higher values improve throughput but use more memory. Range: 10-10000.
  • seed (int, default 42): Base random seed for shuffling. The shuffle seed for each prefetch batch is derived from seed, the batch index, and the epoch. MoS generator selection calls np.random.choice on NumPy's global random state, which this processor does not seed directly.
  • n_epochs (int, default 1): Number of epochs to run. The processor resets automatically after each epoch, enabling multi-epoch training.
  • verbose (bool, default True): If True, print detailed timing and progress information. If False, suppress internal prints for clean output.
  • log_metrics (bool, default False): If True, collect timing metrics for performance analysis: TileDB loading time, shuffle time, total time, and cells processed per prefetch batch.
  • use_mixture_of_scanners (bool, default True): If True, use the MoS strategy, which randomly samples from multiple generators that each scan their own contiguous range of cells. This gives better randomization for foundation model training. If False, cells are read sequentially.
  • n_readers (int, default 50): Total number of generators to create when using MoS. The cells are split evenly across them, so more generators give finer-grained ranges and higher entropy. Range: 1-1000.
  • n_scanners (int, default 8): Number of generators sampled per prefetch batch when using MoS. Higher values provide better entropy but load more data per prefetch batch. Must not exceed n_readers. Range: 1-100.

Raises:

  • ImportError: If TileDB SOMA is not available.
  • ValueError: If MoS parameters are invalid (n_readers < 1, n_scanners < 1, or n_scanners > n_readers).
  • RuntimeError: If the TileDB experiment cannot be opened or is invalid.
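For reference, the MoS setup in the class source (`_initialize_mos_generators`) splits the cell ID space into `n_readers` contiguous ranges and gives the first `total_cells % n_readers` generators one extra cell. The standalone sketch below reproduces that arithmetic without TileDB; `mos_scan_ranges` is a hypothetical helper, not part of the SLAF API. It also prints the per-prefetch upper bound on cells read in MoS mode, which `n_scanners` and `prefetch_batch_size` control jointly.

```python
def mos_scan_ranges(total_cells: int, n_readers: int) -> list[tuple[int, int]]:
    """Hypothetical helper: split [0, total_cells) into n_readers contiguous ranges."""
    base, remainder = divmod(total_cells, n_readers)
    ranges = []
    start = 0
    for i in range(n_readers):
        # The first `remainder` generators each take one extra cell.
        count = base + (1 if i < remainder else 0)
        ranges.append((start, start + count))  # [start, stop), stop is exclusive
        start += count
    return ranges


print(mos_scan_ranges(total_cells=10, n_readers=4))
# [(0, 3), (3, 6), (6, 8), (8, 10)]

# Upper bound on cells read per MoS prefetch batch with the default settings.
n_scanners, prefetch_batch_size = 8, 100
print(n_scanners * prefetch_batch_size)  # 800
```

When total_cells < n_readers some ranges are empty; in the source, such generators are marked inactive on their first read attempt.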

Examples:

>>> # Basic initialization with default MoS strategy
>>> processor = TileDBBatchProcessor(
...     tiledb_path="path/to/experiment",
...     batch_size=32,
...     prefetch_batch_size=100
... )
>>> print(f"Total cells: {processor.total_cells}")
Total cells: 50000
>>> # Sequential loading for maximum throughput
>>> processor = TileDBBatchProcessor(
...     tiledb_path="path/to/experiment",
...     use_mixture_of_scanners=False,
...     batch_size=64
... )
>>> print(f"MoS enabled: {processor.use_mixture_of_scanners}")
MoS enabled: False
>>> # High-entropy MoS configuration
>>> processor = TileDBBatchProcessor(
...     tiledb_path="path/to/experiment",
...     n_readers=100,
...     n_scanners=16
... )
>>> print(f"MoS readers: {processor.n_readers}, scanners: {processor.n_scanners}")
MoS readers: 100, scanners: 16
Source code in slaf/ml/tiledb_dataloaders.py
def __init__(
    self,
    tiledb_path: str,
    batch_size: int = 32,
    prefetch_batch_size: int = 100,
    seed: int = 42,
    n_epochs: int = 1,
    verbose: bool = True,
    log_metrics: bool = False,
    use_mixture_of_scanners: bool = True,
    n_readers: int = 50,
    n_scanners: int = 8,
):
    """
    Initialize the TileDB batch processor with training configuration.

    Args:
        tiledb_path: Path to the TileDB SOMA experiment directory.
                     Must contain a valid SOMA experiment with RNA measurement data.
        batch_size: Number of cells per training batch. Larger batches use more
                   memory but may improve training efficiency. Range: 1-512, default: 32.
        prefetch_batch_size: Number of cells to load per prefetch batch from TileDB.
                           Higher values improve throughput but use more memory.
                           Range: 10-10000, default: 100.
        seed: Random seed for reproducible shuffling and MoS sampling.
              Used for consistent data ordering across runs. Default: 42.
        n_epochs: Number of epochs to run. The processor will automatically reset
                 after each epoch, enabling multi-epoch training. Default: 1.
        verbose: If True, print detailed timing and progress information.
                If False, suppress all internal prints for clean output. Default: True.
        log_metrics: If True, collect detailed timing metrics for performance analysis.
                    Metrics include loading time, shuffle time, and memory usage.
                    Default: False.
        use_mixture_of_scanners: If True, use MoS strategy for higher entropy by
                               randomly sampling from multiple fragment generators.
                               Provides better randomization for foundation model training.
                               Default: True.
        n_readers: Total number of fragment generators to create when using MoS.
                  Higher values provide better entropy but use more memory.
                  Range: 1-1000, default: 50.
        n_scanners: Number of active scanners to sample from simultaneously when using MoS.
                   Higher values provide better entropy but use more memory.
                   Range: 1-100, default: 8.

    Raises:
        ImportError: If TileDB SOMA is not available.
        ValueError: If MoS parameters are invalid (n_readers < 1, n_scanners < 1,
                   or n_scanners > n_readers).
        RuntimeError: If the TileDB experiment cannot be opened or is invalid.

    Examples:
        >>> # Basic initialization with default MoS strategy
        >>> processor = TileDBBatchProcessor(
        ...     tiledb_path="path/to/experiment",
        ...     batch_size=32,
        ...     prefetch_batch_size=100
        ... )
        >>> print(f"Total cells: {processor.total_cells}")
        Total cells: 50000

        >>> # Sequential loading for maximum throughput
        >>> processor = TileDBBatchProcessor(
        ...     tiledb_path="path/to/experiment",
        ...     use_mixture_of_scanners=False,
        ...     batch_size=64
        ... )
        >>> print(f"MoS enabled: {processor.use_mixture_of_scanners}")
        MoS enabled: False

        >>> # High-entropy MoS configuration
        >>> processor = TileDBBatchProcessor(
        ...     tiledb_path="path/to/experiment",
        ...     n_readers=100,
        ...     n_scanners=16
        ... )
        >>> print(f"MoS readers: {processor.n_readers}, scanners: {processor.n_scanners}")
        MoS readers: 100, scanners: 16
    """
    if not TILEDB_AVAILABLE:
        raise ImportError("TileDB SOMA is required but not available")

    self.tiledb_path = tiledb_path
    self.batch_size = batch_size
    self.prefetch_batch_size = prefetch_batch_size
    self.seed = seed
    self.n_epochs = n_epochs
    self.verbose = verbose
    self.log_metrics = log_metrics
    self.use_mixture_of_scanners = use_mixture_of_scanners
    self.n_readers = n_readers
    self.n_scanners = n_scanners

    # Validate MoS parameters
    if self.use_mixture_of_scanners:
        if self.n_readers < 1:
            raise ValueError("n_readers must be at least 1")
        if self.n_scanners < 1:
            raise ValueError("n_scanners must be at least 1")
        if self.n_scanners > self.n_readers:
            raise ValueError("n_scanners cannot exceed n_readers")

    # Initialize state
    self.batch_id = 0
    self.current_epoch = 0
    self.total_cells = 0

    # Open TileDB experiment
    self.experiment = tiledbsoma.Experiment.open(tiledb_path)
    self.X = self.experiment.ms["RNA"].X["data"]

    # Get total number of cells
    self.total_cells = self.X.shape[0]

    # Initialize shuffling strategy (similar to SLAF)
    from slaf.ml.samplers import ShuffleType, create_shuffle

    self.shuffle = create_shuffle(ShuffleType.RANDOM)

    # Initialize MoS generators if enabled
    if self.use_mixture_of_scanners:
        self._initialize_mos_generators()

    # Initialize timing metrics for benchmarking
    self._timing_metrics: dict[str, list[float]] | None
    if self.log_metrics:
        self._timing_metrics = {
            "tiledb_loading": [],
            "shuffle": [],
            "total": [],
            "cells_processed": [],
        }
    else:
        self._timing_metrics = None

    # Initialize timing variables for consolidated reporting
    self._last_load_time = 0.0
    self._last_memory_mb = 0.0
reset_for_epoch(epoch: int) -> None

Reset the processor for a new epoch.

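This method resets the batch processor state to start a new epoch: it sets current_epoch, resets the batch counter to 0, and rewinds every MoS generator to its start position. Because the per-batch shuffle seed is derived from the seed, the batch counter, and the epoch, each epoch is shuffled differently. load_prefetch_batch() calls this method automatically when the current epoch is exhausted and more epochs remain.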
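Resetting the batch counter does not replay the previous epoch's shuffle, because the seed handed to the shuffler mixes in the epoch number. A minimal sketch of that derivation, copied from the expression passed to `shuffle.apply(...)` in `load_prefetch_batch`; `shuffle_seed` is a hypothetical name used only for illustration.

```python
def shuffle_seed(seed: int, batch_id: int, epoch: int) -> int:
    # Same expression that load_prefetch_batch passes to shuffle.apply(...)
    return seed + batch_id + epoch * 10000


# batch_id restarts at 0 in every epoch, so the epoch term separates epochs.
print([shuffle_seed(42, b, epoch=0) for b in range(3)])  # [42, 43, 44]
print([shuffle_seed(42, b, epoch=1) for b in range(3)])  # [10042, 10043, 10044]

# Consequence of the formula: seeds repeat once an epoch exceeds 10000 prefetch batches.
print(shuffle_seed(42, 10000, epoch=0) == shuffle_seed(42, 0, epoch=1))  # True
```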

Parameters:

  • epoch (int, required): The epoch number to start (0-based indexing). Must satisfy 0 <= epoch < n_epochs.

Raises:

  • ValueError: If epoch is invalid (negative or >= n_epochs).

Examples:

>>> # Reset for epoch 1
>>> processor = TileDBBatchProcessor("path/to/experiment", n_epochs=3, verbose=False)
>>> processor.reset_for_epoch(1)
>>> print(f"Current epoch: {processor.current_epoch}")
Current epoch: 1
>>> # Invalid epoch raises error
>>> try:
...     processor.reset_for_epoch(5)  # n_epochs=3
... except ValueError as e:
...     print(f"Error: {e}")
Error: Invalid epoch 5. Must be 0 <= epoch < 3
Source code in slaf/ml/tiledb_dataloaders.py
def reset_for_epoch(self, epoch: int) -> None:
    """
    Reset the processor for a new epoch.

    This method resets the batch processor state to start a new epoch,
    including resetting batch counters, MoS generator positions, and
    shuffling seeds. It is called automatically during multi-epoch training.

    Args:
        epoch: The epoch number to start (0-based indexing).
              Must be 0 <= epoch < n_epochs.

    Raises:
        ValueError: If epoch is invalid (negative or >= n_epochs).

    Examples:
        >>> # Reset for epoch 1
        >>> processor = TileDBBatchProcessor("path/to/experiment", n_epochs=3)
        >>> processor.reset_for_epoch(1)
        >>> print(f"Current epoch: {processor.current_epoch}")
        Current epoch: 1

        >>> # Invalid epoch raises error
        >>> try:
        ...     processor.reset_for_epoch(5)  # n_epochs=3
        ... except ValueError as e:
        ...     print(f"Error: {e}")
        Error: Invalid epoch 5. Must be 0 <= epoch < 3
    """
    if epoch < 0 or epoch >= self.n_epochs:
        raise ValueError(
            f"Invalid epoch {epoch}. Must be 0 <= epoch < {self.n_epochs}"
        )

    self.current_epoch = epoch
    self.batch_id = 0

    # Reset MoS generators if enabled
    if self.use_mixture_of_scanners:
        for generator in self.generators:
            generator["current_position"] = generator["start_position"]
            generator["is_active"] = True

    if self.verbose:
        print(f"🔄 Reset TileDB processor for epoch {epoch}")
load_prefetch_batch() -> TileDBPrefetchBatch

Load and process a chunk of TileDB data into batches using the configured strategy.

This method loads data from a TileDB SOMA experiment, applies shuffling, and returns the first shuffled chunk as a batch ready for training. It supports both MoS and sequential loading and handles epoch transitions automatically by calling reset_for_epoch() when the current epoch is exhausted.
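The sequential path computes its read offset as `(batch_id * prefetch_batch_size) % total_cells` and treats a decrease in that offset as the end of an epoch. The sketch below replays that bookkeeping on plain integers, without TileDB. `sequential_positions` is hypothetical, and the `min(..., total_cells)` clamp on the stop offset is added here for readability; the source passes the unclamped end to `X.read`.

```python
def sequential_positions(total_cells: int, prefetch_batch_size: int, n_epochs: int):
    """Hypothetical model: yield (epoch, start, stop) for each sequential read."""
    epoch, batch_id = 0, 0
    while True:
        pos = (batch_id * prefetch_batch_size) % total_cells
        if batch_id > 0:
            prev = ((batch_id - 1) * prefetch_batch_size) % total_cells
            if pos < prev:  # offset wrapped around: the epoch is finished
                if epoch + 1 < n_epochs:
                    epoch, batch_id = epoch + 1, 0  # what reset_for_epoch does
                    continue
                return  # the source raises StopIteration here
        yield epoch, pos, min(pos + prefetch_batch_size, total_cells)
        batch_id += 1


for read in sequential_positions(total_cells=250, prefetch_batch_size=100, n_epochs=2):
    print(read)
```

Because the epoch boundary is detected only when the offset decreases, a prefetch_batch_size that is an exact multiple of total_cells keeps the offset at 0 and the epoch never ends under this formula, so such a configuration is worth checking against the source before use.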

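The method performs the following steps:

  1. Load data from TileDB using the configured strategy (MoS or sequential).
  2. Convert the Arrow data to a Polars DataFrame and rename the SOMA columns (soma_dim_0, soma_dim_1, soma_data) to cell_integer_id, gene_integer_id, and value.
  3. Shuffle the data, chunking it according to batch_size.
  4. Return the first chunk as a TileDBPrefetchBatch with metadata.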
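In MoS mode, each prefetch batch samples up to `n_scanners` still-active generators without replacement, reads up to `prefetch_batch_size` cells from each, advances their positions, and retires any generator that reaches the end of its range. The sketch below models one epoch of that loop over integer ranges. It is an illustration under stated assumptions: `mos_epoch` is hypothetical, and it draws from a seeded `numpy.random.Generator` over indices, whereas the source calls `np.random.choice` on the list of generator dicts using NumPy's global random state.

```python
import numpy as np


def mos_epoch(ranges, n_scanners, prefetch_batch_size, rng):
    """Hypothetical model of one MoS epoch: yields the (start, stop) reads per prefetch batch."""
    generators = [{"pos": start, "end": stop} for start, stop in ranges]
    while True:
        active = [g for g in generators if g["pos"] < g["end"]]
        if not active:
            return  # every generator is exhausted: the epoch is over
        k = min(n_scanners, len(active))
        # Pick k distinct active generators uniformly at random.
        picks = rng.choice(len(active), size=k, replace=False)
        reads = []
        for i in picks:
            g = active[i]
            stop = min(g["pos"] + prefetch_batch_size, g["end"])
            reads.append((g["pos"], stop))
            g["pos"] = stop  # advance this generator's scan position
        yield reads


rng = np.random.default_rng(42)
for reads in mos_epoch([(0, 3), (3, 6), (6, 8), (8, 10)], n_scanners=2, prefetch_batch_size=2, rng=rng):
    print(reads)
```

Every cell is read exactly once per epoch; only the order in which the ranges are interleaved depends on the random draws.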

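Returns:

  • TileDBPrefetchBatch: The processed batch, with fields:
    • batch_id: Index of this prefetch batch within the current epoch.
    • batch_df: Polars DataFrame in long format (cell_integer_id, gene_integer_id, value) holding the first shuffled chunk.
    • cell_integer_ids: List of the unique cell IDs in batch_df.
    • process_time: Time spent shuffling this prefetch batch, in seconds.
    • memory_mb: Process resident memory in MB as of the most recent sample, taken on every 10th prefetch batch.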
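`batch_df` is in long (sparse) format, one row per stored (cell_integer_id, gene_integer_id, value) entry. If a downstream step needs a dense cells × genes matrix, one option is to scatter those columns into a NumPy array, for example after extracting them with Polars `Series.to_numpy()`. The helper `to_dense` and its `n_genes` argument are hypothetical; the gene count would have to come from the experiment's metadata.

```python
import numpy as np


def to_dense(cell_ids, gene_ids, values, n_genes):
    """Hypothetical helper: scatter (cell, gene, value) triples into a dense cells x genes matrix."""
    cell_ids = np.asarray(cell_ids)
    # Map each distinct cell ID to a row index, in ascending ID order.
    unique_cells, rows = np.unique(cell_ids, return_inverse=True)
    dense = np.zeros((len(unique_cells), n_genes), dtype=np.float32)
    dense[rows, np.asarray(gene_ids)] = np.asarray(values, dtype=np.float32)
    return unique_cells, dense


cells, X = to_dense([7, 7, 3], [0, 2, 1], [1.0, 5.0, 2.0], n_genes=3)
print(cells)  # [3 7]
print(X)
```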

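Raises:

  • StopIteration: When all epochs are completed and no more data is available. In sequential mode, a failed TileDB read is logged and also surfaces as StopIteration, chained to the original exception. In MoS mode, a failed read is logged and the affected generator is deactivated instead of raising.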

Examples:

>>> # Load a batch with MoS strategy
>>> processor = TileDBBatchProcessor(
...     tiledb_path="path/to/experiment",
...     use_mixture_of_scanners=True
... )
>>> batch = processor.load_prefetch_batch()
>>> print(f"Batch {batch.batch_id} has {len(batch.cell_integer_ids)} cells")
Batch 0 has 100 cells
>>> # Load a batch with sequential strategy
>>> processor = TileDBBatchProcessor(
...     tiledb_path="path/to/experiment",
...     use_mixture_of_scanners=False
... )
>>> batch = processor.load_prefetch_batch()
>>> print(f"Sequential batch shape: {batch.batch_df.shape}")
Sequential batch shape: (100, 3)
>>> # Handle epoch completion
>>> processor = TileDBBatchProcessor("path/to/experiment", n_epochs=1, verbose=False)
>>> try:
...     while True:
...         batch = processor.load_prefetch_batch()
... except StopIteration:
...     print("All epochs completed")
All epochs completed
Source code in slaf/ml/tiledb_dataloaders.py
def load_prefetch_batch(self) -> TileDBPrefetchBatch:
    """
    Load and process a chunk of TileDB data into batches using configured strategy.

    This method loads a batch of data from TileDB SOMA format, applies shuffling,
    and returns a processed batch ready for training. It supports both MoS and
    sequential loading strategies and handles epoch transitions automatically.

    The method performs the following steps:
    1. Load data from TileDB using the configured strategy (MoS or sequential)
    2. Convert Arrow data to Polars DataFrame
    3. Apply shuffling strategy for data randomization
    4. Return processed batch with metadata

    Returns:
        TileDBPrefetchBatch: Processed batch containing:
            - batch_df: Polars DataFrame with cell-gene expression data
            - cell_integer_ids: List of unique cell IDs in the batch
            - process_time: Time taken to process the batch
            - memory_mb: Memory usage at batch creation time

    Raises:
        StopIteration: When all epochs are completed and no more data is available.
        RuntimeError: If TileDB data loading fails.

    Examples:
        >>> # Load a batch with MoS strategy
        >>> processor = TileDBBatchProcessor(
        ...     tiledb_path="path/to/experiment",
        ...     use_mixture_of_scanners=True
        ... )
        >>> batch = processor.load_prefetch_batch()
        >>> print(f"Batch {batch.batch_id} has {len(batch.cell_integer_ids)} cells")
        Batch 0 has 100 cells

        >>> # Load a batch with sequential strategy
        >>> processor = TileDBBatchProcessor(
        ...     tiledb_path="path/to/experiment",
        ...     use_mixture_of_scanners=False
        ... )
        >>> batch = processor.load_prefetch_batch()
        >>> print(f"Sequential batch shape: {batch.batch_df.shape}")
        Sequential batch shape: (100, 3)

        >>> # Handle epoch completion
        >>> processor = TileDBBatchProcessor("path/to/experiment", n_epochs=1)
        >>> n_batches = 0
        >>> try:
        ...     while True:
        ...         batch = processor.load_prefetch_batch()
        ...         n_batches += 1
        ... except StopIteration:
        ...     print("All epochs completed")
        All epochs completed
    """
    # Iterative approach to handle epoch transitions
    while True:
        start_time = time.time()

        if self.use_mixture_of_scanners:
            # MoS approach: randomly sample from active generators
            import numpy as np

            # Get indices of currently active generators
            active_generators = [g for g in self.generators if g["is_active"]]

            if not active_generators:
                # Check if we should start a new epoch
                if self.current_epoch + 1 < self.n_epochs:
                    if self.verbose:
                        print(
                            f"🔄 Epoch {self.current_epoch} complete, starting epoch {self.current_epoch + 1}"
                        )
                    self.reset_for_epoch(self.current_epoch + 1)
                    continue
                else:
                    raise StopIteration("No more epochs available") from None

            # Randomly sample from active generators
            n_to_sample = min(self.n_scanners, len(active_generators))
            selected_generators = np.random.choice(
                active_generators, size=n_to_sample, replace=False
            )

            if self.verbose and self.batch_id % 10 == 0:
                print_prefetch(
                    f"TileDB MoS sampling: {len(active_generators)} active generators, "
                    f"sampling from {n_to_sample} generators",
                    self.verbose,
                )

            # Load data from selected generators
            load_start = time.time()
            batch_dfs = []

            for generator in selected_generators:
                try:
                    start_cell = generator["current_position"]
                    end_cell = min(
                        start_cell + self.prefetch_batch_size,
                        generator["end_position"],
                    )

                    if start_cell >= generator["end_position"]:
                        # Generator exhausted
                        generator["is_active"] = False
                        continue

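                    # Read cells start_cell..end_cell - 1. SOMA slices are doubly
                    # inclusive, so the upper bound is end_cell - 1 to avoid
                    # re-reading the first cell of this generator's next slice.
                    arrow_data = (
                        self.X.read((slice(start_cell, end_cell - 1),))
                        .tables()
                        .concat()
                    )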

                    # Convert Arrow table to Polars DataFrame
                    df = pl.from_arrow(arrow_data)  # type: ignore[assignment]
                    if not isinstance(df, pl.DataFrame):
                        raise TypeError("Expected DataFrame from Arrow table")

                    # Rename SOMA columns to expected names
                    df = df.rename(
                        {
                            "soma_dim_0": "cell_integer_id",
                            "soma_dim_1": "gene_integer_id",
                            "soma_data": "value",
                        }
                    )

                    batch_dfs.append(df)

                    # Update generator position
                    generator["current_position"] = end_cell

                    # Mark as inactive if exhausted
                    if generator["current_position"] >= generator["end_position"]:
                        generator["is_active"] = False

                except Exception as e:
                    logger.error(
                        f"Error loading TileDB data from generator {generator['generator_id']}: {e}"
                    )
                    generator["is_active"] = False
                    continue

            if not batch_dfs:
                # All selected generators are exhausted, continue to next iteration
                continue

            # Combine all batches
            combined_df_mos = pl.concat(batch_dfs, how="vertical")
        else:
            # Sequential approach (original implementation)
            current_position = (
                self.batch_id * self.prefetch_batch_size
            ) % self.total_cells

            # Only check for epoch transitions when we actually wrap around
            if self.batch_id > 0:
                prev_position = (
                    (self.batch_id - 1) * self.prefetch_batch_size
                ) % self.total_cells
                if current_position < prev_position:  # We wrapped around
                    if self.current_epoch + 1 < self.n_epochs:
                        if self.verbose:
                            print(
                                f"🔄 Epoch {self.current_epoch} complete, starting epoch {self.current_epoch + 1}"
                            )
                        self.reset_for_epoch(self.current_epoch + 1)
                        continue
                    else:
                        raise StopIteration("No more epochs available") from None

            # Load data from TileDB
            load_start = time.time()
            try:
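                # Read cells [current_position, end_position) as an Arrow table.
                # SOMA slices are doubly inclusive (hence the "- 1"), and the
                # upper bound is clamped to the last cell in the array.
                end_position = min(
                    current_position + self.prefetch_batch_size, self.total_cells
                )
                arrow_data = (
                    self.X.read((slice(current_position, end_position - 1),))
                    .tables()
                    .concat()
                )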

                # Convert Arrow table to Polars DataFrame
                combined_df = pl.from_arrow(arrow_data)  # type: ignore[assignment]
                if not isinstance(combined_df, pl.DataFrame):
                    raise TypeError("Expected DataFrame from Arrow table")

                # Rename SOMA columns to expected names
                combined_df = combined_df.rename(
                    {
                        "soma_dim_0": "cell_integer_id",
                        "soma_dim_1": "gene_integer_id",
                        "soma_data": "value",
                    }
                )

            except Exception as e:
                logger.error(f"Error loading TileDB data: {e}")
                raise StopIteration(f"Failed to load TileDB data: {e}") from e

        load_time = time.time() - load_start
        self._record_timing("tiledb_loading", load_time)

        # Every 10 batches, sample process memory and keep the load time for the consolidated report
        if self.batch_id % 10 == 0:
            import psutil

            process = psutil.Process()
            memory_mb = process.memory_info().rss / 1024 / 1024

            # Store timing info for consolidated report
            self._last_load_time = load_time
            self._last_memory_mb = memory_mb

        # Apply shuffling strategy
        shuffle_start = time.time()

        # Apply shuffling with chunking
        if self.use_mixture_of_scanners:
            shuffled_chunks = self.shuffle.apply(
                combined_df_mos,  # type: ignore[arg-type]
                self.seed + self.batch_id + self.current_epoch * 10000,
                batch_size=self.batch_size,
            )
        else:
            shuffled_chunks = self.shuffle.apply(
                combined_df,  # type: ignore[arg-type]
                self.seed + self.batch_id + self.current_epoch * 10000,
                batch_size=self.batch_size,
            )

        shuffle_time = time.time() - shuffle_start
        total_time = time.time() - start_time

        # Record timing metrics
        self._record_timing("shuffle", shuffle_time)
        self._record_timing("total", total_time)

        # Count total cells across all chunks
        total_cells_in_chunks = sum(
            len(chunk.get_column("cell_integer_id").unique())
            for chunk in shuffled_chunks
            if isinstance(chunk, pl.DataFrame)
        )

        # Record cells processed
        self._record_timing("cells_processed", 0, total_cells_in_chunks)

        # Print consolidated prefetch batch reporting
        if self.batch_id % 10 == 0:
            strategy_name = "MoS" if self.use_mixture_of_scanners else "sequential"
            prefetch_report = f"TileDB {strategy_name} prefetch batch {self.batch_id} (epoch {self.current_epoch}):\n"
            prefetch_report += f"   TileDB loading: {self._last_load_time * 1000:.1f}ms ({self.prefetch_batch_size} cells)\n"
            prefetch_report += (
                f"   Processing: {shuffle_time * 1000:.1f}ms shuffle\n"
            )
            prefetch_report += f"   Total: {total_time * 1000:.1f}ms, {len(shuffled_chunks)} chunks, {total_cells_in_chunks} cells, {self._last_memory_mb:.1f} MB"

            print_prefetch(prefetch_report, self.verbose)

        self.batch_id += 1

        # Return only the first shuffled chunk; remaining chunks in shuffled_chunks are not returned by this call
        if shuffled_chunks:
            first_chunk = shuffled_chunks[0]
            return TileDBPrefetchBatch(
                batch_id=self.batch_id - 1,
                batch_df=first_chunk,
                cell_integer_ids=first_chunk["cell_integer_id"].unique().to_list(),
                process_time=shuffle_time,
                memory_mb=self._last_memory_mb,
            )
        else:
            # No data in this chunk, continue to next iteration
            continue
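The mixture-of-scanners branch of load_prefetch_batch can be modeled without TileDB. The sketch below is a simplified illustration, not SLAF's implementation. It assumes each generator owns one contiguous, half-open cell range tracked by the same `current_position`, `end_position`, and `is_active` keys. Each step samples up to `n_scanners` active generators with Python's `random.sample` (in place of `np.random.choice`) and returns cell ranges instead of reading data. `make_generators` and `mos_schedule` are hypothetical names.

```python
import random


def make_generators(total_cells: int, n_generators: int) -> list[dict]:
    """Split [0, total_cells) into contiguous ranges, one per generator."""
    bounds = [total_cells * i // n_generators for i in range(n_generators + 1)]
    return [
        {
            "generator_id": i,
            "current_position": bounds[i],
            "end_position": bounds[i + 1],
            "is_active": bounds[i] < bounds[i + 1],
        }
        for i in range(n_generators)
    ]


def mos_schedule(total_cells, n_generators, n_scanners, prefetch_batch_size, seed=0):
    """Yield one list of (start, end) cell ranges per prefetch step until the epoch ends."""
    rng = random.Random(seed)
    generators = make_generators(total_cells, n_generators)
    while True:
        active = [g for g in generators if g["is_active"]]
        if not active:
            return  # every generator is exhausted: epoch finished
        selected = rng.sample(active, min(n_scanners, len(active)))
        step = []
        for g in selected:
            start = g["current_position"]
            end = min(start + prefetch_batch_size, g["end_position"])
            step.append((start, end))
            g["current_position"] = end  # advance this scanner
            if end >= g["end_position"]:
                g["is_active"] = False
        yield step


# Every cell is visited exactly once per epoch, in a shuffled range order.
visited = [c for step in mos_schedule(1000, 8, 3, 50) for (s, e) in step for c in range(s, e)]
print(len(visited), sorted(visited) == list(range(1000)))  # 1000 True
```

Each generator only advances within its own range and is deactivated once exhausted. As a result, every cell is produced once per epoch, while the interleaving of ranges depends on the seed.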

TileDBAsyncPrefetcher

Asynchronous prefetcher for TileDB batch processing with background loading.

TileDBAsyncPrefetcher provides background batch processing and prefetching for TileDB data to improve training throughput. It runs a separate worker thread that continuously loads and processes batches while the main training loop consumes pre-processed data.

Key Features
  • Background batch processing in separate worker thread
  • Configurable queue size for memory management
  • Comprehensive performance monitoring and statistics
  • Automatic epoch transition handling
  • Graceful shutdown and cleanup
  • Real-time rate monitoring and reporting
  • Error handling and recovery

The prefetcher maintains a queue of pre-processed batches and provides statistics about loading rates, memory usage, and processing times.
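The worker/queue design described above is a bounded producer/consumer pattern. The following is a minimal standard-library sketch, built on several assumptions:

  • `TinyPrefetcher` and its `load_fn` argument are hypothetical stand-ins for TileDBAsyncPrefetcher and TileDBBatchProcessor.load_prefetch_batch.
  • A `threading.Event` replaces the plain `should_stop` flag.
  • The worker blocks (with a short timeout) while the queue is full, so no batch is dropped.

```python
import queue
import threading


class TinyPrefetcher:
    """Hypothetical minimal prefetcher: one daemon worker fills a bounded queue."""

    def __init__(self, load_fn, max_queue_size: int = 4):
        self.load_fn = load_fn  # returns the next batch, raises StopIteration when done
        self.queue: queue.Queue = queue.Queue(maxsize=max_queue_size)
        self.should_stop = threading.Event()
        self.worker = threading.Thread(target=self._worker, daemon=True)

    def start(self) -> None:
        self.worker.start()

    def stop(self) -> None:
        self.should_stop.set()
        self.worker.join(timeout=1.0)

    def _worker(self) -> None:
        while not self.should_stop.is_set():
            try:
                batch = self.load_fn()
            except StopIteration:
                return  # data exhausted: let the thread exit
            # Backpressure: wait while the queue is full, re-checking the stop flag.
            while not self.should_stop.is_set():
                try:
                    self.queue.put(batch, timeout=0.1)
                    break
                except queue.Full:
                    continue

    def get_batch(self, timeout: float = 1.0):
        """Return the next batch, or None if nothing arrives within `timeout` seconds."""
        try:
            return self.queue.get(timeout=timeout)
        except queue.Empty:
            return None


source = iter(range(10))
prefetcher = TinyPrefetcher(lambda: next(source), max_queue_size=3)
prefetcher.start()
received = []
while (batch := prefetcher.get_batch()) is not None:
    received.append(batch)
prefetcher.stop()
print(received)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

The queue bound caps how many batches are held in memory at once. With a single producer, the FIFO queue preserves load order.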

Examples:

>>> # Create prefetcher with a batch processor
>>> processor = TileDBBatchProcessor("path/to/experiment")
>>> prefetcher = TileDBAsyncPrefetcher(processor, max_queue_size=100)
>>>
>>> # Start background processing
>>> prefetcher.start()
>>>
>>> # Get pre-processed batches
>>> batch = prefetcher.get_batch()
>>> if batch:
...     print(f"Got batch {batch.batch_id} with {len(batch.cell_integer_ids)} cells")
>>>
>>> # Check performance statistics
>>> stats = prefetcher.get_stats()
>>> print(f"Loading rate: {stats['cells_per_sec']:.1f} cells/sec")
>>>
>>> # Stop background processing
>>> prefetcher.stop()
Source code in slaf/ml/tiledb_dataloaders.py
class TileDBAsyncPrefetcher:
    """
    Asynchronous prefetcher for TileDB batch processing with background loading.

    TileDBAsyncPrefetcher provides background batch processing and prefetching
    for TileDB data to improve training throughput. It runs a separate worker
    thread that continuously loads and processes batches while the main training
    loop consumes pre-processed data.

    Key Features:
        - Background batch processing in separate worker thread
        - Configurable queue size for memory management
        - Comprehensive performance monitoring and statistics
        - Automatic epoch transition handling
        - Graceful shutdown and cleanup
        - Real-time rate monitoring and reporting
        - Error handling and recovery

    The prefetcher maintains a queue of pre-processed batches and provides
    statistics about loading rates, memory usage, and processing times.

    Examples:
            >>> # Create prefetcher with a batch processor
            >>> processor = TileDBBatchProcessor("path/to/experiment")
            >>> prefetcher = TileDBAsyncPrefetcher(processor, max_queue_size=100)
            >>>
            >>> # Start background processing
            >>> prefetcher.start()
            >>>
            >>> # Get pre-processed batches
            >>> batch = prefetcher.get_batch()
            >>> if batch:
            ...     print(f"Got batch {batch.batch_id} with {len(batch.cell_integer_ids)} cells")
            >>>
            >>> # Check performance statistics
            >>> stats = prefetcher.get_stats()
            >>> print(f"Loading rate: {stats['cells_per_sec']:.1f} cells/sec")
            >>>
            >>> # Stop background processing
            >>> prefetcher.stop()
    """

    def __init__(
        self, batch_processor: TileDBBatchProcessor, max_queue_size: int = 500
    ):
        """
        Initialize the TileDB async prefetcher with background processing.

        Args:
            batch_processor: TileDBBatchProcessor instance to use for loading batches.
                           Must be properly initialized with TileDB path and configuration.
            max_queue_size: Maximum number of pre-processed batches to keep in queue.
                          Higher values use more memory but provide better buffering.
                          Range: 10-10000, default: 500.

        Examples:
            >>> # Create prefetcher with default queue size
            >>> processor = TileDBBatchProcessor("path/to/experiment")
            >>> prefetcher = TileDBAsyncPrefetcher(processor)
            >>> print(f"Max queue size: {prefetcher.max_queue_size}")
            Max queue size: 500

            >>> # Create prefetcher with custom queue size
            >>> prefetcher = TileDBAsyncPrefetcher(processor, max_queue_size=1000)
            >>> print(f"Custom queue size: {prefetcher.max_queue_size}")
            Custom queue size: 1000
        """
        self.batch_processor = batch_processor
        self.max_queue_size = max_queue_size
        self.queue: Queue[TileDBPrefetchBatch] = Queue(maxsize=max_queue_size)
        self.worker_thread = None
        self.should_stop = False

        # Monitoring stats
        self.total_cells_added = 0
        self.start_time = None
        self.last_rate_print = 0
        self.total_process_time = 0.0
        self.process_count = 0
        self.current_epoch = 0

    def start(self):
        """
        Start the prefetching worker thread for background batch processing.

        This method starts a background worker thread that continuously loads
        and processes batches from the TileDB batch processor. The worker thread
        runs as a daemon thread and will automatically stop when the main
        process exits.

        The prefetcher will begin loading batches immediately after starting.
        Use get_batch() to retrieve pre-processed batches from the queue.

        Examples:
            >>> # Start background prefetching
            >>> processor = TileDBBatchProcessor("path/to/experiment")
            >>> prefetcher = TileDBAsyncPrefetcher(processor)
            >>> prefetcher.start()
            >>> print("Prefetcher started")
            Prefetcher started

            >>> # Check if prefetcher is ready
            >>> import time
            >>> time.sleep(1)  # Wait for first batch
            >>> if prefetcher.has_batch():
            ...     print("Prefetcher is ready with data")
            ... else:
            ...     print("Prefetcher not ready yet")
            Prefetcher is ready with data
        """
        if self.worker_thread is None or not self.worker_thread.is_alive():
            self.should_stop = False
            self.start_time = time.time()
            self.worker_thread = threading.Thread(
                target=self._prefetch_worker, daemon=True
            )
            self.worker_thread.start()

    def stop(self):
        """
        Stop the prefetching worker thread and clean up resources.

        This method sets the stop flag and joins the background worker thread
        with a one-second timeout, so it waits for the current batch to finish
        without hanging. If that batch takes longer than one second, stop()
        returns while the worker is still finishing in the background.

        After calling stop(), the prefetcher will no longer load new batches.
        Any remaining batches in the queue can still be retrieved with get_batch().

        Examples:
            >>> # Stop the prefetcher
            >>> processor = TileDBBatchProcessor("path/to/experiment")
            >>> prefetcher = TileDBAsyncPrefetcher(processor)
            >>> prefetcher.start()
            >>>
            >>> # Do some work...
            >>> batch = prefetcher.get_batch()
            >>>
            >>> # Stop when done
            >>> prefetcher.stop()
            >>> print("Prefetcher stopped")
            Prefetcher stopped
        """
        self.should_stop = True
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=1.0)

    def _prefetch_worker(self):
        """Worker thread that loads batches in background"""
        while not self.should_stop:
            try:
                # Load batch chunk
                batch = self.batch_processor.load_prefetch_batch()

                # Update monitoring stats
                self.total_cells_added += len(batch.cell_integer_ids)
                self.total_process_time += batch.process_time
                self.process_count += 1
                self.current_epoch = self.batch_processor.current_epoch

                elapsed = time.time() - (self.start_time or 0)
                rate = self.total_cells_added / elapsed if elapsed > 0 else 0

                # Print rate every 10 batches
                if batch.batch_id % 10 == 0 and batch.batch_id > self.last_rate_print:
                    avg_process_ms = (
                        self.total_process_time / self.process_count
                    ) * 1000
                    rate_report = f"TileDB prefetch rate: {rate:.1f} cells/sec (epoch {self.current_epoch}, total: {self.total_cells_added} cells, avg process: {avg_process_ms:.1f}ms)"
                    print_prefetch(rate_report, self.batch_processor.verbose)
                    self.last_rate_print = batch.batch_id

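                # Put in queue; block while it is full so no batch is dropped,
                # but wake up periodically to honor a stop request
                while not self.should_stop:
                    try:
                        self.queue.put(batch, timeout=0.1)
                        break
                    except queue.Full:
                        continue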

            except StopIteration as e:
                if "No more epochs available" in str(e):
                    if self.batch_processor.verbose:
                        print(
                            f"✅ All {self.batch_processor.n_epochs} epochs completed"
                        )
                else:
                    logger.info("Reached end of batches")
                break
            except Exception as e:
                logger.error(f"Error loading TileDB batch: {e}")
                break

    def get_batch(self) -> TileDBPrefetchBatch | None:
        """
        Get the next pre-processed batch from the queue.

        This method waits up to one second for a pre-processed batch from the
        internal queue and returns None if none arrives in that time, so it
        never blocks indefinitely. A None result does not necessarily mean the
        data is exhausted; the worker may still be loading.

        Returns:
            TileDBPrefetchBatch | None: The next available batch, or None if
                                       no batch is available within the timeout.

        Examples:
            >>> # Get a batch from the prefetcher
            >>> processor = TileDBBatchProcessor("path/to/experiment")
            >>> prefetcher = TileDBAsyncPrefetcher(processor)
            >>> prefetcher.start()
            >>>
            >>> # Wait for a batch
            >>> batch = prefetcher.get_batch()
            >>> if batch:
            ...     print(f"Got batch {batch.batch_id} with {len(batch.cell_integer_ids)} cells")
            ... else:
            ...     print("No batch available")
            Got batch 0 with 100 cells

            >>> # Check for multiple batches
            >>> batches = []
            >>> for _ in range(3):
            ...     batch = prefetcher.get_batch()
            ...     if batch:
            ...         batches.append(batch)
            ...     else:
            ...         break
            >>> print(f"Retrieved {len(batches)} batches")
            Retrieved 3 batches
        """
        try:
            return self.queue.get(timeout=1.0)
        except queue.Empty:
            return None

    def has_batch(self) -> bool:
        """
        Check if a pre-processed batch is available in the queue.

        This method provides a non-blocking way to check if batches are
        available for immediate retrieval. It returns True if the queue
        contains at least one batch, False otherwise.

        Returns:
            bool: True if at least one batch is available, False otherwise.

        Examples:
            >>> # Check for available batches
            >>> processor = TileDBBatchProcessor("path/to/experiment")
            >>> prefetcher = TileDBAsyncPrefetcher(processor)
            >>> prefetcher.start()
            >>>
            >>> # Check if batches are ready
            >>> if prefetcher.has_batch():
            ...     batch = prefetcher.get_batch()
            ...     print(f"Processing batch {batch.batch_id}")
            ... else:
            ...     print("No batches ready yet")
            No batches ready yet

            >>> # Wait and check again
            >>> import time
            >>> time.sleep(2)
            >>> if prefetcher.has_batch():
            ...     print("Batches are now available")
            ... else:
            ...     print("Still waiting for batches")
            Batches are now available
        """
        return not self.queue.empty()

    def get_stats(self) -> dict:
        """
        Get comprehensive statistics about the prefetcher's performance.

        This method returns detailed performance statistics including loading
        rates, memory usage, queue status, and processing times. Useful for
        monitoring and debugging the prefetcher's performance.

        Returns:
            dict: Performance statistics dictionary containing:
                - total_cells: Total number of cells processed
                - elapsed_time: Total time since prefetcher started (seconds)
                - cells_per_sec: Average loading rate (cells per second)
                - queue_size: Current number of batches in queue
                - queue_full: Whether the queue is at maximum capacity
                - total_process_time: Total time spent processing batches
                - process_count: Number of batches processed
                - avg_process_time_ms: Average processing time per batch (ms)
                - current_epoch: Current epoch number
                - n_epochs: Total number of epochs configured

        Examples:
            >>> # Get performance statistics
            >>> processor = TileDBBatchProcessor("path/to/experiment")
            >>> prefetcher = TileDBAsyncPrefetcher(processor)
            >>> prefetcher.start()
            >>>
            >>> # Wait for some processing
            >>> import time
            >>> time.sleep(5)
            >>>
            >>> # Get and display stats
            >>> stats = prefetcher.get_stats()
            >>> print(f"Loading rate: {stats['cells_per_sec']:.1f} cells/sec")
            >>> print(f"Queue size: {stats['queue_size']}/{prefetcher.max_queue_size}")
            >>> print(f"Current epoch: {stats['current_epoch']}/{stats['n_epochs']}")
            Loading rate: 1250.5 cells/sec
            Queue size: 45/500
            Current epoch: 0/1

            >>> # Monitor performance over time
            >>> for i in range(3):
            ...     stats = prefetcher.get_stats()
            ...     print(f"Check {i+1}: {stats['cells_per_sec']:.1f} cells/sec")
            ...     time.sleep(2)
            Check 1: 1200.3 cells/sec
            Check 2: 1250.5 cells/sec
            Check 3: 1180.7 cells/sec
        """
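        # Before start() there is no start time, so report zero elapsed time
        elapsed = time.time() - self.start_time if self.start_time is not None else 0.0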
        rate = self.total_cells_added / elapsed if elapsed > 0 else 0
        avg_process_time = (
            self.total_process_time / self.process_count
            if self.process_count > 0
            else 0
        )
        return {
            "total_cells": self.total_cells_added,
            "elapsed_time": elapsed,
            "cells_per_sec": rate,
            "queue_size": self.queue.qsize(),
            "queue_full": self.queue.full(),
            "total_process_time": self.total_process_time,
            "process_count": self.process_count,
            "avg_process_time_ms": avg_process_time * 1000,
            "current_epoch": self.current_epoch,
            "n_epochs": self.batch_processor.n_epochs,
        }
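The rate and average figures reported by get_stats are simple ratios of the counters kept by the worker thread. The hypothetical helper below reproduces that arithmetic so the units are easy to check. It is not part of SLAF. The inputs are illustrative numbers, not measurements: 50,000 cells in 40 s, and 100 batches taking 3.125 s of processing in total.

```python
def summarize_stats(
    total_cells: int, elapsed_s: float, total_process_s: float, process_count: int
) -> dict:
    """Hypothetical helper mirroring the ratios computed in get_stats()."""
    # Cells per second over the whole run; zero if no time has elapsed.
    cells_per_sec = total_cells / elapsed_s if elapsed_s > 0 else 0.0
    # Mean processing time per batch, converted from seconds to milliseconds.
    avg_process_time_ms = (
        total_process_s / process_count * 1000 if process_count > 0 else 0.0
    )
    return {"cells_per_sec": cells_per_sec, "avg_process_time_ms": avg_process_time_ms}


print(summarize_stats(total_cells=50_000, elapsed_s=40.0, total_process_s=3.125, process_count=100))
# {'cells_per_sec': 1250.0, 'avg_process_time_ms': 31.25}
```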
Functions
__init__(batch_processor: TileDBBatchProcessor, max_queue_size: int = 500)

Initialize the TileDB async prefetcher with background processing.

Parameters:

  • batch_processor (TileDBBatchProcessor, required): TileDBBatchProcessor instance to use for loading batches. Must be properly initialized with TileDB path and configuration.
  • max_queue_size (int, default 500): Maximum number of pre-processed batches to keep in queue. Higher values use more memory but provide better buffering. Range: 10-10000.

Examples:

>>> # Create prefetcher with default queue size
>>> processor = TileDBBatchProcessor("path/to/experiment")
>>> prefetcher = TileDBAsyncPrefetcher(processor)
>>> print(f"Max queue size: {prefetcher.max_queue_size}")
Max queue size: 500
>>> # Create prefetcher with custom queue size
>>> prefetcher = TileDBAsyncPrefetcher(processor, max_queue_size=1000)
>>> print(f"Custom queue size: {prefetcher.max_queue_size}")
Custom queue size: 1000
Source code in slaf/ml/tiledb_dataloaders.py
def __init__(
    self, batch_processor: TileDBBatchProcessor, max_queue_size: int = 500
):
    """
    Initialize the TileDB async prefetcher with background processing.

    Args:
        batch_processor: TileDBBatchProcessor instance to use for loading batches.
                       Must be properly initialized with TileDB path and configuration.
        max_queue_size: Maximum number of pre-processed batches to keep in queue.
                      Higher values use more memory but provide better buffering.
                      Range: 10-10000, default: 500.

    Examples:
        >>> # Create prefetcher with default queue size
        >>> processor = TileDBBatchProcessor("path/to/experiment")
        >>> prefetcher = TileDBAsyncPrefetcher(processor)
        >>> print(f"Max queue size: {prefetcher.max_queue_size}")
        Max queue size: 500

        >>> # Create prefetcher with custom queue size
        >>> prefetcher = TileDBAsyncPrefetcher(processor, max_queue_size=1000)
        >>> print(f"Custom queue size: {prefetcher.max_queue_size}")
        Custom queue size: 1000
    """
    self.batch_processor = batch_processor
    self.max_queue_size = max_queue_size
    self.queue: Queue[TileDBPrefetchBatch] = Queue(maxsize=max_queue_size)
    self.worker_thread = None
    self.should_stop = False

    # Monitoring stats
    self.total_cells_added = 0
    self.start_time = None
    self.last_rate_print = 0
    self.total_process_time = 0.0
    self.process_count = 0
    self.current_epoch = 0
start()

Start the prefetching worker thread for background batch processing.

This method starts a background worker thread that continuously loads and processes batches from the TileDB batch processor. The worker thread runs as a daemon thread and will automatically stop when the main process exits.

The prefetcher will begin loading batches immediately after starting. Use get_batch() to retrieve pre-processed batches from the queue.
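The guard at the top of start() makes repeated calls safe: a new worker is spawned only if none exists or the previous one has exited. A minimal sketch of that guard follows. It uses a hypothetical `Restartable` class, with a sleep standing in for batch loading.

```python
import threading
import time


class Restartable:
    """Hypothetical class reproducing the start() guard from TileDBAsyncPrefetcher."""

    def __init__(self) -> None:
        self.worker_thread = None
        self.threads_started = 0

    def _work(self) -> None:
        time.sleep(0.5)  # stand-in for loading batches

    def start(self) -> None:
        # Spawn a worker only if there is none yet or the previous one has exited.
        if self.worker_thread is None or not self.worker_thread.is_alive():
            self.worker_thread = threading.Thread(target=self._work, daemon=True)
            self.worker_thread.start()
            self.threads_started += 1


r = Restartable()
r.start()
r.start()  # ignored: the first worker is still sleeping
print(r.threads_started)  # 1
r.worker_thread.join()
r.start()  # the previous worker has exited, so a new one starts
print(r.threads_started)  # 2
```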

Examples:

>>> # Start background prefetching
>>> processor = TileDBBatchProcessor("path/to/experiment")
>>> prefetcher = TileDBAsyncPrefetcher(processor)
>>> prefetcher.start()
>>> print("Prefetcher started")
Prefetcher started
>>> # Check if prefetcher is ready
>>> import time
>>> time.sleep(1)  # Wait for first batch
>>> if prefetcher.has_batch():
...     print("Prefetcher is ready with data")
... else:
...     print("Prefetcher not ready yet")
Prefetcher is ready with data
Source code in slaf/ml/tiledb_dataloaders.py
def start(self):
    """
    Start the prefetching worker thread for background batch processing.

    This method starts a background worker thread that continuously loads
    and processes batches from the TileDB batch processor. The worker thread
    runs as a daemon thread and will automatically stop when the main
    process exits.

    The prefetcher will begin loading batches immediately after starting.
    Use get_batch() to retrieve pre-processed batches from the queue.

    Examples:
        >>> # Start background prefetching
        >>> processor = TileDBBatchProcessor("path/to/experiment")
        >>> prefetcher = TileDBAsyncPrefetcher(processor)
        >>> prefetcher.start()
        >>> print("Prefetcher started")
        Prefetcher started

        >>> # Check if prefetcher is ready
        >>> import time
        >>> time.sleep(1)  # Wait for first batch
        >>> if prefetcher.has_batch():
        ...     print("Prefetcher is ready with data")
        ... else:
        ...     print("Prefetcher not ready yet")
        Prefetcher is ready with data
    """
    if self.worker_thread is None or not self.worker_thread.is_alive():
        self.should_stop = False
        self.start_time = time.time()
        self.worker_thread = threading.Thread(
            target=self._prefetch_worker, daemon=True
        )
        self.worker_thread.start()
stop()

Stop the prefetching worker thread and clean up resources.

This method sets the stop flag and joins the background worker thread with a one-second timeout, so it waits for the current batch to finish without hanging. If that batch takes longer than one second, stop() returns while the worker is still finishing in the background.

After calling stop(), the prefetcher will no longer load new batches. Any remaining batches in the queue can still be retrieved with get_batch().
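The stop() behavior follows from two standard-library facts:

  • `Thread.join(timeout=...)` returns after the timeout even if the thread is still running.
  • Items already placed in a `queue.Queue` stay there after the producer stops.

The sketch below demonstrates both with plain threads. It uses a `threading.Event` where the prefetcher uses a boolean `should_stop` flag, and all names are illustrative.

```python
import queue
import threading
import time

# 1) A cooperative worker that checks a stop flag between items.
q: queue.Queue = queue.Queue()
stop_flag = threading.Event()


def producer() -> None:
    i = 0
    while not stop_flag.is_set():
        q.put(i)  # stand-in for one prefetched batch
        i += 1
        time.sleep(0.01)


t = threading.Thread(target=producer, daemon=True)
t.start()
time.sleep(0.1)
stop_flag.set()
t.join(timeout=1.0)
print(t.is_alive())  # False: the worker saw the flag and exited

# Batches produced before stopping can still be drained.
drained = []
while True:
    try:
        drained.append(q.get_nowait())
    except queue.Empty:
        break
print(drained[0])  # 0

# 2) join(timeout) does not wait for a worker stuck in a long operation.
slow = threading.Thread(target=time.sleep, args=(0.5,), daemon=True)
slow.start()
slow.join(timeout=0.1)
still_running = slow.is_alive()
print(still_running)  # True: join returned before the worker finished
```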

Examples:

>>> # Stop the prefetcher
>>> processor = TileDBBatchProcessor("path/to/experiment")
>>> prefetcher = TileDBAsyncPrefetcher(processor)
>>> prefetcher.start()
>>>
>>> # Do some work...
>>> batch = prefetcher.get_batch()
>>>
>>> # Stop when done
>>> prefetcher.stop()
>>> print("Prefetcher stopped")
Prefetcher stopped
Source code in slaf/ml/tiledb_dataloaders.py
def stop(self):
    """
    Stop the prefetching worker thread and clean up resources.

    This method sets the stop flag and joins the background worker thread
    with a one-second timeout, so it waits for the current batch to finish
    without hanging. If that batch takes longer than one second, stop()
    returns while the worker is still finishing in the background.

    After calling stop(), the prefetcher will no longer load new batches.
    Any remaining batches in the queue can still be retrieved with get_batch().

    Examples:
        >>> # Stop the prefetcher
        >>> processor = TileDBBatchProcessor("path/to/experiment")
        >>> prefetcher = TileDBAsyncPrefetcher(processor)
        >>> prefetcher.start()
        >>>
        >>> # Do some work...
        >>> batch = prefetcher.get_batch()
        >>>
        >>> # Stop when done
        >>> prefetcher.stop()
        >>> print("Prefetcher stopped")
        Prefetcher stopped
    """
    self.should_stop = True
    if self.worker_thread and self.worker_thread.is_alive():
        self.worker_thread.join(timeout=1.0)
get_batch() -> TileDBPrefetchBatch | None

Get the next pre-processed batch from the queue.

This method retrieves a pre-processed batch from the internal queue. If no batch is available, it returns None. The method has a timeout to prevent blocking indefinitely.

Returns:

Type Description
TileDBPrefetchBatch | None

The next available batch, or None if no batch is available within the timeout.

Examples:

>>> # Get a batch from the prefetcher
>>> processor = TileDBBatchProcessor("path/to/experiment")
>>> prefetcher = TileDBAsyncPrefetcher(processor)
>>> prefetcher.start()
>>>
>>> # Wait for a batch
>>> batch = prefetcher.get_batch()
>>> if batch:
...     print(f"Got batch {batch.batch_id} with {len(batch.cell_integer_ids)} cells")
... else:
...     print("No batch available")
Got batch 0 with 100 cells
>>> # Check for multiple batches
>>> batches = []
>>> for _ in range(3):
...     batch = prefetcher.get_batch()
...     if batch:
...         batches.append(batch)
...     else:
...         break
>>> print(f"Retrieved {len(batches)} batches")
Retrieved 3 batches
Source code in slaf/ml/tiledb_dataloaders.py
def get_batch(self) -> TileDBPrefetchBatch | None:
    """
    Get the next pre-processed batch from the queue.

    This method retrieves a pre-processed batch from the internal queue.
    If no batch is available, it returns None. The method has a timeout
    to prevent blocking indefinitely.

    Returns:
        TileDBPrefetchBatch | None: The next available batch, or None if
                                   no batch is available within the timeout.

    Examples:
        >>> # Get a batch from the prefetcher
        >>> processor = TileDBBatchProcessor("path/to/experiment")
        >>> prefetcher = TileDBAsyncPrefetcher(processor)
        >>> prefetcher.start()
        >>>
        >>> # Wait for a batch
        >>> batch = prefetcher.get_batch()
        >>> if batch:
        ...     print(f"Got batch {batch.batch_id} with {len(batch.cell_integer_ids)} cells")
        ... else:
        ...     print("No batch available")
        Got batch 0 with 100 cells

        >>> # Check for multiple batches
        >>> batches = []
        >>> for _ in range(3):
        ...     batch = prefetcher.get_batch()
        ...     if batch:
        ...         batches.append(batch)
        ...     else:
        ...         break
        >>> print(f"Retrieved {len(batches)} batches")
        Retrieved 3 batches
    """
    try:
        return self.queue.get(timeout=1.0)
    except queue.Empty:
        return None
has_batch() -> bool

Check if a pre-processed batch is available in the queue.

This method checks, without blocking, whether a batch is available for immediate retrieval. It returns True if the queue holds at least one batch and False otherwise. The result is a snapshot: the queue can change as soon as the call returns.
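Since `has_batch()` is only a snapshot (the Python `queue` documentation notes that `empty()` gives no guarantee about a subsequent `get()` when other threads share the queue), it is best used for polling, as the dataset's `_wait_for_prefetcher_ready` does. A sketch of that poll-until-deadline loop, where `wait_until` is a hypothetical helper and a `threading.Timer` stands in for the producer:

```python
import queue
import threading
import time


def wait_until(predicate, timeout: float, interval: float = 0.01) -> bool:
    """Poll predicate() every `interval` seconds; True once it holds, False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    return False


q = queue.Queue()


def has_batch() -> bool:
    return not q.empty()  # same test as has_batch() above


ready_before = wait_until(has_batch, timeout=0.05)        # nothing queued yet
threading.Timer(0.02, q.put, args=("batch-0",)).start()   # simulated producer
ready_after = wait_until(has_batch, timeout=1.0)
```

The sketch measures the deadline with `time.monotonic()`, which, unlike `time.time()`, is unaffected by wall-clock adjustments.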

Returns:

Type Description
bool

True if at least one batch is available, False otherwise.

Examples:

>>> # Check for available batches
>>> processor = TileDBBatchProcessor("path/to/experiment")
>>> prefetcher = TileDBAsyncPrefetcher(processor)
>>> prefetcher.start()
>>>
>>> # Check if batches are ready
>>> if prefetcher.has_batch():
...     batch = prefetcher.get_batch()
...     print(f"Processing batch {batch.batch_id}")
... else:
...     print("No batches ready yet")
No batches ready yet
>>> # Wait and check again
>>> import time
>>> time.sleep(2)
>>> if prefetcher.has_batch():
...     print("Batches are now available")
... else:
...     print("Still waiting for batches")
Batches are now available
Source code in slaf/ml/tiledb_dataloaders.py
def has_batch(self) -> bool:
    """
    Check if a pre-processed batch is available in the queue.

    This method provides a non-blocking way to check if batches are
    available for immediate retrieval. It returns True if the queue
    contains at least one batch, False otherwise.

    Returns:
        bool: True if at least one batch is available, False otherwise.

    Examples:
        >>> # Check for available batches
        >>> processor = TileDBBatchProcessor("path/to/experiment")
        >>> prefetcher = TileDBAsyncPrefetcher(processor)
        >>> prefetcher.start()
        >>>
        >>> # Check if batches are ready
        >>> if prefetcher.has_batch():
        ...     batch = prefetcher.get_batch()
        ...     print(f"Processing batch {batch.batch_id}")
        ... else:
        ...     print("No batches ready yet")
        No batches ready yet

        >>> # Wait and check again
        >>> import time
        >>> time.sleep(2)
        >>> if prefetcher.has_batch():
        ...     print("Batches are now available")
        ... else:
        ...     print("Still waiting for batches")
        Batches are now available
    """
    return not self.queue.empty()
get_stats() -> dict

Get comprehensive statistics about the prefetcher's performance.

This method returns performance statistics covering loading rate, queue status, processing times, and epoch progress. It is useful for monitoring and debugging the prefetcher.

Returns:

Type Description
dict

Performance statistics dictionary containing:
  • total_cells: Total number of cells processed
  • elapsed_time: Total time since the prefetcher started (seconds)
  • cells_per_sec: Average loading rate (cells per second)
  • queue_size: Current number of batches in the queue
  • queue_full: Whether the queue is at maximum capacity
  • total_process_time: Total time spent processing batches (seconds)
  • process_count: Number of batches processed
  • avg_process_time_ms: Average processing time per batch (milliseconds)
  • current_epoch: Current epoch number
  • n_epochs: Total number of epochs configured
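The derived fields are simple ratios of running counters. The sketch below recomputes them from hypothetical counter values with the same divide-by-zero guards as the source; `derive_stats` is an illustrative helper, not a slaf function.

```python
def derive_stats(total_cells: int, elapsed_s: float,
                 total_process_time_s: float, process_count: int) -> dict:
    """Recompute the rate fields reported by get_stats() from raw counters."""
    cells_per_sec = total_cells / elapsed_s if elapsed_s > 0 else 0
    avg_process_time = total_process_time_s / process_count if process_count > 0 else 0
    return {
        "cells_per_sec": cells_per_sec,
        "avg_process_time_ms": avg_process_time * 1000,  # seconds -> milliseconds
    }


# Hypothetical counters: 5,000 cells in 4 s; 50 batches taking 2 s in total
stats = derive_stats(total_cells=5000, elapsed_s=4.0,
                     total_process_time_s=2.0, process_count=50)
print(stats)  # {'cells_per_sec': 1250.0, 'avg_process_time_ms': 40.0}
```

Note that cells_per_sec is an average over the whole run, so it smooths out short stalls; sampling it repeatedly, as in the monitoring example, shows only slow drift.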

Examples:

>>> # Get performance statistics
>>> processor = TileDBBatchProcessor("path/to/experiment")
>>> prefetcher = TileDBAsyncPrefetcher(processor)
>>> prefetcher.start()
>>>
>>> # Wait for some processing
>>> import time
>>> time.sleep(5)
>>>
>>> # Get and display stats
>>> stats = prefetcher.get_stats()
>>> print(f"Loading rate: {stats['cells_per_sec']:.1f} cells/sec")
>>> print(f"Queue size: {stats['queue_size']}/{prefetcher.max_queue_size}")
>>> print(f"Current epoch: {stats['current_epoch']}/{stats['n_epochs']}")
Loading rate: 1250.5 cells/sec
Queue size: 45/500
Current epoch: 0/1
>>> # Monitor performance over time
>>> for i in range(3):
...     stats = prefetcher.get_stats()
...     print(f"Check {i+1}: {stats['cells_per_sec']:.1f} cells/sec")
...     time.sleep(2)
Check 1: 1200.3 cells/sec
Check 2: 1250.5 cells/sec
Check 3: 1180.7 cells/sec
Source code in slaf/ml/tiledb_dataloaders.py
def get_stats(self) -> dict:
    """
    Get comprehensive statistics about the prefetcher's performance.

    This method returns performance statistics covering loading rate,
    queue status, processing times, and epoch progress. Useful for
    monitoring and debugging the prefetcher's performance.

    Returns:
        dict: Performance statistics dictionary containing:
            - total_cells: Total number of cells processed
            - elapsed_time: Total time since prefetcher started (seconds)
            - cells_per_sec: Average loading rate (cells per second)
            - queue_size: Current number of batches in queue
            - queue_full: Whether the queue is at maximum capacity
            - total_process_time: Total time spent processing batches
            - process_count: Number of batches processed
            - avg_process_time_ms: Average processing time per batch (ms)
            - current_epoch: Current epoch number
            - n_epochs: Total number of epochs configured

    Examples:
        >>> # Get performance statistics
        >>> processor = TileDBBatchProcessor("path/to/experiment")
        >>> prefetcher = TileDBAsyncPrefetcher(processor)
        >>> prefetcher.start()
        >>>
        >>> # Wait for some processing
        >>> import time
        >>> time.sleep(5)
        >>>
        >>> # Get and display stats
        >>> stats = prefetcher.get_stats()
        >>> print(f"Loading rate: {stats['cells_per_sec']:.1f} cells/sec")
        >>> print(f"Queue size: {stats['queue_size']}/{prefetcher.max_queue_size}")
        >>> print(f"Current epoch: {stats['current_epoch']}/{stats['n_epochs']}")
        Loading rate: 1250.5 cells/sec
        Queue size: 45/500
        Current epoch: 0/1

        >>> # Monitor performance over time
        >>> for i in range(3):
        ...     stats = prefetcher.get_stats()
        ...     print(f"Check {i+1}: {stats['cells_per_sec']:.1f} cells/sec")
        ...     time.sleep(2)
        Check 1: 1200.3 cells/sec
        Check 2: 1250.5 cells/sec
        Check 3: 1180.7 cells/sec
    """
    # Report zero elapsed time until start() has set start_time
    elapsed = time.time() - self.start_time if self.start_time else 0.0
    rate = self.total_cells_added / elapsed if elapsed > 0 else 0
    avg_process_time = (
        self.total_process_time / self.process_count
        if self.process_count > 0
        else 0
    )
    return {
        "total_cells": self.total_cells_added,
        "elapsed_time": elapsed,
        "cells_per_sec": rate,
        "queue_size": self.queue.qsize(),
        "queue_full": self.queue.full(),
        "total_process_time": self.total_process_time,
        "process_count": self.process_count,
        "avg_process_time_ms": avg_process_time * 1000,
        "current_epoch": self.current_epoch,
        "n_epochs": self.batch_processor.n_epochs,
    }

TileDBIterableDataset

Bases: IterableDataset

PyTorch IterableDataset for streaming TileDB SOMA data with async prefetching.

TileDBIterableDataset provides a PyTorch-compatible interface for streaming single-cell data from TileDB SOMA format. It combines the TileDBBatchProcessor and TileDBAsyncPrefetcher to provide efficient, asynchronous data loading for machine learning training.

Key Features
  • PyTorch IterableDataset compatibility
  • Asynchronous background prefetching for improved throughput
  • Multiple loading strategies (MoS and sequential)
  • Multi-epoch training support
  • Automatic epoch transition handling
  • Memory-efficient streaming
  • Comprehensive error handling
  • Configurable batch and prefetch sizes

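The dataset starts background prefetching as soon as it is constructed, and the constructor waits up to 10 seconds for the first batch before returning. Iteration then yields batches of TileDB data, handles epoch transitions, and, when verbose is enabled, reports timing every 100 batches. If the prefetcher produces nothing for several seconds before all epochs have finished, iteration ends early instead of blocking indefinitely.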
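The per-batch bookkeeping in `__iter__` amounts to: collect the distinct cell IDs of each batch, attach an `epoch` key only when `n_epochs > 1`, and note when the epoch changes. The pure-Python sketch below uses toy records in place of Polars DataFrames; `tag_batches` and the `(epoch, rows)` input shape are assumptions for illustration. One difference worth knowing: the source reads the epoch from `self.batch_processor.current_epoch` when a batch is consumed, and since the worker can run up to `max_queue_size` batches ahead, that value may already reflect a later epoch than the batch's own. The sketch carries each batch's epoch with it instead.

```python
from typing import Iterable, Iterator


def tag_batches(stream: Iterable[tuple[int, list[dict]]], n_epochs: int,
                transitions: list) -> Iterator[dict]:
    """Yield batch dicts shaped like TileDBIterableDataset's and record epoch changes."""
    last_epoch = -1
    for epoch, rows in stream:
        if epoch != last_epoch:
            transitions.append((last_epoch, epoch))  # the source prints these when verbose
            last_epoch = epoch
        # Distinct cell IDs in first-seen order (Polars unique() does not promise an order)
        cell_ids = list(dict.fromkeys(row["cell_integer_id"] for row in rows))
        batch = {"X": rows, "cell_ids": cell_ids}
        if n_epochs > 1:
            batch["epoch"] = epoch
        yield batch


# Toy stream: two epochs, rows reduced to their cell_integer_id
stream = [
    (0, [{"cell_integer_id": 0}, {"cell_integer_id": 0}, {"cell_integer_id": 1}]),
    (1, [{"cell_integer_id": 2}, {"cell_integer_id": 3}]),
]
transitions = []
batches = list(tag_batches(stream, n_epochs=2, transitions=transitions))
single = list(tag_batches(stream, n_epochs=1, transitions=[]))
```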

Examples:

>>> # Create dataset with default MoS strategy
>>> dataset = TileDBIterableDataset(
...     tiledb_path="path/to/experiment",
...     batch_size=32,
...     prefetch_batch_size=100
... )
>>>
>>> # Iterate through batches
>>> for batch in dataset:
...     print(f"Batch keys: {list(batch.keys())}")
...     print(f"Cell IDs: {batch['cell_ids']}")
...     break
Batch keys: ['X', 'cell_ids']
Cell IDs: [0, 1, 2, ..., 29, 30, 31]

>>> # Sequential loading for maximum throughput
>>> dataset = TileDBIterableDataset(
...     tiledb_path="path/to/experiment",
...     use_mixture_of_scanners=False,
...     batch_size=64
... )
>>> print(f"MoS enabled: {dataset.use_mixture_of_scanners}")
MoS enabled: False

>>> # Multi-epoch training
>>> dataset = TileDBIterableDataset(
...     tiledb_path="path/to/experiment",
...     n_epochs=3
... )
>>> epochs_seen = set()
>>> for batch in dataset:
...     if 'epoch' in batch:
...         epochs_seen.add(batch['epoch'])
...     if len(epochs_seen) >= 3:
...         break
>>> print(f"Epochs completed: {sorted(epochs_seen)}")
Epochs completed: [0, 1, 2]
Source code in slaf/ml/tiledb_dataloaders.py
class TileDBIterableDataset(IterableDataset):
    """
    PyTorch IterableDataset for streaming TileDB SOMA data with async prefetching.

    TileDBIterableDataset provides a PyTorch-compatible interface for streaming
    single-cell data from TileDB SOMA format. It combines the TileDBBatchProcessor
    and TileDBAsyncPrefetcher to provide efficient, asynchronous data loading
    for machine learning training.

    Key Features:
        - PyTorch IterableDataset compatibility
        - Asynchronous background prefetching for improved throughput
        - Multiple loading strategies (MoS and sequential)
        - Multi-epoch training support
        - Automatic epoch transition handling
        - Memory-efficient streaming
        - Comprehensive error handling
        - Configurable batch and prefetch sizes

    The dataset automatically manages background prefetching and provides
    seamless iteration over batches of TileDB data. It handles epoch
    transitions and provides detailed timing information for performance
    monitoring.

    Examples:
        >>> # Create dataset with default MoS strategy
        >>> dataset = TileDBIterableDataset(
        ...     tiledb_path="path/to/experiment",
        ...     batch_size=32,
        ...     prefetch_batch_size=100
        ... )
        >>>
        >>> # Iterate through batches
        >>> for batch in dataset:
        ...     print(f"Batch keys: {list(batch.keys())}")
        ...     print(f"Cell IDs: {batch['cell_ids']}")
        ...     break
        Batch keys: ['X', 'cell_ids']
        Cell IDs: [0, 1, 2, ..., 29, 30, 31]

        >>> # Sequential loading for maximum throughput
        >>> dataset = TileDBIterableDataset(
        ...     tiledb_path="path/to/experiment",
        ...     use_mixture_of_scanners=False,
        ...     batch_size=64
        ... )
        >>> print(f"MoS enabled: {dataset.use_mixture_of_scanners}")
        MoS enabled: False

        >>> # Multi-epoch training
        >>> dataset = TileDBIterableDataset(
        ...     tiledb_path="path/to/experiment",
        ...     n_epochs=3
        ... )
        >>> epochs_seen = set()
        >>> for batch in dataset:
        ...     if 'epoch' in batch:
        ...         epochs_seen.add(batch['epoch'])
        ...     if len(epochs_seen) >= 3:
        ...         break
        >>> print(f"Epochs completed: {sorted(epochs_seen)}")
        Epochs completed: [0, 1, 2]
    """

    def __init__(
        self,
        tiledb_path: str,
        batch_size: int = 32,
        prefetch_batch_size: int = 100,
        seed: int = 42,
        max_queue_size: int = 500,
        n_epochs: int = 1,
        verbose: bool = True,
        use_mixture_of_scanners: bool = True,
        n_readers: int = 50,
        n_scanners: int = 8,
    ):
        """
        Initialize the TileDB IterableDataset with async prefetching.

        Args:
            tiledb_path: Path to the TileDB SOMA experiment directory.
                         Must contain a valid SOMA experiment with RNA measurement data.
            batch_size: Number of cells per training batch. Larger batches use more
                       memory but may improve training efficiency. Range: 1-512, default: 32.
            prefetch_batch_size: Number of cells to load per prefetch batch from TileDB.
                               Higher values improve throughput but use more memory.
                               Range: 10-10000, default: 100.
            seed: Random seed for reproducible shuffling and MoS sampling.
                  Used for consistent data ordering across runs. Default: 42.
            max_queue_size: Maximum number of pre-processed batches to keep in queue.
                          Higher values use more memory but provide better buffering.
                          Range: 10-10000, default: 500.
            n_epochs: Number of epochs to run. The dataset will automatically reset
                     after each epoch, enabling multi-epoch training. Default: 1.
            verbose: If True, print detailed timing and progress information.
                    If False, suppress all internal prints for clean output. Default: True.
            use_mixture_of_scanners: If True, use MoS strategy for higher entropy by
                                   randomly sampling from multiple fragment generators.
                                   Provides better randomization for foundation model training.
                                   Default: True.
            n_readers: Total number of fragment generators to create when using MoS.
                      Higher values provide better entropy but use more memory.
                      Range: 1-1000, default: 50.
            n_scanners: Number of active scanners to sample from simultaneously when using MoS.
                       Higher values provide better entropy but use more memory.
                       Range: 1-100, default: 8.

        Raises:
            ImportError: If TileDB SOMA is not available.
            ValueError: If MoS parameters are invalid.
            RuntimeError: If the TileDB experiment cannot be opened or is invalid.

        Examples:
            >>> # Basic initialization with default MoS strategy
            >>> dataset = TileDBIterableDataset(
            ...     tiledb_path="path/to/experiment",
            ...     batch_size=32,
            ...     prefetch_batch_size=100
            ... )
            >>> print(f"MoS enabled: {dataset.use_mixture_of_scanners}")
            MoS enabled: True

            >>> # Sequential loading for maximum throughput
            >>> dataset = TileDBIterableDataset(
            ...     tiledb_path="path/to/experiment",
            ...     use_mixture_of_scanners=False,
            ...     batch_size=64
            ... )
            >>> print(f"Sequential loading: {not dataset.use_mixture_of_scanners}")
            Sequential loading: True

            >>> # High-entropy MoS configuration
            >>> dataset = TileDBIterableDataset(
            ...     tiledb_path="path/to/experiment",
            ...     n_readers=100,
            ...     n_scanners=16
            ... )
            >>> print(f"MoS readers: {dataset.n_readers}, scanners: {dataset.n_scanners}")
            MoS readers: 100, scanners: 16
        """
        super().__init__()
        self.tiledb_path = tiledb_path
        self.batch_size = batch_size
        self.prefetch_batch_size = prefetch_batch_size
        self.seed = seed
        self.max_queue_size = max_queue_size
        self.n_epochs = n_epochs
        self.verbose = verbose
        self.use_mixture_of_scanners = use_mixture_of_scanners
        self.n_readers = n_readers
        self.n_scanners = n_scanners

        # Initialize batch processor
        self.batch_processor = TileDBBatchProcessor(
            tiledb_path=tiledb_path,
            batch_size=batch_size,
            prefetch_batch_size=prefetch_batch_size,
            seed=seed,
            n_epochs=n_epochs,
            verbose=verbose,
            log_metrics=False,
            use_mixture_of_scanners=use_mixture_of_scanners,
            n_readers=n_readers,
            n_scanners=n_scanners,
        )

        # Initialize async prefetcher
        self.prefetcher = TileDBAsyncPrefetcher(
            batch_processor=self.batch_processor,
            max_queue_size=max_queue_size,
        )

        # Start async prefetching
        self.prefetcher.start()

        # Wait for prefetcher to initialize
        self._wait_for_prefetcher_ready()

    def _wait_for_prefetcher_ready(self, timeout: float = 10.0):
        """Wait for the prefetcher to be ready with data."""
        start_time = time.time()
        while time.time() - start_time < timeout:
            if self.prefetcher.has_batch():
                if self.verbose:
                    print(
                        f"✅ TileDB prefetcher ready after {time.time() - start_time:.2f}s"
                    )
                return
            time.sleep(0.1)

        if self.verbose:
            print(
                f"⚠️ TileDB prefetcher not ready after {timeout}s, proceeding anyway..."
            )

    def __iter__(self) -> Iterator[dict]:
        """
        Iterate through batches of TileDB data with async prefetching.

        This method provides an iterator over batches of TileDB data, automatically
        handling background prefetching, epoch transitions, and error recovery.
        It yields dictionaries containing the batch data and metadata.

        The iterator automatically manages:
        - Background prefetching for improved throughput
        - Epoch transitions for multi-epoch training
        - Error handling and recovery
        - Performance monitoring and reporting

        Yields:
            dict: Batch dictionary containing:
                - X: Polars DataFrame with cell-gene expression data
                - cell_ids: List of unique cell IDs in the batch
                - epoch: Current epoch number (when n_epochs > 1)

        Examples:
            >>> # Basic iteration
            >>> dataset = TileDBIterableDataset("path/to/experiment")
            >>> for batch in dataset:
            ...     print(f"Batch keys: {list(batch.keys())}")
            ...     print(f"Cell IDs: {batch['cell_ids']}")
            ...     break
            Batch keys: ['X', 'cell_ids']
            Cell IDs: [0, 1, 2, ..., 29, 30, 31]

            >>> # Multi-epoch training
            >>> dataset = TileDBIterableDataset("path/to/experiment", n_epochs=3)
            >>> epochs_seen = set()
            >>> for batch in dataset:
            ...     if 'epoch' in batch:
            ...         epochs_seen.add(batch['epoch'])
            ...     if len(epochs_seen) >= 3:
            ...         break
            >>> print(f"Epochs completed: {sorted(epochs_seen)}")
            Epochs completed: [0, 1, 2]

            >>> # Training loop with error handling
            >>> dataset = TileDBIterableDataset("path/to/experiment")
            >>> for batch_idx, batch in enumerate(dataset):
            ...     try:
            ...         x = batch["X"]
            ...         cell_ids = batch["cell_ids"]
            ...         print(f"Processed batch {batch_idx} with {len(cell_ids)} cells")
            ...     except Exception as e:
            ...         print(f"Error in batch {batch_idx}: {e}")
            ...         continue
            ...     if batch_idx >= 2:  # Just first few batches
            ...         break
            Processed batch 0 with 32 cells
            Processed batch 1 with 32 cells
            Processed batch 2 with 32 cells
        """
        batches_yielded = 0
        current_epoch = 0
        last_epoch = -1

        while True:
            # Get data from prefetcher
            data_start = time.time()
            data = self.prefetcher.get_batch()
            data_time = time.time() - data_start

            if data is None:
                # Check if prefetcher has finished all epochs
                stats = self.prefetcher.get_stats()
                if stats["current_epoch"] >= stats["n_epochs"]:
                    if self.verbose:
                        print(
                            f"✅ Dataset iteration complete: all {stats['n_epochs']} epochs finished"
                        )
                    break

                # Wait for more data with timeout
                wait_start = time.time()
                while not self.prefetcher.has_batch():
                    time.sleep(0.1)
                    # Timeout after 5 seconds to avoid infinite wait
                    if time.time() - wait_start > 5.0:
                        if self.verbose:
                            print("⚠️ Timeout waiting for prefetcher data")
                        break

                data = self.prefetcher.get_batch()
                if data is None:
                    # Double-check if prefetcher is done
                    stats = self.prefetcher.get_stats()
                    if stats["current_epoch"] >= stats["n_epochs"]:
                        if self.verbose:
                            print(
                                f"✅ Dataset iteration complete: all {stats['n_epochs']} epochs finished"
                            )
                        break
                    else:
                        if self.verbose:
                            print("⚠️ No data available from prefetcher")
                        break

            # Track epoch transitions
            current_epoch = self.batch_processor.current_epoch
            if current_epoch != last_epoch:
                if self.verbose:
                    print(
                        f"🔄 Epoch transition detected: {last_epoch} -> {current_epoch}"
                    )
                last_epoch = current_epoch

            # Process batch data
            batch_df = data.batch_df

            # Time the overall batch processing
            batch_start_time = time.time()

            # Get unique cell IDs in this batch
            batch_cell_ids = batch_df["cell_integer_id"].unique().to_list()

            # Calculate total batch processing time
            total_batch_time = time.time() - batch_start_time

            # Create batch dictionary
            batch_dict = {
                "X": batch_df,  # Polars DataFrame with CSR-like structure
                "cell_ids": batch_cell_ids,
            }

            # Add epoch info if multi-epoch training
            if self.n_epochs > 1:
                batch_dict["epoch"] = current_epoch

            batches_yielded += 1

            # Print detailed timing every 100 batches
            if batches_yielded % 100 == 0:
                # Consolidate training batch reporting
                training_report = f"  TileDB training batch {batches_yielded} (epoch {current_epoch}) processing:\n"
                training_report += f"     Data retrieval: {data_time * 1000:.1f}ms\n"
                training_report += (
                    f"     Total batch time: {total_batch_time * 1000:.1f}ms\n"
                )
                training_report += "     Raw data (polars DataFrame)"

                print_training(training_report, self.verbose)

            yield batch_dict

    def __del__(self):
        """
        Cleanup when dataset is destroyed.

        This method is called when the dataset object is garbage collected.
        It ensures that the underlying prefetcher is properly stopped to
        prevent resource leaks and background thread issues.

        Examples:
            >>> # Dataset cleanup happens automatically
            >>> dataset = TileDBIterableDataset("path/to/experiment")
            >>> print("Dataset created")
            Dataset created
            >>> # When dataset goes out of scope, __del__ is called automatically
            >>> del dataset
            >>> print("Dataset destroyed and cleaned up")
            Dataset destroyed and cleaned up
        """
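        # __init__ may have raised before the prefetcher was created
        prefetcher = getattr(self, "prefetcher", None)
        if prefetcher is not None:
            prefetcher.stop()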
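The end-of-data handling in `__iter__` combines a timed get, a "producer finished" check, and a bounded extra wait. The sketch below reproduces that control flow in simplified form with a toy producer thread; the `finished` event stands in for the `current_epoch >= n_epochs` check, and `consume` and `producer` are hypothetical names. Checking the queue again after `finished` is set avoids dropping batches that were queued just before the producer exited.

```python
import queue
import threading
import time
from typing import Iterator


def consume(q: queue.Queue, finished: threading.Event,
            get_timeout: float = 0.05, max_wait: float = 1.0) -> Iterator:
    """Yield items until the producer has finished and the queue is drained."""
    while True:
        try:
            item = q.get(timeout=get_timeout)
        except queue.Empty:
            # Every put happens before finished.set(), so empty-after-finished means done
            if finished.is_set() and q.empty():
                return
            # Producer still running: wait a bounded time for more data
            deadline = time.monotonic() + max_wait
            while q.empty() and not finished.is_set() and time.monotonic() < deadline:
                time.sleep(0.01)
            if q.empty() and not finished.is_set():
                return  # stalled producer: stop instead of blocking forever
            continue
        yield item


def producer(q: queue.Queue, finished: threading.Event, n: int) -> None:
    for i in range(n):
        time.sleep(0.005)  # simulate loading work
        q.put(i)
    finished.set()


q = queue.Queue()
finished = threading.Event()
threading.Thread(target=producer, args=(q, finished, 20), daemon=True).start()
received = list(consume(q, finished))
```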
Functions
__init__(tiledb_path: str, batch_size: int = 32, prefetch_batch_size: int = 100, seed: int = 42, max_queue_size: int = 500, n_epochs: int = 1, verbose: bool = True, use_mixture_of_scanners: bool = True, n_readers: int = 50, n_scanners: int = 8)

Initialize the TileDB IterableDataset with async prefetching.

Parameters:

Name Type Description Default
tiledb_path str

Path to the TileDB SOMA experiment directory. Must contain a valid SOMA experiment with RNA measurement data.

required
batch_size int

Number of cells per training batch. Larger batches use more memory but may improve training efficiency. Range: 1-512, default: 32.

32
prefetch_batch_size int

Number of cells to load per prefetch batch from TileDB. Higher values improve throughput but use more memory. Range: 10-10000, default: 100.

100
seed int

Random seed for reproducible shuffling and MoS sampling. Used for consistent data ordering across runs. Default: 42.

42
max_queue_size int

Maximum number of pre-processed batches to keep in queue. Higher values use more memory but provide better buffering. Range: 10-10000, default: 500.

500
n_epochs int

Number of epochs to run. The dataset will automatically reset after each epoch, enabling multi-epoch training. Default: 1.

1
verbose bool

If True, print detailed timing and progress information. If False, suppress all internal prints for clean output. Default: True.

True
use_mixture_of_scanners bool

If True, use MoS strategy for higher entropy by randomly sampling from multiple fragment generators. Provides better randomization for foundation model training. Default: True.

True
n_readers int

Total number of fragment generators to create when using MoS. Higher values provide better entropy but use more memory. Range: 1-1000, default: 50.

50
n_scanners int

Number of active scanners to sample from simultaneously when using MoS. Higher values provide better entropy but use more memory. Range: 1-100, default: 8.

8
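To make the roles of n_readers and n_scanners concrete, here is one plausible reading of mixture-of-scanners sampling: create n_readers sequential scanners over disjoint spans of rows, keep n_scanners of them active, draw each batch from a randomly chosen active scanner, and refill from the pool when one is exhausted. This is an illustrative sketch under those assumptions, not slaf's implementation; `mixture_of_scanners`, `_scan`, and the span-splitting scheme are hypothetical.

```python
import random
from typing import Iterator


def _scan(lo: int, hi: int, chunk: int) -> Iterator[list[int]]:
    """Sequential scanner over row IDs [lo, hi), yielding `chunk` IDs at a time."""
    for start in range(lo, hi, chunk):
        yield list(range(start, min(start + chunk, hi)))


def mixture_of_scanners(n_items: int, chunk: int, n_readers: int,
                        n_scanners: int, seed: int = 42) -> Iterator[list[int]]:
    """Randomly interleave chunks from several sequential scanners (illustrative only)."""
    rng = random.Random(seed)  # seeded so the order is reproducible
    # One scanner per contiguous span of rows (stand-in for fragment generators)
    bounds = [n_items * i // n_readers for i in range(n_readers + 1)]
    pool = [_scan(lo, hi, chunk) for lo, hi in zip(bounds, bounds[1:])]
    rng.shuffle(pool)
    active = [pool.pop() for _ in range(min(n_scanners, len(pool)))]
    while active:
        idx = rng.randrange(len(active))
        batch = next(active[idx], None)
        if batch is None:
            # Exhausted scanner: swap in a fresh one from the pool, or retire it
            if pool:
                active[idx] = pool.pop()
            else:
                active.pop(idx)
            continue
        yield batch


batches = list(mixture_of_scanners(n_items=100, chunk=5, n_readers=10, n_scanners=3, seed=0))
flat = [row for batch in batches for row in batch]
```

Under this reading, raising n_scanners widens the set of spans the next batch can come from (more entropy) at the cost of holding more scanners open (more memory), which matches the trade-off described for both parameters.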

Raises:

Type Description
ImportError

If TileDB SOMA is not available.

ValueError

If MoS parameters are invalid.

RuntimeError

If the TileDB experiment cannot be opened or is invalid.

Examples:

>>> # Basic initialization with default MoS strategy
>>> dataset = TileDBIterableDataset(
...     tiledb_path="path/to/experiment",
...     batch_size=32,
...     prefetch_batch_size=100
... )
>>> print(f"MoS enabled: {dataset.use_mixture_of_scanners}")
MoS enabled: True
>>> # Sequential loading for maximum throughput
>>> dataset = TileDBIterableDataset(
...     tiledb_path="path/to/experiment",
...     use_mixture_of_scanners=False,
...     batch_size=64
... )
>>> print(f"Sequential loading: {not dataset.use_mixture_of_scanners}")
Sequential loading: True
>>> # High-entropy MoS configuration
>>> dataset = TileDBIterableDataset(
...     tiledb_path="path/to/experiment",
...     n_readers=100,
...     n_scanners=16
... )
>>> print(f"MoS readers: {dataset.n_readers}, scanners: {dataset.n_scanners}")
MoS readers: 100, scanners: 16
Source code in slaf/ml/tiledb_dataloaders.py
def __init__(
    self,
    tiledb_path: str,
    batch_size: int = 32,
    prefetch_batch_size: int = 100,
    seed: int = 42,
    max_queue_size: int = 500,
    n_epochs: int = 1,
    verbose: bool = True,
    use_mixture_of_scanners: bool = True,
    n_readers: int = 50,
    n_scanners: int = 8,
):
    """
    Initialize the TileDB IterableDataset with async prefetching.

    Args:
        tiledb_path: Path to the TileDB SOMA experiment directory.
                     Must contain a valid SOMA experiment with RNA measurement data.
        batch_size: Number of cells per training batch. Larger batches use more
                   memory but may improve training efficiency. Range: 1-512, default: 32.
        prefetch_batch_size: Number of cells to load per prefetch batch from TileDB.
                           Higher values improve throughput but use more memory.
                           Range: 10-10000, default: 100.
        seed: Random seed for reproducible shuffling and MoS sampling.
              Used for consistent data ordering across runs. Default: 42.
        max_queue_size: Maximum number of pre-processed batches to keep in queue.
                      Higher values use more memory but provide better buffering.
                      Range: 10-10000, default: 500.
        n_epochs: Number of epochs to run. The dataset will automatically reset
                 after each epoch, enabling multi-epoch training. Default: 1.
        verbose: If True, print detailed timing and progress information.
                If False, suppress all internal prints for clean output. Default: True.
        use_mixture_of_scanners: If True, use MoS strategy for higher entropy by
                               randomly sampling from multiple fragment generators.
                               Provides better randomization for foundation model training.
                               Default: True.
        n_readers: Total number of fragment generators to create when using MoS.
                  Higher values provide better entropy but use more memory.
                  Range: 1-1000, default: 50.
        n_scanners: Number of active scanners to sample from simultaneously when using MoS.
                   Higher values provide better entropy but use more memory.
                   Range: 1-100, default: 8.

    Raises:
        ImportError: If TileDB SOMA is not available.
        ValueError: If MoS parameters are invalid.
        RuntimeError: If the TileDB experiment cannot be opened or is invalid.

    Examples:
        >>> # Basic initialization with default MoS strategy
        >>> dataset = TileDBIterableDataset(
        ...     tiledb_path="path/to/experiment",
        ...     batch_size=32,
        ...     prefetch_batch_size=100
        ... )
        >>> print(f"MoS enabled: {dataset.use_mixture_of_scanners}")
        MoS enabled: True

        >>> # Sequential loading for maximum throughput
        >>> dataset = TileDBIterableDataset(
        ...     tiledb_path="path/to/experiment",
        ...     use_mixture_of_scanners=False,
        ...     batch_size=64
        ... )
        >>> print(f"Sequential loading: {not dataset.use_mixture_of_scanners}")
        Sequential loading: True

        >>> # High-entropy MoS configuration
        >>> dataset = TileDBIterableDataset(
        ...     tiledb_path="path/to/experiment",
        ...     n_readers=100,
        ...     n_scanners=16
        ... )
        >>> print(f"MoS readers: {dataset.n_readers}, scanners: {dataset.n_scanners}")
        MoS readers: 100, scanners: 16
    """
    super().__init__()
    self.tiledb_path = tiledb_path
    self.batch_size = batch_size
    self.prefetch_batch_size = prefetch_batch_size
    self.seed = seed
    self.max_queue_size = max_queue_size
    self.n_epochs = n_epochs
    self.verbose = verbose
    self.use_mixture_of_scanners = use_mixture_of_scanners
    self.n_readers = n_readers
    self.n_scanners = n_scanners

    # Initialize batch processor
    self.batch_processor = TileDBBatchProcessor(
        tiledb_path=tiledb_path,
        batch_size=batch_size,
        prefetch_batch_size=prefetch_batch_size,
        seed=seed,
        n_epochs=n_epochs,
        verbose=verbose,
        log_metrics=False,
        use_mixture_of_scanners=use_mixture_of_scanners,
        n_readers=n_readers,
        n_scanners=n_scanners,
    )

    # Initialize async prefetcher
    self.prefetcher = TileDBAsyncPrefetcher(
        batch_processor=self.batch_processor,
        max_queue_size=max_queue_size,
    )

    # Start async prefetching
    self.prefetcher.start()

    # Wait for prefetcher to initialize
    self._wait_for_prefetcher_ready()
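The constructor above starts a background prefetcher and then waits until it is ready. The sketch below shows that producer/consumer pattern. It assumes a thread-based design with a bounded `queue.Queue`, whose capacity stands in for `max_queue_size`. `ToyPrefetcher` is hypothetical and is not the `TileDBAsyncPrefetcher` implementation.

```python
import queue
import threading
from collections.abc import Callable, Iterator

_SENTINEL = object()  # marks the end of the stream


class ToyPrefetcher:
    """Hypothetical thread-based prefetcher feeding a bounded queue (single pass)."""

    def __init__(self, make_batches: Callable[[], Iterator], max_queue_size: int = 500):
        self._make_batches = make_batches
        # put() blocks while the queue is full, which caps buffered memory
        self._queue = queue.Queue(maxsize=max_queue_size)
        self._ready = threading.Event()
        self._error = None  # set if the producer raises
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        try:
            for batch in self._make_batches():
                self._queue.put(batch)
                self._ready.set()  # at least one batch is available
        except Exception as exc:  # hand producer errors to the consumer
            self._error = exc
        finally:
            self._ready.set()  # never leave waiters hanging
            self._queue.put(_SENTINEL)

    def start(self) -> None:
        self._thread.start()

    def wait_until_ready(self, timeout: float = 10.0) -> bool:
        """Block until the first batch (or the end of the stream) is available."""
        return self._ready.wait(timeout)

    def __iter__(self) -> Iterator:
        while True:
            item = self._queue.get()
            if item is _SENTINEL:
                if self._error is not None:
                    raise self._error
                return
            yield item


prefetcher = ToyPrefetcher(lambda: iter(range(5)), max_queue_size=2)
prefetcher.start()
prefetcher.wait_until_ready()
print(list(prefetcher))  # [0, 1, 2, 3, 4]
```

The bounded queue provides backpressure. When the consumer falls behind, the producer thread blocks instead of buffering without limit.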

TileDBDataLoader

High-performance DataLoader for TileDB SOMA data optimized for ML training.

TileDBDataLoader provides efficient streaming of single-cell data from TileDB SOMA format for machine learning applications. It uses async batch processing and provides multiple loading strategies for different use cases.

Key Features
  • Multiple loading strategies for different entropy requirements:
    • Mixture of Scanners (MoS): Maximum entropy, best randomization (default)
    • Sequential loading: Fastest, lowest entropy
  • Asynchronous background prefetching for improved throughput
  • Multi-epoch training support with automatic epoch transitions
  • Memory-efficient streaming with configurable batch sizes
  • Comprehensive error handling and validation
  • Performance monitoring and statistics
  • PyTorch IterableDataset compatibility
Loading Strategies
  1. Mixture of Scanners (default): Randomly samples from multiple generators for maximum entropy and randomization
  2. Sequential: Loads contiguous data chunks for maximum throughput
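The Mixture of Scanners strategy can be sketched in three steps:

1. Create `n_readers` generators over disjoint parts of the data.
2. Keep `n_scanners` of them active at a time.
3. Draw each chunk from a randomly chosen active scanner, and refill the active set from the pending pool as scanners run out.

The pure-Python illustration below follows those assumptions. `mixture_of_scanners` is hypothetical, and contiguous row ranges stand in for TileDB reads. The library's actual sampling and chunk sizes may differ.

```python
import random
from collections.abc import Iterator


def make_reader(start: int, stop: int, chunk: int) -> Iterator[list[int]]:
    """Yield contiguous chunks of row indices from [start, stop)."""
    for lo in range(start, stop, chunk):
        yield list(range(lo, min(lo + chunk, stop)))


def mixture_of_scanners(
    n_rows: int, n_readers: int, n_scanners: int, chunk: int, seed: int = 42
) -> Iterator[list[int]]:
    """Hypothetical MoS sampler: interleave chunks from randomly chosen readers."""
    if not (1 <= n_scanners <= n_readers):
        raise ValueError("need 1 <= n_scanners <= n_readers")
    rng = random.Random(seed)
    # Split the rows into n_readers contiguous ranges, one generator each
    bounds = [n_rows * i // n_readers for i in range(n_readers + 1)]
    pending = [make_reader(bounds[i], bounds[i + 1], chunk) for i in range(n_readers)]
    rng.shuffle(pending)
    active = [pending.pop() for _ in range(n_scanners)]
    while active:
        idx = rng.randrange(len(active))
        try:
            yield next(active[idx])
        except StopIteration:
            # Replace the exhausted scanner with a pending one, or drop it
            if pending:
                active[idx] = pending.pop()
            else:
                active.pop(idx)


chunks = list(mixture_of_scanners(n_rows=1000, n_readers=50, n_scanners=8, chunk=5))
rows = [r for c in chunks for r in c]
print(sorted(rows) == list(range(1000)))  # True: every row appears exactly once
print(chunks[:3])  # neighbouring chunks usually come from different readers
```

Each reader still scans contiguously, which keeps reads cheap. Randomness comes only from interleaving readers. Raising `n_scanners` widens the pool that each draw is taken from.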

Examples:

>>> # Basic usage with default MoS strategy
>>> dataloader = TileDBDataLoader(
...     tiledb_path="path/to/experiment",
...     batch_size=32,
...     prefetch_batch_size=100
... )
>>> for batch in dataloader:
...     print(f"Batch keys: {list(batch.keys())}")
...     print(f"Cell IDs: {batch['cell_ids']}")
...     break
Batch keys: ['X', 'cell_ids']
Cell IDs: [0, 1, 2, ..., 29, 30, 31]
>>> # Sequential loading for maximum throughput
>>> dataloader = TileDBDataLoader(
...     tiledb_path="path/to/experiment",
...     use_mixture_of_scanners=False,
...     batch_size=64
... )
>>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
MoS enabled: False
>>> # Multi-epoch training
>>> dataloader = TileDBDataLoader(
...     tiledb_path="path/to/experiment",
...     n_epochs=3
... )
>>> print(f"Number of epochs: {dataloader.n_epochs}")
Number of epochs: 3
>>> # Custom MoS configuration
>>> dataloader = TileDBDataLoader(
...     tiledb_path="path/to/experiment",
...     n_readers=100,
...     n_scanners=16
... )
>>> print(f"MoS readers: {dataloader.n_readers}, scanners: {dataloader.n_scanners}")
MoS readers: 100, scanners: 16
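`__iter__` documents an `epoch` key that is present when `n_epochs > 1`. Multi-epoch iteration of that kind could be structured as a generator that replays the data once per epoch. The minimal sketch below reshuffles each epoch with a seed of `seed + epoch`. `iterate_epochs` and that seeding scheme are illustrative assumptions, not the library's implementation.

```python
import random
from collections.abc import Iterator, Sequence


def iterate_epochs(
    cell_ids: Sequence[int], batch_size: int, n_epochs: int = 1, seed: int = 42
) -> Iterator[dict]:
    """Yield batch dicts over n_epochs, reshuffling the cell order each epoch."""
    for epoch in range(n_epochs):
        order = list(cell_ids)
        # Assumed scheme: a distinct but reproducible seed per epoch
        random.Random(seed + epoch).shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = {"cell_ids": order[start : start + batch_size]}
            if n_epochs > 1:
                batch["epoch"] = epoch  # only tagged in multi-epoch runs
            yield batch


batches = list(iterate_epochs(range(10), batch_size=4, n_epochs=3))
print(len(batches))  # 9: three batches (4 + 4 + 2 cells) per epoch
print(sorted({b["epoch"] for b in batches}))  # [0, 1, 2]
```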
Source code in slaf/ml/tiledb_dataloaders.py
class TileDBDataLoader:
    """
    High-performance DataLoader for TileDB SOMA data optimized for ML training.

    TileDBDataLoader provides efficient streaming of single-cell data from TileDB
    SOMA format for machine learning applications. It uses async batch processing
    and provides multiple loading strategies for different use cases.

    Key Features:
        - Multiple loading strategies for different entropy requirements:
            * Mixture of Scanners (MoS): Maximum entropy, best randomization (default)
            * Sequential loading: Fastest, lowest entropy
        - Asynchronous background prefetching for improved throughput
        - Multi-epoch training support with automatic epoch transitions
        - Memory-efficient streaming with configurable batch sizes
        - Comprehensive error handling and validation
        - Performance monitoring and statistics
        - PyTorch IterableDataset compatibility

    Loading Strategies:
        1. Mixture of Scanners (default): Randomly samples from multiple generators
           for maximum entropy and randomization
        2. Sequential: Loads contiguous data chunks for maximum throughput

    Examples:
        >>> # Basic usage with default MoS strategy
        >>> dataloader = TileDBDataLoader(
        ...     tiledb_path="path/to/experiment",
        ...     batch_size=32,
        ...     prefetch_batch_size=100
        ... )
        >>> for batch in dataloader:
        ...     print(f"Batch keys: {list(batch.keys())}")
        ...     print(f"Cell IDs: {batch['cell_ids']}")
        ...     break
        Batch keys: ['X', 'cell_ids']
        Cell IDs: [0, 1, 2, ..., 29, 30, 31]

        >>> # Sequential loading for maximum throughput
        >>> dataloader = TileDBDataLoader(
        ...     tiledb_path="path/to/experiment",
        ...     use_mixture_of_scanners=False,
        ...     batch_size=64
        ... )
        >>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
        MoS enabled: False

        >>> # Multi-epoch training
        >>> dataloader = TileDBDataLoader(
        ...     tiledb_path="path/to/experiment",
        ...     n_epochs=3
        ... )
        >>> print(f"Number of epochs: {dataloader.n_epochs}")
        Number of epochs: 3

        >>> # Custom MoS configuration
        >>> dataloader = TileDBDataLoader(
        ...     tiledb_path="path/to/experiment",
        ...     n_readers=100,
        ...     n_scanners=16
        ... )
        >>> print(f"MoS readers: {dataloader.n_readers}, scanners: {dataloader.n_scanners}")
        MoS readers: 100, scanners: 16
    """

    def __init__(
        self,
        tiledb_path: str,
        batch_size: int = 32,
        prefetch_batch_size: int = 100,
        seed: int = 42,
        n_epochs: int = 1,
        verbose: bool = True,
        max_queue_size: int = 500,
        use_mixture_of_scanners: bool = True,
        n_readers: int = 50,
        n_scanners: int = 8,
    ):
        """
        Initialize the TileDB DataLoader with training configuration.

        Args:
            tiledb_path: Path to the TileDB SOMA experiment directory.
                         Must contain a valid SOMA experiment with RNA measurement data.
            batch_size: Number of cells per training batch. Larger batches use more
                       memory but may improve training efficiency. Range: 1-512, default: 32.
            prefetch_batch_size: Number of cells to prefetch from TileDB per batch.
                               Higher values improve throughput but use more memory.
                               Range: 10-10000, default: 100.
            seed: Random seed for reproducible shuffling and MoS sampling.
                  Used for consistent data ordering across runs. Default: 42.
            n_epochs: Number of epochs to run. The dataloader will automatically reset
                     after each epoch, enabling multi-epoch training. Default: 1.
            verbose: If True, print detailed timing and progress information.
                    If False, suppress all internal prints for clean output. Default: True.
            max_queue_size: Maximum number of pre-processed batches to keep in queue.
                          Higher values use more memory but provide better buffering.
                          Range: 10-10000, default: 500.
            use_mixture_of_scanners: If True, use MoS strategy for higher entropy by
                                   randomly sampling from multiple fragment generators.
                                   Provides better randomization for foundation model training.
                                   Default: True.
            n_readers: Total number of fragment generators to create when using MoS.
                      Higher values provide better entropy but use more memory.
                      Range: 1-1000, default: 50.
            n_scanners: Number of active scanners to sample from simultaneously when using MoS.
                       Higher values provide better entropy but use more memory.
                       Range: 1-100, default: 8.

        Raises:
            ImportError: If TileDB SOMA is not available.
            ValueError: If MoS parameters are invalid.
            RuntimeError: If the TileDB experiment cannot be opened or is invalid.

        Examples:
            >>> # Basic initialization with default MoS strategy
            >>> dataloader = TileDBDataLoader(
            ...     tiledb_path="path/to/experiment",
            ...     batch_size=32,
            ...     prefetch_batch_size=100
            ... )
            >>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
            MoS enabled: True

            >>> # Sequential loading for maximum throughput
            >>> dataloader = TileDBDataLoader(
            ...     tiledb_path="path/to/experiment",
            ...     use_mixture_of_scanners=False,
            ...     batch_size=64
            ... )
            >>> print(f"Sequential loading: {not dataloader.use_mixture_of_scanners}")
            Sequential loading: True

            >>> # High-entropy MoS configuration
            >>> dataloader = TileDBDataLoader(
            ...     tiledb_path="path/to/experiment",
            ...     n_readers=100,
            ...     n_scanners=16
            ... )
            >>> print(f"MoS readers: {dataloader.n_readers}, scanners: {dataloader.n_scanners}")
            MoS readers: 100, scanners: 16
        """
        self.tiledb_path = tiledb_path
        self.batch_size = batch_size
        self.prefetch_batch_size = prefetch_batch_size
        self.seed = seed
        self.n_epochs = n_epochs
        self.verbose = verbose
        self.max_queue_size = max_queue_size
        self.use_mixture_of_scanners = use_mixture_of_scanners
        self.n_readers = n_readers
        self.n_scanners = n_scanners

        # Check that required modules are available
        if not TILEDB_AVAILABLE:
            raise ImportError("TileDB SOMA is required but not available")

        # Use IterableDataset
        self._dataset = TileDBIterableDataset(
            tiledb_path=tiledb_path,
            batch_size=batch_size,
            prefetch_batch_size=prefetch_batch_size,
            seed=seed,
            max_queue_size=max_queue_size,
            n_epochs=n_epochs,
            verbose=verbose,
            use_mixture_of_scanners=use_mixture_of_scanners,
            n_readers=n_readers,
            n_scanners=n_scanners,
        )

    def __iter__(self):
        """
        Iterate through batches of TileDB data with async prefetching.

        This method provides an iterator over batches of TileDB data, automatically
        handling background prefetching, epoch transitions, and error recovery.
        It yields dictionaries containing the batch data and metadata.

        The iterator automatically manages:
        - Background prefetching for improved throughput
        - Epoch transitions for multi-epoch training
        - Error handling and recovery
        - Performance monitoring and reporting

        Yields:
            dict: Batch dictionary containing:
                - X: Polars DataFrame with cell-gene expression data
                - cell_ids: List of unique cell IDs in the batch
                - epoch: Current epoch number (when n_epochs > 1)

        Examples:
            >>> # Basic iteration
            >>> dataloader = TileDBDataLoader("path/to/experiment")
            >>> for batch in dataloader:
            ...     print(f"Batch keys: {list(batch.keys())}")
            ...     print(f"Cell IDs: {batch['cell_ids']}")
            ...     break
            Batch keys: ['X', 'cell_ids']
            Cell IDs: [0, 1, 2, ..., 29, 30, 31]

            >>> # Multi-epoch training
            >>> dataloader = TileDBDataLoader("path/to/experiment", n_epochs=3)
            >>> epochs_seen = set()
            >>> for batch in dataloader:
            ...     if 'epoch' in batch:
            ...         epochs_seen.add(batch['epoch'])
            ...     if len(epochs_seen) >= 3:
            ...         break
            >>> print(f"Epochs completed: {sorted(epochs_seen)}")
            Epochs completed: [0, 1, 2]

            >>> # Training loop with error handling
            >>> dataloader = TileDBDataLoader("path/to/experiment")
            >>> for batch_idx, batch in enumerate(dataloader):
            ...     try:
            ...         x = batch["X"]
            ...         cell_ids = batch["cell_ids"]
            ...         print(f"Processed batch {batch_idx} with {len(cell_ids)} cells")
            ...     except Exception as e:
            ...         print(f"Error in batch {batch_idx}: {e}")
            ...         continue
            ...     if batch_idx >= 2:  # Just first few batches
            ...         break
            Processed batch 0 with 32 cells
            Processed batch 1 with 32 cells
            Processed batch 2 with 32 cells
        """
        yield from self._dataset

    def __len__(self):
        """
        Return the number of batches in the dataset.

        Note: Since TileDBDataLoader uses an IterableDataset that streams data,
        the exact number of batches is not known in advance. This method
        returns 0 to indicate an unknown length for streaming datasets.

        Returns:
            int: Always returns 0 to indicate unknown length for streaming datasets.

        Examples:
            >>> # Check dataset length
            >>> dataloader = TileDBDataLoader("path/to/experiment")
            >>> print(f"Dataset length: {len(dataloader)}")
            Dataset length: 0

            >>> # IterableDataset behavior
            >>> batch_count = 0
            >>> for batch in dataloader:
            ...     batch_count += 1
            ...     if batch_count >= 5:  # Just count first 5 batches
            ...         break
            >>> print(f"Actually processed {batch_count} batches")
            Actually processed 5 batches

            >>> # Length is consistent
            >>> print(f"Length check: {len(dataloader)}")
            Length check: 0
        """
        return 0  # Indicates unknown length for streaming datasets

    def __del__(self):
        """
        Cleanup method to stop async prefetching.

        This method is called when the DataLoader object is garbage collected.
        It ensures that the underlying dataset's prefetcher is properly cleaned up
        to prevent resource leaks.

        Examples:
            >>> # DataLoader cleanup happens automatically
            >>> dataloader = TileDBDataLoader("path/to/experiment")
            >>> print("DataLoader created")
            DataLoader created
            >>> # When dataloader goes out of scope, __del__ is called automatically
            >>> del dataloader
            >>> print("DataLoader destroyed and cleaned up")
            DataLoader destroyed and cleaned up

            >>> # Manual cleanup (not usually needed)
            >>> dataloader = TileDBDataLoader("path/to/experiment")
            >>> dataloader.__del__()
            >>> print("Manual cleanup completed")
            Manual cleanup completed
        """
        if hasattr(self, "_dataset"):
            # The TileDBIterableDataset doesn't have a stop method,
            # so we just let it finish its current epoch.
            pass
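Because `__len__` returns 0, Python treats a `TileDBDataLoader` instance as falsy. When a class defines no `__bool__`, truth testing falls back to `__len__`. The hypothetical stand-in class below copies only the `__iter__`/`__len__` shape of the class above. It shows why code should check for `None` explicitly instead of relying on truthiness.

```python
class StreamingLoader:
    """Hypothetical stand-in: streams items but reports an unknown length as 0."""

    def __init__(self, n_items: int):
        self._n_items = n_items

    def __iter__(self):
        yield from range(self._n_items)

    def __len__(self):
        return 0  # unknown length for a streaming source


loader = StreamingLoader(3)
print(bool(loader))  # False, because __len__ returns 0
print(loader is not None)  # True: the safe way to check the loader exists
print(sum(1 for _ in loader))  # 3: iteration still yields every item
```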
Functions
__init__(tiledb_path: str, batch_size: int = 32, prefetch_batch_size: int = 100, seed: int = 42, n_epochs: int = 1, verbose: bool = True, max_queue_size: int = 500, use_mixture_of_scanners: bool = True, n_readers: int = 50, n_scanners: int = 8)

Initialize the TileDB DataLoader with training configuration.

Parameters:

Name Type Description Default
tiledb_path str

Path to the TileDB SOMA experiment directory. Must contain a valid SOMA experiment with RNA measurement data.

required
batch_size int

Number of cells per training batch. Larger batches use more memory but may improve training efficiency. Range: 1-512, default: 32.

32
prefetch_batch_size int

Number of cells to prefetch from TileDB per batch. Higher values improve throughput but use more memory. Range: 10-10000, default: 100.

100
seed int

Random seed for reproducible shuffling and MoS sampling. Used for consistent data ordering across runs. Default: 42.

42
n_epochs int

Number of epochs to run. The dataloader will automatically reset after each epoch, enabling multi-epoch training. Default: 1.

1
verbose bool

If True, print detailed timing and progress information. If False, suppress all internal prints for clean output. Default: True.

True
max_queue_size int

Maximum number of pre-processed batches to keep in queue. Higher values use more memory but provide better buffering. Range: 10-10000, default: 500.

500
use_mixture_of_scanners bool

If True, use MoS strategy for higher entropy by randomly sampling from multiple fragment generators. Provides better randomization for foundation model training. Default: True.

True
n_readers int

Total number of fragment generators to create when using MoS. Higher values provide better entropy but use more memory. Range: 1-1000, default: 50.

50
n_scanners int

Number of active scanners to sample from simultaneously when using MoS. Higher values provide better entropy but use more memory. Range: 1-100, default: 8.

8

Raises:

Type Description
ImportError

If TileDB SOMA is not available.

ValueError

If MoS parameters are invalid.

RuntimeError

If the TileDB experiment cannot be opened or is invalid.

Examples:

>>> # Basic initialization with default MoS strategy
>>> dataloader = TileDBDataLoader(
...     tiledb_path="path/to/experiment",
...     batch_size=32,
...     prefetch_batch_size=100
... )
>>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
MoS enabled: True
>>> # Sequential loading for maximum throughput
>>> dataloader = TileDBDataLoader(
...     tiledb_path="path/to/experiment",
...     use_mixture_of_scanners=False,
...     batch_size=64
... )
>>> print(f"Sequential loading: {not dataloader.use_mixture_of_scanners}")
Sequential loading: True
>>> # High-entropy MoS configuration
>>> dataloader = TileDBDataLoader(
...     tiledb_path="path/to/experiment",
...     n_readers=100,
...     n_scanners=16
... )
>>> print(f"MoS readers: {dataloader.n_readers}, scanners: {dataloader.n_scanners}")
MoS readers: 100, scanners: 16
Source code in slaf/ml/tiledb_dataloaders.py
def __init__(
    self,
    tiledb_path: str,
    batch_size: int = 32,
    prefetch_batch_size: int = 100,
    seed: int = 42,
    n_epochs: int = 1,
    verbose: bool = True,
    max_queue_size: int = 500,
    use_mixture_of_scanners: bool = True,
    n_readers: int = 50,
    n_scanners: int = 8,
):
    """
    Initialize the TileDB DataLoader with training configuration.

    Args:
        tiledb_path: Path to the TileDB SOMA experiment directory.
                     Must contain a valid SOMA experiment with RNA measurement data.
        batch_size: Number of cells per training batch. Larger batches use more
                   memory but may improve training efficiency. Range: 1-512, default: 32.
        prefetch_batch_size: Number of cells to prefetch from TileDB per batch.
                           Higher values improve throughput but use more memory.
                           Range: 10-10000, default: 100.
        seed: Random seed for reproducible shuffling and MoS sampling.
              Used for consistent data ordering across runs. Default: 42.
        n_epochs: Number of epochs to run. The dataloader will automatically reset
                 after each epoch, enabling multi-epoch training. Default: 1.
        verbose: If True, print detailed timing and progress information.
                If False, suppress all internal prints for clean output. Default: True.
        max_queue_size: Maximum number of pre-processed batches to keep in queue.
                      Higher values use more memory but provide better buffering.
                      Range: 10-10000, default: 500.
        use_mixture_of_scanners: If True, use MoS strategy for higher entropy by
                               randomly sampling from multiple fragment generators.
                               Provides better randomization for foundation model training.
                               Default: True.
        n_readers: Total number of fragment generators to create when using MoS.
                  Higher values provide better entropy but use more memory.
                  Range: 1-1000, default: 50.
        n_scanners: Number of active scanners to sample from simultaneously when using MoS.
                   Higher values provide better entropy but use more memory.
                   Range: 1-100, default: 8.

    Raises:
        ImportError: If TileDB SOMA is not available.
        ValueError: If MoS parameters are invalid.
        RuntimeError: If the TileDB experiment cannot be opened or is invalid.

    Examples:
        >>> # Basic initialization with default MoS strategy
        >>> dataloader = TileDBDataLoader(
        ...     tiledb_path="path/to/experiment",
        ...     batch_size=32,
        ...     prefetch_batch_size=100
        ... )
        >>> print(f"MoS enabled: {dataloader.use_mixture_of_scanners}")
        MoS enabled: True

        >>> # Sequential loading for maximum throughput
        >>> dataloader = TileDBDataLoader(
        ...     tiledb_path="path/to/experiment",
        ...     use_mixture_of_scanners=False,
        ...     batch_size=64
        ... )
        >>> print(f"Sequential loading: {not dataloader.use_mixture_of_scanners}")
        Sequential loading: True

        >>> # High-entropy MoS configuration
        >>> dataloader = TileDBDataLoader(
        ...     tiledb_path="path/to/experiment",
        ...     n_readers=100,
        ...     n_scanners=16
        ... )
        >>> print(f"MoS readers: {dataloader.n_readers}, scanners: {dataloader.n_scanners}")
        MoS readers: 100, scanners: 16
    """
    self.tiledb_path = tiledb_path
    self.batch_size = batch_size
    self.prefetch_batch_size = prefetch_batch_size
    self.seed = seed
    self.n_epochs = n_epochs
    self.verbose = verbose
    self.max_queue_size = max_queue_size
    self.use_mixture_of_scanners = use_mixture_of_scanners
    self.n_readers = n_readers
    self.n_scanners = n_scanners

    # Check that required modules are available
    if not TILEDB_AVAILABLE:
        raise ImportError("TileDB SOMA is required but not available")

    # Use IterableDataset
    self._dataset = TileDBIterableDataset(
        tiledb_path=tiledb_path,
        batch_size=batch_size,
        prefetch_batch_size=prefetch_batch_size,
        seed=seed,
        max_queue_size=max_queue_size,
        n_epochs=n_epochs,
        verbose=verbose,
        use_mixture_of_scanners=use_mixture_of_scanners,
        n_readers=n_readers,
        n_scanners=n_scanners,
    )

Functions

print_prefetch(message: str, verbose: bool = True)

Print prefetch-related messages with colored formatting.

This function prints prefetch-related messages using rich console formatting when available, or falls back to loguru logging. Messages are displayed in cyan-colored panels for better visual distinction during training.

Parameters:

Name Type Description Default
message str

The message to print.

required
verbose bool

If True, print the message. If False, suppress output.

True

Examples:

>>> # Print a prefetch message
>>> print_prefetch("Loading batch 1 of 100")
>>> # Suppress output
>>> print_prefetch("Loading batch 1 of 100", verbose=False)
Source code in slaf/ml/tiledb_dataloaders.py
def print_prefetch(message: str, verbose: bool = True):
    """
    Print prefetch-related messages with colored formatting.

    This function prints prefetch-related messages using rich console formatting
    when available, or falls back to loguru logging. Messages are displayed in
    cyan-colored panels for better visual distinction during training.

    Args:
        message: The message to print.
        verbose: If True, print the message. If False, suppress output.

    Examples:
        >>> # Print a prefetch message
        >>> print_prefetch("Loading batch 1 of 100")
        >>> # Suppress output
        >>> print_prefetch("Loading batch 1 of 100", verbose=False)
    """
    if not verbose:
        return

    if RICH_AVAILABLE and console is not None:
        console.print(Panel(message, border_style="cyan"))
    else:
        logger.info(f"🔍 {message}")
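`print_prefetch` and `print_training` follow an optional-dependency pattern: they use `rich` when it can be imported and fall back to a logger otherwise. The standalone sketch below shows that pattern, using the standard-library `logging` module in place of loguru. `emit` and its boolean return value are illustrative additions, not part of SLAF.

```python
import logging

try:  # rich is optional; fall back to plain logging if it is missing
    from rich.console import Console
    from rich.panel import Panel

    console = Console()
    RICH_AVAILABLE = True
except ImportError:
    console = None
    RICH_AVAILABLE = False

logger = logging.getLogger(__name__)


def emit(message: str, verbose: bool = True, border_style: str = "cyan", prefix: str = "🔍") -> bool:
    """Show message in a colored panel (rich) or via logging; return True if emitted."""
    if not verbose:
        return False
    if RICH_AVAILABLE and console is not None:
        console.print(Panel(message, border_style=border_style))
    else:
        logger.info(f"{prefix} {message}")
    return True


emit("Loading batch 1 of 100")  # cyan panel, like print_prefetch
emit("Processing batch with 32 cells", border_style="green", prefix="📊")  # like print_training
emit("suppressed", verbose=False)  # emits nothing
```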

print_training(message: str, verbose: bool = True)

Print training-related messages with colored formatting.

This function prints training-related messages using rich console formatting when available, or falls back to loguru logging. Messages are displayed in green-colored panels for better visual distinction during training.

Parameters:

Name Type Description Default
message str

The message to print.

required
verbose bool

If True, print the message. If False, suppress output.

True

Examples:

>>> # Print a training message
>>> print_training("Processing batch with 32 cells")
>>> # Suppress output
>>> print_training("Processing batch with 32 cells", verbose=False)
Source code in slaf/ml/tiledb_dataloaders.py
def print_training(message: str, verbose: bool = True):
    """
    Print training-related messages with colored formatting.

    This function prints training-related messages using rich console formatting
    when available, or falls back to loguru logging. Messages are displayed in
    green-colored panels for better visual distinction during training.

    Args:
        message: The message to print.
        verbose: If True, print the message. If False, suppress output.

    Examples:
        >>> # Print a training message
        >>> print_training("Processing batch with 32 cells")
        >>> # Suppress output
        >>> print_training("Processing batch with 32 cells", verbose=False)
    """
    if not verbose:
        return

    if RICH_AVAILABLE and console is not None:
        console.print(Panel(message, border_style="green"))
    else:
        logger.info(f"📊 {message}")

Tokenizers

slaf.ml.tokenizers

Classes

TokenizerType

Bases: str, Enum

Tokenizer types

Source code in slaf/ml/tokenizers.py
class TokenizerType(str, Enum):
    """Tokenizer types"""

    GENEFORMER = "geneformer"
    SCPGPT = "scgpt"
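`TokenizerType` mixes in `str`, so its members compare equal to their plain string values, and calling the enum with a value looks up the matching member. The minimal sketch below validates a user-supplied `tokenizer_type` that way. The enum is re-declared to match the listing above so that the block runs on its own. `parse_tokenizer_type` is a hypothetical helper whose error message imitates the one shown in the `SLAFTokenizer` examples.

```python
from enum import Enum


class TokenizerType(str, Enum):
    """Mirror of the enum above, re-declared so this block is self-contained."""

    GENEFORMER = "geneformer"
    SCPGPT = "scgpt"


def parse_tokenizer_type(name: str) -> TokenizerType:
    """Map a user-supplied string to a TokenizerType (case-insensitive)."""
    try:
        return TokenizerType(name.lower())  # lookup by value, not by member name
    except ValueError:
        supported = [t.value for t in TokenizerType]
        raise ValueError(
            f"Unsupported tokenizer type: {name}. Supported types: {supported}"
        ) from None


print(parse_tokenizer_type("scGPT") is TokenizerType.SCPGPT)  # True
print(TokenizerType.GENEFORMER == "geneformer")  # True, thanks to the str mixin
```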

SLAFTokenizer

Tokenizer for single-cell RNA-seq data in SLAF format.

SLAFTokenizer converts single-cell gene expression data into token sequences suitable for machine learning models. It supports multiple tokenization strategies including GeneFormer and scGPT formats with optimized vectorized operations.

Key Features
  • Multiple tokenization strategies (GeneFormer, scGPT)
  • Vectorized tokenization for high performance
  • Expression binning for scGPT format
  • Device-agnostic CPU tensor output
  • Memory-efficient processing
  • Comprehensive vocabulary management
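Expression binning turns continuous expression values into a small set of discrete tokens for scGPT-style input. The sketch below shows one simple scheme: equal-width bins over a single cell's range of expressed values, with bin 0 reserved for zero expression. The bin count and the equal-width rule are assumptions for illustration. `bin_expression` is not a `SLAFTokenizer` method, and SLAF's own binning may differ.

```python
from collections.abc import Sequence


def bin_expression(values: Sequence[float], n_bins: int = 10) -> list[int]:
    """Bin one cell's expression values: 0 for unexpressed, 1..n_bins otherwise.

    Hypothetical equal-width scheme over (0, max expressed value] of the cell.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be >= 1")
    max_val = max((v for v in values if v > 0), default=0.0)
    bins = []
    for v in values:
        if v <= 0:
            bins.append(0)  # reserved bin for zero expression
            continue
        # Scale into [0, n_bins], clamp the maximum into the top bin, shift by 1
        b = min(int(v / max_val * n_bins), n_bins - 1)
        bins.append(b + 1)
    return bins


print(bin_expression([0.0, 0.5, 0.8, 0.2]))  # [0, 7, 10, 3]
```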

Examples:

>>> # Basic usage with GeneFormer
>>> slaf_array = SLAFArray("path/to/data.slaf")
>>> tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="geneformer")
>>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
>>> input_ids, attention_mask = tokenizer.tokenize(gene_sequences)
>>> print(f"Input shape: {input_ids.shape}")
Input shape: torch.Size([2, 2048])
>>> # scGPT with expression sequences
>>> tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="scgpt")
>>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
>>> expr_sequences = [[0.5, 0.8, 0.2], [0.9, 0.1, 0.7]]
>>> input_ids, attention_mask = tokenizer.tokenize(
...     gene_sequences, expr_sequences
... )
>>> print(f"Input shape: {input_ids.shape}")
Input shape: torch.Size([2, 2050])
>>> # Error handling for invalid tokenizer type
>>> try:
...     tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="invalid")
... except ValueError as e:
...     print(f"Error: {e}")
Error: Unsupported tokenizer type: invalid. Supported types: ['geneformer', 'scgpt']
>>> # Vocabulary information
>>> vocab_info = tokenizer.get_vocab_info()
>>> print(f"Vocabulary size: {vocab_info['vocab_size']}")
Vocabulary size: 50000
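The fixed `[2, 2048]` shape in the GeneFormer example suggests that each cell's token sequence is padded or truncated to a fixed maximum length, with an attention mask marking the real tokens. The pure-Python sketch below shows that step. It assumes pad id 0, `max_len=2048`, and no special tokens. `pad_and_mask` is hypothetical, and SLAFTokenizer's actual vocabulary, special tokens, and tensor types may differ.

```python
from collections.abc import Sequence


def pad_and_mask(
    sequences: Sequence[Sequence[int]], max_len: int = 2048, pad_id: int = 0
) -> tuple[list[list[int]], list[list[int]]]:
    """Truncate/pad token sequences to max_len and build matching attention masks."""
    input_ids, attention_mask = [], []
    for seq in sequences:
        tokens = list(seq[:max_len])  # truncate overly long cells
        n_pad = max_len - len(tokens)
        input_ids.append(tokens + [pad_id] * n_pad)
        attention_mask.append([1] * len(tokens) + [0] * n_pad)  # 1 = real token
    return input_ids, attention_mask


ids, mask = pad_and_mask([[1, 2, 3], [4, 5, 6]])
print(len(ids), len(ids[0]))  # 2 2048, matching the example's shape
print(ids[0][:5], mask[0][:5])  # [1, 2, 3, 0, 0] [1, 1, 1, 0, 0]
```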
Source code in slaf/ml/tokenizers.py
class SLAFTokenizer:
    """
    Tokenizer for single-cell RNA-seq data in SLAF format.

    SLAFTokenizer converts single-cell gene expression data into token sequences
    suitable for machine learning models. It supports multiple tokenization strategies
    including GeneFormer and scGPT formats with optimized vectorized operations.

    Key Features:
        - Multiple tokenization strategies (GeneFormer, scGPT)
        - Vectorized tokenization for high performance
        - Expression binning for scGPT format
        - Device-agnostic CPU tensor output
        - Memory-efficient processing
        - Comprehensive vocabulary management

    Examples:
        >>> # Basic usage with GeneFormer
        >>> slaf_array = SLAFArray("path/to/data.slaf")
        >>> tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="geneformer")
        >>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
        >>> input_ids, attention_mask = tokenizer.tokenize(gene_sequences)
        >>> print(f"Input shape: {input_ids.shape}")
        Input shape: torch.Size([2, 2048])

        >>> # scGPT with expression sequences
        >>> tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="scgpt")
        >>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
        >>> expr_sequences = [[0.5, 0.8, 0.2], [0.9, 0.1, 0.7]]
        >>> input_ids, attention_mask = tokenizer.tokenize(
        ...     gene_sequences, expr_sequences
        ... )
        >>> print(f"Input shape: {input_ids.shape}")
        Input shape: torch.Size([2, 2050])

        >>> # Error handling for invalid tokenizer type
        >>> try:
        ...     tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="invalid")
        ... except ValueError as e:
        ...     print(f"Error: {e}")
        Error: Unsupported tokenizer type: invalid. Supported types: ['geneformer', 'scgpt']

        >>> # Vocabulary information
        >>> vocab_info = tokenizer.get_vocab_info()
        >>> print(f"Vocabulary size: {vocab_info['vocab_size']}")
        Vocabulary size: 50000
    """

    def __init__(
        self,
        slaf_array: SLAFArray,
        tokenizer_type: TokenizerType | str = TokenizerType.GENEFORMER,
        vocab_size: int = 50000,
        n_expression_bins: int = 10,
    ):
        """
        Initialize SLAFTokenizer with SLAF array and vocabulary settings.

        Args:
            slaf_array: Initialized SLAFArray instance containing the single-cell data.
                       Used to build the gene vocabulary and access expression data.
                       Must be a valid SLAFArray with proper var DataFrame.
            tokenizer_type: Type of tokenizer to use. Options: "geneformer", "scgpt".
                          Can be passed as string or TokenizerType enum.
            vocab_size: Maximum gene vocabulary size. Genes whose integer ID is
                       >= vocab_size are left out of the vocabulary used for
                       decoding; scGPT expression-bin tokens start at this value.
            n_expression_bins: Number of uniform expression bins for scGPT
                             tokenization. Higher values give finer expression
                             resolution. Intended range: 1-1000 (not validated);
                             default: 10.

        Raises:
            ValueError: If tokenizer_type is a string that does not name a
                supported tokenizer. Problems reading the var DataFrame are
                not raised; the tokenizer falls back to a placeholder
                1000-gene vocabulary instead.

        Examples:
            >>> # Basic initialization
            >>> slaf_array = SLAFArray("path/to/data.slaf")
            >>> tokenizer = SLAFTokenizer(slaf_array)
            >>> print(f"Tokenizer type: {tokenizer.tokenizer_type}")
            Tokenizer type: TokenizerType.GENEFORMER

            >>> # scGPT with custom settings
            >>> tokenizer = SLAFTokenizer(
            ...     slaf_array=slaf_array,
            ...     tokenizer_type="scgpt",
            ...     vocab_size=30000,
            ...     n_expression_bins=20
            ... )
            >>> print(f"Expression bins: {tokenizer.n_expression_bins}")
            Expression bins: 20

            >>> # Error handling for invalid tokenizer type
            >>> try:
            ...     tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="invalid")
            ... except ValueError as e:
            ...     print(f"Error: {e}")
            Error: Unsupported tokenizer type: invalid. Supported types: ['geneformer', 'scgpt']

            >>> # An invalid SLAF array does not raise; the tokenizer falls back
            >>> # to a placeholder vocabulary of 1000 genes
            >>> tokenizer = SLAFTokenizer(None)
            >>> print(f"Gene vocabulary size: {len(tokenizer.gene_vocab)}")
            Gene vocabulary size: 1000
        """
        self.slaf_array = slaf_array
        self.vocab_size = vocab_size
        self.n_expression_bins = n_expression_bins

        # Convert string to enum if needed
        if isinstance(tokenizer_type, str):
            try:
                self.tokenizer_type = TokenizerType(tokenizer_type.lower())
            except ValueError as err:
                raise ValueError(
                    f"Unsupported tokenizer type: {tokenizer_type}. "
                    f"Supported types: {[t.value for t in TokenizerType]}"
                ) from err
        else:
            self.tokenizer_type = tokenizer_type

        # Build vocabulary and special tokens
        self._build_gene_vocabulary()
        self._setup_special_tokens()

    def _build_gene_vocabulary(self):
        """Build gene vocabulary from SLAF var DataFrame."""
        try:
            var_df = self.slaf_array.var.reset_index()

            # Check if we have a real SLAF array or a Mock object
            if (
                hasattr(var_df, "columns")
                and "gene_integer_id" in var_df.columns
                and "gene_id" in var_df.columns
            ):
                # Real SLAF array - build vocabulary from gene data
                gene_vocab = {}

                # Use Polars native iteration
                for row in var_df.iter_rows(named=True):
                    gene_id = row["gene_id"]
                    gene_integer_id = row["gene_integer_id"]

                    # Only include genes within vocab size limit
                    if gene_integer_id < self.vocab_size:
                        gene_vocab[gene_id] = gene_integer_id

                self.gene_vocab = gene_vocab
                # Account for the +4 offset used in tokenization
                self.token_to_gene = {v + 4: k for k, v in self.gene_vocab.items()}

                # Pre-build vectorized mapping array for fast lookup. Only gene
                # IDs that parse as integers can index the array; others (e.g.
                # Ensembl accessions) are skipped instead of raising, which
                # would otherwise discard the real vocabulary via the fallback.
                numeric_vocab = {}
                for gene_id, token_id in self.gene_vocab.items():
                    try:
                        numeric_vocab[int(gene_id)] = token_id
                    except (ValueError, TypeError):
                        continue
                max_gene_id = max(numeric_vocab, default=0)
                self.vocab_mapping = np.full(max_gene_id + 1, -1, dtype=int)
                for gene_id_int, token_id in numeric_vocab.items():
                    self.vocab_mapping[gene_id_int] = token_id
            else:
                # Mock object - create dummy vocabulary
                self.gene_vocab = {f"gene_{i}": i for i in range(1000)}
                # Account for the +4 offset used in tokenization
                self.token_to_gene = {v + 4: k for k, v in self.gene_vocab.items()}
                # No mapping array needed for mock objects

        except Exception:
            # Fallback for testing or error cases
            self.gene_vocab = {f"gene_{i}": i for i in range(1000)}
            # Account for the +4 offset used in tokenization
            self.token_to_gene = {v + 4: k for k, v in self.gene_vocab.items()}
            # No mapping array needed for fallback

    def _setup_special_tokens(self):
        """Setup special tokens for tokenization."""
        # Special token IDs
        self.special_tokens = {
            "PAD": 0,
            "CLS": 1,
            "SEP": 2,
            "MASK": 3,
        }

        # Expression binning setup for scGPT
        self.expr_bin_start = self.vocab_size
        self.expr_bin_size = 1.0 / self.n_expression_bins

    def _expression_to_bin(self, expression_value: float) -> int:
        """Convert expression value to bin token ID"""
        if expression_value <= 0:
            return self.special_tokens["PAD"]

        # Bin the expression value
        bin_id = min(
            int(expression_value / self.expr_bin_size), self.n_expression_bins - 1
        )
        return self.expr_bin_start + bin_id

    def _expression_to_bin_vectorized(
        self, expression_values: np.ndarray
    ) -> np.ndarray:
        """Vectorized version of expression binning"""
        # Handle edge cases
        if len(expression_values) == 0:
            return np.array([], dtype=int)

        # Create bins
        bins = np.clip(
            (expression_values / self.expr_bin_size).astype(int),
            0,
            self.n_expression_bins - 1,
        )

        # Convert to token IDs
        result = np.where(
            expression_values > 0,
            self.expr_bin_start + bins,
            self.special_tokens["PAD"],
        )

        return result.astype(int)

    def _map_gene_ids_to_tokens_vectorized(self, gene_ids) -> np.ndarray:
        """Vectorized mapping of gene IDs to token IDs"""
        # Handle edge cases
        if hasattr(gene_ids, "is_empty") and gene_ids.is_empty():
            return np.array([], dtype=int)
        elif hasattr(gene_ids, "__len__") and len(gene_ids) == 0:
            return np.array([], dtype=int)

        # Convert to numpy array for vectorized operations
        gene_ids_array = np.array(gene_ids, dtype=int)

        # Check if we have a real SLAF array or a Mock object
        try:
            # Try to access the DataFrame properly
            var_df = self.slaf_array.var.reset_index()

            # Check if it's actually a DataFrame with the expected columns
            if (
                hasattr(var_df, "columns")
                and "gene_integer_id" in var_df.columns
                and "gene_id" in var_df.columns
            ):
                # Real SLAF array - use pre-built vectorized mapping
                # Vectorized lookup using pre-built mapping array
                tokens = self.vocab_mapping[gene_ids_array]

                # Filter out missing genes (-1 values)
                valid_mask = tokens != -1
                return tokens[valid_mask]
            else:
                # Mock or incomplete var table: apply the +4 special-token offset directly
                return gene_ids_array + 4
        except Exception:
            # Fallback when var cannot be read: apply the +4 offset directly
            return gene_ids_array + 4

    def tokenize(
        self,
        gene_sequences: list[list[int] | list[tuple[int, float]]],
        expr_sequences: list[list[float]] | None = None,
        max_genes: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize gene expression sequences into model-ready tensors.

        This method converts gene and expression sequences into tokenized tensors
        suitable for machine learning models. It supports both GeneFormer and scGPT
        tokenization strategies with optimized vectorized operations.

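        Args:
            gene_sequences: List of gene integer ID sequences, one per cell
            expr_sequences: List of expression value sequences, one per cell
                (required for scGPT; ignored for GeneFormer)
            max_genes: GeneFormer: total sequence length including CLS and SEP
                (default 2048). scGPT: number of gene-expression pairs, giving a
                sequence length of 2 * max_genes + 2 (default 1024).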

        Returns:
            tuple: (input_ids, attention_mask) tensors
                - input_ids: Tokenized sequences with padding
                - attention_mask: Boolean mask indicating valid tokens

        Raises:
            ValueError: If gene_sequences is empty

        Examples:
            >>> # GeneFormer tokenization
            >>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
            >>> input_ids, attention_mask = tokenizer.tokenize(gene_sequences)
            >>> print(f"Shape: {input_ids.shape}")
            Shape: torch.Size([2, 2048])

            >>> # scGPT tokenization (requires an scGPT tokenizer)
            >>> scgpt_tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="scgpt")
            >>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
            >>> expr_sequences = [[0.5, 0.8, 0.2], [0.9, 0.1, 0.7]]
            >>> input_ids, attention_mask = scgpt_tokenizer.tokenize(
            ...     gene_sequences, expr_sequences
            ... )
            >>> print(f"Shape: {input_ids.shape}")
            Shape: torch.Size([2, 2050])
        """
        if not gene_sequences:
            raise ValueError("Gene sequences cannot be empty")

        # Set default max_genes based on tokenizer type
        if max_genes is None:
            if self.tokenizer_type == TokenizerType.GENEFORMER:
                max_genes = 2048
            else:
                # For scGPT: CLS + (gene,expr)*n + SEP = 2*n + 2
                # So if we want max_genes total tokens, n = (max_genes - 2) / 2
                max_genes = 1024  # This is the number of gene-expression pairs

        # Always define max_sequence_length based on tokenizer type
        if self.tokenizer_type == TokenizerType.GENEFORMER:
            max_sequence_length = max_genes  # For Geneformer, same as max_genes
        else:
            # For scGPT: CLS + (gene,expr)*n + SEP = 2*n + 2
            max_sequence_length = 2 * max_genes + 2  # Total sequence length

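        # scGPT reads genes and expressions as parallel per-cell lists, so
        # require one expression sequence for every gene sequence
        if self.tokenizer_type == TokenizerType.SCPGPT and (
            expr_sequences is None or len(expr_sequences) != len(gene_sequences)
        ):
            raise ValueError(
                "scGPT tokenization requires one expression sequence per cell"
            )

        batch_size = len(gene_sequences)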

        # Pre-allocate numpy array with correct dimensions
        if self.tokenizer_type == TokenizerType.SCPGPT:
            # For scGPT: use max_sequence_length (2*max_genes+2)
            array_width = max_sequence_length
        else:
            # For Geneformer: use max_genes
            array_width = max_genes

        token_array = np.full(
            (batch_size, array_width), self.special_tokens["PAD"], dtype=np.int64
        )

        if self.tokenizer_type == TokenizerType.SCPGPT:
            # scGPT format: [CLS] gene1 expr1 gene2 expr2 ... [SEP]
            for i, (gene_sequence, expr_sequence) in enumerate(
                zip(gene_sequences, expr_sequences or [], strict=False)
            ):
                # For scGPT, we now use separate gene_sequence and expr_sequence columns
                if gene_sequence and len(gene_sequence) > 0:
                    # Fast path: separate gene_sequence and expr_sequence columns
                    genes = gene_sequence
                    exprs = expr_sequence if expr_sequence else []

                    # Vectorized operations - use simple +4 for performance
                    gene_tokens = np.array(genes, dtype=np.int64) + 4

                    # Handle expression tokens - don't bin if already binned by window function
                    if len(exprs) > 0 and isinstance(exprs[0], int | np.integer):
                        # Already binned by window function - just convert to tokens
                        expr_tokens = (
                            np.array(exprs, dtype=np.int64) + self.expr_bin_start
                        )
                    else:
                        # Raw values - need to bin them
                        expr_tokens = self._expression_to_bin_vectorized(
                            np.array(exprs, dtype=np.float32)
                        )

                    # Vectorized interleaving (much faster than Python loop)
                    if len(gene_tokens) > 0:
                        # Pre-allocate full sequence: CLS + (gene,expr)*n + SEP
                        sequence_length = 1 + 2 * len(gene_tokens) + 1
                        tokens = np.full(
                            sequence_length, self.special_tokens["PAD"], dtype=np.int64
                        )

                        # Set CLS token
                        tokens[0] = self.special_tokens["CLS"]

                        # Vectorized interleaving
                        tokens[1::2][: len(gene_tokens)] = gene_tokens  # type: ignore[assignment]
                        tokens[2::2][: len(expr_tokens)] = expr_tokens  # type: ignore[assignment]

                        tokens[1 + 2 * len(gene_tokens)] = self.special_tokens["SEP"]
                    else:
                        # Empty sequence case
                        tokens = np.array(
                            [self.special_tokens["CLS"], self.special_tokens["SEP"]],
                            dtype=np.int64,
                        )  # type: ignore[assignment]
                else:
                    # Empty sequence case
                    tokens = np.array(
                        [self.special_tokens["CLS"], self.special_tokens["SEP"]],
                        dtype=np.int64,
                    )  # type: ignore[assignment]

                # Pad/truncate to the scGPT sequence length (2*max_genes+2)
                target_length = max_sequence_length

                tokens = tokens[:target_length]  # type: ignore[assignment]
                if len(tokens) < target_length:
                    padding = np.full(
                        target_length - len(tokens),
                        self.special_tokens["PAD"],
                        dtype=np.int64,
                    )
                    tokens = np.concatenate([tokens, padding])  # type: ignore[assignment]

                # Fill array
                token_array[i, :] = tokens  # type: ignore[assignment]

        else:
            # Geneformer format: [CLS] gene1 gene2 gene3 ... [SEP]
            for i, gene_sequence in enumerate(gene_sequences):
                # Convert gene IDs to tokens (fast mapping)
                gene_tokens = np.array(gene_sequence, dtype=np.int64) + 4

                # Vectorized sequence building: use concatenation for speed
                if len(gene_tokens) > 0:
                    # Use concatenation: CLS + genes + SEP
                    tokens = np.concatenate(
                        [
                            [self.special_tokens["CLS"]],
                            gene_tokens,
                            [self.special_tokens["SEP"]],
                        ]
                    )  # type: ignore[assignment]
                else:
                    # Empty sequence case
                    tokens = np.array(
                        [self.special_tokens["CLS"], self.special_tokens["SEP"]],
                        dtype=np.int64,
                    )  # type: ignore[assignment]

                # Pad/truncate to max_genes
                tokens = tokens[:max_genes]  # type: ignore[assignment]
                if len(tokens) < max_genes:
                    padding = np.full(
                        max_genes - len(tokens),
                        self.special_tokens["PAD"],
                        dtype=np.int64,
                    )
                    tokens = np.concatenate([tokens, padding])  # type: ignore[assignment]

                # Fill array
                token_array[i, :] = tokens  # type: ignore[assignment]

        # Convert to tensors in one operation
        input_ids = torch.from_numpy(token_array)
        attention_mask = input_ids != self.special_tokens["PAD"]

        return input_ids, attention_mask

    def get_vocab_info(self) -> dict[str, Any]:
        """
        Get vocabulary information for debugging and analysis.

        Returns:
            dict: Vocabulary information including size, special tokens, etc.

        Examples:
            >>> vocab_info = tokenizer.get_vocab_info()
            >>> print(f"Vocabulary size: {vocab_info['vocab_size']}")
            Vocabulary size: 50000
            >>> print(f"Special tokens: {vocab_info['special_tokens']}")
            Special tokens: {'PAD': 0, 'CLS': 1, 'SEP': 2, 'MASK': 3}
        """
        return {
            "vocab_size": self.vocab_size,
            "tokenizer_type": self.tokenizer_type.value,
            "special_tokens": self.special_tokens,
            "n_expression_bins": self.n_expression_bins,
            "gene_vocab_size": len(self.gene_vocab),
        }

    def decode_tokens(self, tokens: list[int]) -> dict[str, Any]:
        """
        Decode token sequence back to gene information.

        Args:
            tokens: List of token IDs to decode

        Returns:
            dict: Decoded information including genes, expressions, etc.

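        Examples:
            >>> # Decode an scGPT token sequence (default vocab_size and bins):
            >>> # gene tokens are gene integer IDs + 4, expression tokens are
            >>> # vocab_size + bin index
            >>> tokens = [1, 104, 50005, 204, 50008, 2]  # CLS, gene, expr, gene, expr, SEP
            >>> decoded = tokenizer.decode_tokens(tokens)
            >>> # decoded['genes'] holds the gene_id values for integer IDs 100 and 200
            >>> print(f"Expressions: {decoded['expressions']}")
            Expressions: [0.5, 0.8]
            >>> print(f"Special tokens: {decoded['special_tokens']}")
            Special tokens: ['CLS', 'SEP']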
        """
        if not tokens:
            return {"genes": [], "expressions": [], "special_tokens": []}

        genes = []
        expressions = []
        special_tokens = []

        i = 0
        while i < len(tokens):
            token = tokens[i]

            if token == self.special_tokens["CLS"]:
                special_tokens.append("CLS")
                i += 1
            elif token == self.special_tokens["SEP"]:
                special_tokens.append("SEP")
                i += 1
            elif token == self.special_tokens["PAD"]:
                special_tokens.append("PAD")
                i += 1
            elif token == self.special_tokens["MASK"]:
                special_tokens.append("MASK")
                i += 1
            elif (
                self.tokenizer_type == TokenizerType.SCPGPT
                and token >= self.expr_bin_start
            ):
                # Expression token
                bin_id = token - self.expr_bin_start
                expr_value = bin_id * self.expr_bin_size
                expressions.append(expr_value)
                i += 1
            else:
                # Gene token
                if token in self.token_to_gene:
                    genes.append(self.token_to_gene[token])
                else:
                    genes.append(f"unknown_gene_{token}")
                i += 1

        return {
            "genes": genes,
            "expressions": expressions,
            "special_tokens": special_tokens,
        }
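decode_tokens can only partly invert tokenization: gene tokens map back through token_to_gene, and expression tokens map back to the lower edge of their bin, since the original value is discarded during binning. The sketch below mirrors that branching order (special tokens, then expression bins at or above vocab_size, then genes) but returns gene integer IDs instead of names, so it runs without a SLAFArray. decode_scgpt is a hypothetical helper, not part of slaf.

```python
SPECIAL = {0: "PAD", 1: "CLS", 2: "SEP", 3: "MASK"}
GENE_OFFSET = 4


def decode_scgpt(tokens, vocab_size=50000, n_expression_bins=10):
    """Hypothetical helper: split scGPT tokens into gene IDs and bin lower edges."""
    bin_size = 1.0 / n_expression_bins
    gene_ids, expressions, specials = [], [], []
    for token in tokens:
        if token in SPECIAL:
            specials.append(SPECIAL[token])
        elif token >= vocab_size:
            # Only the bin index survives tokenization, so this is the
            # lower edge of the bin, not the original expression value
            expressions.append((token - vocab_size) * bin_size)
        else:
            # Undo the special-token offset to recover the gene integer ID
            gene_ids.append(token - GENE_OFFSET)
    return {"gene_ids": gene_ids, "expressions": expressions, "special_tokens": specials}


decoded = decode_scgpt([1, 104, 50005, 204, 50008, 2, 0, 0])
print(decoded["gene_ids"])        # [100, 200]
print(decoded["expressions"])     # [0.5, 0.8]
print(decoded["special_tokens"])  # ['CLS', 'SEP', 'PAD', 'PAD']
```

Because non-positive expression values are tokenized as PAD, a decoded scGPT sequence can contain fewer expressions than genes; PAD alone does not tell you which gene it belonged to.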
Functions
__init__(slaf_array: SLAFArray, tokenizer_type: TokenizerType | str = TokenizerType.GENEFORMER, vocab_size: int = 50000, n_expression_bins: int = 10)

Initialize SLAFTokenizer with SLAF array and vocabulary settings.

Parameters:

slaf_array (SLAFArray, required): Initialized SLAFArray instance containing the single-cell data. Used to build the gene vocabulary and access expression data. Must be a valid SLAFArray with a proper var DataFrame.

tokenizer_type (TokenizerType | str, default GENEFORMER): Type of tokenizer to use. Options: "geneformer", "scgpt". Can be passed as a string or a TokenizerType enum.

vocab_size (int, default 50000): Maximum gene vocabulary size. Genes whose integer ID is >= vocab_size are left out of the vocabulary used for decoding; scGPT expression-bin tokens start at this value.

n_expression_bins (int, default 10): Number of uniform expression bins for scGPT tokenization. Higher values give finer expression resolution. Intended range: 1-1000 (not validated).
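A rough sketch of how vocab_size limits the vocabulary, based on _build_gene_vocabulary in the class source: each var row contributes gene_id → gene_integer_id when the integer ID is below vocab_size, and the reverse map used for decoding is keyed by the token ID (integer ID + 4). build_gene_vocab is hypothetical, and plain dicts stand in for the Polars rows of SLAFArray.var.

```python
def build_gene_vocab(var_rows, vocab_size=50000, gene_offset=4):
    """Hypothetical helper mirroring _build_gene_vocabulary.

    var_rows: iterable of dicts with "gene_id" and "gene_integer_id" keys,
    standing in for the rows of SLAFArray.var.
    """
    gene_vocab = {
        row["gene_id"]: row["gene_integer_id"]
        for row in var_rows
        if row["gene_integer_id"] < vocab_size  # drop IDs past the limit
    }
    # Tokens sit gene_offset above the integer ID, past PAD/CLS/SEP/MASK
    token_to_gene = {gid + gene_offset: name for name, gid in gene_vocab.items()}
    return gene_vocab, token_to_gene


rows = [
    {"gene_id": "TP53", "gene_integer_id": 0},
    {"gene_id": "BRCA1", "gene_integer_id": 1},
    {"gene_id": "BRCA2", "gene_integer_id": 2},
]
vocab, token_to_gene = build_gene_vocab(rows, vocab_size=2)
print(vocab)          # {'TP53': 0, 'BRCA1': 1}
print(token_to_gene)  # {4: 'TP53', 5: 'BRCA1'}
```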

Raises:

ValueError: If tokenizer_type is a string that does not name a supported tokenizer.

Problems reading the SLAF array's var DataFrame are not raised; the tokenizer falls back to a placeholder 1000-gene vocabulary instead.

Examples:

>>> # Basic initialization
>>> slaf_array = SLAFArray("path/to/data.slaf")
>>> tokenizer = SLAFTokenizer(slaf_array)
>>> print(f"Tokenizer type: {tokenizer.tokenizer_type}")
Tokenizer type: TokenizerType.GENEFORMER
>>> # scGPT with custom settings
>>> tokenizer = SLAFTokenizer(
...     slaf_array=slaf_array,
...     tokenizer_type="scgpt",
...     vocab_size=30000,
...     n_expression_bins=20
... )
>>> print(f"Expression bins: {tokenizer.n_expression_bins}")
Expression bins: 20
>>> # Error handling for invalid tokenizer type
>>> try:
...     tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="invalid")
... except ValueError as e:
...     print(f"Error: {e}")
Error: Unsupported tokenizer type: invalid. Supported types: ['geneformer', 'scgpt']
>>> # An invalid SLAF array does not raise; the tokenizer falls back
>>> # to a placeholder vocabulary of 1000 genes
>>> tokenizer = SLAFTokenizer(None)
>>> print(f"Gene vocabulary size: {len(tokenizer.gene_vocab)}")
Gene vocabulary size: 1000
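n_expression_bins controls scGPT binning. The sketch below restates the logic of _expression_to_bin_vectorized as a standalone function: the bin width is 1 / n_expression_bins, so it assumes expression values already scaled to roughly [0, 1]. Anything at or above 1 lands in the top bin, and non-positive values become the PAD token. bin_expression is a hypothetical name; the default vocab_size of 50000 is the bin offset.

```python
import numpy as np


def bin_expression(values, vocab_size=50000, n_expression_bins=10, pad_id=0):
    """Hypothetical restatement of uniform scGPT expression binning."""
    values = np.asarray(values, dtype=np.float32)
    if values.size == 0:
        return np.array([], dtype=np.int64)
    bin_size = 1.0 / n_expression_bins
    # Floor into a bin index, clamping out-of-range values to the end bins
    bins = np.clip((values / bin_size).astype(np.int64), 0, n_expression_bins - 1)
    # Positive values become bin tokens; zero or negative values become PAD
    return np.where(values > 0, vocab_size + bins, pad_id).astype(np.int64)


print(bin_expression([0.0, 0.05, 0.55, 0.99, 3.2]).tolist())
# [0, 50000, 50005, 50009, 50009]
```

Note that tokenize skips this step when a cell's expression list already holds integers: it treats them as pre-computed bin indices and only adds the vocab_size offset.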
tokenize(gene_sequences: list[list[int] | list[tuple[int, float]]], expr_sequences: list[list[float]] | None = None, max_genes: int | None = None) -> tuple[torch.Tensor, torch.Tensor]

Tokenize gene expression sequences into model-ready tensors.

This method converts gene and expression sequences into tokenized tensors suitable for machine learning models. It supports both GeneFormer and scGPT tokenization strategies with optimized vectorized operations.

Parameters:

gene_sequences (list[list[int] | list[tuple[int, float]]], required): List of gene integer ID sequences, one per cell.

expr_sequences (list[list[float]] | None, default None): List of expression value sequences, one per cell. Required for scGPT; ignored for GeneFormer.

max_genes (int | None, default None): For GeneFormer, the total sequence length including CLS and SEP (default 2048). For scGPT, the number of gene-expression pairs, giving a sequence length of 2 * max_genes + 2 (default 1024).

Returns:

tuple (tuple[Tensor, Tensor]): (input_ids, attention_mask)
  • input_ids: int64 token IDs of shape (batch_size, sequence_length), right-padded with PAD (0)
  • attention_mask: Boolean tensor of the same shape, True wherever input_ids is not PAD

Raises:

ValueError: If gene_sequences is empty.

Examples:

>>> # GeneFormer tokenization
>>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
>>> input_ids, attention_mask = tokenizer.tokenize(gene_sequences)
>>> print(f"Shape: {input_ids.shape}")
Shape: torch.Size([2, 2048])
>>> # scGPT tokenization (requires an scGPT tokenizer)
>>> scgpt_tokenizer = SLAFTokenizer(slaf_array, tokenizer_type="scgpt")
>>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
>>> expr_sequences = [[0.5, 0.8, 0.2], [0.9, 0.1, 0.7]]
>>> input_ids, attention_mask = scgpt_tokenizer.tokenize(gene_sequences, expr_sequences)
>>> print(f"Shape: {input_ids.shape}")
Shape: torch.Size([2, 2050])
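To make the two sequence layouts concrete, here is a NumPy-only sketch of how a single row is assembled, following the tokenize source: GeneFormer rows are [CLS] genes [SEP] truncated or padded to max_genes, and scGPT rows interleave gene and expression tokens between CLS and SEP, padded to 2 * max_genes + 2. build_geneformer and build_scgpt are hypothetical helpers. They return NumPy arrays rather than torch tensors, and expr_tokens is assumed to be already binned, with one entry per gene.

```python
import numpy as np

PAD, CLS, SEP = 0, 1, 2
GENE_OFFSET = 4


def build_geneformer(gene_ids, max_genes=2048):
    # [CLS] g1 g2 ... [SEP], then truncated or right-padded to max_genes
    genes = np.asarray(gene_ids, dtype=np.int64) + GENE_OFFSET
    seq = np.concatenate([[CLS], genes, [SEP]]).astype(np.int64)[:max_genes]
    return np.pad(seq, (0, max_genes - len(seq)), constant_values=PAD)


def build_scgpt(gene_ids, expr_tokens, max_genes=1024):
    # [CLS] g1 e1 g2 e2 ... [SEP]; output width is 2 * max_genes + 2
    genes = np.asarray(gene_ids, dtype=np.int64) + GENE_OFFSET
    n = len(genes)
    seq = np.full(2 * n + 2, PAD, dtype=np.int64)
    seq[0] = CLS
    seq[1 : 2 * n : 2] = genes                    # odd positions: genes
    seq[2 : 2 * n + 1 : 2] = expr_tokens          # even positions: expressions
    seq[-1] = SEP
    width = 2 * max_genes + 2
    seq = seq[:width]
    return np.pad(seq, (0, width - len(seq)), constant_values=PAD)


batch = np.stack([build_geneformer([10, 20, 30], max_genes=8),
                  build_geneformer([], max_genes=8)])
attention_mask = batch != PAD   # same rule as tokenize: True where not PAD
print(batch.tolist())
# [[1, 14, 24, 34, 2, 0, 0, 0], [1, 2, 0, 0, 0, 0, 0, 0]]
print(attention_mask.sum(axis=1).tolist())  # [5, 2]
print(build_scgpt([10, 20], [50005, 50008], max_genes=3).tolist())
# [1, 14, 50005, 24, 50008, 2, 0, 0]
```

Two consequences follow from this layout. A GeneFormer cell with more than max_genes - 2 genes loses its SEP token to truncation. In scGPT rows, an expression binned to PAD (a non-positive value) is also masked out by attention_mask.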
Source code in slaf/ml/tokenizers.py
def tokenize(
    self,
    gene_sequences: list[list[int] | list[tuple[int, float]]],
    expr_sequences: list[list[float]] | None = None,
    max_genes: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Tokenize gene expression sequences into model-ready tensors.

    This method converts gene and expression sequences into tokenized tensors
    suitable for machine learning models. It supports both GeneFormer and scGPT
    tokenization strategies with optimized vectorized operations.

    Args:
        gene_sequences: List of gene ID sequences for each cell
        expr_sequences: List of expression value sequences for each cell (required for scGPT)
        max_genes: Maximum number of genes per cell (defaults based on tokenizer type)

    Returns:
        tuple: (input_ids, attention_mask) tensors
            - input_ids: Tokenized sequences with padding
            - attention_mask: Boolean mask indicating valid tokens

    Raises:
        ValueError: If gene_sequences is empty

    Examples:
        >>> # GeneFormer tokenization
        >>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
        >>> input_ids, attention_mask = tokenizer.tokenize(gene_sequences)
        >>> print(f"Shape: {input_ids.shape}")
        Shape: torch.Size([2, 2048])

        >>> # scGPT tokenization
        >>> gene_sequences = [[1, 2, 3], [4, 5, 6]]
        >>> expr_sequences = [[0.5, 0.8, 0.2], [0.9, 0.1, 0.7]]
        >>> input_ids, attention_mask = tokenizer.tokenize(gene_sequences, expr_sequences)
        >>> print(f"Shape: {input_ids.shape}")
        Shape: torch.Size([2, 2050])
    """
    if not gene_sequences:
        raise ValueError("Gene sequences cannot be empty")

    # Set default max_genes based on tokenizer type
    if max_genes is None:
        if self.tokenizer_type == TokenizerType.GENEFORMER:
            max_genes = 2048
        else:
            # For scGPT, max_genes counts (gene, expr) pairs, not tokens:
            # CLS + (gene, expr) * n + SEP = 2 * n + 2 tokens
            max_genes = 1024  # 1024 pairs -> 2050 tokens

    # Always define max_sequence_length based on tokenizer type
    if self.tokenizer_type == TokenizerType.GENEFORMER:
        max_sequence_length = max_genes  # For Geneformer, same as max_genes
    else:
        # For scGPT: CLS + (gene,expr)*n + SEP = 2*n + 2
        max_sequence_length = 2 * max_genes + 2  # Total sequence length

    # For scGPT, expression values come from the separate expr_sequences argument.
    # It is not validated here: if it is None, the zip() below yields no rows
    # and the scGPT output stays entirely PAD.

    batch_size = len(gene_sequences)

    # Build token IDs with NumPy, then convert to torch once at the end
    import numpy as np

    # Pre-allocate numpy array with correct dimensions
    if self.tokenizer_type == TokenizerType.SCPGPT:
        # For scGPT: use max_sequence_length (2*max_genes+2)
        array_width = max_sequence_length
    else:
        # For Geneformer: use max_genes
        array_width = max_genes

    token_array = np.full(
        (batch_size, array_width), self.special_tokens["PAD"], dtype=np.int64
    )

    if self.tokenizer_type == TokenizerType.SCPGPT:
        # scGPT format: [CLS] gene1 expr1 gene2 expr2 ... [SEP]
        for i, (gene_sequence, expr_sequence) in enumerate(
            zip(gene_sequences, expr_sequences or [], strict=False)
        ):
            # For scGPT, we now use separate gene_sequence and expr_sequence columns
            if gene_sequence and len(gene_sequence) > 0:
                # Fast path: separate gene_sequence and expr_sequence columns
                genes = gene_sequence
                exprs = expr_sequence if expr_sequence else []

                # Shift gene IDs by 4 so they start after the special tokens
                gene_tokens = np.array(genes, dtype=np.int64) + 4

                # Handle expression tokens - don't bin if already binned by window function
                if len(exprs) > 0 and isinstance(exprs[0], int | np.integer):
                    # Already binned by window function - just convert to tokens
                    expr_tokens = (
                        np.array(exprs, dtype=np.int64) + self.expr_bin_start
                    )
                else:
                    # Raw values - need to bin them
                    expr_tokens = self._expression_to_bin_vectorized(
                        np.array(exprs, dtype=np.float32)
                    )

                # Vectorized interleaving (much faster than Python loop)
                if len(gene_tokens) > 0:
                    # Pre-allocate full sequence: CLS + (gene,expr)*n + SEP
                    sequence_length = 1 + 2 * len(gene_tokens) + 1
                    tokens = np.full(
                        sequence_length, self.special_tokens["PAD"], dtype=np.int64
                    )

                    # Set CLS token
                    tokens[0] = self.special_tokens["CLS"]

                    # Vectorized interleaving
                    tokens[1::2][: len(gene_tokens)] = gene_tokens  # type: ignore[assignment]
                    tokens[2::2][: len(expr_tokens)] = expr_tokens  # type: ignore[assignment]

                    tokens[1 + 2 * len(gene_tokens)] = self.special_tokens["SEP"]
                else:
                    # Empty sequence case
                    tokens = np.array(
                        [self.special_tokens["CLS"], self.special_tokens["SEP"]],
                        dtype=np.int64,
                    )  # type: ignore[assignment]
            else:
                # Empty sequence case
                tokens = np.array(
                    [self.special_tokens["CLS"], self.special_tokens["SEP"]],
                    dtype=np.int64,
                )  # type: ignore[assignment]

            # Pad/truncate to correct sequence length
            if self.tokenizer_type == TokenizerType.SCPGPT:
                # For scGPT: use max_sequence_length (2*max_genes+2)
                target_length = max_sequence_length
            else:
                # For Geneformer: use max_genes
                target_length = max_genes

            tokens = tokens[:target_length]  # type: ignore[assignment]
            if len(tokens) < target_length:
                padding = np.full(
                    target_length - len(tokens),
                    self.special_tokens["PAD"],
                    dtype=np.int64,
                )
                tokens = np.concatenate([tokens, padding])  # type: ignore[assignment]

            # Fill array
            token_array[i, :] = tokens  # type: ignore[assignment]

    else:
        # Geneformer format: [CLS] gene1 gene2 gene3 ... [SEP]
        for i, gene_sequence in enumerate(gene_sequences):
            # Convert gene IDs to tokens (fast mapping)
            gene_tokens = np.array(gene_sequence, dtype=np.int64) + 4

            # Vectorized sequence building: use concatenation for speed
            if len(gene_tokens) > 0:
                # Use concatenation: CLS + genes + SEP
                tokens = np.concatenate(
                    [
                        [self.special_tokens["CLS"]],
                        gene_tokens,
                        [self.special_tokens["SEP"]],
                    ]
                )  # type: ignore[assignment]
            else:
                # Empty sequence case
                tokens = np.array(
                    [self.special_tokens["CLS"], self.special_tokens["SEP"]],
                    dtype=np.int64,
                )  # type: ignore[assignment]

            # Pad/truncate to max_genes
            tokens = tokens[:max_genes]  # type: ignore[assignment]
            if len(tokens) < max_genes:
                padding = np.full(
                    max_genes - len(tokens),
                    self.special_tokens["PAD"],
                    dtype=np.int64,
                )
                tokens = np.concatenate([tokens, padding])  # type: ignore[assignment]

            # Fill array
            token_array[i, :] = tokens  # type: ignore[assignment]

    # Convert to tensors in one operation
    input_ids = torch.from_numpy(token_array)
    attention_mask = input_ids != self.special_tokens["PAD"]

    return input_ids, attention_mask
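The scGPT branch above interleaves genes and expression bins with strided slice assignment instead of a Python loop. The sketch below isolates that step for already-binned integer expressions. `EXPR_BIN_START = 1000` is a hypothetical value (the real offset is the tokenizer's expr_bin_start attribute), the special-token IDs are assumed from the get_vocab_info example, `scgpt_row` is a hypothetical name, and the float-binning path through _expression_to_bin_vectorized is not reproduced.

```python
import numpy as np

PAD, CLS, SEP = 0, 1, 2   # assumed special-token IDs
GENE_OFFSET = 4           # gene IDs are shifted past the special tokens
EXPR_BIN_START = 1000     # hypothetical; the tokenizer uses self.expr_bin_start


def scgpt_row(genes: list[int], expr_bins: list[int], max_genes: int = 1024) -> np.ndarray:
    """Return one padded scGPT row: [CLS] g1 e1 g2 e2 ... [SEP] PAD ... (sketch)."""
    width = 2 * max_genes + 2
    gene_tokens = np.asarray(genes, dtype=np.int64) + GENE_OFFSET
    expr_tokens = np.asarray(expr_bins, dtype=np.int64) + EXPR_BIN_START
    n = len(gene_tokens)
    tokens = np.full(2 * n + 2, PAD, dtype=np.int64)
    tokens[0] = CLS
    # Slices are views, so these assignments write into `tokens`:
    # positions 1, 3, ..., 2n-1 hold genes; positions 2, 4, ..., 2n hold expressions.
    tokens[1::2][:n] = gene_tokens
    tokens[2::2][: len(expr_tokens)] = expr_tokens
    tokens[2 * n + 1] = SEP
    # Truncate or right-pad to the fixed width.
    row = np.full(width, PAD, dtype=np.int64)
    tokens = tokens[:width]
    row[: len(tokens)] = tokens
    return row


print(scgpt_row([1, 2], [5, 7], max_genes=3).tolist())
# [1, 5, 1005, 6, 1007, 2, 0, 0]
```

As in the source, if expr_bins is shorter than genes, the unfilled expression slots stay PAD.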
get_vocab_info() -> dict[str, Any]

Get vocabulary information for debugging and analysis.

Returns:

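  • dict[str, Any]: Vocabulary information with the keys vocab_size, tokenizer_type (the string value of the tokenizer type), special_tokens (mapping from token name to ID), n_expression_bins, and gene_vocab_size (number of entries in the gene vocabulary).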

Examples:

>>> vocab_info = tokenizer.get_vocab_info()
>>> print(f"Vocabulary size: {vocab_info['vocab_size']}")
Vocabulary size: 50000
>>> print(f"Special tokens: {vocab_info['special_tokens']}")
Special tokens: {'PAD': 0, 'CLS': 1, 'SEP': 2, 'MASK': 3}
Source code in slaf/ml/tokenizers.py
def get_vocab_info(self) -> dict[str, Any]:
    """
    Get vocabulary information for debugging and analysis.

    Returns:
        dict: Vocabulary information including size, special tokens, etc.

    Examples:
        >>> vocab_info = tokenizer.get_vocab_info()
        >>> print(f"Vocabulary size: {vocab_info['vocab_size']}")
        Vocabulary size: 50000
        >>> print(f"Special tokens: {vocab_info['special_tokens']}")
        Special tokens: {'PAD': 0, 'CLS': 1, 'SEP': 2, 'MASK': 3}
    """
    return {
        "vocab_size": self.vocab_size,
        "tokenizer_type": self.tokenizer_type.value,
        "special_tokens": self.special_tokens,
        "n_expression_bins": self.n_expression_bins,
        "gene_vocab_size": len(self.gene_vocab),
    }
decode_tokens(tokens: list[int]) -> dict[str, Any]

Decode a token sequence back into gene names, expression values, and special-token names.

Parameters:

  • tokens (list[int], required): List of token IDs to decode.

Returns:

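  • dict[str, Any]: A dictionary with the keys:
    • genes: gene names looked up in token_to_gene, or "unknown_gene_<token>" for unmapped tokens
    • expressions: approximate values reconstructed as bin_id * expr_bin_size (scGPT only; binning is lossy)
    • special_tokens: names of the special tokens encountered, in order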

Examples:

>>> # Decode a token sequence
>>> tokens = [1, 100, 50050, 200, 50060, 2]  # CLS, gene1, expr1, gene2, expr2, SEP
>>> decoded = tokenizer.decode_tokens(tokens)
>>> print(f"Genes: {decoded['genes']}")
Genes: ['gene_100', 'gene_200']
>>> print(f"Expressions: {decoded['expressions']}")
Expressions: [0.5, 0.6]
Source code in slaf/ml/tokenizers.py
def decode_tokens(self, tokens: list[int]) -> dict[str, Any]:
    """
    Decode token sequence back to gene information.

    Args:
        tokens: List of token IDs to decode

    Returns:
        dict: Decoded information including genes, expressions, etc.

    Examples:
        >>> # Decode a token sequence
        >>> tokens = [1, 100, 50050, 200, 50060, 2]  # CLS, gene1, expr1, gene2, expr2, SEP
        >>> decoded = tokenizer.decode_tokens(tokens)
        >>> print(f"Genes: {decoded['genes']}")
        Genes: ['gene_100', 'gene_200']
        >>> print(f"Expressions: {decoded['expressions']}")
        Expressions: [0.5, 0.6]
    """
    if not tokens:
        return {"genes": [], "expressions": [], "special_tokens": []}

    genes = []
    expressions = []
    special_tokens = []

    i = 0
    while i < len(tokens):
        token = tokens[i]

        if token == self.special_tokens["CLS"]:
            special_tokens.append("CLS")
            i += 1
        elif token == self.special_tokens["SEP"]:
            special_tokens.append("SEP")
            i += 1
        elif token == self.special_tokens["PAD"]:
            special_tokens.append("PAD")
            i += 1
        elif token == self.special_tokens["MASK"]:
            special_tokens.append("MASK")
            i += 1
        elif (
            self.tokenizer_type == TokenizerType.SCPGPT
            and token >= self.expr_bin_start
        ):
            # Expression token
            bin_id = token - self.expr_bin_start
            expr_value = bin_id * self.expr_bin_size
            expressions.append(expr_value)
            i += 1
        else:
            # Gene token
            if token in self.token_to_gene:
                genes.append(self.token_to_gene[token])
            else:
                genes.append(f"unknown_gene_{token}")
            i += 1

    return {
        "genes": genes,
        "expressions": expressions,
        "special_tokens": special_tokens,
    }
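decode_tokens classifies each ID by range: the four special tokens, expression tokens at or above expr_bin_start (scGPT only), and everything else as a gene looked up in token_to_gene. The following is a standalone sketch of that logic. The special-token IDs are assumed from the get_vocab_info example, while the bin start of 1000, the bin size of 0.1, the toy gene map, and the `decode` name are illustrative assumptions, not SLAF defaults.

```python
from typing import Any

SPECIAL = {0: "PAD", 1: "CLS", 2: "SEP", 3: "MASK"}  # assumed special-token IDs
EXPR_BIN_START = 1000  # hypothetical expression-bin offset
EXPR_BIN_SIZE = 0.1    # hypothetical bin width


def decode(tokens: list[int], token_to_gene: dict[int, str], is_scgpt: bool = True) -> dict[str, Any]:
    """Split token IDs into genes, expressions, and special tokens (sketch)."""
    genes, expressions, special = [], [], []
    for token in tokens:
        if token in SPECIAL:
            special.append(SPECIAL[token])
        elif is_scgpt and token >= EXPR_BIN_START:
            # Lossy: only the bin index survives, so the value is bin_id * bin_size.
            expressions.append((token - EXPR_BIN_START) * EXPR_BIN_SIZE)
        else:
            genes.append(token_to_gene.get(token, f"unknown_gene_{token}"))
    return {"genes": genes, "expressions": expressions, "special_tokens": special}


result = decode([1, 5, 1005, 9, 1007, 2], {5: "CD3E"})
print(result["genes"])           # ['CD3E', 'unknown_gene_9']
print(result["special_tokens"])  # ['CLS', 'SEP']
```

Because each bin collapses a range of raw values to one number, decoding recovers only an approximation of the original expression.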