From 8164f10590b192fe2b2d4f5f5803d24470444a28 Mon Sep 17 00:00:00 2001 From: Love Waern Date: Mon, 27 May 2024 08:45:49 +0200 Subject: [PATCH] Type Equality/Compatibility Revamp Taken from #213, but with some refinements --- py/dml/codegen.py | 18 +- py/dml/ctree.py | 66 +++++--- py/dml/ctree_test.py | 2 + py/dml/structure.py | 6 +- py/dml/traits.py | 4 +- py/dml/types.py | 319 +++++++++++++++++++++++++++--------- test/1.4/errors/T_ECAST.dml | 6 +- 7 files changed, 300 insertions(+), 121 deletions(-) diff --git a/py/dml/codegen.py b/py/dml/codegen.py index 1fe97b5a2..b5978c065 100644 --- a/py/dml/codegen.py +++ b/py/dml/codegen.py @@ -1306,18 +1306,7 @@ def expr_cast(tree, location, scope): for (site, _) in struct_defs: report(EANONSTRUCT(site, "'cast' expression")) - if (compat.dml12_misc in dml.globals.enabled_compat - and isinstance(expr, InterfaceMethodRef)): - # Workaround for SIMICS-9868 - return mkLit(tree.site, "%s->%s" % ( - expr.node_expr.read(), expr.method_name), type) - - if isinstance(expr, NonValue) and ( - not isinstance(expr, NodeRef) - or not isinstance(safe_realtype(type), TTrait)): - raise expr.exc() - else: - return mkCast(tree.site, expr, type) + return mkCast(tree.site, expr, type) @expression_dispatcher def expr_undefined(tree, location, scope): @@ -2008,9 +1997,8 @@ def mk_sym(name, typ, mkunique=not dml.globals.debuggable): for (name, typ) in decls: sym = mk_sym(name, typ) tgt_typ = safe_realtype_shallow(typ) - if tgt_typ.const: - nonconst_typ = tgt_typ.clone() - nonconst_typ.const = False + if shallow_const(tgt_typ): + nonconst_typ = safe_realtype_unconst(tgt_typ) tgt_sym = mk_sym('_tmp_' + name, nonconst_typ, True) sym.init = ExpressionInitializer(mkLocalVariable(stmt.site, tgt_sym)) diff --git a/py/dml/ctree.py b/py/dml/ctree.py index 12a5acf9d..52756f564 100644 --- a/py/dml/ctree.py +++ b/py/dml/ctree.py @@ -1222,7 +1222,7 @@ def mkIfExpr(site, cond, texpr, fexpr): (texpr, fexpr, utype) = usual_int_conv( texpr, ttype, fexpr, ftype) else: - if not compatible_types(ttype, ftype): + if not compatible_types_fuzzy(ttype, ftype): raise EBINOP(site, ':', texpr, fexpr) # TODO: in C, the rules are more complex, # but our type system is too primitive to cover that @@ -1396,7 +1396,7 @@ def make(cls, site, lh, rh): if ((lhtype.is_arith and rhtype.is_arith) or (isinstance(lhtype, (TPtr, TArray)) and isinstance(rhtype, (TPtr, TArray)) - and compatible_types(lhtype.base, rhtype.base))): + and compatible_types_fuzzy(lhtype.base, rhtype.base))): return cls.make_simple(site, lh, rh) raise EILLCOMP(site, lh, lhtype, rh, rhtype) @@ -1601,7 +1601,7 @@ def make(cls, site, lh, rh): if ((lhtype.is_arith and rhtype.is_arith) or (isinstance(lhtype, (TPtr, TArray)) and isinstance(rhtype, (TPtr, TArray)) - and compatible_types(lhtype, rhtype)) + and compatible_types_fuzzy(lhtype, rhtype)) or (isinstance(lhtype, TBool) and isinstance(rhtype, TBool))): return Equals(site, lh, rh) @@ -2932,8 +2932,8 @@ def mkInterfaceMethodRef(site, iface_node, indices, method_name): if (not isinstance(ftype, TPtr) or not isinstance(ftype.base, TFunction) or not ftype.base.input_types - or TPtr(safe_realtype(TNamed('conf_object_t'))).cmp( - safe_realtype(ftype.base.input_types[0])) != 0): + or TPtr(safe_realtype_unconst(TNamed('conf_object_t'))).cmp( + safe_realtype_unconst(ftype.base.input_types[0])) != 0): # non-method members are not accessible raise EMEMBER(site, struct_name, method_name) @@ -4674,7 +4674,10 @@ class ArrayRef(LValue): explicit_type = True @auto_init def __init__(self, site, expr, idx): - self.type = realtype_shallow(expr.ctype()).base + expr_type = realtype_shallow(expr.ctype()) + self.type = conv_const(expr_type.const + and isinstance(expr_type, TArray), + expr_type.base) def __str__(self): return '%s[%s]' % (self.expr, self.idx) def read(self): @@ -4787,6 +4790,15 @@ def mkCast(site, expr, new_type): raise ETEMPLATEUPCAST(site, "object", new_type) else: return mkTraitUpcast(site, expr, real.trait) + + if (compat.dml12_misc in dml.globals.enabled_compat + and isinstance(expr, InterfaceMethodRef)): + # Workaround for SIMICS-9868 + return mkLit(site, "%s->%s" % ( + expr.node_expr.read(), expr.method_name), new_type) + + if isinstance(expr, NonValue): + raise expr.exc() old_type = safe_realtype(expr.ctype()) if (dml.globals.compat_dml12_int(site) and (isinstance(old_type, (TStruct, TVector)) @@ -4794,15 +4806,17 @@ def mkCast(site, expr, new_type): # these casts are permitted by C only if old and new are # the same type, which is useless return Cast(site, expr, new_type) - if isinstance(real, TStruct): - if isinstance(old_type, TStruct) and old_type.label == real.label: - return expr - raise ECAST(site, expr, new_type) - if isinstance(real, TExternStruct): - if isinstance(old_type, TExternStruct) and old_type.id == real.id: - return expr + if isinstance(real, (TVoid, TArray, TFunction)): raise ECAST(site, expr, new_type) - if isinstance(real, (TVoid, TArray, TVector, TTraitList, TFunction)): + if old_type.cmp(real) == 0: + if (old_type.is_int + and not old_type.is_endian + and dml.globals.compat_dml12_int(expr.site)): + # 1.2 integer expressions often lie about their actual type, + # and require a "redundant" cast! Why yes, this IS horrid! + return Cast(site, expr, new_type) + return mkRValue(expr) + if isinstance(real, (TStruct, TExternStruct, TVector, TTraitList)): raise ECAST(site, expr, new_type) if isinstance(old_type, (TVoid, TStruct, TVector, TTraitList, TTrait)): raise ECAST(site, expr, new_type) @@ -4810,7 +4824,7 @@ def mkCast(site, expr, new_type): expr = as_int(expr) old_type = safe_realtype(expr.ctype()) if real.is_int and not real.is_endian: - if isinstance(expr, IntegerConstant): + if old_type.is_int and expr.constant: value = truncate_int_bits(expr.value, real.signed, real.bits) if dml.globals.compat_dml12_int(site): return IntegerConstant_dml12(site, value, real) @@ -4821,8 +4835,8 @@ def mkCast(site, expr, new_type): # Shorten redundant chains of integer casts. Avoids insane C # output for expressions like a+b+c+d. if (isinstance(expr, Cast) - and isinstance(expr.type, TInt) - and expr.type.bits >= real.bits): + and isinstance(old_type, TInt) + and old_type.bits >= real.bits): # (uint64)(int64)x -> (uint64)x expr = expr.expr old_type = safe_realtype(expr.ctype()) @@ -4858,9 +4872,7 @@ def mkCast(site, expr, new_type): return expr elif real.is_int and real.is_endian: old_type = safe_realtype(expr.ctype()) - if real.cmp(old_type) == 0: - return expr - elif old_type.is_arith or isinstance(old_type, TPtr): + if old_type.is_arith or isinstance(old_type, TPtr): return mkApply( expr.site, mkLit(expr.site, *real.get_store_fun()), @@ -4917,7 +4929,6 @@ def mkCast(site, expr, new_type): class RValue(Expression): '''Wraps an lvalue to prohibit write. Useful when a composite expression is reduced down to a single variable.''' - writable = False @auto_init def __init__(self, site, expr): pass def __str__(self): @@ -4926,11 +4937,22 @@ def ctype(self): return self.expr.ctype() def read(self): return self.expr.read() - def discard(self): pass + def discard(self): + return self.expr.discard() def incref(self): self.expr.incref() def decref(self): self.expr.decref() + @property + def explicit_type(self): + return self.expr.explicit_type + @property + def type(self): + assert self.explicit_type + return self.expr.type + @property + def is_pointer_to_stack_allocation(self): + return self.expr.is_pointer_to_stack_allocation def mkRValue(expr): if isinstance(expr, LValue) or expr.writable: diff --git a/py/dml/ctree_test.py b/py/dml/ctree_test.py index bc494f89e..ee4a67f53 100644 --- a/py/dml/ctree_test.py +++ b/py/dml/ctree_test.py @@ -1416,6 +1416,8 @@ def const_types(self): # abstract type types.IntegerType, # abstract type + types.ArchDependentIntegerType, + # abstract type types.StructType, # 1.2, weird types.TUnknown, diff --git a/py/dml/structure.py b/py/dml/structure.py index efba64bd3..13b3ea1b9 100644 --- a/py/dml/structure.py +++ b/py/dml/structure.py @@ -674,7 +674,8 @@ def typecheck_method_override(m1, m2, location): # TODO move to caller (_, type1) = eval_type(t1, a1.site, location, global_scope) (_, type2) = eval_type(t2, a2.site, location, global_scope) - if safe_realtype(type1).cmp(safe_realtype(type2)) != 0: + if safe_realtype_unconst(type1).cmp( + safe_realtype_unconst(type2)) != 0: raise EMETH(a1.site, a2.site, f"mismatching types in input argument {n1}") @@ -683,7 +684,8 @@ def typecheck_method_override(m1, m2, location): ((n1, t1), (n2, t2)) = (a1.args, a2.args) (_, type1) = eval_type(t1, a1.site, location, global_scope) (_, type2) = eval_type(t2, a2.site, location, global_scope) - if safe_realtype(type1).cmp(safe_realtype(type2)) != 0: + if safe_realtype_unconst(type1).cmp( + safe_realtype_unconst(type2)) != 0: msg = "mismatching types in return value" if len(outp1) > 1: msg += f" {i + 1}" diff --git a/py/dml/traits.py b/py/dml/traits.py index 115125ae8..091e8e3b9 100644 --- a/py/dml/traits.py +++ b/py/dml/traits.py @@ -398,11 +398,11 @@ def typecheck_method_override(left, right): if throws0 != throws1: raise EMETH(site0, site1, "different nothrow annotations") for ((n, t0), (_, t1)) in zip(inp0, inp1): - if realtype(t0).cmp(realtype(t1)) != 0: + if safe_realtype_unconst(t0).cmp(safe_realtype_unconst(t1)) != 0: raise EMETH(site0, site1, "mismatching types in input argument %s" % (n,)) for (i, ((_, t0), (_, t1))) in enumerate(zip(outp0, outp1)): - if realtype(t0).cmp(realtype(t1)) != 0: + if safe_realtype_unconst(t0).cmp(safe_realtype_unconst(t1)) != 0: raise EMETH(site0, site1, "mismatching types in output argument %d" % (i + 1,)) diff --git a/py/dml/types.py b/py/dml/types.py index 049331671..411ec2d7d 100644 --- a/py/dml/types.py +++ b/py/dml/types.py @@ -12,10 +12,13 @@ 'realtype', 'safe_realtype_shallow', 'safe_realtype', + 'safe_realtype_unconst', 'conv_const', + 'shallow_const', 'deep_const', 'type_union', 'compatible_types', + 'compatible_types_fuzzy', 'typedefs', 'global_type_declaration_order', 'global_anonymous_structs', @@ -133,7 +136,7 @@ def realtype(t): elif isinstance(t, TVector): t2 = realtype(t.base) if t2 != t: - return TVector(t2, t.const) + return TVector(t2, t.const, t.uniq) elif isinstance(t, TFunction): input_types = tuple(realtype(sub) for sub in t.input_types) output_type = realtype(t.output_type) @@ -168,6 +171,27 @@ def conv_const(const, t): t.const = True return t +def safe_realtype_unconst(t0): + def sub(t): + if isinstance(t, (TArray, TVector)): + base = sub(t.base) + if t.const or base is not t.base: + t = t.clone() + t.const = False + t.base = base + elif t.const: + t = t.clone() + t.const = False + return t + return sub(safe_realtype(t0)) + +def shallow_const(t): + t = safe_realtype_shallow(t) + while not t.const and isinstance(t, (TArray, TVector)): + t = safe_realtype_shallow(t.base) + + return t.const + def deep_const(origt): subtypes = [origt] while subtypes: @@ -200,7 +224,7 @@ def __eq__(self, other): in zip(self.types, other.types))) def __hash__(self): - return hash(tuple(type(elem) for elem in self.types)) + return hash(tuple(elem.hashed() for elem in self.types)) class DMLType(metaclass=abc.ABCMeta): '''The type of a value (expression or declaration) in DML. One DML @@ -231,25 +255,47 @@ def sizeof(self): return None def cmp(self, other): - """Compare this type to another. - - Return 0 if the types are equivalent, - Return NotImplemented otherwise. - - The exact meaning of this is somewhat fuzzy. The - method is used for three purposes: - + """Strict type compatibility/equality. + + Return 0 if the types are run-time compatible, + Return NotImplemented otherwise + + "Run-time compatibility" has two minimal criteria: + 1. The C representations of the types MUST be compatible, in a C sense + 2. A value of one type can be treated by DMLC as though it were of the + other type without any additional risk of undefined behavior or invalid + generated C. + For example, all trait reference types share the same C representation, + and so satisfy (1), but trait reference types for different traits do + not share vtables; trying to use a vtable for one trait with an + incompatible reference would result in undefined behavior, and so do + not satisfy (2). + + The method is used for three purposes: 1. in TPtr.canstore(), to judge whether pointer target types are compatible. 2. in ctree, to compare how large values different numerical types can hold - 3. when judging whether a method override is allowed, as an inaccurate - replacement of TPtr(self).canstore(TPtr(other))[0] + 3. when judging whether a method override is allowed See SIMICS-9504 for further discussions. + """ + return (0 if type(self) is type(other) and self.const == other.const + else NotImplemented) + def cmp_fuzzy(self, other): + """Compare this type to another. + Return 0 if the types are pretty much equivalent, + Return NotImplemented otherwise. + As implied, the exact meaning of this is fuzzy. It mostly relaxes + criteria (1) of 'cmp'); for example, TPtr(void).cmp(TPtr(TBool())) is + allowed to return 0, as is TPtr(TBool()).cmp(TArray(TBool())). + + Most notably, cmp_fuzzy does not take const-qualification into account. + + Any usage of cmp_fuzzy should be considered a HACK """ - return NotImplemented + return safe_realtype_unconst(self).cmp(safe_realtype_unconst(other)) def canstore(self, other): """Can a variable of this type store a value of another type. @@ -262,7 +308,8 @@ def canstore(self, other): The correctness of the return value can not be trusted; see SIMICS-9504 for further discussions. """ - return (self.cmp(other) == 0, False, False) + return (safe_realtype_unconst(self).cmp(safe_realtype_unconst(other)) + == 0, False, False) @abc.abstractmethod def clone(self): @@ -279,6 +326,13 @@ def print_declaration(self, var, init = None, unused = False): def describe(self): raise Exception("%s.describe not implemented" % self.__class__.__name__) + def hashed(self): + '''Hash the DML type in a way compatible with cmp. I.e. + a.cmp(b) == 0 implies a.hashed() == b.hashed()''' + assert type(self).cmp is DMLType.cmp, \ + '.cmp() overridden without overriding .hashed()' + return hash((type(self), self.const)) + def key(self): return self.const_str + self.describe() @@ -309,8 +363,6 @@ def declaration(self, var): return 'void ' + self.const_str + ' ' + var def clone(self): return TVoid() - def cmp(self, other): - return 0 if isinstance(realtype(other), TVoid) else NotImplemented class TUnknown(DMLType): '''A type unknown to DML. Typically used for a generic C macro @@ -344,8 +396,6 @@ def describe(self): return 'pointer to %s' % self.name def key(self): return 'device' - def cmp(self, other): - return 0 if isinstance(realtype(other), TDevice) else NotImplemented def canstore(self, other): constviol = False if not self.const and other.const: @@ -400,6 +450,10 @@ def key(self): raise ICE(self.declaration_site, 'need realtype before key') def cmp(self, other): assert False, 'need realtype before cmp' + def cmp_fuzzy(self, other): + assert False, 'need realtype before cmp_fuzzy' + def hashed(self): + assert False, 'need realtype before hashed' def clone(self): return TNamed(self.c, self.const) @@ -418,14 +472,10 @@ def describe(self): return 'bool' def declaration(self, var): return 'bool ' + self.const_str + var - def cmp(self, other): - if isinstance(other, TBool): - return 0 - return NotImplemented def canstore(self, other): constviol = False - if isinstance(other, TBool): + if type(other) is TBool: return (True, False, constviol) if (other.is_int and other.bits == 1 and not other.signed): @@ -454,6 +504,8 @@ def __init__(self, bits, signed, members = None, const = False): is_int = True is_arith = True is_endian = False + is_arch_dependent = False + @property def is_bitfields(self): return self.members is not None @@ -477,8 +529,13 @@ def get_member_qualified(self, member): return t def cmp(self, other): + if self.const != other.const: + return NotImplemented if not other.is_int: return NotImplemented + if ((self.is_arch_dependent or other.is_arch_dependent) + and type(self) is not type(other)): + return NotImplemented if self.is_endian: if not other.is_endian: return NotImplemented @@ -486,11 +543,21 @@ def cmp(self, other): return NotImplemented elif other.is_endian: return NotImplemented - if isinstance(self, TLong) != isinstance(other, TLong): + return (0 if (self.bits, self.signed) == (other.bits, other.signed) + else NotImplemented) + + def cmp_fuzzy(self, other): + if not other.is_int: return NotImplemented - if isinstance(self, TSize) != isinstance(other, TSize): + if ((self.is_arch_dependent or other.is_arch_dependent) + and type(self) is not type(other)): return NotImplemented - if isinstance(self, TInt64_t) != isinstance(other, TInt64_t): + if self.is_endian: + if not other.is_endian: + return NotImplemented + if self.byte_order != other.byte_order: + return NotImplemented + elif other.is_endian: return NotImplemented if (dml.globals.dml_version == (1, 2) and compat.dml12_int in dml.globals.enabled_compat): @@ -499,6 +566,12 @@ def cmp(self, other): else: return (0 if (self.bits, self.signed) == (other.bits, other.signed) else NotImplemented) + + def hashed(self): + cls = type(self) if self.is_arch_dependent else IntegerType + byte_order = self.byte_order if self.is_endian else None + return hash((cls, self.const, self.bits, self.signed, byte_order)) + # This is the most restrictive canstore definition for # IntegerTypes, if this is overridden then it should be # because we want to be less restrictive @@ -582,7 +655,13 @@ def declaration(self, var): else: return 'uint8 ' + self.const_str + var + '[' + str(self.bytes) + ']' -class TLong(IntegerType): +class ArchDependentIntegerType(IntegerType): + '''Integer types whose definition and/or properties are architecture + dependent.''' + __slots__ = () + is_arch_dependent = True + +class TLong(ArchDependentIntegerType): '''The 'long' type from C''' __slots__ = () def __init__(self, signed, const=False): @@ -604,7 +683,7 @@ def clone(self): def declaration(self, var): return f'{self.c_name()} {var}' -class TSize(IntegerType): +class TSize(ArchDependentIntegerType): '''The 'size_t' type from C''' __slots__ = () def __init__(self, signed, const=False): @@ -625,7 +704,7 @@ def clone(self): def declaration(self, var): return f'{self.c_name()} {var}' -class TInt64_t(IntegerType): +class TInt64_t(ArchDependentIntegerType): '''The '[u]int64_t' type from ISO C. For compatibility with C APIs, e.g., calling an externally defined C function that takes a `uint64_t *` arg. We find `uint64` a generally more useful type @@ -731,10 +810,14 @@ def __repr__(self): def describe(self): return self.name def cmp(self, other): - if other.is_float and self.name == other.name: + if (self.const == other.const + and other.is_float and self.name == other.name): return 0 return NotImplemented + def hashed(self): + return hash((TFloat, self.const, self.name)) + def canstore(self, other): constviol = False if other.is_float: @@ -765,7 +848,7 @@ def key(self): % (conv_const(self.const, self.base).key(), self.size.value)) def describe(self): - return 'array of size %s of %s' % (self.size.read(), + return 'array of size %s of %s' % (str(self.size), self.base.describe()) def declaration(self, var): return self.base.declaration(self.const_str + var @@ -784,16 +867,30 @@ def sizeof(self): if elt_size == None: return None return self.size.value * elt_size + def cmp(self, other): - if compat.dml12_misc in dml.globals.enabled_compat: - if isinstance(other, (TArray, TPtr)): - return self.base.cmp(other.base) - elif isinstance(other, (TPtr, TArray)): - if self.base.void or other.base.void: - return 0 - if self.base.cmp(other.base) == 0: + if not isinstance(other, TArray): + return NotImplemented + if not (self.size is other.size + or (self.size.constant and other.size.constant + and self.size.value == other.size.value)): + return NotImplemented + return conv_const(self.const, self.base).cmp( + conv_const(other.const, other.base)) + + def cmp_fuzzy(self, other): + if isinstance(other, (TArray, TPtr)): + if other.base.void: return 0 + return self.base.cmp_fuzzy(other.base) return NotImplemented + + def hashed(self): + size = self.size.value if self.size.constant else self.size + return hash((TArray, + size, + conv_const(self.const, self.base).hashed())) + def canstore(self, other): return (False, False, False) def clone(self): @@ -816,33 +913,41 @@ def key(self): def describe(self): return 'pointer to %s' % (self.base.describe()) def cmp(self, other): - if compat.dml12_misc in dml.globals.enabled_compat: - if isinstance(other, TPtr): - # Can only compare for voidness or equality - if self.base.void or other.base.void: - return 0 - if self.base.cmp(other.base) == 0: - return 0 - elif isinstance(other, (TPtr, TArray)): + if DMLType.cmp(self, other) != 0: + return NotImplemented + return self.base.cmp(other.base) + + def cmp_fuzzy(self, other): + if isinstance(other, (TPtr, TArray)): if self.base.void or other.base.void: return 0 - if self.base.cmp(other.base) == 0: - return 0 + return self.base.cmp_fuzzy(other.base) return NotImplemented + def hashed(self): + return hash((TPtr, self.const, self.base.hashed())) + def canstore(self, other): ok = False trunc = False constviol = False if isinstance(other, (TPtr, TArray)): + constviol = (not shallow_const(self.base) + and shallow_const(other.base)) if self.base.void or other.base.void: ok = True else: - if not self.base.const and other.base.const: - constviol = True - ok = (self.base.cmp(other.base) == 0) + unconst_self_base = safe_realtype_unconst(self.base) + unconst_other_base = safe_realtype_unconst(other.base) + + ok = ((unconst_self_base.cmp_fuzzy + if compat.dml12_int in dml.globals.enabled_compat + else unconst_self_base.cmp)(unconst_other_base) + == 0) elif isinstance(other, TFunction): - ok = True + ok = safe_realtype_unconst(self.base).cmp(other) == 0 + # TODO gate this behind dml.globals.dml_version == (1, 2) or + # dml12_misc? if self.base.void and isinstance(other, TDevice): ok = True #dbg('TPtr.canstore %r %r => %r' % (self, other, ok)) @@ -862,9 +967,14 @@ def resolve(self): return self class TVector(DMLType): - __slots__ = ('base',) - def __init__(self, base, const = False): + count = 0 + __slots__ = ('base', 'uniq',) + def __init__(self, base, const=False, uniq=None): DMLType.__init__(self, const) + if uniq is None: + uniq = TVector.count + TVector.count += 1 + self.uniq = uniq if not base: raise DMLTypeError("Null base") self.base = base @@ -875,15 +985,20 @@ def key(self): def describe(self): return 'vector of %s' % self.base.describe() def cmp(self, other): + return (0 if (DMLType.cmp(self, other) == 0 + and self.uniq == other.uniq) + else NotImplemented) + def cmp_fuzzy(self, other): if isinstance(other, TVector): # Can only compare for voidness or equality if self.base.void or other.base.void: return 0 - if self.base.cmp(other.base) == 0: - return 0 + return self.base.cmp_fuzzy(other.base) return NotImplemented + def hashed(self): + return hash((TVector, self.const, self.uniq)) def clone(self): - return TVector(self.base, self.const) + return TVector(self.base, self.const, self.uniq) def declaration(self, var): s = self.base.declaration('') return 'VECT(%s) %s%s' % (s, self.const_str, var) @@ -904,14 +1019,16 @@ def clone(self): return TTrait(self.trait) def cmp(self, other): - if isinstance(other, TTrait) and self.trait is other.trait: - return 0 - else: - return NotImplemented + return (0 if (DMLType.cmp(self, other) == 0 + and self.trait is other.trait) + else NotImplemented) def key(self): return f'{self.const_str}trait({self.trait.name})' + def hashed(self): + return hash((TTrait, self.const, self.trait)) + def c_name(self): return f'{self.const_str}{cident(self.trait.name)}' @@ -935,14 +1052,16 @@ def clone(self): return TTraitList(self.traitname, self.const) def cmp(self, other): - if isinstance(other, TTraitList) and self.traitname == other.traitname: - return 0 - else: - return NotImplemented + return (0 if (DMLType.cmp(self, other) == 0 + and self.traitname == other.traitname) + else NotImplemented) def key(self): return f'{self.const_str}sequence({self.traitname})' + def hashed(self): + return hash((TTraitList, self.const, self.traitname)) + def c_type(self): return f'{self.const_str}_each_in_t' @@ -1011,9 +1130,12 @@ def declaration(self, var): return "%s %s%s" % (self.typename, self.const_str, var) def cmp(self, other): - if isinstance(other, TExternStruct) and self.id == other.id: - return 0 - return NotImplemented + return (0 if (DMLType.cmp(self, other) == 0 + and self.id == other.id) + else NotImplemented) + + def hashed(self): + return hash((TExternStruct, self.const, self.id)) def clone(self): return TExternStruct(self.members, self.id, self.typename, self.const) @@ -1066,9 +1188,12 @@ def print_struct_definition(self): out("};\n", preindent = -1) def cmp(self, other): - if isinstance(other, TStruct) and self.label == other.label: - return 0 - return NotImplemented + return (0 if (DMLType.cmp(self, other) == 0 + and self.label == other.label) + else NotImplemented) + + def hashed(self): + return hash((TStruct, self.const, self.label)) def clone(self): return TStruct(self.members, self.label, self.const) @@ -1205,14 +1330,34 @@ def describe(self): def cmp(self, other): if (isinstance(other, TFunction) and len(self.input_types) == len(other.input_types) - and all(arg1.cmp(arg2) == 0 - for (arg1, arg2) in zip(self.input_types, - other.input_types)) - and self.output_type.cmp(other.output_type) == 0 + and all(safe_realtype_unconst(arg1).cmp( + safe_realtype_unconst(arg2)) == 0 + for (arg1, arg2) + in zip(self.input_types, other.input_types)) + and safe_realtype_unconst(self.output_type).cmp( + safe_realtype_unconst(other.output_type)) == 0 + and self.varargs == other.varargs): + return 0 + return NotImplemented + + def cmp_fuzzy(self, other): + if (isinstance(other, TFunction) + and len(self.input_types) == len(other.input_types) + and all(arg1.cmp_fuzzy(arg2) == 0 + for (arg1, arg2) + in zip(self.input_types, other.input_types)) + and self.output_type.cmp_fuzzy(other.output_type) == 0 and self.varargs == other.varargs): return 0 return NotImplemented + def hashed(self): + return hash((TFunction, + tuple(safe_realtype_unconst(typ).hashed() + for typ in self.input_types), + safe_realtype_unconst(self.output_type).hashed(), + self.varargs)) + def canstore(self, other): return (False, False, False) @@ -1242,7 +1387,7 @@ def clone(self): return THook(self.msg_types, self.validated, self.const) def cmp(self, other): - if (isinstance(other, THook) + if (DMLType.cmp(self, other) == 0 and len(self.msg_types) == len(other.msg_types) and all(own_comp.cmp(other_comp) == 0 for (own_comp, other_comp) in zip(self.msg_types, @@ -1251,6 +1396,11 @@ def cmp(self, other): else: return NotImplemented + def hashed(self): + return hash((THook, + self.const, + tuple(comp.hashed() for comp in self.msg_types))) + def key(self): return ('%shook(%s)' % (self.const_str, ','.join(t.key() for t in self.msg_types))) @@ -1311,11 +1461,22 @@ def type_union(type1, type2): def compatible_types(type1, type2): # This function intends to verify that two DML types are # compatible in the sense defined by the C spec, possibly with - # some DML-specific restrictions added. TODO: DMLType.cmp is only - # a rough approximation of this; we should write tests and - # either repair cmp or rewrite the logic from scratch. + # some DML-specific restrictions added. return type1.cmp(type2) == 0 +# TODO We should look into getting rid of this and cmp_fuzzy, and replace +# their usages with usage-specific checks. +def compatible_types_fuzzy(type1, type2): + # This function intends to verify that two DML types are + # compatible in the sense defined by the C spec, possibly with + # some DML-specific restrictions added. + # DMLType.cmp_fuzzy is only a very rough approximation of this, + # meant to suite usages such as type-checking the ternary + # operator. + # Any use of .cmp_fuzzy or compatible_type_fuzzy should be considered + # a HACK. + return type1.cmp_fuzzy(type2) == 0 + void = TVoid() # These are the named types used. This includes both "imported" # typedefs for types declared in C header files, and types defined in diff --git a/test/1.4/errors/T_ECAST.dml b/test/1.4/errors/T_ECAST.dml index 9736eda8e..15eb6f762 100644 --- a/test/1.4/errors/T_ECAST.dml +++ b/test/1.4/errors/T_ECAST.dml @@ -10,6 +10,8 @@ typedef struct { uint32 x; } s_t; typedef layout "little-endian" { uint32 x; } l_t; /// WARNING WEXPERIMENTAL typedef int vect v_t; +/// WARNING WEXPERIMENTAL +typedef int vect alt_v_t; typedef int a_t[1]; typedef void f_t(void); extern f_t f; @@ -43,9 +45,11 @@ method init() { cast(v, uint32); /// ERROR ECAST cast(i, v_t); - /// ERROR ECAST + // no error cast(v, v_t); /// ERROR ECAST + cast(v, alt_v_t); + /// ERROR ECAST cast(l, uint32); // no error! cast(a, uint32);