[pypy-commit] pypy jvm-improvements: Enforce casting string arguments with rjvm.native_string

benol noreply at buildbot.pypy.org
Wed Jun 6 23:54:59 CEST 2012


Author: Michal Bendowski <michal at bendowski.pl>
Branch: jvm-improvements
Changeset: r55451:a5f570e4506b
Date: 2012-06-06 23:53 +0200
http://bitbucket.org/pypy/pypy/changeset/a5f570e4506b/

Log:	Enforce casting string arguments with rjvm.native_string

diff --git a/pypy/annotation/model.py b/pypy/annotation/model.py
--- a/pypy/annotation/model.py
+++ b/pypy/annotation/model.py
@@ -645,9 +645,6 @@
         if witness.contains(s_val):
             return T
 
-    if isinstance(s_val, SomeString):
-        return ootype.String
-
     if info is None:
         info = ''
     else:
diff --git a/pypy/module/jvm/interp_helpers.py b/pypy/module/jvm/interp_helpers.py
--- a/pypy/module/jvm/interp_helpers.py
+++ b/pypy/module/jvm/interp_helpers.py
@@ -2,7 +2,8 @@
 from pypy.interpreter.error import OperationError
 from pypy.interpreter.typedef import TypeDef
 from pypy.rlib import rjvm, rstring
-from pypy.rlib.rjvm import java
+from pypy.rlib.rjvm import java, native_string
+
 
 class W_JvmObject(Wrappable):
     """
@@ -124,8 +125,9 @@
         return str(b_type.getName())
 
 def class_for_name(space, class_name):
+    b_class_name = native_string(class_name)
     try:
-        return java.lang.Class.forName(class_name)
+        return java.lang.Class.forName(b_class_name)
     except rjvm.ReflectionException:
         raise OperationError(space.w_TypeError,
                              space.wrap("Class %s not found!" % class_name))
diff --git a/pypy/module/jvm/interp_level.py b/pypy/module/jvm/interp_level.py
--- a/pypy/module/jvm/interp_level.py
+++ b/pypy/module/jvm/interp_level.py
@@ -3,7 +3,7 @@
 from pypy.module.jvm import interp_helpers as helpers
 from pypy.module.jvm.interp_helpers import W_JvmObject
 from pypy.rlib import rjvm
-from pypy.rlib.rjvm import java
+from pypy.rlib.rjvm import java, native_string
 
 # ============== Interp-level module API ==============
 
@@ -106,7 +106,7 @@
     b_java_class = helpers.class_for_name(space, class_name)
 
     try:
-        b_meth = b_java_class.getMethod(method_name, types)
+        b_meth = b_java_class.getMethod(native_string(method_name), types)
     except rjvm.ReflectionException:
         raise helpers.raise_type_error(space,
                          "No method called %s found in class %s" % (method_name, str(b_java_class.getName())))
@@ -126,9 +126,10 @@
     """
     b_java_class = helpers.class_for_name(space, class_name)
     args, types = helpers.get_args_types(space, args_w)
+    b_method_name = native_string(method_name)
 
     try:
-        b_meth = b_java_class.getMethod(method_name, types)
+        b_meth = b_java_class.getMethod(b_method_name, types)
     except rjvm.ReflectionException:
         raise helpers.raise_type_error(space,
                          "No method called %s found in class %s" % (method_name, str(b_java_class.getName())))
@@ -308,8 +309,9 @@
     """
     The logic behind get_(static)_field_value.
     """
+    b_field_name = native_string(field_name)
     try:
-        b_field = b_class.getField(field_name)
+        b_field = b_class.getField(b_field_name)
     except rjvm.ReflectionException:
         raise helpers.raise_type_error(space, "No field called %s in class %s" % (
             field_name, str(b_class.getName())))
@@ -324,8 +326,9 @@
     """
     The logic behind set_(static)_field_value.
     """
+    b_field_name = native_string(field_name)
     try:
-        b_field = b_class.getField(field_name)
+        b_field = b_class.getField(b_field_name)
     except rjvm.ReflectionException:
         msg = "No field called %s in class %s" % (field_name, str(b_class.getName()))
         raise helpers.raise_type_error(space, msg)
diff --git a/pypy/rlib/rjvm/api.py b/pypy/rlib/rjvm/api.py
--- a/pypy/rlib/rjvm/api.py
+++ b/pypy/rlib/rjvm/api.py
@@ -146,6 +146,8 @@
             return self._unwrap_item(item._array)
         elif isinstance(item, jvm_str):
             return str(item)
+        elif isinstance(item, str):
+            raise TypeError("You have to wrap strings using rjvm.native_string!")
         return item
 
     def __call__(self, *args):
diff --git a/pypy/rlib/rjvm/test/test_rjvm.py b/pypy/rlib/rjvm/test/test_rjvm.py
--- a/pypy/rlib/rjvm/test/test_rjvm.py
+++ b/pypy/rlib/rjvm/test/test_rjvm.py
@@ -2,7 +2,7 @@
 import pypy.annotation.model as annmodel
 import pypy.rlib.rjvm as rjvm
 from pypy.rlib import rstring, rarithmetic
-from pypy.rlib.rjvm import java
+from pypy.rlib.rjvm import java, native_string
 from pypy.rpython.test.tool import BaseRtypingTest, OORtypeMixin
 from pypy.annotation.annrpython import RPythonAnnotator
 import pypy.translator.jvm.rjvm_support as rjvm_support
@@ -40,7 +40,7 @@
     al = java.util.ArrayList()
     assert isinstance(al, rjvm.JvmInstanceWrapper)
     assert isinstance(al.add, rjvm.JvmMethodWrapper)
-    al.add("test")
+    al.add(native_string("test"))
     assert str(al.get(0)) == "test"
 
 def test_class_repr():
@@ -53,12 +53,12 @@
 
 def test_invalid_method_name():
     al = java.util.ArrayList()
-    al.add("test")
+    al.add(native_string("test"))
     with py.test.raises(AttributeError):
         al.typo(0)
 
 def test_interpreted_reflection():
-    al_class = java.lang.Class.forName("java.util.ArrayList")
+    al_class = java.lang.Class.forName(native_string("java.util.ArrayList"))
     assert isinstance(al_class, rjvm.JvmInstanceWrapper)
     assert isinstance(java.util.Collection.class_, rjvm.JvmInstanceWrapper)
 
@@ -76,7 +76,7 @@
     assert isinstance(al, rjvm.JvmInstanceWrapper)
     assert isinstance(al.add, rjvm.JvmMethodWrapper)
 
-    al_clear = al_class.getMethod('clear', [])
+    al_clear = al_class.getMethod(native_string('clear'), [])
     assert isinstance(al_clear, rjvm.JvmInstanceWrapper)
     assert isinstance(al_clear.invoke, rjvm.JvmMethodWrapper)
 
@@ -87,10 +87,10 @@
     assert al.isEmpty()
     assert al.size() == 0
 
-    al_add = al_class.getMethod('add', [java.lang.Object.class_])
+    al_add = al_class.getMethod(native_string('add'), [java.lang.Object.class_])
     assert isinstance(al_add, rjvm.JvmInstanceWrapper)
     assert isinstance(al_add.invoke, rjvm.JvmMethodWrapper)
-    al_add.invoke(al, ["Hello"])
+    al_add.invoke(al, [native_string("Hello")])
     assert str(al.get(0)) == "Hello"
 
 
@@ -117,7 +117,7 @@
     def test_returning_string_as_object(self):
         def fn():
             al = java.util.ArrayList()
-            al.add('foobar')
+            al.add(native_string('foobar'))
             str_as_obj = al.get(0)
             str_as_jstr = rjvm.downcast(java.lang.String, str_as_obj)
             return str_as_jstr
@@ -154,7 +154,7 @@
             elif x == 2:
                 v = rjvm.upcast(java.lang.Object, java.lang.Boolean(True))
             else:
-                v = rjvm.upcast(java.lang.Object, rjvm.native_string('foobar'))
+                v = rjvm.upcast(java.lang.Object, native_string('foobar'))
             return v.toString()
 
         a = RPythonAnnotator()
@@ -168,7 +168,7 @@
             elif x == 2:
                 v = java.lang.Boolean(True)
             else:
-                v = rjvm.native_string('foobar')
+                v = native_string('foobar')
             return v.toString()
 
         a = RPythonAnnotator()
@@ -184,7 +184,7 @@
 
     def test_constructor_args(self):
         def fn():
-            sb = java.lang.StringBuilder('foobar')
+            sb = java.lang.StringBuilder(native_string('foobar'))
         res = self.interpret(fn, [])
         assert res is None
 
@@ -206,7 +206,7 @@
     def test_method_call_no_overload(self):
         def fn():
             t = java.lang.Thread()
-            t.setName('foo')
+            t.setName(native_string('foo'))
             return str(t.getName())
         res = self.ll_to_string(self.interpret(fn, []))
         assert res == 'foo'
@@ -214,7 +214,7 @@
     def test_method_call_overload(self):
         def fn():
             sb = java.lang.StringBuilder()
-            sb.append('foo ')
+            sb.append(native_string('foo '))
             sb.append(7)
             return str(sb.toString())
         res = self.ll_to_string(self.interpret(fn, []))
@@ -238,7 +238,8 @@
 
     def test_static_method_no_overload(self):
         def fn():
-            return java.lang.Integer.bitCount(5), str(java.util.regex.Pattern.compile('abc').toString())
+            pattern = java.util.regex.Pattern.compile(native_string('abc'))
+            return java.lang.Integer.bitCount(5), str(pattern.toString())
         (a,b) = self.ll_unpack_tuple(self.interpret(fn, []), 2)
         assert a == 2
         assert self.ll_to_string(b) == 'abc'
@@ -252,9 +253,9 @@
     def test_collections(self):
         def fn():
             array_list = java.util.ArrayList()
-            array_list.add("one")
-            array_list.add("two")
-            array_list.add("three")
+            array_list.add(native_string("one"))
+            array_list.add(native_string("two"))
+            array_list.add(native_string("three"))
             return array_list.size()
 
         res = self.interpret(fn, [])
@@ -288,7 +289,7 @@
 
     def test_array_result(self):
         def fn():
-            ms = java.lang.Class.forName('java.lang.Object').getMethods()
+            ms = java.lang.Class.forName(native_string('java.lang.Object')).getMethods()
             i = 0
             for m in xrange(len(ms)):
                 i += 1
@@ -310,7 +311,7 @@
 
     def test_reflection_for_name(self):
         def fn():
-            al_class = java.lang.Class.forName('java.util.ArrayList')
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
             return str(al_class.getName())
 
         res = self.interpret(fn, [])
@@ -327,7 +328,7 @@
     def test_reflection_get_empty_constructor(self):
 
         def fn():
-            al_class = java.lang.Class.forName('java.util.ArrayList')
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
             c = al_class.getConstructor(rjvm.new_array(java.lang.Class, 0))
             return c.getModifiers()
 
@@ -336,7 +337,7 @@
 
     def test_reflection_get_int_constructor(self):
         def fn():
-            al_class = java.lang.Class.forName('java.util.ArrayList')
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
             c = al_class.getConstructor([java.lang.Integer.TYPE])
             return c.getModifiers()
 
@@ -345,7 +346,7 @@
 
     def test_reflection_get_collection_constructor_class_literal(self):
         def fn():
-            al_class = java.lang.Class.forName('java.util.ArrayList')
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
             c = al_class.getConstructor([java.util.Collection.class_])
             return c.getModifiers()
 
@@ -354,9 +355,9 @@
 
     def test_reflection_get_collection_constructor_dynamic(self):
         def fn():
-            al_class = java.lang.Class.forName('java.util.ArrayList')
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
             types = rjvm.new_array(java.lang.Class, 1)
-            types[0] = java.lang.Class.forName('java.util.Collection')
+            types[0] = java.lang.Class.forName(native_string('java.util.Collection'))
             c = al_class.getConstructor(types)
             return c.getModifiers()
 
@@ -365,7 +366,7 @@
 
     def test_reflection_instance_creation_no_args(self):
         def fn1():
-            al_class = java.lang.Class.forName('java.util.ArrayList')
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
             c = al_class.getConstructor(rjvm.new_array(java.lang.Class, 0))
             object_al = c.newInstance(rjvm.new_array(java.lang.Object, 0))
             al = rjvm.downcast(java.util.ArrayList, object_al)
@@ -376,7 +377,7 @@
 
     def test_reflection_instance_creation_array_args(self):
         def fn():
-            al_class = java.lang.Class.forName('java.util.ArrayList')
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
             c = al_class.getConstructor([java.lang.Integer.TYPE])
             args_array = rjvm.new_array(java.lang.Object, 1)
             args_array[0] = java.lang.Integer.valueOf(15)
@@ -389,7 +390,7 @@
 
     def test_reflection_instance_creation_arraylist_args(self):
         def fn():
-            al_class = java.lang.Class.forName('java.util.ArrayList')
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
             c = al_class.getConstructor([java.lang.Integer.TYPE])
             args = java.util.ArrayList()
             args.add(java.lang.Integer.valueOf(15))
@@ -406,8 +407,8 @@
             al = java.util.ArrayList()
             o = java.lang.Object()
             al.add(o)
-            al_class = java.lang.Class.forName('java.util.ArrayList')
-            size_meth = al_class.getMethod('size', rjvm.new_array(java.lang.Class, 0))
+            al_class = java.lang.Class.forName(native_string('java.util.ArrayList'))
+            size_meth = al_class.getMethod(native_string('size'), rjvm.new_array(java.lang.Class, 0))
             size = rjvm.downcast(java.lang.Integer, size_meth.invoke(al, rjvm.new_array(java.lang.Object, 0)))
             return size.intValue()
 
@@ -417,7 +418,7 @@
     def test_returning_string_as_object(self):
         def fn():
             al = java.util.ArrayList()
-            al.add('foobar')
+            al.add(native_string('foobar'))
             str_as_obj = al.get(0)
             str_as_jstr = rjvm.downcast(java.lang.String, str_as_obj)
             return str(str_as_jstr)
@@ -427,11 +428,11 @@
 
     def test_reflection_static_field(self):
         def fn():
-            system_class = java.lang.Class.forName('java.lang.System')
-            out_field = system_class.getField('out')
+            system_class = java.lang.Class.forName(native_string('java.lang.System'))
+            out_field = system_class.getField(native_string('out'))
             dummy = java.lang.Object()
             out = out_field.get(dummy)
-            hashcode_meth = java.lang.Class.forName('java.lang.Object').getMethod('hashCode', rjvm.new_array(java.lang.Class, 0))
+            hashcode_meth = java.lang.Class.forName(native_string('java.lang.Object')).getMethod(native_string('hashCode'), rjvm.new_array(java.lang.Class, 0))
             res_as_obj = hashcode_meth.invoke(out, rjvm.new_array(java.lang.Object, 0))
             res_as_integer = rjvm.downcast(java.lang.Integer, res_as_obj)
             return res_as_integer.intValue()
@@ -440,13 +441,13 @@
         assert isinstance(res, int)
 
     def test_method_name(self):
-        def fn(s):
-            cls = java.lang.Class.forName(s)
+        def fn():
+            cls = java.lang.Class.forName(native_string('java.lang.Object'))
             m = cls.getMethods()[0]
             name = m.getName()
             return str(name)
 
-        res = self.interpret(fn, [self.string_to_ll('java.lang.Object')])
+        res = self.interpret(fn, [])
         assert isinstance(self.ll_to_string(res), str)
 
     def test_str_on_strings(self):
@@ -474,7 +475,7 @@
             elif x == 2:
                 v = java.lang.Boolean(True)
             else:
-                v = java.lang.String('foobar')
+                v = java.lang.String(native_string('foobar'))
 
             return 1 if v else 0
 
@@ -496,13 +497,13 @@
         def fn():
             args_w = [(java.awt.Point(), 'java.awt.Point')]
 
-            b_java_cls = java.lang.Class.forName('java.awt.Point')
+            b_java_cls = java.lang.Class.forName(native_string('java.awt.Point'))
             types = rjvm.new_array(java.lang.Class, 1)
             args = rjvm.new_array(java.lang.Object, 1)
 
             for i, w_arg_type in enumerate(args_w):
                 w_arg, w_type = w_arg_type
-                type_name = w_type
+                type_name = native_string(w_type)
                 b_arg = w_arg
                 types[i] = java.lang.Class.forName(type_name)
                 args[i] = b_arg
@@ -535,7 +536,7 @@
 
     def test_comparing_classes(self):
         def fn():
-            c1 = java.lang.Class.forName('java.lang.String')
+            c1 = java.lang.Class.forName(native_string('java.lang.String'))
             c2 = java.lang.String.class_
             return c1 == c2
 
@@ -545,7 +546,7 @@
     def test_native_strings(self):
         def fn():
             obj_array = rjvm.new_array(java.lang.Object, 3)
-            obj_array[1] = rjvm.native_string('foobar')
+            obj_array[1] = native_string('foobar')
             b_str = rjvm.downcast(java.lang.String, obj_array[1])
             return str(b_str)
 
@@ -555,7 +556,7 @@
     def test_exceptions_static_call(self):
         def fn():
             try:
-                java.lang.Class.forName('foobar')
+                java.lang.Class.forName(native_string('foobar'))
                 return False
             except rjvm.ReflectionException:
                return True
diff --git a/pypy/translator/jvm/rjvm_support/utils.py b/pypy/translator/jvm/rjvm_support/utils.py
--- a/pypy/translator/jvm/rjvm_support/utils.py
+++ b/pypy/translator/jvm/rjvm_support/utils.py
@@ -74,8 +74,7 @@
     def _can_convert_from_to(self, arg1, arg2):
         # Just the simplest logic for now:
         if isinstance(arg2, ootypemodel.NativeRJvmInstance) and arg2.class_name == 'java.lang.Object':
-            # TODO: autoboxing?
-            return isinstance(arg1, ootypemodel.NativeRJvmInstance) or arg1 == ootype.String
+            return isinstance(arg1, ootypemodel.NativeRJvmInstance)
         return super(JvmOverloadingResolver, self)._can_convert_from_to(arg1, arg2)
 
     def _get_refclass(self, meth):


More information about the pypy-commit mailing list