Skip to content

Commit 1100299

Browse files
committed
Doing all neighbor indices upfront
1 parent a84609b commit 1100299

2 files changed

Lines changed: 71 additions & 58 deletions

File tree

biojava-structure/src/main/java/org/biojava/nbio/structure/asa/AsaCalculator.java

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public void run() {
104104
private int nThreads;
105105
private Point3d[] spherePoints;
106106
private double cons;
107-
private List<Contact> contacts;
107+
private int[][] neighborIndices;
108108

109109
private boolean useSpatialHashingForNeighbors;
110110

@@ -250,6 +250,12 @@ public double[] calculateAsas() {
250250

251251
double[] asas = new double[atomCoords.length];
252252

253+
if (useSpatialHashingForNeighbors) {
254+
neighborIndices = findNeighborIndicesSpatialHashing();
255+
} else {
256+
neighborIndices = findNeighborIndices();
257+
}
258+
253259
if (nThreads<=1) { // (i.e. it will also be 1 thread if 0 or negative number specified)
254260
for (int i=0;i<atomCoords.length;i++) {
255261
asas[i] = calcSingleAsa(i);
@@ -333,69 +339,73 @@ private Point3d[] generateSpherePoints(int nSpherePoints) {
333339
}
334340

335341
/**
336-
* Returns list of indices of atoms within probe distance to atom k.
337-
* @param k index of atom for which we want neighbor indices
338-
* @return the indices of neighboring atoms
342+
* Returns the 2-dimensional array with neighbor indices for every atom.
343+
* @return 2-dimensional array of size: n_atoms x n_neighbors_per_atom
339344
*/
340-
Integer[] findNeighborIndices(int k) {
341-
// looking at a typical protein case, number of neighbours are from ~10 to ~50, with an average of ~30
342-
// Thus 40 seems to be a good compromise for the starting capacity
343-
ArrayList<Integer> neighbor_indices = new ArrayList<>(40);
345+
int[][] findNeighborIndices() {
344346

345-
double radius = radii[k] + probe + probe;
347+
int[][] nbsIndices = new int[atomCoords.length][];
346348

347-
for (int i=0;i<atomCoords.length;i++) {
348-
if (i==k) continue;
349+
for (int k=0; k<atomCoords.length; k++) {
350+
double radius = radii[k] + probe + probe;
349351

350-
double dist = atomCoords[i].distance(atomCoords[k]);
352+
List<Integer> thisNbIndices = new ArrayList<>();
351353

352-
if (dist < radius + radii[i]) {
353-
neighbor_indices.add(i);
354+
for (int i = 0; i < atomCoords.length; i++) {
355+
if (i == k) continue;
356+
357+
double dist = atomCoords[i].distance(atomCoords[k]);
358+
359+
if (dist < radius + radii[i]) {
360+
thisNbIndices.add(i);
361+
}
354362
}
355363

364+
int[] indicesArray = new int[thisNbIndices.size()];
365+
for (int i=0;i<thisNbIndices.size();i++) indicesArray[i] = thisNbIndices.get(i);
366+
nbsIndices[k] = indicesArray;
356367
}
357-
358-
Integer[] indicesArray = new Integer[neighbor_indices.size()];
359-
indicesArray = neighbor_indices.toArray(indicesArray);
360-
return indicesArray;
368+
return nbsIndices;
361369
}
362370

363371
/**
364-
* Returns list of indices of atoms within probe distance to atom k,
372+
* Returns the 2-dimensional array with neighbor indices for every atom,
365373
* using spatial hashing to avoid all to all distance calculation.
366-
* @param k index of atom for which we want neighbor indices
367-
* @return the indices of neighboring atoms
374+
* @return 2-dimensional array of size: n_atoms x n_neighbors_per_atom
368375
*/
369-
Integer[] findNeighborIndicesSpatialHashing(int k) {
376+
int[][] findNeighborIndicesSpatialHashing() {
370377

371-
if (contacts == null) {
372-
contacts = calcContacts();
373-
}
378+
int[][] nbsIndices = new int[atomCoords.length][];
374379

375-
// looking at a typical protein case, number of neighbours are from ~10 to ~50, with an average of ~30
376-
// Thus 40 seems to be a good compromise for the starting capacity
377-
ArrayList<Integer> neighbor_indices = new ArrayList<>(40);
380+
List<Contact> contactList = calcContacts();
378381

379-
double radius = radii[k] + probe + probe;
382+
for (int k=0; k<atomCoords.length; k++) {
383+
double radius = radii[k] + probe + probe;
380384

381-
for (Contact contact : contacts) {
382-
double dist = contact.getDistance();
383-
int i;
384-
if (contact.getJ() == k) {
385-
i = contact.getI();
386-
} else if (contact.getI() == k) {
387-
i = contact.getJ();
388-
} else {
389-
continue;
390-
}
391-
if (dist < radius + radii[i]) {
392-
neighbor_indices.add(i);
385+
List<Integer> thisNbIndices = new ArrayList<>();
386+
387+
// TODO make this the outer loop
388+
for (Contact contact : contactList) {
389+
double dist = contact.getDistance();
390+
int i;
391+
if (contact.getJ() == k) {
392+
i = contact.getI();
393+
} else if (contact.getI() == k) {
394+
i = contact.getJ();
395+
} else {
396+
continue;
397+
}
398+
if (dist < radius + radii[i]) {
399+
thisNbIndices.add(i);
400+
}
393401
}
402+
403+
int[] indicesArray = new int[thisNbIndices.size()];
404+
for (int i=0;i<thisNbIndices.size();i++) indicesArray[i] = thisNbIndices.get(i);
405+
nbsIndices[k] = indicesArray;
394406
}
395407

396-
Integer[] indicesArray = new Integer[neighbor_indices.size()];
397-
indicesArray = neighbor_indices.toArray(indicesArray);
398-
return indicesArray;
408+
return nbsIndices;
399409
}
400410

401411
Point3d[] getAtomCoords() {
@@ -405,6 +415,7 @@ Point3d[] getAtomCoords() {
405415
private List<Contact> calcContacts() {
406416
double maxRadius = maxValue(radii);
407417
double cutoff = maxRadius + maxRadius + probe + probe;
418+
logger.debug("Max radius is {}, cutoff is {}", maxRadius, cutoff);
408419
Grid grid = new Grid(cutoff);
409420
grid.addCoords(atomCoords);
410421
return grid.getIndicesContacts();
@@ -422,13 +433,9 @@ private static double maxValue(double[] array) {
422433

423434
private double calcSingleAsa(int i) {
424435
Point3d atom_i = atomCoords[i];
425-
Integer[] neighbor_indices;
426-
if (useSpatialHashingForNeighbors) {
427-
neighbor_indices = findNeighborIndicesSpatialHashing(i);
428-
} else {
429-
neighbor_indices = findNeighborIndices(i);
430-
}
431-
int n_neighbor = neighbor_indices.length;
436+
437+
int n_neighbor = neighborIndices[i].length;
438+
int[] neighbor_indices = neighborIndices[i];
432439
int j_closest_neighbor = 0;
433440
double radius = probe + radii[i];
434441

@@ -579,7 +586,7 @@ private static double getRadiusForNucl(NucleotideImpl nuc, Atom atom) {
579586
*
580587
* If atom is neither part of a nucleotide nor of a standard aminoacid,
581588
* the default vdw radius for the element is returned. If atom is of
582-
* unknown type (element) the vdw radius of {@link #Element().N} is returned
589+
* unknown type (element) the vdw radius of {@link Element().N} is returned
583590
*
584591
* @param atom
585592
* @return

biojava-structure/src/test/java/org/biojava/nbio/structure/asa/TestAsaCalc.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,21 @@ public void testNeighborIndicesFinding() throws StructureException, IOException
9191
AsaCalculator.DEFAULT_PROBE_SIZE,
9292
1000, 1, false);
9393

94-
for (int indexToTest =0; indexToTest < asaCalc.getAtomCoords().length; indexToTest++) {
95-
//int indexToTest = 198;
94+
int[][] allNbsSh = asaCalc.findNeighborIndicesSpatialHashing();
9695

97-
Integer[] nbsSh = asaCalc.findNeighborIndicesSpatialHashing(indexToTest);
96+
int[][] allNbs = asaCalc.findNeighborIndices();
9897

99-
Integer[] nbs = asaCalc.findNeighborIndices(indexToTest);
98+
for (int indexToTest =0; indexToTest < asaCalc.getAtomCoords().length; indexToTest++) {
99+
//int indexToTest = 198;
100+
int[] nbsSh = allNbsSh[indexToTest];
101+
int[] nbs = allNbs[indexToTest];
100102

101103
int countNotInNbs = 0;
102104
List<Integer> listOfMatchingIndices = new ArrayList<>();
103105
for (int i = 0; i < nbsSh.length; i++) {
104106
boolean contained = false;
105107
for (int j = 0; j < nbs.length; j++) {
106-
if (nbs[j].equals(nbsSh[i])) {
108+
if (nbs[j] == nbsSh[i]) {
107109
listOfMatchingIndices.add(j);
108110
contained = true;
109111
break;
@@ -142,7 +144,7 @@ public void testPerformance() throws StructureException, IOException {
142144
Structure structure = StructureIO.getStructure("4F5X");
143145
Chain c = structure.getPolyChainByPDB("W");
144146
Atom[] atoms = StructureTools.getAllAtomArray(c);
145-
System.out.printf("Total of %d atoms\n", atoms.length);
147+
System.out.printf("Total of %d atoms. n(n-1)/2= %d \n", atoms.length, atoms.length*(atoms.length-1)/2);
146148

147149
int nThreads = 1;
148150
// 1. WITH SPATIAL HASHING
@@ -164,6 +166,8 @@ public void testPerformance() throws StructureException, IOException {
164166
double withSH = totAtoms;
165167
System.out.printf("Total ASA is %6.2f \n", totAtoms);
166168

169+
//System.out.println("Distances calculated: " + asaCalc.distancesCalculated);
170+
167171

168172
// 2. WITHOUT SPATIAL HASHING
169173
start = System.currentTimeMillis();
@@ -183,6 +187,8 @@ public void testPerformance() throws StructureException, IOException {
183187
double withoutSH = totAtoms;
184188
System.out.printf("Total ASA is %6.2f \n", totAtoms);
185189

190+
//System.out.println("Distances calculated: " + asaCalc.distancesCalculated);
191+
186192
assertEquals(withoutSH, withSH, 0.000001);
187193

188194
}

0 commit comments

Comments
 (0)