Skip to content

Commit

Permalink
Enhance local module lookup (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad authored Jul 23, 2024
1 parent 205beee commit 099932d
Show file tree
Hide file tree
Showing 13 changed files with 179 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package com.ibm.wala.cast.python.parser;

import static com.ibm.wala.cast.python.util.Util.MODULE_INITIALIZATION_ENTITY_NAME;
import static com.ibm.wala.cast.python.util.Util.PYTHON_FILE_EXTENSION;

import com.ibm.wala.cast.python.ir.PythonCAstToIRTranslator;
import com.ibm.wala.cast.python.util.Util;
Expand All @@ -32,7 +33,9 @@
import java.net.URL;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
Expand Down Expand Up @@ -140,22 +143,6 @@ private CAstNode createImportNode(
.collect(Collectors.toList()));
}

/**
* Returns the {@link Path} corresponding to the given {@link SourceModule}. If a {@link
* SourceModule} is not supplied, an {@link IllegalStateException} is thrown.
*
* @param module The {@link SourceModule} for which to extract a {@link Path}.
* @return The {@link Path} corresponding to the given {@link SourceModule}.
* @throws IllegalStateException If the given {@link SourceModule} is not present.
*/
private static Path getPath(Optional<SourceModule> module) {
return module
.map(SourceModule::getURL)
.map(URL::getFile)
.map(Path::of)
.orElseThrow(IllegalStateException::new);
}

@Override
public CAstNode visitImportFrom(ImportFrom importFrom) throws Exception {
Optional<String> s =
Expand Down Expand Up @@ -336,12 +323,35 @@ protected Reader getReader() throws IOException {
return new InputStreamReader(fileName.getInputStream());
}

private boolean isLocalModule(String moduleName) {
boolean ret =
localModules.stream()
.map(lm -> scriptName((SourceModule) lm))
.anyMatch(sn -> sn.endsWith("/" + moduleName + ".py"));
/**
* Returns the {@link Path} corresponding to the given {@link SourceModule}. If a {@link
* SourceModule} is not supplied, an {@link IllegalStateException} is thrown.
*
* @param module The {@link SourceModule} for which to extract a {@link Path}.
* @return The {@link Path} corresponding to the given {@link SourceModule}.
* @throws IllegalStateException If the given {@link SourceModule} is not present.
*/
private static Path getPath(Optional<SourceModule> module) {
return module
.map(SourceModule::getURL)
.map(URL::getFile)
.map(Path::of)
.orElseThrow(IllegalStateException::new);
}

/**
* Get the {@link Path} of the parsed {@link SourceModule}.
*
* @see getPath(Optional<SourceModule>)
* @return The {@link Path} corresponding to the parsed {@link SourceModule}.
*/
@SuppressWarnings("unused")
private Path getPath() {
return getPath(Optional.of(this.fileName));
}

private boolean isLocalModule(String moduleName) {
boolean ret = this.getLocalModule(moduleName).isPresent();
LOGGER.finer("Module: " + moduleName + (ret ? " is" : " isn't") + " local.");
return ret;
}
Expand All @@ -353,12 +363,30 @@ private boolean isLocalModule(String moduleName) {
* @return The corresponding {@link SourceModule}.
*/
private Optional<SourceModule> getLocalModule(String moduleName) {
return localModules.stream()
.filter(
lm -> {
String scriptName = scriptName((SourceModule) lm);
return scriptName.endsWith("/" + moduleName + ".py");
})
// A map of paths to known local modules.
Map<String, SourceModule> pathToLocalModule = new HashMap<>();

for (SourceModule module : this.localModules) {
String scriptName = scriptName(module);
pathToLocalModule.put(scriptName, module);
}

// first, check the current directory, i.e., the directory where the import statement is
// executed. If the module is found here, the search stops.
String scriptName = scriptName();
String scriptDirectory = scriptName.substring(0, scriptName.lastIndexOf('/') + 1);
String moduleFileName = moduleName + "." + PYTHON_FILE_EXTENSION;
String modulePath = scriptDirectory + moduleFileName;
SourceModule module = pathToLocalModule.get(modulePath);

if (module != null) return Optional.of(module);

// otherwise, go through the local modules. NOTE: Should instead traverse PYTHONPATH here per
// https://g.co/gemini/share/310ca39fbd43. However, the problem is that the local modules may
// not be on disk. As such, this is our best approximation.
return pathToLocalModule.keySet().stream()
.filter(p -> p.endsWith(moduleFileName))
.map(p -> pathToLocalModule.get(p))
.findFirst();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2976,6 +2976,76 @@ public void testModule78()
test(new String[] {"module.py", "client2.py"}, "module.py", "f", "", 1, 1, new int[] {2});
}

/** Test https://github.com/wala/ML/issues/209. */
@Test
public void testModule79()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test(
new String[] {
"proj73/models/__init__.py",
"proj73/models/albert.py",
"proj73/bert.py",
"proj73/models/bert.py",
"proj73/client.py"
},
"models/albert.py",
"f",
"proj73",
1,
1,
new int[] {2});

test(
new String[] {
"proj73/models/__init__.py",
"proj73/models/albert.py",
"proj73/bert.py",
"proj73/models/bert.py",
"proj73/client.py"
},
"models/bert.py",
"g",
"proj73",
1,
1,
new int[] {2});
}

/** Test https://github.com/wala/ML/issues/209. */
@Test
public void testModule80()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test(
new String[] {
"proj74/models/__init__.py",
"proj74/models/albert.py",
"proj74/bert.py",
"proj74/models/bert.py",
"proj74/client.py"
},
"models/albert.py",
"f",
"proj74",
1,
1,
new int[] {2});

test(
new String[] {
"proj74/models/__init__.py",
"proj74/models/albert.py",
"proj74/bert.py",
"proj74/models/bert.py",
"proj74/client.py"
},
"models/bert.py",
"g",
"proj74",
1,
1,
new int[] {2});
}

@Test
public void testStaticMethod() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_static_method.py", "MyClass.the_static_method", 1, 1, 2);
Expand Down
2 changes: 2 additions & 0 deletions com.ibm.wala.cast.python.test/.pydevproject
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@
<path>/${PROJECT_DIR_NAME}/data/proj70</path>
<path>/${PROJECT_DIR_NAME}/data/proj71</path>
<path>/${PROJECT_DIR_NAME}/data/proj72</path>
<path>/${PROJECT_DIR_NAME}/data/proj73</path>
<path>/${PROJECT_DIR_NAME}/data/proj74</path>
</pydev_pathproperty>
</pydev_project>
1 change: 1 addition & 0 deletions com.ibm.wala.cast.python.test/data/proj73/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/tests/GNN/BERT-TextGCN/bert.py.
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj73/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test https://github.com/wala/ML/issues/209.
import tensorflow as tf
from models import f
from models import g

f(tf.constant(1))
g(tf.constant(1))
4 changes: 4 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj73/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/__init__.py.

from .albert import *
from .bert import *
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj73/models/albert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/albert.py.kk

import tensorflow as tf


def f(a):
assert isinstance(a, tf.Tensor)
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj73/models/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/bert.py.

import tensorflow as tf


def g(a):
assert isinstance(a, tf.Tensor)
1 change: 1 addition & 0 deletions com.ibm.wala.cast.python.test/data/proj74/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/tests/GNN/BERT-TextGCN/bert.py.
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj74/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test https://github.com/wala/ML/issues/209.
import tensorflow as tf
from models import f
from models import g

f(tf.constant(1))
g(tf.constant(1))
4 changes: 4 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj74/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/__init__.py.

from models.albert import *
from models.bert import *
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj74/models/albert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/albert.py.kk

import tensorflow as tf


def f(a):
assert isinstance(a, tf.Tensor)
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj74/models/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/bert.py.

import tensorflow as tf


def g(a):
assert isinstance(a, tf.Tensor)

0 comments on commit 099932d

Please sign in to comment.