Skip to content

Commit b91071c

Browse files
committed
Fix set comparisons
1 parent 54a9559 commit b91071c

File tree

3 files changed

+66
-25
lines changed

3 files changed

+66
-25
lines changed

src/runtime/dict.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Box* dictCopy(BoxedDict* self) {
8585
raiseExcHelper(TypeError, "descriptor 'copy' requires a 'dict' object but received a '%s'", getTypeName(self));
8686

8787
BoxedDict* r = new BoxedDict();
88-
r->d.insert(self->d.begin(), self->d.end());
88+
r->d = self->d;
8989
return r;
9090
}
9191

@@ -576,11 +576,11 @@ Box* dictEq(BoxedDict* self, Box* _rhs) {
576576
if (self->d.size() != rhs->d.size())
577577
return False;
578578

579-
for (const auto& p : *self) {
579+
for (const auto& p : self->d) {
580580
auto it = rhs->d.find(p.first);
581581
if (it == rhs->d.end())
582582
return False;
583-
if (!nonzero(compare(p.second, it->second, AST_TYPE::Eq)))
583+
if (!PyEq()(p.second, it->second))
584584
return False;
585585
}
586586

src/runtime/set.cpp

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ static Box* setIssubset(BoxedSet* self, Box* container) {
407407
assert(PyAnySet_Check(container));
408408

409409
BoxedSet* rhs = static_cast<BoxedSet*>(container);
410+
if (self->s.size() > rhs->s.size())
411+
return False;
412+
410413
for (auto e : self->s) {
411414
if (rhs->s.find(e) == rhs->s.end())
412415
return False;
@@ -421,13 +424,7 @@ static Box* setIssuperset(BoxedSet* self, Box* container) {
421424
container = makeNewSet(set_cls, container);
422425
}
423426
assert(PyAnySet_Check(container));
424-
425-
BoxedSet* rhs = static_cast<BoxedSet*>(container);
426-
for (auto e : rhs->s) {
427-
if (self->s.find(e) == self->s.end())
428-
return False;
429-
}
430-
return True;
427+
return setIssubset((BoxedSet*)container, self);
431428
}
432429

433430
static Box* setIsdisjoint(BoxedSet* self, Box* container) {
@@ -473,7 +470,7 @@ Box* setCopy(BoxedSet* self) {
473470
RELEASE_ASSERT(PyAnySet_Check(self), "");
474471

475472
BoxedSet* rtn = new BoxedSet();
476-
rtn->s.insert(self->s.begin(), self->s.end());
473+
rtn->s = self->s;
477474
return rtn;
478475
}
479476

@@ -497,24 +494,56 @@ Box* setContains(BoxedSet* self, Box* v) {
497494
Box* setEq(BoxedSet* self, BoxedSet* rhs) {
498495
RELEASE_ASSERT(PyAnySet_Check(self), "");
499496
if (!PyAnySet_Check(rhs))
500-
return NotImplemented;
497+
return False;
501498

502499
if (self->s.size() != rhs->s.size())
503500
return False;
504501

505-
for (auto e : self->s) {
506-
if (!rhs->s.count(e))
507-
return False;
508-
}
509-
return True;
502+
return setIssubset(self, rhs);
510503
}
511504

512505
Box* setNe(BoxedSet* self, BoxedSet* rhs) {
513506
Box* r = setEq(self, rhs);
514-
if (r->cls == bool_cls)
515-
return boxBool(r == False);
516-
assert(r == NotImplemented);
517-
return r;
507+
assert(r->cls == bool_cls);
508+
return boxBool(r == False);
509+
}
510+
511+
Box* setLe(BoxedSet* self, BoxedSet* rhs) {
512+
RELEASE_ASSERT(PyAnySet_Check(self), "");
513+
if (!PyAnySet_Check(rhs))
514+
raiseExcHelper(TypeError, "can only compare to a set");
515+
516+
return setIssubset(self, rhs);
517+
}
518+
519+
Box* setLt(BoxedSet* self, BoxedSet* rhs) {
520+
RELEASE_ASSERT(PyAnySet_Check(self), "");
521+
if (!PyAnySet_Check(rhs))
522+
raiseExcHelper(TypeError, "can only compare to a set");
523+
524+
if (self->s.size() >= rhs->s.size())
525+
return False;
526+
527+
return setIssubset(self, rhs);
528+
}
529+
530+
Box* setGe(BoxedSet* self, BoxedSet* rhs) {
531+
RELEASE_ASSERT(PyAnySet_Check(self), "");
532+
if (!PyAnySet_Check(rhs))
533+
raiseExcHelper(TypeError, "can only compare to a set");
534+
535+
return setIssuperset(self, rhs);
536+
}
537+
538+
Box* setGt(BoxedSet* self, BoxedSet* rhs) {
539+
RELEASE_ASSERT(PyAnySet_Check(self), "");
540+
if (!PyAnySet_Check(rhs))
541+
raiseExcHelper(TypeError, "can only compare to a set");
542+
543+
if (self->s.size() <= rhs->s.size())
544+
return False;
545+
546+
return setIssuperset(self, rhs);
518547
}
519548

520549
Box* setNonzero(BoxedSet* self) {
@@ -627,10 +656,18 @@ void setupSet() {
627656
set_cls->giveAttr("__contains__", new BoxedFunction(boxRTFunction((void*)setContains, BOXED_BOOL, 2)));
628657
frozenset_cls->giveAttr("__contains__", set_cls->getattr(internStringMortal("__contains__")));
629658

630-
set_cls->giveAttr("__eq__", new BoxedFunction(boxRTFunction((void*)setEq, UNKNOWN, 2)));
659+
set_cls->giveAttr("__eq__", new BoxedFunction(boxRTFunction((void*)setEq, BOXED_BOOL, 2)));
631660
frozenset_cls->giveAttr("__eq__", set_cls->getattr(internStringMortal("__eq__")));
632-
set_cls->giveAttr("__ne__", new BoxedFunction(boxRTFunction((void*)setNe, UNKNOWN, 2)));
661+
set_cls->giveAttr("__ne__", new BoxedFunction(boxRTFunction((void*)setNe, BOXED_BOOL, 2)));
633662
frozenset_cls->giveAttr("__ne__", set_cls->getattr(internStringMortal("__ne__")));
663+
set_cls->giveAttr("__le__", new BoxedFunction(boxRTFunction((void*)setLe, BOXED_BOOL, 2)));
664+
frozenset_cls->giveAttr("__le__", set_cls->getattr(internStringMortal("__le__")));
665+
set_cls->giveAttr("__lt__", new BoxedFunction(boxRTFunction((void*)setLt, BOXED_BOOL, 2)));
666+
frozenset_cls->giveAttr("__lt__", set_cls->getattr(internStringMortal("__lt__")));
667+
set_cls->giveAttr("__ge__", new BoxedFunction(boxRTFunction((void*)setGe, BOXED_BOOL, 2)));
668+
frozenset_cls->giveAttr("__ge__", set_cls->getattr(internStringMortal("__ge__")));
669+
set_cls->giveAttr("__gt__", new BoxedFunction(boxRTFunction((void*)setGt, BOXED_BOOL, 2)));
670+
frozenset_cls->giveAttr("__gt__", set_cls->getattr(internStringMortal("__gt__")));
634671

635672
set_cls->giveAttr("__nonzero__", new BoxedFunction(boxRTFunction((void*)setNonzero, BOXED_BOOL, 1)));
636673
frozenset_cls->giveAttr("__nonzero__", set_cls->getattr(internStringMortal("__nonzero__")));

test/tests/set.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,12 @@ class MyFrozenset(frozenset):
128128

129129
for s1 in set(range(5)), frozenset(range(5)):
130130
for s2 in compare_to:
131-
print type(s2), sorted(s2), s1.issubset(s2), s1.issuperset(s2), s1 == s2, s1 != s2, sorted(s1.difference(s2)), s1.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2))
132-
131+
print type(s2), sorted(s2), s1.issubset(s2), s1.issuperset(s2), sorted(s1.difference(s2)), s1.isdisjoint(s2), sorted(s1.union(s2)), sorted(s1.intersection(s2))
132+
print s1 == s2, s1 != s2
133+
try:
134+
print s1 < s2, s1 <= s2, s1 > s2, s1 >= s2
135+
except Exception as e:
136+
print e
133137
f = float('nan')
134138
s = set([f])
135139
print f in s, f == list(s)[0]

0 commit comments

Comments
 (0)