Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance local module lookup #123

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package com.ibm.wala.cast.python.parser;

import static com.ibm.wala.cast.python.util.Util.MODULE_INITIALIZATION_ENTITY_NAME;
import static com.ibm.wala.cast.python.util.Util.PYTHON_FILE_EXTENSION;

import com.ibm.wala.cast.python.ir.PythonCAstToIRTranslator;
import com.ibm.wala.cast.python.util.Util;
Expand All @@ -32,7 +33,9 @@
import java.net.URL;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
Expand Down Expand Up @@ -140,22 +143,6 @@ private CAstNode createImportNode(
.collect(Collectors.toList()));
}

/**
 * Returns the {@link Path} corresponding to the given {@link SourceModule}.
 *
 * <p>The path is built from the file component of the module's {@link URL}, as returned by
 * {@link URL#getFile()}.
 *
 * @param module The {@link SourceModule} for which to extract a {@link Path}.
 * @return The {@link Path} corresponding to the given {@link SourceModule}.
 * @throws IllegalStateException If the given {@link SourceModule} is not present or if it does
 *     not have a URL.
 */
private static Path getPath(Optional<SourceModule> module) {
  // Fail with a descriptive message rather than a bare exception so that the cause is obvious.
  SourceModule source =
      module.orElseThrow(() -> new IllegalStateException("No source module was supplied."));

  URL url = source.getURL();

  // A module without a URL has always been reported as an IllegalStateException; keep that
  // contract but say which module was at fault.
  if (url == null) {
    throw new IllegalStateException("Source module has no URL: " + source);
  }

  return Path.of(url.getFile());
}

@Override
public CAstNode visitImportFrom(ImportFrom importFrom) throws Exception {
Optional<String> s =
Expand Down Expand Up @@ -336,12 +323,35 @@ protected Reader getReader() throws IOException {
return new InputStreamReader(fileName.getInputStream());
}

private boolean isLocalModule(String moduleName) {
boolean ret =
localModules.stream()
.map(lm -> scriptName((SourceModule) lm))
.anyMatch(sn -> sn.endsWith("/" + moduleName + ".py"));
/**
 * Returns the {@link Path} corresponding to the given {@link SourceModule}.
 *
 * <p>The path is built from the file component of the module's {@link URL}, as returned by
 * {@link URL#getFile()}.
 *
 * @param module The {@link SourceModule} for which to extract a {@link Path}.
 * @return The {@link Path} corresponding to the given {@link SourceModule}.
 * @throws IllegalStateException If the given {@link SourceModule} is not present or if it does
 *     not have a URL.
 */
private static Path getPath(Optional<SourceModule> module) {
  // Fail with a descriptive message rather than a bare exception so that the cause is obvious.
  SourceModule source =
      module.orElseThrow(() -> new IllegalStateException("No source module was supplied."));

  URL url = source.getURL();

  // A module without a URL has always been reported as an IllegalStateException; keep that
  // contract but say which module was at fault.
  if (url == null) {
    throw new IllegalStateException("Source module has no URL: " + source);
  }

  return Path.of(url.getFile());
}

/**
 * Returns the {@link Path} of the {@link SourceModule} being parsed, i.e., {@code this.fileName}.
 *
 * <p>This is a convenience overload that wraps {@code this.fileName} in an {@link Optional} and
 * delegates to {@link #getPath(Optional)}. It is currently marked as unused.
 *
 * @return The {@link Path} corresponding to the parsed {@link SourceModule}.
 * @throws NullPointerException If {@code this.fileName} is {@code null}, since it is wrapped
 *     using {@link Optional#of(Object)}.
 * @see #getPath(Optional)
 */
@SuppressWarnings("unused")
private Path getPath() {
return getPath(Optional.of(this.fileName));
}

/**
 * Determines whether the module with the given name is one of the known local modules.
 *
 * <p>The outcome is also logged at the FINER level.
 *
 * @param moduleName The name of the module to look up.
 * @return True iff {@code getLocalModule(moduleName)} yields a module.
 */
private boolean isLocalModule(String moduleName) {
  final boolean local = this.getLocalModule(moduleName).isPresent();
  final String verb = local ? " is" : " isn't";

  LOGGER.finer("Module: " + moduleName + verb + " local.");
  return local;
}
Expand All @@ -353,12 +363,30 @@ private boolean isLocalModule(String moduleName) {
* @return The corresponding {@link SourceModule}.
*/
private Optional<SourceModule> getLocalModule(String moduleName) {
return localModules.stream()
.filter(
lm -> {
String scriptName = scriptName((SourceModule) lm);
return scriptName.endsWith("/" + moduleName + ".py");
})
// A map of paths to known local modules.
Map<String, SourceModule> pathToLocalModule = new HashMap<>();

for (SourceModule module : this.localModules) {
String scriptName = scriptName(module);
pathToLocalModule.put(scriptName, module);
}

// first, check the current directory, i.e., the directory where the import statement is
// executed. If the module is found here, the search stops.
String scriptName = scriptName();
String scriptDirectory = scriptName.substring(0, scriptName.lastIndexOf('/') + 1);
String moduleFileName = moduleName + "." + PYTHON_FILE_EXTENSION;
String modulePath = scriptDirectory + moduleFileName;
SourceModule module = pathToLocalModule.get(modulePath);

if (module != null) return Optional.of(module);

// otherwise, go through the local modules. NOTE: Should instead traverse PYTHONPATH here per
// https://g.co/gemini/share/310ca39fbd43. However, the problem is that the local modules may
// not be on disk. As such, this is our best approximation.
return pathToLocalModule.keySet().stream()
.filter(p -> p.endsWith(moduleFileName))
.map(p -> pathToLocalModule.get(p))
.findFirst();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2976,6 +2976,76 @@ public void testModule78()
test(new String[] {"module.py", "client2.py"}, "module.py", "f", "", 1, 1, new int[] {2});
}

/**
 * Test https://github.com/wala/ML/issues/209.
 *
 * <p>Runs two checks over the same set of {@code proj73} files: one for function {@code f} in
 * {@code models/albert.py} and one for function {@code g} in {@code models/bert.py}.
 */
@Test
public void testModule79()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  // Both checks analyze the same project files.
  String[] projectFiles = {
    "proj73/models/__init__.py",
    "proj73/models/albert.py",
    "proj73/bert.py",
    "proj73/models/bert.py",
    "proj73/client.py"
  };

  test(projectFiles, "models/albert.py", "f", "proj73", 1, 1, new int[] {2});
  test(projectFiles, "models/bert.py", "g", "proj73", 1, 1, new int[] {2});
}

/**
 * Test https://github.com/wala/ML/issues/209.
 *
 * <p>Runs two checks over the same set of {@code proj74} files: one for function {@code f} in
 * {@code models/albert.py} and one for function {@code g} in {@code models/bert.py}.
 */
@Test
public void testModule80()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  // Both checks analyze the same project files.
  String[] projectFiles = {
    "proj74/models/__init__.py",
    "proj74/models/albert.py",
    "proj74/bert.py",
    "proj74/models/bert.py",
    "proj74/client.py"
  };

  test(projectFiles, "models/albert.py", "f", "proj74", 1, 1, new int[] {2});
  test(projectFiles, "models/bert.py", "g", "proj74", 1, 1, new int[] {2});
}

@Test
public void testStaticMethod() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_static_method.py", "MyClass.the_static_method", 1, 1, 2);
Expand Down
2 changes: 2 additions & 0 deletions com.ibm.wala.cast.python.test/.pydevproject
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@
<path>/${PROJECT_DIR_NAME}/data/proj70</path>
<path>/${PROJECT_DIR_NAME}/data/proj71</path>
<path>/${PROJECT_DIR_NAME}/data/proj72</path>
<path>/${PROJECT_DIR_NAME}/data/proj73</path>
<path>/${PROJECT_DIR_NAME}/data/proj74</path>
</pydev_pathproperty>
</pydev_project>
1 change: 1 addition & 0 deletions com.ibm.wala.cast.python.test/data/proj73/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/tests/GNN/BERT-TextGCN/bert.py.
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj73/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test https://github.com/wala/ML/issues/209.
import tensorflow as tf
from models import f
from models import g

f(tf.constant(1))
g(tf.constant(1))
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/__init__.py.

from .albert import *
from .bert import *
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj73/models/albert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/albert.py.

import tensorflow as tf


def f(a):
assert isinstance(a, tf.Tensor)
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj73/models/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/bert.py.

import tensorflow as tf


def g(a):
assert isinstance(a, tf.Tensor)
1 change: 1 addition & 0 deletions com.ibm.wala.cast.python.test/data/proj74/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/tests/GNN/BERT-TextGCN/bert.py.
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj74/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test https://github.com/wala/ML/issues/209.
import tensorflow as tf
from models import f
from models import g

f(tf.constant(1))
g(tf.constant(1))
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/__init__.py.

from models.albert import *
from models.bert import *
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj74/models/albert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/albert.py.

import tensorflow as tf


def f(a):
assert isinstance(a, tf.Tensor)
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj74/models/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# From https://github.com/kyzhouhzau/NLPGNN/blob/b9ecec2c6df1b3e40a54511366dcb6085cf90c34/nlpgnn/models/bert.py.

import tensorflow as tf


def g(a):
assert isinstance(a, tf.Tensor)
Loading