Source code for haddock.modules.analysis.contactmap.contmap
"""Module computing contact maps of complexes, alone or grouped by cluster.
Chord diagram functions were adapted from:
https://plotly.com/python/v3/filled-chord-diagram/
"""
import os
import glob
from pathlib import Path
import numpy as np
import plotly.graph_objs as go
from scipy.spatial.distance import pdist, squareform
from haddock import log
from haddock.libs.libontology import PDBFile
from haddock.libs.libpdb import (
slc_name,
slc_resname,
slc_chainid,
slc_resseq,
slc_x,
slc_y,
slc_z,
)
from haddock.core.typing import (
Any,
NDFloat,
NDArray,
Optional,
Union,
SupportsRun,
)
from haddock.libs.libplots import heatmap_plotly, fig_to_html
###############################
# Global variable definitions #
###############################
# Polarity class of every standard amino-acid, keyed by 3 letters code.
# Entries are listed per polarity class to keep the table compact.
RESIDUE_POLARITY = {
    resname: polarity
    for polarity, resnames in (
        ("polar", ("CYS", "HIS", "ASN", "GLN", "SER", "THR", "TYR", "TRP")),
        ("apolar", ("ALA", "PHE", "GLY", "ILE", "VAL", "MET", "PRO", "LEU")),
        ("negative", ("GLU", "ASP")),
        ("positive", ("LYS", "ARG")),
    )
    for resname in resnames
}
# Every nucleotide name (one letter and `D`-prefixed form) is negative.
DNA_RNA_POLARITY = {
    name: "negative"
    for base in ("A", "T", "U", "C", "G")
    for name in (base, f"D{base}")
}
# Make nucleotides available through the same lookup table
RESIDUE_POLARITY.update(DNA_RNA_POLARITY)
PI = np.pi
# RGB color of each interaction type, keyed by "<polarity>-<polarity>"
CONNECT_COLORS = dict([
    ("polar-polar", (153, 255, 153)),
    ("polar-apolar", (255, 204, 204)),
    ("polar-negative", (255, 204, 153)),
    ("polar-positive", (153, 204, 255)),
    ("apolar-apolar", (255, 255, 0)),
    ("apolar-negative", (255, 229, 204)),
    ("apolar-positive", (204, 229, 255)),
    ("negative-negative", (255, 127, 0)),
    ("negative-positive", (0, 204, 0)),
    ("positive-positive", (255, 127, 0)),
])
# Same colors accessible with the two polarities written the other way round
REVERSED_CONNECT_COLORS_KEYS = {
    "-".join(reversed(pair.split("-"))): rgb
    for pair, rgb in CONNECT_COLORS.items()
}
CONNECT_COLORS.update(REVERSED_CONNECT_COLORS_KEYS)
# Colors for each amino-acids
# CSS `rgba()` color of each amino-acid, keyed by 3 letters code.
RESIDUES_COLORS = {
    "CYS": "rgba(229, 255, 204, 0.80)",
    "MET": "rgba(229, 255, 204, 0.80)",
    "ASN": "rgba(128, 255, 0, 0.80)",
    "GLN": "rgba(128, 255, 0, 0.80)",
    "SER": "rgba(153, 255, 51, 0.80)",
    "THR": "rgba(153, 255, 51, 0.80)",
    "TYR": "rgba(204, 155, 53, 0.80)",
    "TRP": "rgba(204, 155, 53, 0.80)",
    "HIS": "rgba(204, 155, 53, 0.80)",
    "PHE": "rgba(255, 255, 51, 0.80)",
    "ALA": "rgba(255, 255, 0, 0.80)",
    "ILE": "rgba(255, 255, 0, 0.80)",
    "VAL": "rgba(255, 255, 0, 0.80)",
    # Fixed: the comma between the blue and alpha channels was missing,
    # which made this entry an invalid color string.
    "PRO": "rgba(255, 255, 0, 0.80)",
    "LEU": "rgba(255, 255, 0, 0.80)",
    "GLY": "rgba(255, 255, 255, 0.80)",
    "GLU": "rgba(255, 0, 0, 0.80)",
    "ASP": "rgba(255, 0, 0, 0.80)",
    "LYS": "rgba(0, 0, 255, 0.80)",
    "ARG": "rgba(0, 0, 255, 0.80)",
}
# Colors for DNA / RNA
# Note: Based on WebLogo color scheme
DNARNA_COLORS = {
    "A": "rgba(51, 204, 51, 0.80)",
    "T": "rgba(204, 0, 0, 0.80)",
    "U": "rgba(204, 0, 0, 0.80)",
    "C": "rgba(51, 102, 255, 0.80)",
    "G": "rgba(255, 163, 26, 0.80)",
}
# Same colors for the `D`-prefixed nucleotide names (DA, DT, ...)
FULL_DNARNA_COLORS = {f"D{k}": rgba for k, rgba in DNARNA_COLORS.items()}
# Combine all of them in a single lookup table
AA_DNA_RNA_COLORS = {}
AA_DNA_RNA_COLORS.update(RESIDUES_COLORS)
AA_DNA_RNA_COLORS.update(DNARNA_COLORS)
AA_DNA_RNA_COLORS.update(FULL_DNARNA_COLORS)
# Chain colors
# Colors attributed to chains. Channels are declared in the historical
# order, then the list is built reversed so the last declared comes first.
CHAIN_COLORS = [
    f'rgba({red}, {green}, {blue}, 0.85)'
    for red, green, blue in reversed((
        (51, 255, 51),
        (51, 153, 255),
        (255, 153, 51),
        (255, 255, 51),
        (255, 0, 0),
        (255, 0, 127),
        (0, 255, 0),
        (0, 0, 255),
        (0, 153, 0),
    ))
]
##################
# Define classes #
##################
[docs]
class ContactsMapJob(SupportsRun):
    """Job wrapper holding a contact map object and its settings.

    Attributes
    ----------
    output
        Output of the job, as given at initialisation.
    params
        Parameters of the job, as given at initialisation.
    name
        Name of the job.
    contact_obj
        The contact map object this job is dedicated to.
    """

    def __init__(self, output, params, name, contact_obj):
        super().__init__()
        self.contact_obj = contact_obj
        self.name = name
        self.output = output
        self.params = params
[docs]
class ContactsMap():
    """Contact analysis of one single PDB structure."""

    def __init__(self, model: Path, output: Path, params: dict) -> None:
        # Paths of the files written by `generate_output`, by content type
        self.files: dict[str, Union[str, Path]] = {}
        self.params = params
        self.output = output
        self.model = model

    def run(self):
        """Compute residue-residue and interchain atom-atom contacts.

        Every unordered pair of residues is visited once (half matrix).
        Heavy atoms contacts are only searched between residues whose keys
        start with a different chain id.

        Return
        ------
        res_res_contacts : list[dict]
            One entry per pair of residues, see `gen_contact_dt()`.
        all_heavy_interchain_contacts : list[dict]
            Interchain atom-atom contacts, see `extract_heavyatom_contacts()`.
        """
        # Parse structure and compute the all-atoms distance matrix
        pdb_dt = extract_pdb_dt(self.model)
        all_coords, resid_keys, resid_dt = get_ordered_coords(pdb_dt)
        full_dist_matrix = compute_distance_matrix(all_coords)

        res_res_contacts = []
        all_heavy_interchain_contacts = []
        # `next_index` points the first residue following `reskey_1`
        for next_index, reskey_1 in enumerate(resid_keys, start=1):
            for reskey_2 in resid_keys[next_index:]:
                res_res_contacts.append(
                    gen_contact_dt(
                        full_dist_matrix,
                        resid_dt,
                        reskey_1,
                        reskey_2,
                        )
                    )
                # Only pairs of residues from different chains go further
                if reskey_2.split('-')[0] != reskey_1.split('-')[0]:
                    all_heavy_interchain_contacts.extend(
                        extract_heavyatom_contacts(
                            full_dist_matrix,
                            resid_dt,
                            reskey_1,
                            reskey_2,
                            contact_distance=self.params['shortest_dist_threshold'],  # noqa : E501
                            )
                        )

        # Outputs for single models are optional
        if self.params['single_model_analysis']:
            self.generate_output(
                res_res_contacts,
                all_heavy_interchain_contacts,
                )
        return res_res_contacts, all_heavy_interchain_contacts

    def generate_output(
            self,
            res_res_contacts: list[dict],
            all_heavy_interchain_contacts: list[dict],
            ) -> None:
        """Write contacts files and the optional plots of this model.

        Paths of generated files are stored in `self.files`.

        Parameters
        ----------
        res_res_contacts : list[dict]
            List of residue-residue contacts
        all_heavy_interchain_contacts : list[dict]
            List of heavy atoms interchain contacts
        """
        params = self.params
        # Residue keys first, remaining columns in alphabetical order
        leading_columns = ['res1', 'res2']
        header = leading_columns + [
            column for column in sorted(res_res_contacts[0])
            if column not in leading_columns
            ]
        fpath = write_res_contacts(
            res_res_contacts,
            header,
            f'{self.output}_contacts.tsv',
            interchain_data={
                'path': f'{self.output}_interchain_contacts.tsv',
                'data_key': 'ca-ca-dist',
                'contact_threshold': params['ca_ca_dist_threshold'],
                },
            )
        log.info(f'Generated contacts file: {fpath}')
        self.files['res-res-contacts'] = fpath

        # Heatmap of the Ca-Ca distances
        if params['generate_heatmap']:
            heatmap = tsv_to_heatmap(
                fpath,
                data_key='ca-ca-dist',
                contact_threshold=params['ca_ca_dist_threshold'],
                colorscale=params['color_ramp'],
                output_fname=f'{self.output}_heatmap.html',
                offline=params["offline"],
                )
            log.info(f'Generated single model heatmap file: {heatmap}')
            self.files['res-res-contactmap'] = heatmap

        # Chord chart, using the threshold matching the plotted data type
        if params['generate_chordchart']:
            threshold = (
                params['ca_ca_dist_threshold']
                if params['chordchart_datatype'] == 'ca-ca-dist'
                else params['shortest_dist_threshold']
                )
            chordp = tsv_to_chordchart(
                fpath,
                data_key=params['chordchart_datatype'],
                contact_threshold=threshold,
                output_fname=f'{self.output}_chordchart.html',
                filter_intermolecular_contacts=True,
                title=Path(self.output).stem.replace('_', ' '),
                offline=params["offline"],
                )
            log.info(f'Generated single model chordchart file: {chordp}')
            self.files['res-res-chordchart'] = chordp

        # Interchain heavy atoms contacts file
        fpath2 = write_res_contacts(
            all_heavy_interchain_contacts,
            ['atom1', 'atom2', 'dist'],
            f'{self.output}_heavyatoms_interchain_contacts.tsv',
            )
        log.info(f'Generated contacts file: {fpath2}')
        self.files['atom-atom-interchain-contacts'] = fpath2
[docs]
class ClusteredContactMap():
    """ContactMap analysis for set of clustered structures."""

    def __init__(
            self,
            models: list[Path],
            output: Path,
            params: dict,
            ) -> None:
        self.models = models
        self.output = output
        self.params = params
        # Paths of generated files, keyed by content type
        self.files: dict[str, Union[str, Path]] = {}
        # Set to True once `run()` reached its end
        self.terminated = False

    @staticmethod
    def aggregate_contacts(
            contacts_holder: dict,
            contact_keys: list[str],
            contacts: list[dict],
            key1: str,
            key2: str,
            ) -> None:
        """Aggregate single models data belonging to a cluster.

        Both `contacts_holder` and `contact_keys` are modified in place.
        A pair is stored under the `first/second` combined key; if the pair
        was already registered in the opposite order, that existing key is
        reused so both orientations end up in the same entry.

        Parameters
        ----------
        contacts_holder : dict
            Dictionary holding, for each combined key, lists of contact data.
        contact_keys : list[str]
            Ordered list of the combined keys present in `contacts_holder`.
        contacts : list[dict]
            Single model contact data.
        key1 : str
            Name of the key to access first entry in data.
        key2 : str
            Name of the key to access second entry in data.
        """
        for cont in contacts:
            # Try the reversed orientation first, then the normal one
            combined_key = f'{cont[key2]}/{cont[key1]}'
            if combined_key not in contacts_holder:
                combined_key = f'{cont[key1]}/{cont[key2]}'
                if combined_key not in contacts_holder:
                    # First time this pair is observed: register it
                    contact_keys.append(combined_key)
                    contacts_holder[combined_key] = {
                        k: []
                        for k in cont
                        if k not in (key1, key2)
                        }
            # Add data
            pair_data = contacts_holder[combined_key]
            for dtk in pair_data:
                pair_data[dtk].append(cont[dtk])

    @staticmethod
    def _most_frequent(values: list):
        """Return the most frequent item of a list (first seen wins ties)."""
        # dict.fromkeys() de-duplicates while keeping first-seen order, and
        # max() returns the first maximal item: the result is deterministic.
        return max(dict.fromkeys(values), key=values.count)

    def run(self):
        """Process analysis of contacts of a set of PDB structures.

        Each model is analysed with `ContactsMap`, then per-pair data are
        averaged over all models and written to tsv files (plus the optional
        heatmap and chord chart). `self.terminated` is set to True at the end.
        """
        # initiate holding variables
        clusters_contacts: dict = {}  # Residue-residue contacts
        resres_keys_list: list[str] = []  # Ordered res-res contacts keys
        clusters_heavyatm_contacts: dict = {}  # Interchain atom-atom contacts
        atat_keys_list: list[str] = []  # Ordered atom-atom contacts keys
        # loop over models/structures
        for pdb_path in self.models:
            contact_map_obj = ContactsMap(
                pdb_path,
                f'{self.output}_{pdb_path.stem}',
                self.params,
                )
            pdb_contacts, interchain_heavy_contacts = contact_map_obj.run()
            # Aggregate residue-residue contacts
            self.aggregate_contacts(
                clusters_contacts, resres_keys_list,
                pdb_contacts,
                "res1", "res2",
                )
            # Aggregate interchain heavy atoms contacts
            self.aggregate_contacts(
                clusters_heavyatm_contacts, atat_keys_list,
                interchain_heavy_contacts,
                "atom1", "atom2",
                )

        # Summarize heavy atoms contacts over the cluster
        heavy_atm_clust_list = []
        for atatk in atat_keys_list:
            at1, at2 = atatk.split('/')
            h_dists = clusters_heavyatm_contacts[atatk]['dist']
            heavy_atm_clust_list.append({
                "atom1": at1,
                "atom2": at2,
                "nb_dists": len(h_dists),
                "avg_dist": round(np.mean(h_dists), 2),
                "std_dist": round(np.std(h_dists), 2),
                })
        # The header is spelled out (atom keys first, then sorted columns)
        # instead of being read from the first entry, so that a cluster
        # without any interchain contact yields a header-only file rather
        # than an IndexError.
        hfpath = write_res_contacts(
            heavy_atm_clust_list,
            ['atom1', 'atom2', 'avg_dist', 'nb_dists', 'std_dist'],
            f'{self.output}_heavyatoms_interchain_contacts.tsv',
            )
        log.info(f'Generated heavy atoms interchain contacts file: {hfpath}')

        # Summarize residue-residue contacts over the cluster
        ca_ca_threshold = self.params['ca_ca_dist_threshold']
        shortest_threshold = self.params['shortest_dist_threshold']
        combined_clusters_list = []
        for combined_key in resres_keys_list:
            dt = clusters_contacts[combined_key]
            ca_ca_dists = dt['ca-ca-dist']
            shortest_dists = dt['shortest-dist']
            # Fraction of models holding a value under each threshold
            ca_ca_cont_probability = (
                sum(1 for v in ca_ca_dists if v <= ca_ca_threshold)
                / len(ca_ca_dists)
                )
            short_cont_proba = (
                sum(1 for v in shortest_dists if v <= shortest_threshold)
                / len(shortest_dists)
                )
            # Most represented contact type. Occurrences are counted on the
            # full list: counting on the de-duplicated one always gave 1.
            cont_t = self._most_frequent(dt['contact-type'])
            # Split key to recover residues names
            res1, res2 = combined_key.split('/')
            combined_clusters_list.append({
                'res1': res1,
                'res2': res2,
                'ca-ca-dist': round(np.mean(ca_ca_dists), 1),
                'ca-ca-cont-probability': round(ca_ca_cont_probability, 2),
                'shortest-dist': round(np.mean(shortest_dists), 1),
                'shortest-cont-probability': round(short_cont_proba, 2),
                'contact-type': cont_t,
                })

        # write contacts (residue keys first, then sorted columns)
        header = [
            'res1', 'res2',
            'ca-ca-cont-probability', 'ca-ca-dist', 'contact-type',
            'shortest-cont-probability', 'shortest-dist',
            ]
        fpath = write_res_contacts(
            combined_clusters_list,
            header,
            f'{self.output}_contacts.tsv',
            interchain_data={
                'path': f'{self.output}_interchain_contacts.tsv',
                'data_key': 'ca-ca-dist',
                'contact_threshold': ca_ca_threshold,
                },
            )
        log.info(f'Generated contacts file: {fpath}')
        self.files['res-res-contacts'] = fpath
        # NOTE(review): this entry points at the residue-level interchain
        # file, whereas `ContactsMap` stores the heavy atoms file under the
        # same key. Kept as is; confirm which one consumers expect.
        self.files['atom-atom-interchain-contacts'] = f'{self.output}_interchain_contacts.tsv'  # noqa : E501

        # Generate corresponding heatmap
        if self.params['generate_heatmap']:
            heatmap_path = tsv_to_heatmap(
                fpath,
                data_key=self.params['cluster_heatmap_datatype'],
                contact_threshold=1,
                colorscale=self.params['color_ramp'],
                output_fname=f'{self.output}_heatmap.html',
                offline=self.params["offline"],
                )
            log.info(f'Generated cluster contacts heatmap: {heatmap_path}')
            self.files['res-res-contactmap'] = heatmap_path

        # Generate corresponding chord chart
        if self.params['generate_chordchart']:
            # Threshold must match the type of data being plotted
            if self.params['chordchart_datatype'] == 'ca-ca-dist':
                threshold = ca_ca_threshold
            else:
                threshold = shortest_threshold
            chordp = tsv_to_chordchart(
                fpath,
                data_key=self.params['chordchart_datatype'],
                contact_threshold=threshold,
                output_fname=f'{self.output}_chordchart.html',
                filter_intermolecular_contacts=True,
                title=Path(self.output).stem.replace('_', ' '),
                offline=self.params["offline"],
                )
            log.info(f'Generated cluster contacts chordchart file: {chordp}')
            self.files['res-res-chordchart'] = chordp
        self.terminated = True
[docs]
def make_contactmap_report(
        contactmap_jobs: list[ContactsMapJob],
        outputpath: Union[str, Path],
        ) -> Union[str, Path]:
    """Write a small HTML navigation block linking all generated files.

    For each job, files whose path starts with `<job.output>_` are searched
    on disk, ordered by extension then by name, and turned into links.

    Parameters
    ----------
    contactmap_jobs : list[ContactsMapJob]
        All the terminated jobs
    outputpath : Union[str, Path]
        Output filepath where to write the report.

    Returns
    -------
    outputpath: Union[str, Path]
        Path to the generated report.
    """
    report_entries: list[str] = []
    for job in contactmap_jobs:
        prefix = f"{job.output}_"

        def short_name(filepath: str) -> str:
            """Strip the job prefix from a file path."""
            return filepath.replace(prefix, "")

        # Group the files generated by this job by file extension
        by_extension: dict[str, list[str]] = {}
        for filepath in glob.glob(f"{prefix}*"):
            extension = os.path.splitext(filepath)[1]
            by_extension.setdefault(extension, []).append(filepath)
        # One link per file: extensions sorted, then short names sorted
        links = [
            f'<a href="{filepath}" target="_blank">{short_name(filepath)}</a>'
            for extension in sorted(by_extension)
            for filepath in sorted(by_extension[extension], key=short_name)
            ]
        report_entries.append(f"<b>{job.name}:</b> {', '.join(links)}")
    # Each job becomes one item of the html list
    all_access = '</li>\n <li>'.join(report_entries)
    htmldt = f"""
<div id="contactmap_report">
<ul>
<li>
{all_access}
</li>
</ul>
</div>
"""
    with open(outputpath, 'w') as reportout:
        reportout.write(htmldt)
    log.info(f'Generated report file: {outputpath}')
    return outputpath
[docs]
def get_clusters_sets(models: list[PDBFile]) -> dict:
    """Group models by their cluster id.

    Parameters
    ----------
    models : list
        List of pdb models/complexes, each holding a `clt_id` attribute.

    Return
    ------
    clusters_sets : dict
        Lists of models, accessible by their cluster id as key. Models keep
        their input order within each list.
    """
    clusters_sets: dict = {}
    for model in models:
        # Create the cluster entry on first encounter, then fill it
        clusters_sets.setdefault(model.clt_id, []).append(model)
    return clusters_sets
[docs]
def topX_models(models: list[PDBFile], topX: int = 10) -> list[Any]:
    """Sort and return subset of top X best models.

    Models are sorted by increasing `score`. When they cannot be sorted
    (no `score` attribute, or scores that cannot be compared such as
    `None`), the input order is kept.

    Parameters
    ----------
    models : list
        List of pdb models/complexes.
    topX : int
        Number of models to return after sorting.

    Return
    ------
    subset_bests : list
        List of top `X` best models.
    """
    try:
        sorted_models = sorted(models, key=lambda m: m.score)
    except (AttributeError, TypeError):
        # Unscored models: fall back on the order they were provided in
        sorted_models = models
    # The slice is taken outside of any `finally` clause: there, an
    # unexpected error was masked by an UnboundLocalError.
    subset_bests = sorted_models[:topX]
    return subset_bests
####################
# Define functions #
####################
[docs]
def extract_pdb_dt(path: Path) -> dict:
    """Parse ATOM/HETATM records of a pdb file into nested dictionaries.

    Atoms whose name starts with `H` are skipped. Note that a residue is
    registered as soon as one of its records is read, even if that record
    is later skipped.

    Parameters
    ----------
    path : Path
        Path to a pdb file.

    Return
    ------
    pdb_chains : dict
        A dictionary of the pdb file accessible using chains as keys.
        `chain_order` holds chain ids in order of appearance; each chain
        holds its residue ids under `order` plus one entry per residue id.
    """
    pdb_chains: dict = {'chain_order': []}
    with open(path, 'r') as pdb_file:
        for record in pdb_file:
            # Only coordinates records are of interest
            if not record.startswith(('ATOM', 'HETATM')):
                continue
            resname = record[slc_resname]
            chainid = record[slc_chainid]
            resid = record[slc_resseq].strip()
            # Register chain on first encounter
            if chainid not in pdb_chains:
                pdb_chains['chain_order'].append(chainid)
                pdb_chains[chainid] = {'order': []}
            chain_dt = pdb_chains[chainid]
            # Register residue on first encounter
            if resid not in chain_dt:
                chain_dt['order'].append(resid)
                nb_residues = len(chain_dt['order'])
                chain_dt[resid] = {
                    'index': nb_residues - 1,
                    'resname': resname,
                    'chainid': chainid,
                    'resid': resid,
                    'position': nb_residues,
                    'atoms_order': [],
                    'atoms': {},
                    }
            atname = record[slc_name].strip()
            # Hydrogens are not considered
            if atname[0] == 'H':
                continue
            coords = extract_pdb_coords(record)
            residue_dt = chain_dt[resid]
            residue_dt['atoms_order'].append(atname)
            residue_dt['atoms'][atname] = coords
    return pdb_chains
[docs]
def extract_pdb_coords(line: str) -> list[float]:
    """Read the three cartesian coordinates of a PDB record.

    Parameters
    ----------
    line : str
        A standard ATOM/HETATM pdb record.

    Return
    ------
    coords : list[float]
        List of the X, Y and Z coordinate of this atom.
    """
    # Each axis is read from its own column slice of the record
    return [
        float(line[axis_slice].strip())
        for axis_slice in (slc_x, slc_y, slc_z)
        ]
[docs]
def get_ordered_coords(
        pdb_chains: dict,
        ) -> tuple[list[list[float]], list[str], dict]:
    """Flatten atoms coordinates and index them per residue.

    Parameters
    ----------
    pdb_chains : dict
        A dictionary of the pdb file accessible using chains as keys,
        as provided by the `extract_pdb_dt()` function.

    Return
    ------
    all_coords : list[list[float]]
        All atomic coordinates in a single list.
    resid_keys : list[str]
        Ordered list of residues keys (`chainid-resid-resname`).
    resid_dt : dict
        For each residue key: indices of its atoms in `all_coords`
        (`atoms_indices`), its `resname`, its `atoms_order` and, when the
        residue holds one, the index of its `CA` atom.
    """
    all_coords: list[list[float]] = []
    resid_keys: list[str] = []
    resid_dt: dict = {}
    for chainid in pdb_chains['chain_order']:
        chain_dt = pdb_chains[chainid]
        for resid in chain_dt['order']:
            residue = chain_dt[resid]
            resname = residue['resname']
            atom_names = residue['atoms_order']
            resdt = {
                'atoms_indices': [],
                'resname': resname,
                'atoms_order': atom_names,
                }
            for atname in atom_names:
                # Index of an atom is its position in the flat list
                atom_index = len(all_coords)
                resdt['atoms_indices'].append(atom_index)
                if atname == 'CA':
                    resdt['CA'] = atom_index
                all_coords.append(residue['atoms'][atname])
            reskey = f'{chainid}-{resid}-{resname}'
            resid_dt[reskey] = resdt
            resid_keys.append(reskey)
    return (all_coords, resid_keys, resid_dt)
[docs]
def compute_distance_matrix(all_atm_coords: list[list[float]]) -> NDFloat:
    """Compute the all vs all distance matrix of a set of coordinates.

    Parameters
    ----------
    all_atm_coords : list[list[float]]
        List of atomic coordinates.

    Return
    ------
    dist_matrix : NDFloat
        N*N distance matrix between all coordinates.
    """
    # Pairwise distances come as a condensed vector, expanded afterwards
    condensed_distances = pdist(all_atm_coords)
    return squareform(condensed_distances)
[docs]
def extract_submatrix(
        matrix: NDFloat,
        indices: list[int],
        indices2: Optional[list[int]] = None,
        ) -> NDFloat:
    """Extract submatrix based on desired indices.

    Parameters
    ----------
    matrix : NDFloat
        A N*N matrix.
    indices : list[int]
        List of `row` indices to extract from this matrix
    indices2 : Optional[list[int]]
        List of `columns` indices to extract from this matrix.
        If unspecified (None), indices2 == indices and the symmetric
        submatrix is extracted. An empty list is a valid selection and
        results in a submatrix without columns.

    Return
    ------
    submat : NDFloat
        The extracted submatrix.
    """
    # Only fall back on rows indices when columns were really not provided:
    # testing truthiness would wrongly turn an empty selection into the
    # symmetric submatrix.
    if indices2 is None:
        indices2 = indices
    # np.ix_ builds the open mesh selecting all rows/columns combinations
    submat = matrix[np.ix_(indices, indices2)]
    return submat
[docs]
def gen_contact_dt(
        matrix: NDFloat,
        resdt: dict,
        res1_key: str,
        res2_key: str,
        ) -> dict:
    """Gather contact data between two residues.

    Parameters
    ----------
    matrix : NDFloat
        The distance matrix.
    resdt : dict
        Residues data with atom indices as returned by `get_ordered_coords()`.
    res1_key : str
        First residue of interest.
    res2_key : str
        Second residue of interest

    Return
    ------
    cont_dt : dict
        Dictionary holding both residue keys, their Ca-Ca distance
        (`ca-ca-dist`, 9999 when one of them has no `CA` atom), the
        shortest distance between their atoms (`shortest-dist`) and the
        `contact-type`. Distances are rounded to one decimal.
    """
    first_res, second_res = resdt[res1_key], resdt[res2_key]
    # Ca-Ca distance only exists if both residues hold a CA atom
    if 'CA' in first_res and 'CA' in second_res:
        ca_ca_dist = matrix[first_res['CA'], second_res['CA']]
    else:
        ca_ca_dist = 9999
    # Shortest distance observed between atoms of the two residues
    atoms_distances = extract_submatrix(
        matrix,
        first_res['atoms_indices'],
        second_res['atoms_indices'],
        )
    shortest_dist = min_dist(atoms_distances)
    # Polarity based contact type
    cont_type = get_cont_type(first_res['resname'], second_res['resname'])
    return {
        'res1': res1_key,
        'res2': res2_key,
        'ca-ca-dist': round(ca_ca_dist, 1),
        'shortest-dist': round(shortest_dist, 1),
        'contact-type': cont_type,
        }
[docs]
def extract_heavyatom_contacts(
        matrix: NDFloat,
        resdt: dict,
        res1_key: str,
        res2_key: str,
        contact_distance: float = 4.5,
        ) -> list[dict[str, Union[float, str]]]:
    """List atom-atom contacts found between two residues.

    Parameters
    ----------
    matrix : NDFloat
        The distance matrix.
    resdt : dict
        Residues data with atom indices as returned by `get_ordered_coords()`.
    res1_key : str
        First residue of interest.
    res2_key : str
        Second residue of interest.
    contact_distance : float
        Distance defining a contact (pairs at this distance are included).

    Return
    ------
    all_contacts : list[dict[str, Union[float, str]]]
        One entry per pair of atoms in contact, holding both atom keys
        (`<residue key>-<atom name>`) and their distance.
    """
    first_res = resdt[res1_key]
    second_res = resdt[res2_key]
    # Pair each atom name with its index in the distance matrix
    first_atoms = list(
        zip(first_res['atoms_order'], first_res['atoms_indices'])
        )
    second_atoms = list(
        zip(second_res['atoms_order'], second_res['atoms_indices'])
        )
    all_contacts: list[dict[str, Union[float, str]]] = [
        {
            'atom1': f'{res1_key}-{first_name}',
            'atom2': f'{res2_key}-{second_name}',
            'dist': matrix[first_index, second_index],
            }
        for first_name, first_index in first_atoms
        for second_name, second_index in second_atoms
        if matrix[first_index, second_index] <= contact_distance
        ]
    return all_contacts
[docs]
def get_cont_type(resn1: str, resn2: str) -> str:
    """Build the polarity key of a pair of residues.

    Residue names are stripped before lookup in `RESIDUE_POLARITY`; names
    absent from that table are reported as `unknow`.

    Parameters
    ----------
    resn1 : str
        Residue name of first residue.
    resn2 : str
        Residue name of second residue.

    Return
    ------
    pol_key : str
        Combined residues polarities, separated by a dash.
    """
    return '-'.join(
        RESIDUE_POLARITY.get(resn.strip(), 'unknow')
        for resn in (resn1, resn2)
        )
[docs]
def min_dist(matrix: NDFloat) -> float:
    """Return the smallest value held by a matrix.

    Parameters
    ----------
    matrix : NDFloat
        The matrix to search in.

    Return
    ------
    smallest : float
        Minimum value over all entries of the matrix.
    """
    smallest = np.min(matrix)
    return smallest
[docs]
def write_res_contacts(
        res_res_contacts: list[dict],
        header: list[str],
        path: Union[Path, str],
        sep: str = '\t',
        interchain_data: Union[bool, dict, None] = None,
        ) -> Union[Path, str]:
    """Write a tsv file based on residues-residues contacts data.

    The file starts with commented (`#`) lines describing each column,
    followed by the header line and one line per contact.

    Parameters
    ----------
    res_res_contacts : list[dict]
        List of dict holding data for each residue-residue contacts.
    header : list[str]
        Ordered list of keys to access in the dicts. Each of them must have
        a description in the README table of this function.
    path : Union[Path, str]
        Path to the output file to generate.
    sep : str
        Character used to separate data within a line.
    interchain_data : Union[bool, dict, None]
        When a non-empty dict, a second file restricted to interchain
        contacts is written. It must hold the keys `path` (file to write),
        `data_key` (column to test) and `contact_threshold` (only contacts
        with a value strictly under it are kept). Any other value disables
        the generation of this second file.

    Return
    ------
    path : Union[Path, str]
        Path to the generated file, as provided.

    Raises
    ------
    KeyError
        If `interchain_data` is a dict missing one of its expected keys, or
        if a column of `header` has no known description.
    """
    # define README data type content
    # Residue keys are built as `chainid-resid-resname`, and atom keys add
    # the atom name at the end: descriptions below follow that order.
    dttype_info = {
        'res1': 'Chain-ResID-Resname key identifying first residue',
        'res2': 'Chain-ResID-Resname key identifying second residue',
        'ca-ca-dist': 'Observed distances between the two carbon alpha (Ca) atoms',  # noqa : E501
        'ca-ca-cont-probability': 'Fraction of times a contact is observed under the ca-ca-dist threshold over all analysed models of the same cluster',  # noqa : E501
        'shortest-dist': 'Observed shortest distance between the two residues',
        'shortest-cont-probability': 'Fraction of times a contact is observed under the shortest-dist threshold over all analysed models of the same cluster',  # noqa : E501
        'contact-type': 'ResidueType - ResidueType contact name',
        'atom1': 'Chain-ResID-Resname-Atome key identifying first atom',
        'atom2': 'Chain-ResID-Resname-Atome key identifying second atom',
        'dist': 'Observed distance between two atoms',
        'nb_dists': 'Total number of observed distances',
        'avg_dist': 'Cluster average distance',
        'std_dist': 'Cluster distance standard deviation',
        }
    # Check for inter chain contacts
    gen_interchain_tsv = False
    interchain_tsvdt: list[list[str]] = []
    if interchain_data and isinstance(interchain_data, dict):
        expected_keys = ('path', 'contact_threshold', 'data_key')
        missing_keys = [k for k in expected_keys if k not in interchain_data]
        if missing_keys:
            raise KeyError(
                "Missing key(s) in `interchain_data`: "
                f"{', '.join(missing_keys)}"
                )
        gen_interchain_tsv = True
        interchain_tsvdt.append(header)
    # initiate file content
    tsvdt: list[list[str]] = [header]
    for res_res_cont in res_res_contacts:
        row = [str(res_res_cont[h]) for h in header]
        tsvdt.append(row)
        if not gen_interchain_tsv:
            continue
        # Chain id is the first field of a residue key
        chain1 = res_res_cont['res1'].split('-')[0]
        chain2 = res_res_cont['res2'].split('-')[0]
        if chain1 == chain2:
            continue
        dist = res_res_cont[interchain_data['data_key']]
        if dist < interchain_data['contact_threshold']:
            interchain_tsvdt.append(row)
    tsv_str = '\n'.join([sep.join(row) for row in tsvdt])
    # generate commented lines to be placed on top of file
    banner = '#' * 80
    readme = [
        banner,
        '# This file contains extracted contacts half-matrix information',
        *[f'# {head}: {dttype_info[head]}' for head in header],
        banner,
        '',  # makes the joined readme end with a newline
        ]
    # Write file
    with open(path, 'w') as tsvout:
        tsvout.write('\n'.join(readme))
        tsvout.write(tsv_str)
    # Write inter chain file
    if gen_interchain_tsv:
        # Adapt the title line of the readme
        readme[1] = readme[1].replace(
            'contacts half-matrix',
            'interchain contacts',
            )
        with open(interchain_data['path'], 'w') as f:
            f.write('\n'.join(readme))
            f.write('\n'.join([sep.join(row) for row in interchain_tsvdt]))
    return path
[docs]
def tsv_to_heatmap(
        tsv_path: Path,
        sep: str = '\t',
        data_key: str = 'ca-ca-dist',
        contact_threshold: float = 7.5,
        colorscale: str = 'Greys',
        output_fname: Union[Path, str] = 'contacts.html',
        offline: bool = False,
        ) -> Union[Path, str]:
    """Read a tsv file and generate a heatmap from it.

    Lines starting with `#` and blank lines are ignored; the first
    remaining line is the header. Data lines are expected to describe the
    upper half of the residue-residue matrix, in order.

    Parameters
    ----------
    tsv_path : Path
        Path to the .tsv file containing contact data.
    sep : str
        Separator character used to split data in each line.
    data_key : str
        Data key used to draw the plot.
    contact_threshold : float
        Upper boundary of maximum value to be plotted.
        any value above it will be set to this value.
    colorscale : str
        Name of the base color scale; it is reversed by
        `datakey_to_colorscale()` when `data_key` is not a probability.
    output_fname : Path
        Path to the generated graph.
    offline : bool
        Forwarded as is to `heatmap_plotly()`.

    Return
    ------
    output_filepath : Union[Path, str]
        Path to the generated file.
    """
    half_matrix: list[float] = []
    labels: list[str] = []
    # Set mirroring `labels`, for constant time membership tests
    known_labels: set[str] = set()
    header: Optional[list[str]] = None
    res1_col = res2_col = data_col = -1
    with open(tsv_path, 'r') as f:
        for line in f:
            # skip commented lines
            if line.startswith('#'):
                continue
            stripped = line.strip()
            # skip blank lines (e.g. trailing newline at end of file)
            if not stripped:
                continue
            s_ = stripped.split(sep)
            # gather header, and locate columns of interest only once
            if header is None:
                header = s_
                res1_col = header.index('res1')
                res2_col = header.index('res2')
                data_col = header.index(data_key)
                continue
            # Add labels in order of first appearance
            for label in (s_[res1_col], s_[res2_col]):
                if label not in known_labels:
                    known_labels.add(label)
                    labels.append(label)
            # bound data to contact_threshold
            half_matrix.append(min(float(s_[data_col]), contact_threshold))
    # Generate full symmetric matrix from the condensed half one
    matrix = squareform(half_matrix)
    # set data label
    color_scale = datakey_to_colorscale(data_key, color_scale=colorscale)
    if 'probability' in data_key:
        data_label = 'probability'
        # A residue is always in contact with itself
        np.fill_diagonal(matrix, 1)
    else:
        data_label = 'distance'

    # Compute chains length (dict keeps chains in order of appearance)
    chains_length: dict[str, int] = {}
    for label in labels:
        chainid = label.split('-')[0]
        chains_length[chainid] = chains_length.get(chainid, 0) + 1
    # Compute chains delineations positions (cumulative lengths)
    del_posi = [0]
    for chain_len in chains_length.values():
        del_posi.append(del_posi[-1] + chain_len)
    # Compute chains delineations lines; the 0.5 offset places them in
    # between two cells of the heatmap
    lower, upper = -0.5, len(labels) - 0.5
    chains_limits: list[dict[str, float]] = []
    for delpos in del_posi:
        # Vertical line
        chains_limits.append({
            "x0": delpos - 0.5,
            "x1": delpos - 0.5,
            "y0": lower,
            "y1": upper,
            })
        # Horizontal line
        chains_limits.append({
            "y0": delpos - 0.5,
            "y1": delpos - 0.5,
            "x0": lower,
            "x1": upper,
            })
    # Generate hover template
    hovertemplate = (
        ' %{y} ↭ %{x} <br>'
        f' Contact {data_label}: %{{z}}'
        '<extra></extra>'
        )
    # Generate heatmap
    output_filepath = heatmap_plotly(
        matrix,
        labels={'color': data_label},
        xlabels=labels,
        ylabels=labels,
        color_scale=color_scale,
        output_fname=output_fname,
        offline=offline,
        delineation_traces=chains_limits,
        hovertemplate=hovertemplate,
        )
    return output_filepath
[docs]
def datakey_to_colorscale(data_key: str, color_scale: str = 'Greys') -> str:
    """Pick the direction of a color scale according to the data type.

    Parameters
    ----------
    data_key : str
        A dictionary key pointing to data type.
    color_scale : str
        Name of a base color scale.

    Return
    ------
    color_scale : str
        The given name if `data_key` contains `probability`, otherwise the
        name of its reversed version (`_r` suffix).
    """
    if 'probability' in data_key:
        return color_scale
    return f'{color_scale}_r'
######################################
# Start of the chord chart functions #
######################################
[docs]
def moduloAB(val: float, lb: float, ub: float) -> float:
    """Map a real number onto the unit circle.

    The unit circle is identified with the interval [lb, ub), ub-lb=2*PI.

    Parameters
    ----------
    val : float
        The value to be mapped into the unit circle.
    lb : float
        The lower boundary.
    ub : float
        The upper boundary

    Return
    ------
    moduloab : float
        The modulo of val between lb and ub

    Raises
    ------
    ValueError
        If `lb` is not strictly lower than `ub`.
    """
    if lb >= ub:
        raise ValueError('Incorrect interval ends')
    # The divisor is strictly positive here, so the result of `%` is never
    # negative: shifting it by the lower boundary is all that is needed
    # (the former `y < 0` branch could not be reached).
    moduloab = (val - lb) % (ub - lb) + lb
    return moduloab
[docs]
def within_2PI(val: float) -> bool:
    """Tell if a value lies in the [0, 2*PI) interval.

    Parameters
    ----------
    val : float
        The value to be tested.

    Return
    ------
    bool
        True if `val` is positive or null and strictly lower than 2*PI.
    """
    upper_bound = 2 * PI
    return 0 <= val and val < upper_bound
[docs]
def check_square_matrix(data_matrix: NDArray) -> int:
    """Check if the matrix is a square one.

    Parameters
    ----------
    data_matrix : NDArray (2DArray)
        The matrix to be checked.

    Return
    ------
    nb_rows : int
        Number of rows in this matrix.

    Raises
    ------
    ValueError
        If the array does not have exactly two dimensions, or if its two
        dimensions differ.
    """
    matrixshape = data_matrix.shape
    # Dimensions are validated before being read: arrays with less than two
    # dimensions used to end in an IndexError instead of a ValueError.
    if len(matrixshape) != 2:
        raise ValueError('Data array must have only two dimensions')
    nb_rows, nb_columns = matrixshape
    if nb_rows != nb_columns:
        raise ValueError('Data array must have (n,n) shape')
    return nb_rows
[docs]
def get_ideogram_ends(
        ideogram_len: NDFloat,
        gap: float,
        ) -> list[tuple[float, float]]:
    """Compute start and end positions of successive ideograms.

    Ideograms are laid one after the other, starting from 0, each being
    separated from the previous one by `gap`.

    Parameters
    ----------
    ideogram_len : NDArray
        Length of each ideograms.
    gap : float
        Gap to add in between each ideogram.

    Return
    ------
    ideo_ends : list[tuple[float]]
        List of start and end position for each ideograms.
    """
    ideo_ends: list[tuple[float, float]] = []
    cursor = 0.0
    for length in ideogram_len:
        stop = float(cursor + length)
        ideo_ends.append((cursor, stop))
        # Next ideogram starts after the gap
        cursor = stop + gap
    return ideo_ends
[docs]
def make_ideogram_arc(
        radius: float,
        _phi: tuple[float, float],
        nb_points: float = 50,
        ) -> NDFloat:
    """Sample points along an ideogram arc.

    Parameters
    ----------
    radius : float
        The circle radius.
    _phi : tuple[float, float]
        Tuple of ends angle coordinates of an arc. Both angles are wrapped
        into [0, 2*PI) when at least one of them lies outside this range.
    nb_points : float
        Parameter that controls the number of points to be evaluated on an
        arc: arcs longer than PI/4 are sampled with
        int(nb_points * length / PI) points, shorter ones with 5 points.

    Return
    ------
    arc_positions : NDArray
        Array of complex numbers (radius * exp(1j * theta)) whose real and
        imaginary parts are the 2D coordinates of the arc points.
    """
    full_turn = 2 * PI
    if within_2PI(_phi[0]) and within_2PI(_phi[1]):
        phi = list(_phi)
    else:
        phi = [moduloAB(angle, 0, full_turn) for angle in _phi]
    # Angular length of the arc drives the sampling density
    length = (phi[1] - phi[0]) % full_turn
    if length <= (PI / 4):
        nr = 5
    else:
        nr = int((nb_points * length) / PI)
    start, stop = phi[0], phi[1]
    if not start < stop:
        # Ends are not in increasing order: re-express both in [-PI, PI)
        start = moduloAB(start, -PI, PI)
        stop = moduloAB(stop, -PI, PI)
    theta = np.linspace(start, stop, nr)
    return radius * np.exp(1j * theta)
[docs]
def make_ribbon_ends(
        matrix: NDArray,
        row_sum: list[int],
        ideo_ends: list[tuple[float, float]],
        L: int,
        ) -> list[list[tuple[float, float]]]:
    """Split each ideogram into the arcs where its ribbons are anchored.

    Every ideogram is divided into `row_sum[k]` arcs of equal length, which
    are assigned in column order to the non-zero entries of row `k`.

    Parameters
    ----------
    matrix : NDArray
        The data matrix.
    row_sum : list[int]
        Number of connexions in each row.
    ideo_ends : list[tuple[float, float]]
        List of start and end position for each ideogram.
    L : int
        Number of columns of `matrix` to go through for each row.

    Returns
    -------
    ribbon_boundary : list[list[tuple[float, float]]]
        Matrix of per residue ribbons start and end positions.
        Entries without ribbon are set to (0., 0.).
    """
    no_ribbon = (0., 0.)
    ribbon_boundary: list[list[tuple[float, float]]] = []
    for k, ideo_end in enumerate(ideo_ends):
        # Ribbons of this row are stacked from the ideogram start
        cursor = float(ideo_end[0])
        if row_sum[k] == 0:
            # Nothing connected to this entry: only placeholders
            ribbon_boundary.append([no_ribbon] * len(ideo_ends))
            continue
        # Arc length allocated to each ribbon of this row
        step = (ideo_end[1] - cursor) / row_sum[k]
        row_ribbon_ends: list[tuple[float, float]] = []
        for j in range(L):
            if matrix[k][j] == 0:
                row_ribbon_ends.append(no_ribbon)
            else:
                stop = float(cursor + step)
                row_ribbon_ends.append((cursor, stop))
                # Next ribbon starts where this one ends
                cursor = stop
        ribbon_boundary.append(row_ribbon_ends)
    return ribbon_boundary
[docs]
def control_pts(
        angle: list[float],
        radius: float,
        ) -> list[tuple[float, float]]:
    """Compute the three control points of a quadratic Bezier curve.

    The first and last points (b0, b2) lie on the unit circle, while the
    middle one (b1) is placed at distance `radius` from the origin.

    Parameters
    ----------
    angle : list[float]
        A list containing angular coordinates of the control points b0, b1, b2.
    radius : float
        The distance from b1 to the origin O(0,0)

    Returns
    -------
    control_points : list[tuple[float, float]]
        The (x, y) coordinates of the three control points.

    Raises
    ------
    ValueError
        Raised if the number of angular coordinates is not equal to 3.
    """
    if len(angle) != 3:
        raise ValueError('angle must have len = 3')
    # Points on the unit circle, in complex form
    on_circle = np.array([np.exp(1j * theta) for theta in angle[:3]])
    # Only the middle point is moved to the requested radius
    on_circle[1] = radius * on_circle[1]
    return list(zip(on_circle.real, on_circle.imag))
[docs]
def ctrl_rib_chords(
        side1: tuple[float, float],
        side2: tuple[float, float],
        radius: float,
        ) -> list[list[tuple[float, float]]]:
    """Generate the control polygons of the two curved edges of a ribbon.

    Each edge links `side1[j]` to `side2[j]` (j = 0, 1); its middle control
    point is set at the mean of both angles, at distance `radius` from the
    origin.

    Parameters
    ----------
    side1 : tuple[float, float]
        Angular variables of the ribbon arc ends defining
        the ribbon starting (ending) arc
    side2 : tuple[float, float]
        Angular variables of the ribbon arc ends defining
        the ribbon starting (ending) arc
    radius : float
        Distance from the origin of the middle control points.

    Returns
    -------
    poligons : list[list[tuple[float, float]]]
        Two sets of three control points, one per ribbon edge.

    Raises
    ------
    ValueError
        If `side1` or `side2` does not hold exactly two values.
    """
    if not (len(side1) == 2 and len(side2) == 2):
        raise ValueError('the arc ends must be elements in a list of len 2')
    poligons = []
    for angle1, angle2 in zip(side1, side2):
        mid_angle = (angle1 + angle2) / 2
        poligons.append(control_pts([angle1, mid_angle, angle2], radius))
    return poligons
[docs]
def make_q_bezier(control_points: list[tuple[float, float]]) -> str:
    """Write a quadratic Bezier curve as a Plotly SVG path.

    The curve is defined by the list of its three control points.

    Parameters
    ----------
    control_points : list[tuple[float, float]]
        List of control points

    Return
    ------
    svgpath : str
        An SVG path ('M' move command followed by a 'Q' curve command).

    Raises
    ------
    ValueError
        If the number of control points is not equal to 3.
    """
    if len(control_points) != 3:
        raise ValueError('control poligon must have 3 points')
    first, middle, last = control_points
    # `!s` forces the use of str() on each coordinate
    svgpath = (
        f'M {first[0]!s},{first[1]!s}'
        f' Q {middle[0]!s}, {middle[1]!s}'
        f' {last[0]!s}, {last[1]!s}'
    )
    return svgpath
[docs]
def make_ribbon_arc(theta0: float, theta1: float) -> str:
    """Generate the SVG path commands drawing a ribbon arc.

    Parameters
    ----------
    theta0 : float
        Starting angle value
    theta1 : float
        Ending angle value

    Returns
    -------
    string_arc : str
        A string made of consecutive 'L x, y ' commands going through the
        points sampled on the unit circle between both angles.

    Raises
    ------
    ValueError
        If one of the provided angles is not in the range [0, 2*pi).
    ValueError
        If theta0 < theta1 and, once both angles are mapped into
        [-pi, pi), their product is strictly positive.
    """
    # Guard clause: both angles must be valid before anything else
    if not (within_2PI(theta0) and within_2PI(theta1)):
        raise ValueError('the angle coordinates for an arc side of a ribbon '
                         'must be in [0, 2*pi]')
    if theta0 < theta1:
        theta0 = moduloAB(theta0, -PI, PI)
        theta1 = moduloAB(theta1, -PI, PI)
        if theta0 * theta1 > 0:
            raise ValueError('incorrect angle coordinates for ribbon')
    # Number of sampled points, with a minimum of 3
    nr = max(int(40 * (theta0 - theta1) / PI), 3)
    theta = np.linspace(theta0, theta1, nr)
    # Points on arc in polar complex form
    pts = np.exp(1j * theta)
    string_arc: str = ''.join(
        f'L {str(x_coord)}, {str(y_coord)} '
        for x_coord, y_coord in zip(pts.real, pts.imag)
    )
    return string_arc
[docs]
def make_layout(
        title: str,
        plot_size: float,
        layout_shapes: list[dict],
        ) -> go.Layout:
    """Build the plotly layout of the chart.

    Parameters
    ----------
    title : str
        Title to be given to the chart.
    plot_size : float
        Size of the chart, used as height. The width is `plot_size` + 150.
    layout_shapes : list[dict]
        Shapes to be drawn.

    Returns
    -------
    layout : go.Layout
        The plotly layout.
    """
    # Axis are fully hidden: no line, no grid, no tick labels and no title
    hidden_axis = dict(
        showline=False,
        zeroline=False,
        showgrid=False,
        showticklabels=False,
        title='',
    )
    return go.Layout(
        title=title,
        xaxis=hidden_axis,
        yaxis=hidden_axis,
        # Important to show legend
        showlegend=True,
        # +150 to accommodate legend / keep circle round
        width=plot_size + 150,
        height=plot_size,
        margin=dict(t=25, b=25, l=25, r=25),
        hovermode='closest',
        shapes=layout_shapes,
    )
[docs]
def make_ideo_shape(
        path: str,
        line_color: str,
        fill_color: str,
        ) -> dict:
    """Describe an ideogram as a layout shape.

    Parameters
    ----------
    path : str
        A SVGPath to be drawn.
    line_color : str
        Color of the shape boundary.
    fill_color : str
        Filling color of the shape.

    Returns
    -------
    dict
        Data enabling to draw a ideogram shape in layout:
        a 'path' shape placed on the 'below' layer.
    """
    outline = dict(color=line_color, width=0.45)
    return dict(
        line=outline,
        path=path,
        type='path',
        fillcolor=fill_color,
        layer='below',
    )
[docs]
def make_ribbon(
        side1: tuple[float, float],
        side2: tuple[float, float],
        line_color: str,
        fill_color: str,
        radius: float = 0.2,
        ) -> dict:
    """Describe a ribbon as a layout shape.

    The ribbon outline is the concatenation of four pieces: a Bezier curve,
    the arc between both `side2` angles, a second Bezier curve (followed in
    reverse direction) and the arc going from `side1[1]` to `side1[0]`.

    Parameters
    ----------
    side1 : tuple[float, float]
        Angular variables of first ribbon arc ends defining
        the ribbon starting (ending) arc.
    side2 : tuple[float, float]
        Angular variables of the other ribbon arc ends defining
        the ribbon starting (ending) arc.
    line_color : str
        Color of the shape boundary.
    fill_color : str
        Filling color of the ribbon shape.
    radius : float, optional
        Distance from the origin of the middle control point of the Bezier
        curves, by default 0.2.

    Returns
    -------
    dict
        Data enabling to draw a ribbon in layout.
    """
    first_edge, second_edge = ctrl_rib_chords(side1, side2, radius)
    # Pieces of the SVGpath, in drawing order
    path_pieces = [
        make_q_bezier(first_edge),
        make_ribbon_arc(side2[0], side2[1]),
        make_q_bezier(second_edge[::-1]),
        make_ribbon_arc(side1[1], side1[0]),
    ]
    return dict(
        line=dict(color=line_color, width=0.5),
        path=''.join(path_pieces),
        type='path',
        fillcolor=fill_color,
        layer='below',
    )
[docs]
def invPerm(perm: list[int]) -> list[int]:
    """Compute the inverse of a permutation.

    Parameters
    ----------
    perm : list[int]
        A permutation, given as a list of indices.

    Returns
    -------
    inv : list[int]
        Inverse of the permutation: `inv[perm[i]] == i` for each position
        `i` of `perm`.
    """
    # Zero-filled holder of the same size as the permutation
    inverse = [0] * len(perm)
    for position, target in enumerate(perm):
        inverse[target] = position
    return inverse
[docs]
def get_chains_ideograms_ends(
        chains: dict[str, list[str]],
        gap: float = 2 * PI * 0.005,
        ) -> tuple[list[tuple[float, float]], NDFloat]:
    """Build ideogram ends to represent protein chains.

    Chains are processed in reverse sorted order of their identifiers.
    Each chain receives a share of the circle (2*PI) proportional to its
    number of residues, from which `gap` is subtracted.

    Parameters
    ----------
    chains : dict[str, list[str]]
        Dictionary mapping chains with their respective set of residues labels.
    gap : float, optional
        Gap between two ideograms, by default 2*PI*0.005

    Returns
    -------
    chain_ideo_ends : list[tuple[float, float]]
        Ideogram ends to represent protein chains.
    chain_ideogram_length : NDFloat
        Angular length of each chain ideogram.
    """
    residues_per_chain = [
        len(chains[chainid])
        for chainid in sorted(chains, reverse=True)
    ]
    # Share of the full circle proportional to the number of residues
    circle_share = 2 * PI * np.asarray(residues_per_chain)
    circle_share /= sum(residues_per_chain)
    # Make room for the gap separating two ideograms
    chain_ideogram_length = circle_share - gap * np.ones(len(residues_per_chain))
    return get_ideogram_ends(chain_ideogram_length, gap), chain_ideogram_length
[docs]
def get_all_ideograms_ends(
        chains: dict,
        gap: float = 2 * PI * 0.005,
        ) -> tuple[list[tuple[float, float]], list[tuple[float, float]]]:
    """Generate both chain and residues ideograms ends.

    Within a chain, residues ideograms are contiguous and evenly share the
    chain ideogram length. `gap` is only inserted between two chains.

    Parameters
    ----------
    chains : dict
        Dictionary mapping chains to their list of residues labels.
    gap : float, optional
        Gap distance used to separate two ideograms, by default 2*PI*0.005

    Returns
    -------
    tuple[ideo_ends, chain_ideo_ends]
        A tuple containing residues ideo ends and chains ideo ends.
        ideo_ends : list[tuple[float, float]]
            List of residues ideograms start and ending positions.
        chain_ideo_ends : list[tuple[float, float]]
            List of chain ideograms start and ending positions.
    """
    chain_ideo_ends, chain_ideogram_length = get_chains_ideograms_ends(
        chains,
        gap=gap,
    )
    ideo_ends: list[tuple[float, float]] = []
    left = 0
    # Same chain ordering as the one used to compute chain lengths
    ordered_chains = sorted(chains, reverse=True)
    for chainid, chain_length in zip(ordered_chains, chain_ideogram_length):
        residues = chains[chainid]
        for _ in residues:
            right = left + (chain_length / len(residues))
            ideo_ends.append((left, right))
            left = right
        # Jump over the gap before starting the next chain
        left = right + gap
    return ideo_ends, chain_ideo_ends
[docs]
def split_labels_by_chains(labels: list[str]) -> dict[str, list[str]]:
    """Map each label to its chain.

    Only the first '-'-separated field of a label (the chain identifier)
    is interpreted here; the remaining fields are kept untouched, so labels
    holding additional '-' characters (e.g. a negative residue number) are
    supported.

    Parameters
    ----------
    labels : list[str]
        List of residues keys, starting with the chain identifier followed
        by a '-' character.

    Returns
    -------
    chains : dict[str, list[str]]
        Dictionary mapping chains with their respective set of residues labels.
        Chains and labels are kept in their order of first appearance.
    """
    chains: dict[str, list[str]] = {}
    for lab in labels:
        # Same chain extraction as the other functions of this module
        chain = lab.split('-')[0]
        chains.setdefault(chain, []).append(lab)
    return chains
[docs]
def contacts_to_connect_matrix(
        matrix: NDFloat,
        labels: list[str],
        ) -> list[list[int]]:
    """Convert a contact matrix into an inter-chain connectivity matrix.

    An entry is set to 1 only when both residues belong to different chains
    (first '-'-separated field of their labels) and the corresponding value
    in `matrix` is equal to 1. All other entries are set to 0.

    Parameters
    ----------
    matrix : NDFloat
        A square contact matrix.
    labels : list[str]
        List of labels corresponding row & columns entries.

    Returns
    -------
    connect_matrix : list[list[int]]
        The connectivity matrix without self nor intra-chain contacts.
    """
    # Chain identifier of each row / column entry
    chain_ids = [label.split('-')[0] for label in labels]
    connect_matrix: list[list[int]] = []
    for ri, row_chain in enumerate(chain_ids):
        connect_matrix.append([
            1 if row_chain != col_chain and matrix[ri, ci] == 1 else 0
            for ci, col_chain in enumerate(chain_ids)
        ])
    return connect_matrix
[docs]
def to_nice_label(label: str) -> str:
    """Convert a label into a user friendly label.

    The label is split on '-' characters: the first field is reported as
    the chain, followed by the third and then the second fields.
    e.g.: 'A-12-SER' gives 'Chain A, residue SER 12'.

    Parameters
    ----------
    label : str
        Label name as found in csv

    Returns
    -------
    nicelabel : str
        User friendly description of the label.
    """
    fields = label.split('-')
    # Positional indices pick the chain, then 3rd and 2nd fields
    return "Chain {0}, residue {2} {1}".format(*fields)
[docs]
def to_color_weight(
        distance: float,
        max_dist: float,
        min_dist: float = 2.,
        min_weight: float = 0.2,
        max_weight: float = 0.90,
        ) -> float:
    """Compute color weight based on distance.

    The weight decreases linearly from `max_weight` (at `min_dist`) to
    `min_weight` (at `max_dist`). Distances outside of the
    [min_dist, max_dist] range are clamped to the closest boundary.

    Parameters
    ----------
    distance : float
        The distance to weight.
    max_dist : float
        The max distance observed in the dataset.
    min_dist : float, optional
        The minimum distance observed in the dataset, by default 2.
    min_weight : float, optional
        Color weight for the maximum distance, by default 0.2
    max_weight : float, optional
        Color weight for the minimum distance, by default 0.90

    Returns
    -------
    weight : float
        The color weight, in range [min_weight, max_weight],
        rounded to two decimals.
    """
    # Degenerate range: no scaling possible (would divide by zero),
    # every distance is considered as being the minimum one.
    if max_dist <= min_dist:
        return round(max_weight, 2)
    # Clamp the distance on both sides, so the returned weight can not
    # leave the [min_weight, max_weight] range (e.g. negative alpha).
    dist = min(max(distance, min_dist), max_dist)
    # Relative position of the distance within the range, in [0, 1]
    probability_dist = (dist - min_dist) / (max_dist - min_dist)
    # Obtain corresponding weight
    weight = ((min_weight - max_weight) * probability_dist) + max_weight
    # Return rounded value of weight
    return round(weight, 2)
[docs]
def to_rgba_color_string(
        connect_color: tuple[int, int, int],
        alpha: float,
        ) -> str:
    """Generate a rgba string from color channels and alpha.

    Parameters
    ----------
    connect_color : tuple[int, int, int]
        The three integers defining the red, green and blue colors.
    alpha : float
        The alpha channel value (color weight).

    Returns
    -------
    rgba_color : str
        The html like rgba colors, written without spaces.
        e.g.: 'rgba(123,123,123,0.5)'
    """
    channels = ",".join(map(str, connect_color))
    return "rgba({},{})".format(channels, alpha)
[docs]
def to_full_matrix(
        half_matrix: list[Union[int, float, str]],
        diag_val: Union[int, float, str],
        ) -> NDArray:
    """Rebuild a symmetric square matrix from its condensed form.

    Parameters
    ----------
    half_matrix : list[Any]
        Values of the N*(N-1)/2 half matrix.
    diag_val : Any
        Value to be placed in diagonal of the full matrix.

    Returns
    -------
    matrix : NDArray
        The reconstituted full matrix.
    """
    # Expand the N*(N-1)/2 vector into a N*N symmetric matrix
    full = squareform(half_matrix)
    # Write the diagonal through its indices
    full[np.diag_indices_from(full)] = diag_val
    return full
[docs]
def _ideogram_svgpath(outer_arc: NDArray, inner_arc: NDArray) -> str:
    """Build the closed SVG path of a ring sector delimited by two arcs.

    The path follows `outer_arc`, comes back along `inner_arc` in reverse
    order and is closed on the first point of `outer_arc`.
    """
    segments = [
        f'{str(x_coord)}, {str(y_coord)} L '
        for x_coord, y_coord in zip(outer_arc.real, outer_arc.imag)
    ]
    segments += [
        f'{str(x_coord)}, {str(y_coord)} L '
        for x_coord, y_coord in zip(inner_arc.real[::-1], inner_arc.imag[::-1])
    ]
    closing = f'{str(outer_arc.real[0])}, {str(outer_arc.imag[0])}'
    return 'M ' + ''.join(segments) + closing


def make_chordchart(
        _contact_matrix: list[list[int]],
        _dist_matrix: list[list[float]],
        _interttype_matrix: list[list[str]],
        _labels: list[str],
        gap: float = 2 * PI * 0.005,
        output_fpath: Union[str, Path] = 'chordchart.html',
        title: str = 'Chord diagram',
        offline: bool = False,
        ) -> Union[str, Path]:
    """Generate a plotly chordchart graph.

    Residues are drawn as ideograms on a circle, surrounded by the ideograms
    of their chains. Each contact between residues of different chains is
    drawn as a ribbon, colored by interaction type and weighted by distance.

    Parameters
    ----------
    _contact_matrix : list[list[int]]
        The contact matrix
    _dist_matrix : list[list[float]]
        The distance matrix
    _interttype_matrix : list[list[str]]
        The interaction type matrix
    _labels : list[str]
        Labels of each matrix rows (and columns as supposed to be symmetric)
    gap : float, optional
        Gap between two ideograms, by default 2*PI*0.005
    output_fpath : Union[str, Path], optional
        Path to the output file, by default 'chordchart.html'
    title : str, optional
        Title to give to the diagram, by default 'Chord diagram'
    offline : bool, optional
        Forwarded to `fig_to_html` when writing the html file,
        by default False

    Returns
    -------
    output_fpath : Union[str, Path]
        Path to the generated output file.
    """
    # Keep inter-chain contacts only
    contact_matrix = contacts_to_connect_matrix(_contact_matrix, _labels)
    # Reverse data order so later graph displayed clockwise
    matrix = np.array([ri[::-1] for ri in contact_matrix[::-1]])
    dist_matrix = [ri[::-1] for ri in _dist_matrix[::-1]]
    interttype_matrix = [ri[::-1] for ri in _interttype_matrix[::-1]]
    labels = _labels[::-1]
    # Check matrix shape
    L = check_square_matrix(matrix)
    # Map labels into respective chains
    chains = split_labels_by_chains(labels)
    # Compute residues and chain ideograms positions
    ideo_ends, chain_ideo_ends = get_all_ideograms_ends(chains, gap=gap)
    # Compute number of connexion per residues
    row_sum = [np.sum(matrix[k, :]) for k in range(L)]
    # Compute connexion ribbons positions
    ribbon_ends = make_ribbon_ends(matrix, row_sum, ideo_ends, L)
    # Initiate shape holder
    layout_shapes: list[dict] = []
    # Initiate ribbon info holder
    ribbon_info: list[go.Scatter] = []
    # Loop over entries
    # NOTE: rows are never re-ordered, so ribbon ends are directly indexed
    # by (row, column) without any permutation bookkeeping.
    for k, label1 in enumerate(labels):
        # Half matrix loop to avoid duplicates
        for j in range(k, L):
            # No data to draw
            if matrix[k][j] == 0 and matrix[j][k] == 0:
                continue
            intertype = interttype_matrix[k][j]
            # Obtain ribbon color for this interaction type
            # (grey for unknown interaction types)
            connect_color = CONNECT_COLORS.get(intertype, (123, 123, 123))
            # NOTE(review): 9.5 is a hard-coded upper distance used to
            # scale colors - confirm it suits the contact threshold in use
            color_weight = to_color_weight(dist_matrix[k][j], 9.5)
            rgba_color = to_rgba_color_string(connect_color, color_weight)
            # Point ribbons data
            side1 = ribbon_ends[k][j]
            side2 = ribbon_ends[j][k]
            # Hover markers are placed in the middle of each ribbon end
            zi = 0.9 * np.exp(1j * (side1[0] + side1[1]) / 2)
            zf = 0.9 * np.exp(1j * (side2[0] + side2[1]) / 2)
            # reverse interaction type for second label
            rev_interttype = '-'.join(intertype.split('-')[::-1])
            # Obtain nice labels
            nicelabel1 = to_nice_label(label1)
            nicelabel2 = to_nice_label(labels[j])
            # texti and textf are the strings that will be displayed when
            # hovering the mouse over the two ribbon ends
            texti = f'{nicelabel1} {rev_interttype} with {nicelabel2}'
            textf = (
                f'{nicelabel2} {interttype_matrix[j][k]} with {nicelabel1}'
            )
            # Generate interactive labels
            for zv, text in zip([zi, zf], [texti, textf]):
                # Generate ribbon info
                ribbon_info.append(
                    go.Scatter(
                        x=[zv.real],
                        y=[zv.imag],
                        mode='markers',
                        marker={"size": 0.5, "color": rgba_color},
                        text=text,
                        hoverinfo='text',
                        showlegend=False,
                    )
                )
            # Note: must reverse these arc ends to avoid twisted ribbon
            side2_rev = (side2[1], side2[0])
            # Append the ribbon shape
            layout_shapes.append(
                make_ribbon(
                    side1,
                    side2_rev,
                    'rgba(175,175,175)',
                    rgba_color,
                )
            )

    ideograms: list[go.Scatter] = []
    # Draw ideograms for residues
    for k, label in enumerate(labels):
        z = make_ideogram_arc(1.1, ideo_ends[k])
        zi = make_ideogram_arc(1.0, ideo_ends[k])
        # Point residue name
        resname = label.split('-')[2]
        # Point corresponding color
        try:
            rescolor = AA_DNA_RNA_COLORS[resname.strip()]
        except KeyError:
            rescolor = 'rgba(123, 123, 123, 0.7)'
        # Build textual info
        text_info = f'{to_nice_label(label)}<br>'
        if row_sum[k] == 0:
            text_info += 'No interaction'
        else:
            text_info += f'Total of {row_sum[k]:d} interaction'
            if row_sum[k] >= 2:
                text_info += 's'
        # Add info
        ideograms.append(
            go.Scatter(
                x=z.real,
                y=z.imag,
                mode='lines',
                line={
                    "color": rescolor,
                    "shape": 'spline',
                    "width": 0.25,
                },
                text=text_info,
                hoverinfo='text',
                showlegend=False,
            )
        )
        # Build corresponding shape and hold it
        layout_shapes.append(
            make_ideo_shape(
                _ideogram_svgpath(z, zi),
                'rgba(150,150,150)',
                rescolor,
            )
        )

    # Draw ideograms for chains
    for k, chainid in enumerate(sorted(chains, reverse=True)):
        z = make_ideogram_arc(1.2, chain_ideo_ends[k])
        zi = make_ideogram_arc(1.11, chain_ideo_ends[k])
        ideograms.append(
            go.Scatter(
                x=z.real,
                y=z.imag,
                mode='lines',
                line={
                    "color": CHAIN_COLORS[k],
                    "shape": 'spline',
                    "width": 0.25,
                },
                text=f'Chain {chainid}',
                hoverinfo='text',
                showlegend=False,
            )
        )
        # Build corresponding shape and hold it
        layout_shapes.append(
            make_ideo_shape(
                _ideogram_svgpath(z, zi),
                'rgba(150,150,150)',
                CHAIN_COLORS[k],
            )
        )

    # Compute figure size
    fig_size = 100 * np.log(L * L)
    # Create plotly layout
    layout = make_layout(title, fig_size, layout_shapes)
    # combine all data info
    data = ideograms + ribbon_info
    # Generate the figure
    fig = go.Figure(data=data, layout=layout)
    # Fine tune figure
    fig.update_layout(
        plot_bgcolor='white',
    )
    # Add legend(s)
    add_chordchart_legends(fig)
    # Write it as html file
    fig_to_html(
        fig,
        output_fpath,
        figure_height=fig_size,
        figure_width=fig_size,
        offline=offline,
    )
    return output_fpath
[docs]
def _add_legend_entry(
        fig: go.Figure,
        group: str,
        group_title: str,
        name: str,
        color: str,
        ) -> None:
    """Add a dummy (data-less) trace to `fig`, only visible in the legend."""
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            legendgroup=group,
            legendgrouptitle_text=group_title,
            showlegend=True,
            name=name,
            mode="lines",
            marker={
                "color": color,
                "size": 10,
                "symbol": "line-ew-open",
            },
        )
    )


def add_chordchart_legends(fig: go.Figure) -> None:
    """Add custom legend to chordchart.

    Two legend groups are added to the figure: one describing the colors of
    the interaction types, and one describing the colors of the residues
    and nucleic bases. Each group ends with an 'Unknown' entry.

    Parameters
    ----------
    fig : go.Figure
        A plotly figure. It is modified in place.
    """
    # Add connection types legends
    for key_key, color in REVERSED_CONNECT_COLORS_KEYS.items():
        _add_legend_entry(
            fig,
            "connect_color",
            "Interaction types",
            key_key.replace('-', '↭'),
            to_rgba_color_string(color, 0.75),
        )
    # Add unknown interaction type
    # Same grey as the one used to draw ribbons of unknown interaction type
    _add_legend_entry(
        fig,
        "connect_color",
        "Interaction types",
        "Unknown",
        to_rgba_color_string((123, 123, 123), 0.75),
    )
    # Add aa types legend
    for aa, rgba_color in RESIDUES_COLORS.items():
        _add_legend_entry(fig, "aa_color", "Residues/Bases", aa, rgba_color)
    # Add nucleobases legend
    for na, rgba_color in DNARNA_COLORS.items():
        _add_legend_entry(fig, "aa_color", "Residues/Bases", na, rgba_color)
    # Add unknown type
    # Same color as the one used to draw ideograms of unknown residues
    _add_legend_entry(
        fig,
        "aa_color",
        "Residues/Bases",
        "Unknown",
        'rgba(123, 123, 123, 0.7)',
    )
[docs]
def tsv_to_chordchart(
        tsv_path: Path,
        sep: str = '\t',
        data_key: str = 'ca-ca-dist',
        contact_threshold: float = 7.5,
        filter_intermolecular_contacts: bool = True,
        output_fname: Union[Path, str] = 'contacts_chordchart.html',
        title: str = 'Chord diagram',
        offline: bool = False,
        ) -> Union[Path, str]:
    """Read a tsv file and generate a chord diagram from it.

    The first line that is neither empty nor starting with '#' is used as
    header. It must hold the 'res1', 'res2', 'contact-type' and `data_key`
    columns. Each following line describes one pair of residues.

    Parameters
    ----------
    tsv_path : Path
        Path to the .tsv file containing contact data.
    sep : str
        Separator character used to split data in each line.
    data_key : str
        Data key used to draw the plot.
    contact_threshold : float
        Maximum value of `data_key` for a pair of residues to be
        considered as being in contact.
    filter_intermolecular_contacts : bool
        If True, only residues involved in at least one contact with a
        residue of another chain are kept in the diagram.
    output_fname : Union[Path, str]
        Path where to generate the graph.
    title : str
        Title to give to the Chord diagram
    offline : bool
        Forwarded to `make_chordchart`.

    Return
    ------
    chord_chart_fpath : Union[Path, str]
        Path to the generated graph
    """
    # Initiate holders
    half_contact_matrix: list[int] = []
    half_value_matrix: list[float] = []
    half_intertype_matrix: list[str] = []
    # Dictionary used as an ordered set: O(1) membership test while
    # keeping labels in their order of first appearance
    seen_labels: dict[str, None] = {}
    header: Union[list[str], None] = None
    # Read tsv file
    with open(tsv_path, 'r') as f:
        for line in f:
            # skip commented lines
            if line.startswith('#'):
                continue
            stripped = line.strip()
            # skip empty lines (e.g. trailing newline at end of file)
            if not stripped:
                continue
            # split line
            s_ = stripped.split(sep)
            # gather header
            if header is None:
                header = s_
                # Locate columns of interest once for all lines
                res1_col = header.index('res1')
                res2_col = header.index('res2')
                value_col = header.index(data_key)
                type_col = header.index('contact-type')
                continue
            # Add labels to the ordered set of labels
            for label in (s_[res1_col], s_[res2_col]):
                seen_labels.setdefault(label, None)
            # point data
            value = float(s_[value_col])
            # check if in contact
            contact = 1 if value <= contact_threshold else 0
            # add it to matrix
            half_contact_matrix.append(contact)
            half_value_matrix.append(value)
            half_intertype_matrix.append(s_[type_col])
    labels: list[str] = list(seen_labels)

    # Generate full matrices
    contact_matrix = to_full_matrix(half_contact_matrix, 1)
    dist_matrix = to_full_matrix(half_value_matrix, 0.)
    intertype_matrix = to_full_matrix(half_intertype_matrix, 'self-self')

    # Check if must get only the intermolecular contacts submatrix
    if filter_intermolecular_contacts:
        # Chain identifier of each label, computed once
        chain_ids = [label.split('-')[0] for label in labels]
        # Indices of entries with: intermolecular contacts + dist <= threshold
        contacting_indices: set[int] = set()
        for ri, chain1 in enumerate(chain_ids):
            for ci, chain2 in enumerate(chain_ids):
                # Skip same chains
                if chain1 == chain2:
                    continue
                # Check contacts
                if contact_matrix[ri, ci] == 1:
                    contacting_indices.update((ri, ci))
        # Get indices to be extracted, in the original order of labels
        sorted_indices = sorted(contacting_indices)
        sublabels = [labels[index] for index in sorted_indices]
        # Get submatrix
        contact_submatrix = extract_submatrix(contact_matrix, sorted_indices)
        dist_submatrix = extract_submatrix(dist_matrix, sorted_indices)
        intert_submatrix = extract_submatrix(intertype_matrix, sorted_indices)
    else:
        contact_submatrix = contact_matrix
        dist_submatrix = dist_matrix
        intert_submatrix = intertype_matrix
        sublabels = labels

    # Generate chord chart
    chord_chart_fpath = make_chordchart(
        contact_submatrix,
        dist_submatrix,
        intert_submatrix,
        sublabels,
        output_fpath=output_fname,
        title=title,
        offline=offline,
    )
    return chord_chart_fpath