pugh_torch.utils package

Submodules

pugh_torch.utils.batch_index_select module

Based on:

https://discuss.pytorch.org/t/batch-index-select/9115/11

pugh_torch.utils.batch_index_select.batch_index_select(input, dim, index)

Batch version of torch.index_select.

Returns a new tensor that, for each batch element, indexes the input tensor along dimension dim using the corresponding row of index, which is a LongTensor.

The returned tensor has the same number of dimensions as the original tensor (input). Dimension dim has size N, the number of indices per batch element in index; all other dimensions have the same size as in the original tensor.

Parameters
  • input (torch.Tensor) – (B, …) The input tensor.

  • dim (int) – The dimension along which to index. Must not refer to dimension 0, which is the batch dimension. Negative values (counted from the last dimension) are allowed.

  • index (torch.LongTensor) – (B, N) The indices to select, one 1-D row of N indices per batch element.

Returns

(B, …) tensor that matches the input dimensions, except that dimension dim now has length N.

Note: the returned tensor does NOT share storage with the input tensor.

Return type

torch.Tensor
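For illustration, the semantics above can be reproduced with torch.gather, following the expand-then-gather idea from the linked forum thread. This is a minimal sketch, not necessarily pugh_torch's implementation; batch_index_select_sketch is a hypothetical name.

```python
import torch


def batch_index_select_sketch(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation: per-batch torch.index_select via torch.gather."""
    # Normalize a negative dim so it counts from the end of input's shape.
    if dim < 0:
        dim += input.dim()
    if dim == 0:
        raise ValueError("dim 0 is the batch dimension and cannot be indexed")

    batch, n = index.shape
    # Reshape index from (B, N) to (B, 1, ..., N, ..., 1) with N at position dim.
    view_shape = [batch] + [1] * (input.dim() - 1)
    view_shape[dim] = n
    # Expand every non-indexed dimension to match input; N stays at position dim.
    expand_shape = list(input.shape)
    expand_shape[dim] = n
    expanded = index.reshape(view_shape).expand(expand_shape)
    # gather copies the selected values into a new tensor.
    return torch.gather(input, dim, expanded)


x = torch.arange(24).reshape(2, 3, 4)
idx = torch.tensor([[2, 0], [1, 1]])
out = batch_index_select_sketch(x, 1, idx)  # shape (2, 2, 4)
```

Because torch.gather copies the selected values, the result does not share storage with input, consistent with the note in the documentation above.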

pugh_torch.utils.io module

exception pugh_torch.utils.io.GDriveDownloadError

Bases: OSError

Failed to download a file from Google Drive.

pugh_torch.utils.io.gdrive_download(url, path, progress=True)

Download a file from Google Drive.

Parameters
  • url (str-like) – Public Google Drive link.

  • path (str-like) – Path to where the file should be downloaded. If the path has no extension, it is treated as a directory: it will be created if necessary, and the file will be saved inside it under its Google Drive name.

  • progress (bool) – Display download progress

Returns

Path to the downloaded file.

Return type

pathlib.Path
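The directory-versus-file rule for path can be illustrated with pathlib. This is a sketch under the assumption that "extensionless" means Path.suffix is empty; resolve_download_path and drive_filename are hypothetical names, not part of pugh_torch, and no download is performed.

```python
from pathlib import Path


def resolve_download_path(path, drive_filename: str) -> Path:
    """Hypothetical helper mirroring the documented path rule of gdrive_download."""
    path = Path(path)
    if path.suffix == "":
        # No extension: treat path as a directory and keep Google Drive's filename.
        path.mkdir(parents=True, exist_ok=True)
        return path / drive_filename
    # Has an extension: treat path as the exact destination file.
    return path
```

For example, resolve_download_path("weights", "model.pth") would create the weights directory and return weights/model.pth, while resolve_download_path("out.zip", "model.pth") returns out.zip unchanged.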

pugh_torch.utils.misc module

pugh_torch.utils.misc.timeit(msg='')

Timing context manager

Example usage:

>>> import time
>>> with timeit('Data Reading'):
...     time.sleep(5)
Data Reading executed in 5.002 seconds.

Parameters

msg (str) – Message to be prepended to " executed in %.3f seconds."
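A timing context manager like this can be sketched with contextlib.contextmanager and time.perf_counter. This is an assumed implementation that reproduces the documented message format, not the library's source; timeit_sketch is a hypothetical name.

```python
import time
from contextlib import contextmanager


@contextmanager
def timeit_sketch(msg=""):
    """Hypothetical timing context manager matching the documented output format."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs even if the body raises, so the elapsed time is always reported.
        elapsed = time.perf_counter() - start
        print(msg + " executed in %.3f seconds." % elapsed)
```

time.perf_counter is used rather than time.time because it is monotonic and intended for measuring short intervals. Printing in the finally clause is a choice of this sketch so the time is reported even when the body raises.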

pugh_torch.utils.tensorboard module

class pugh_torch.utils.tensorboard.SummaryWriter(*args, rgb_transform=None, **kwargs)

Bases: torch.utils.tensorboard.writer.SummaryWriter

Extension of SummaryWriter with convenience methods for common logging tasks.

Parameters

rgb_transform (str or torchvision.transforms.*) – Transform applied to RGB images when they are logged. It should convert your batch's RGB data into a float image in the range [0, 1]; values outside this range will be clipped.

add_depth(tag, rgbs, preds, targets, global_step=None, walltime=None, dataformats='CHW', *, rgb_transform=None, n_images=1)

Add a depth image and its paired input montage.

self.add_rgb’s documentation applies to the rgbs input here.

Parameters
  • preds (torch.Tensor) – (B, H, W) Predicted depth in meters.

  • targets (torch.Tensor) – (B, H, W) Ground truth depth in meters.

add_rgb(tag, rgbs, global_step=None, walltime=None, dataformats='CHW', *, rgb_transform=None, n_images=1, labels=None)

Apply a transform and add the images to the log.

A common scenario is when you only have a normalized image (e.g., normalized with ImageNet's mean and standard deviation) and want to log it to TensorBoard.

In this case, it may be more convenient to set rgb_transform="imagenet" in the constructor.

Parameters
  • rgbs (torch.Tensor) – (B, 3, H, W) Image data. See rgb_transform argument.

  • rgb_transform (str or callable) – Transform to apply to the rgb data. If not provided, defaults to the transform provided in __init__.

  • n_images (int) – Maximum number of images to add.

  • labels (list or str) – Text to rasterize and display under each image. If a str, the same text will be applied under all images.
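For the ImageNet scenario described above, a transform of this kind can be sketched as follows. This assumes the "imagenet" option inverts the standard torchvision ImageNet normalization (mean (0.485, 0.456, 0.406), std (0.229, 0.224, 0.225)) and then clips to [0, 1] as the constructor documentation describes; imagenet_unnormalize is a hypothetical name, and the library's actual transform may differ.

```python
import torch

# Standard ImageNet channel statistics used by torchvision's pretrained models.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def imagenet_unnormalize(rgbs: torch.Tensor) -> torch.Tensor:
    """Hypothetical transform: undo ImageNet normalization on a (B, 3, H, W) batch."""
    # Match dtype/device of the input and reshape to (1, 3, 1, 1) for broadcasting.
    mean = IMAGENET_MEAN.to(rgbs).view(1, 3, 1, 1)
    std = IMAGENET_STD.to(rgbs).view(1, 3, 1, 1)
    # Invert x_norm = (x - mean) / std, then clip to the valid display range.
    return (rgbs * std + mean).clamp(0.0, 1.0)
```

A callable like this could be passed as rgb_transform in place of a string.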

add_ss(tag, rgbs, preds, targets, global_step=None, walltime=None, dataformats='CHW', *, rgb_transform=None, n_images=1, palette='ade20k', offset=0, labels=None)

Add a semantic segmentation image and its paired input montage.

self.add_rgb’s documentation applies to the rgbs input here.

TODO: more control over which image to show

Parameters
  • tag (str) – Data identifier

  • rgbs (torch.Tensor) – (B, 3, H, W) Image data. See rgb_transform argument.

  • preds (torch.Tensor) – (B, C, H, W) Predicted semantic segmentation data. This method will argmax over the C dimension.

  • targets (torch.Tensor) – (B, H, W) Indexed ground truth data.

  • rgb_transform (str or callable) – Transform to apply to the rgb data. If not provided, defaults to the transform provided in __init__. Expects data to be in range [0, 1] after transform.

  • n_images (int) – Maximum number of images to add.

  • offset (int) – Value added to the pred and target indices before indexing into the colormap. A common value might be 1 if your network isn't using a background class.

  • labels (list or str) – Text to rasterize and display under each image. If a str, the same text will be applied under all images.
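The preds and offset handling described above can be sketched as an argmax over C followed by a palette lookup. This assumes the palette is a (K, 3) color table indexed by the shifted class indices; colorize_labels and the tiny four-color palette in the usage line are hypothetical, and the real "ade20k" palette and any out-of-range handling may differ.

```python
import torch


def colorize_labels(preds: torch.Tensor, targets: torch.Tensor,
                    palette: torch.Tensor, offset: int = 0):
    """Hypothetical sketch: map (B, C, H, W) scores and (B, H, W) labels to RGB."""
    # Argmax over the class dimension C gives (B, H, W) class indices.
    pred_idx = preds.argmax(dim=1) + offset
    target_idx = targets.long() + offset
    # Indexing the (K, 3) palette yields (B, H, W, 3); permute to (B, 3, H, W).
    pred_rgb = palette[pred_idx].permute(0, 3, 1, 2)
    target_rgb = palette[target_idx].permute(0, 3, 1, 2)
    return pred_rgb, target_rgb


palette = torch.tensor([[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=torch.uint8)
```

With offset=1, class 0 maps to palette row 1, which leaves row 0 free for a background color the network never predicts.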

class pugh_torch.utils.tensorboard.TensorBoardLogger(save_dir: str, name: Optional[str] = 'default', version: Union[int, str, None] = None, log_graph: bool = False, default_hp_metric: bool = True, **kwargs)

Bases: pytorch_lightning.loggers.tensorboard.TensorBoardLogger

Same as the default PyTorch Lightning TensorBoardLogger, but uses the extended SummaryWriter defined in this module.

property experiment

The actual TensorBoard writer object. To use TensorBoard features in your LightningModule, call its methods directly, for example:

self.logger.experiment.some_tensorboard_function()

Module contents