Source code for janggu.data.visualization

"""Genomic track visualization utilities."""
import warnings
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from janggu.utils import NMAP
from janggu.utils import PMAP
from janggu.utils import _to_list


def plotGenomeTrack(tracks, chrom, start, end, figsize=(10, 5), plottypes=None):
    """Plot a genome browser-like view of one or more tracks for an interval.

    A header axis shows the chromosome name and the interval boundaries.
    Below it, every track is drawn in its own group of grid rows together
    with a side bar carrying the track name.

    Parameters
    ----------
    tracks : janggu.data.Cover, list(Cover), janggu.data.Track or list(Track)
        One or more track objects. Entries that are not ``Track`` instances
        are wrapped into a track according to ``plottypes`` and a
        ``FutureWarning`` is emitted.
    chrom : str
        Chromosome name.
    start : int
        The start of the required interval.
    end : int
        The end of the required interval.
    figsize : tuple(int, int)
        Figure size passed on to matplotlib.
    plottypes : None or list(str)
        Plot type for each entry of ``tracks`` that is not already a
        ``Track``: 'heatmap' selects a ``HeatTrack``, 'seqplot' a
        ``SeqTrack`` and any other value (e.g. 'line') a ``LineTrack``.
        With ``plottypes=None`` all such entries are drawn as line plots.
        Otherwise the list must have one entry per track, for example
        ['line', 'heatmap', 'seqplot']. While 'line' and 'heatmap' can be
        used for any type of coverage data, 'seqplot' is reserved to plot
        sequence influence on the output. It is intended to be used in
        conjunction with the 'input_attribution' method, which determines
        the importance of particular sequence letters for the output
        prediction.

    Returns
    -------
    matplotlib Figure
        A matplotlib figure illustrating the genome browser-view of the
        tracks for the given interval. To depict and save the figure the
        native matplotlib functions show() and savefig() can be used.

    Raises
    ------
    AssertionError
        If non-``Track`` objects are supplied and the number of
        ``plottypes`` differs from the number of tracks.
    """
    tracks = _to_list(tracks)

    # Plot types only matter if at least one entry still has to be wrapped.
    if any(not isinstance(track, Track) for track in tracks):
        # Warn once per call instead of once per converted entry.
        warnings.warn('Convert the Dataset object to proper Track objects.'
                      ' In the future, only Track objects will be supported.',
                      FutureWarning)
        if plottypes is None:
            plottypes = ['line'] * len(tracks)

        if len(plottypes) != len(tracks):
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type callers
            # may rely on stays the same.
            raise AssertionError(
                "The number of cover objects must be the same as the "
                "number of plottypes.")

    def _convert_to_track(cover, plottype):
        """Wrap a non-Track object into the track class named by plottype."""
        if plottype == 'heatmap':
            return HeatTrack(cover)
        if plottype == 'seqplot':
            return SeqTrack(cover)
        return LineTrack(cover)

    tracks = [track if isinstance(track, Track)
              else _convert_to_track(track, plottypes[itrack])
              for itrack, track in enumerate(tracks)]

    # Grid rows: two header rows, the rows requested by every track and one
    # spacer row between consecutive tracks.
    headertrack = 2
    trackheights = sum(track.height for track in tracks)
    spacer = len(tracks) - 1

    grid = plt.GridSpec(headertrack + trackheights + spacer, 10,
                        wspace=0.4, hspace=0.3)
    fig = plt.figure(figsize=figsize)

    # title and reference track
    title = fig.add_subplot(grid[0, 1:])
    title.set_title(chrom)
    title.set_xlim([0, end - start])
    for side in ('right', 'top', 'left'):
        title.spines[side].set_visible(False)
    # Axis coordinates are relative to the interval; label them with the
    # genomic start and end positions.
    title.set_xticks([0, end - start])
    title.set_xticklabels([start, end])
    title.set_yticks(())

    y_offset = 1
    for track in tracks:
        # Skip one row (header remainder / spacer) before each track.
        y_offset += 1
        track.add_side_bar(fig, grid, y_offset)
        track.plot(fig, grid, y_offset, chrom, start, end)
        y_offset += track.height

    return fig
class Track(object):
    """General track.

    Base class bundling a data object with the number of grid rows the
    track occupies, plus helpers shared by the concrete track types.

    Parameters
    ----------
    data : Cover object
        Coverage object. It must expose a ``name`` attribute and support
        indexing with ``[chrom, start, end]``.
    height : int
        Track height (number of grid rows).
    """

    def __init__(self, data, height):
        self.data = data
        self.height = height

    @property
    def name(self):
        """Track name, taken from the underlying data object."""
        return self.data.name

    def add_side_bar(self, fig, grid, offset):
        """Add the side-bar axis that labels this track with its name.

        The side bar occupies the first grid column over the rows
        ``offset`` to ``offset + self.height``.
        """
        rows = slice(offset, offset + self.height)
        sidebar = fig.add_subplot(grid[rows, 0])
        sidebar.set_xticks(())
        for side in ('right', 'top', 'bottom'):
            sidebar.spines[side].set_visible(False)
        # A single y tick carries the track name.
        sidebar.set_yticks([0.5])
        sidebar.set_yticklabels([self.name])

    def get_track_axis(self, fig, grid, offset, height):
        """Return a new axis spanning ``height`` rows from ``offset``.

        The axis covers every grid column except the first one, which is
        used by the side bar.
        """
        rows = slice(offset, offset + height)
        return fig.add_subplot(grid[rows, 1:])

    def get_data(self, chrom, start, end):
        """Return the data to plot for the requested interval.

        The data object is indexed with ``(chrom, start, end)`` and the
        first entry along the leading axis of the result is returned.
        """
        # presumably the remaining axes are (position, strand, condition)
        # -- confirm against the Cover implementation
        region = self.data[chrom, start, end]
        return region[0, :, :, :]
class LineTrack(Track):
    """Line track.

    Visualizes genomic data as line plots, one panel per condition.

    Parameters
    ----------
    data : Cover object
        Coverage object
    height : int
        Height of each condition panel. Default=3
    linestyle : str
        Linestyle for plot
    marker : str
        Marker code for plot (used when strands are not plotted separately)
    color : str
        Color code for plot
    linewidth : float
        Line width.
    """

    def __init__(self, data, height=3, linestyle='-', marker='o',
                 color='b', linewidth=2):
        super(LineTrack, self).__init__(data, height)
        # Every condition gets its own panel of `height` rows, so the total
        # track height scales with the number of conditions.
        self.height = height * len(data.conditions)
        self.color = color
        self.marker = marker
        self.linestyle = linestyle
        self.linewidth = linewidth

    def plot(self, fig, grid, offset, chrom, start, end):
        """Plot line track.

        Draws one line-plot panel per condition, stacked from ``offset``
        downwards. Non-finite values are left out of the lines.
        """
        coverage = self.get_data(chrom, start, end)
        panel_height = self.height // len(self.data.conditions)
        row = offset

        def _finite_points(values):
            """Return positions and values of the finite entries."""
            positions = np.where(np.isfinite(values))[0]
            return positions, values[positions]

        for idx, condition in enumerate(self.data.conditions):
            ax = self.get_track_axis(fig, grid, row, panel_height)
            row += panel_height

            if coverage.shape[1] != 2:
                xs, ys = _finite_points(coverage[:, 0, idx])
                ax.plot(xs, ys,
                        linewidth=self.linewidth,
                        color=self.color,
                        linestyle=self.linestyle,
                        marker=self.marker)
            else:
                # both strands are covered separately; each strand gets a
                # fixed marker and a legend entry
                for strand, (label, marker) in enumerate((("+", '+'),
                                                          ("-", 1))):
                    xs, ys = _finite_points(coverage[:, strand, idx])
                    ax.plot(xs, ys,
                            linewidth=self.linewidth,
                            linestyle=self.linestyle,
                            color=self.color,
                            label=label,
                            marker=marker)
                ax.legend()

            ax.set_yticks(())
            ax.set_xticks(())
            ax.set_xlim([0, end - start])

            # Condition labels are only needed to tell several panels apart.
            if len(self.data.conditions) > 1:
                ax.set_ylabel(condition, labelpad=12)

            for side in ('right', 'top'):
                ax.spines[side].set_visible(False)
class SeqTrack(Track):
    """Sequence Track.

    Visualizes sequence importance as stacked bars, one bar segment per
    alphabet letter and position.

    Parameters
    ----------
    data : Cover object
        Coverage object
    height : int
        Track height. Default=3
    """

    def __init__(self, data, height=3):
        super(SeqTrack, self).__init__(data, height)

    def plot(self, fig, grid, offset, chrom, start, end):
        """Plot sequence track.

        Raises
        ------
        ValueError
            If the number of conditions is not a multiple of the size of
            either alphabet map, or if a condition name does not start
            with a letter of the selected alphabet.
        """
        conditions = self.data.conditions
        n_conditions = len(conditions)

        # Pick the alphabet whose size divides the number of conditions.
        if n_conditions % len(NMAP) == 0:
            MAP = NMAP
        elif n_conditions % len(PMAP) == 0:  # pragma: no cover
            MAP = PMAP
        else:  # pragma: no cover
            raise ValueError(
                "Coverage tracks seems not represent biological sequences. "
                "The last dimension must be divisible by the alphabetsize.")
        alphabetsize = len(MAP)

        if any(cond[0] not in MAP for cond in conditions):
            raise ValueError(
                "Coverage tracks seems not represent biological sequences. "
                "Condition names must represent the alphabet letters.")

        coverage = self.get_data(chrom, start, end)

        # project higher-order sequence structure onto original sequence:
        # flatten everything but the first axis, regroup the flattened axis
        # into `alphabetsize` groups and sum within each group.
        n_positions = coverage.shape[0]
        flat = coverage.reshape(n_positions, -1)
        grouped = flat.reshape(
            flat.shape[:-1] + (alphabetsize, flat.shape[-1] // alphabetsize))
        per_letter = grouped.sum(-1)

        ax = self.get_track_axis(fig, grid, offset, self.height)

        positions = np.arange(per_letter.shape[0])
        stack_base = np.zeros(per_letter.shape[0])
        palette = sns.color_palette("hls", len(MAP))

        handles = []
        for letter in MAP:
            column = MAP[letter]
            bars = ax.bar(positions, per_letter[:, column],
                          bottom=stack_base,
                          color=palette[column],
                          label=letter)
            handles.append(bars)
            # Stack the next letter on top of what has been drawn so far.
            stack_base += per_letter[:, column]

        ax.legend(handles=handles)
        ax.set_yticklabels(())
        ax.set_yticks(())
        ax.set_xticks(())
        ax.set_xlim([0, end - start])
class HeatTrack(Track):
    """Heatmap Track.

    Visualizes genomic data as heatmap with one row per condition, or one
    row per strand/condition pair if both strands are covered separately.

    Parameters
    ----------
    data : Cover object
        Coverage object
    height : int
        Track height. Default=3
    """

    def __init__(self, data, height=3):
        super(HeatTrack, self).__init__(data, height)

    def plot(self, fig, grid, offset, chrom, start, end):
        """Plot heatmap track.

        The data for the interval is flattened to two dimensions
        (positions x rows) and drawn with ``pcolor``; the y axis is
        labelled with the condition names, suffixed by ``:+`` / ``:-``
        when the second-to-last axis has size two.
        """
        ax = self.get_track_axis(fig, grid, offset, self.height)
        coverage = self.get_data(chrom, start, end)

        # Rows of the heatmap follow the flattened trailing axes.
        ax.pcolor(coverage.reshape(coverage.shape[0], -1).T)

        if coverage.shape[-2] == 2:
            # Same order as the flattened data: strand varies slowest.
            labels = [':'.join([condition, strand])
                      for strand, condition
                      in product(['+', '-'], self.data.conditions)]
        else:
            labels = list(self.data.conditions)

        ax.set_xticks(())
        # Fix the tick positions before labelling them (matplotlib warns
        # against labelling unfixed ticks) and put exactly one tick at the
        # centre of each heatmap row so every label lines up with its row.
        ax.set_yticks(np.arange(len(labels)) + 0.5)
        ax.set_yticklabels(labels)
        ax.set_xlim([0, end - start])