summaryrefslogtreecommitdiff
path: root/ts_lib.py
blob: 4415d10f9479153b9c0e0d0a93effab0d8c05e93 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
"""Core library for describing triplet-structures in Python."""
import itertools
from collections import defaultdict

class TripletStructure:
    """Represents a triplet structure. Instances are usually named 'ts'.

    A TripletStructure starts out empty, with no nodes and no facts. It can be
    modified using the syntax: `ts["/:A"].map({ts["/:B"]: ts["/:C"]})` which
    adds the fact `(/:A,/:B,/:C)` to the structure. By default, nodes are added
    automatically upon first reference.

    We want to be able to easily roll back changes to the TripletStructure.
    Every direct modification of the TripletStructure is automatically
    registered in the TSDelta instance @ts.buffer, which acts as a buffer of
    changes. The method @ts.commit(...) commits this buffer, i.e., appends it
    to the end of the list @ts.path and replaces @ts.buffer with a fresh
    TSDelta instance. You can always reconstruct the structure by applying the
    TSDeltas in @ts.path successively to an empty structure, then applying
    @ts.buffer. When @ts.buffer is empty, we say the structure is 'clean.'
    @ts.rollback(...) can be used to restore the state of the structure to a
    particular commit in @ts.path.

    Generally, every change to a triplet structure is owned by some TSDelta
    instance. @ts.path gives a list [None, delta_1, delta_2, ..., delta_n] of
    TSDeltas.
    """
    def __init__(self):
        """Initializes a new, empty triplet structure."""
        # A list of the names of all nodes in the structure, in insertion
        # order.
        self.nodes = []
        # Maps full_name -> short_name. The short name will be used in
        # user-facing printouts. We should maintain the invariant
        # self.display_names.keys() == self.nodes.
        self.display_names = dict()
        # self.facts is a pre-computed index of the facts in the structure.
        # Keys are of two types:
        # 1. Triplet keys with 'holes' represented by None. E.g.,
        #    self.facts[(None, x, None)] is a list of all facts containing `x`
        #    in the middle slot. To get a list of all facts, use
        #    self.facts[(None, None, None)].
        # 2. Single-node keys. For a node string @x, self.facts[x] is all facts
        #    with x in at least one slot (see facts_about_node(...)).
        # Notably, if a fact (A, B, C) is in the structure at all, then it
        # *MUST* belong to exactly the keys yielded by
        # self._iter_subfacts((A, B, C)): the 8 hole patterns plus one
        # single-node key per *distinct* node in the fact (at most 11 keys).
        self.facts = defaultdict(list)
        # A prefix applied to node lookups. See ts.scope(...) and
        # ts.__getitem__.
        self.current_scope = "/"
        # The historical and running deltas.
        self.path = [None]
        self.buffer = TSDelta(self)
        # (Optional) an object with [add,remove]_[node,fact] methods which will
        # shadow changes to the structure. Used to implement efficient solving
        # with the C++ extensions.
        self.shadow = None

    def __getitem__(self, node):
        """Returns a (list of) NodeWrapper(s) corresponding to @node.

        This is the main entrypoint to manipulation of the structure.

        @node should be a string containing either (i) the name of a node, or
        (ii) a comma-separated list of node names. Node names should not
        contain spaces or commas.

        If a node name ends in ":??", then the "??" will be replaced with the
        smallest non-negative integer such that the resulting node name does
        not yet exist in the structure. Its use is somewhat analogous to LISP's
        gensym. See ts_utils.py for example usage.

        NOTE: ts[...] *CAN HAVE SIDE-EFFECTS*, namely _it constructs nodes
        which don't already exist_. You may think of it as a Python
        defaultdict. This makes for simpler code, but has a slight drawback of
        making typos harder to catch.  We may decide to change this syntax in
        the future, to something like ts.node(name) or tc(name), but: (i) the
        former would make quickly understanding 'tc-dense' code (like
        mapper.py) difficult while (ii) the latter loses intuition.
        """
        if "," in node:
            return [self[subname.strip()] for subname in node.split(",")]
        full_name = self._full_name(node)
        if full_name.endswith(":??"):
            stem = full_name[:-3]
            for i in itertools.count():
                filled_name = "{}:{}".format(stem, i)
                if not self.has_node(filled_name):
                    full_name = filled_name
                    break
        self.add_node(full_name)
        return NodeWrapper(self, full_name)

    def lookup(self, *template, read_direct=False):
        """Returns all facts according to a given template.

        This method should be called like ts.lookup(A,B,C) where A, B, C can be
        either node names or Nones. Nones match against any node name.

        Setting @read_direct=True returns a reference to the corresponding list
        of facts stored on the Structure instance. _May_ sometimes improve
        performance, but in general should be avoided due to unexpected
        behavior when either this class or the returned list is modified.
        """
        if read_direct:
            return self.facts[template]
        # .get() avoids inserting an empty list into the defaultdict for every
        # template that has no matches; the result is still a fresh list.
        return list(self.facts.get(template, ()))

    def facts_about_node(self, full_name, read_direct=False):
        """Returns all facts involving the node with name @full_name.

        See self.lookup for notes about @read_direct.
        """
        if read_direct:
            return self.facts[full_name]
        # As in lookup(): don't grow the index when the node has no facts.
        return list(self.facts.get(full_name, ()))

    def scope(self, scope="", protect=False):
        """Returns a TSScope representing the given scope.

        Often used like with ts.scope(...): ... to automatically prefix node
        names, e.g., to prevent name collisions.
        """
        return TSScope(self, self._full_name(scope), protect)

    def is_clean(self):
        """True iff the current buffer is empty."""
        return not self.buffer

    def commit(self, commit_if_clean=True):
        """Commits self.buffer to self.path.

        Returns the committed TSDelta, or False if the buffer was empty and
        @commit_if_clean is False (in which case nothing is committed).
        """
        if self.is_clean() and not commit_if_clean:
            return False
        self.path.append(self.buffer)
        self.buffer = TSDelta(self)
        return self.path[-1]

    def rollback(self, to_time=0):
        """Restores the structure to a previously-committed state.

        to_time = 0 means roll back the current buffer only.
        to_time > 0 means roll back so that len(path) == to_time.
        to_time < 0 means roll back the last -to_time commits, i.e., so that
            the new len(path) == old len(path) + to_time.
        NOTE: In the final case, len(path) does *not* include the buffer.
        NOTE: len(path) == 0 is invalid, as path[0] = None (the 'root delta').
        """
        old_running = self.buffer
        # Undoing the buffer itself records changes; give them a throwaway
        # buffer and then discard it.
        self.buffer = TSDelta(self)
        old_running.rollback()
        self.buffer = TSDelta(self)
        if to_time == 0:
            return

        if to_time >= 0:
            target_length = to_time
        else:
            target_length = len(self.path) + to_time
        assert len(self.path) >= target_length > 0

        while len(self.path) > target_length:
            self.path.pop().rollback()
        # buffer will have a bunch of changes which aren't needed. In
        # theory we can 'disable' the TSDelta instead of just overwriting it
        # here, which might improve performance for some such operations.
        self._force_clean()

    def start_recording(self):
        """Returns a new TSRecording to track changes to @self."""
        return TSRecording(self)

    def freeze_frame(self):
        """Returns a new TSFreezeFrame saving the state of the structure."""
        return TSFreezeFrame(self)

    def has_node(self, full_name):
        """True iff @full_name is a registered node in the structure."""
        assert isinstance(full_name, str)
        return full_name in self.nodes

    def add_node(self, full_name, display_name=None):
        """Low-level method to add a node to the structure.

        No-op if the node already exists.
        """
        if not self.has_node(full_name):
            self.nodes.append(full_name)
            self.display_names[full_name] = display_name or full_name
            self.buffer.add_node(full_name)
            if self.shadow:
                self.shadow.add_node(full_name)

    def remove_node(self, full_name):
        """Low-level method to remove a node from the structure.

        All facts mentioning the node must already have been removed.
        """
        # .get() so the check itself does not create an index entry.
        assert not self.facts.get(full_name), \
               f"Remove facts using {full_name} before removing it."
        if full_name in self.nodes:
            self.nodes.remove(full_name)
            self.display_names.pop(full_name)
            self.buffer.remove_node(full_name)
            if self.shadow:
                self.shadow.remove_node(full_name)

    def add_fact(self, fact):
        """Low-level method to add a fact to the structure.

        No-op if the fact already exists. All nodes in @fact must exist.
        """
        if self._has_fact(fact):
            # The fact already exists in the structure.
            return
        assert all(map(self.has_node, fact)), \
               f"Add all nodes in {fact} before adding the fact."
        for key in self._iter_subfacts(fact):
            self.facts[key].append(fact)
        self.buffer.add_fact(fact)
        if self.shadow:
            self.shadow.add_fact(fact)

    def remove_fact(self, fact):
        """Remove a fact from the structure (no-op if it is absent)."""
        if not self._has_fact(fact):
            # Fact was already removed, or never added.
            return
        for key in self._iter_subfacts(fact):
            self.facts[key].remove(fact)
        self.buffer.remove_fact(fact)
        if self.shadow:
            self.shadow.remove_fact(fact)

    def add_nodes(self, nodes):
        """Helper to add multiple nodes to the structure."""
        for node in nodes:
            self.add_node(node)

    def remove_nodes(self, nodes):
        """Helper to remove multiple nodes from the structure."""
        for node in nodes:
            self.remove_node(node)

    def add_facts(self, facts):
        """Helper to add multiple facts to the structure."""
        for fact in facts:
            self.add_fact(fact)

    def remove_facts(self, facts):
        """Helper to remove multiple facts from the structure."""
        for fact in facts:
            self.remove_fact(fact)

    def print_delta(self):
        """Helper context that prints changes to the structure on exit."""
        class DeltaPrinter:
            """Helper context manager for printing changes to a structure."""
            def __init__(self, ts):
                self.ts = ts
                self.frame = None

            def __enter__(self):
                self.frame = self.ts.freeze_frame()
                return self

            def __exit__(self, t, v, tb):
                print(self.ts.freeze_frame() - self.frame)
        return DeltaPrinter(self)

    def _has_fact(self, fact):
        """True iff @fact is currently in the structure.

        Uses dict.get so probing for an absent fact does not create an empty
        entry in the defaultdict index.
        """
        return bool(self.facts.get(tuple(fact)))

    @staticmethod
    def _iter_subfacts(fact):
        """Yields all keys of self.facts which should hold @fact.

        These are the 8 patterns obtained by replacing any subset of the slots
        with None, followed by each distinct node of @fact in sorted order.

        This method *MUST* be used any time ts.facts is modified. For examples,
        see ts.add_fact, ts.remove_fact.
        """
        for subset in range(2**3):
            yield tuple(arg if (subset & (0b1 << i)) else None
                        for i, arg in enumerate(fact))
        for argument in sorted(set(fact)):
            yield argument

    def _full_name(self, name):
        """Returns the full name of a node relative to the current scope.

        Names starting with "/" are already absolute and returned unchanged.
        """
        if name.startswith("/"):
            return name
        return "{}{}".format(self.current_scope, name)

    def _force_clean(self):
        """Manually clears the buffer.

        NOTE: Code outside of this file should **NEVER** call _force_clean.
        """
        self.buffer = TSDelta(self)

    def __str__(self):
        """Returns a string representation of the Structure.

        WARNING: This representation basically prints all the facts; it can get
        quite long, especially with a lot of rules.
        """
        def _format_fact(fact):
            return str(tuple(map(self.display_names.get, fact)))
        return "TripletStructure ({id}):\n\t{facts}".format(
            id=id(self), facts="\n\t".join(
                map(_format_fact, self.lookup(None, None, None))))

class TSScope:
    """Represents a scope (node name prefix) in a particular structure.

    Often used indirectly, as in `with ts.scope("..."): ...`, to automatically
    prefix node names, but it also has some useful methods for direct use
    (e.g., listing all nodes with a certain prefix).
    """
    def __init__(self, structure, prefix, protect=False):
        """Initializes a new TSScope.

        This should usually only be called via ts.scope(...) or
        scope.scope(...).
        """
        self.structure = structure
        self.prefix = prefix
        # Keeps track of the prefix on the structure before __enter__ so we can
        # reset it upon __exit__. A stack, so the same scope may be re-entered.
        self.old_scope_stack = []
        self.protect = protect

    def __enter__(self):
        """Instructs the Structure to prefix nodes with self.prefix by default.

        Returns the TSScope instance for convenience.
        """
        assert not self.protect
        self.old_scope_stack.append(self.structure.current_scope)
        self.structure.current_scope = self.prefix
        return self

    def __exit__(self, type_, value, traceback):
        """Resets the Structure's default prefix."""
        assert not self.protect
        self.structure.current_scope = self.old_scope_stack.pop()

    def __getitem__(self, index):
        """Get node relative to self regardless of the structure's prefix.

        In a protected scope this returns the prefixed name as a string and
        does not add a node.
        """
        if self.protect:
            return "{}{}".format(self.prefix, index)
        with self:
            return self.structure[index]

    def scope(self, scope):
        """Get a sub-scope relative to self regardless of structure's prefix.

        The returned scope inherits self.protect. The structure's current
        scope is swapped directly (not via `with self:`) because __enter__
        asserts that the scope is unprotected, which would make this method
        unusable on protected scopes.
        """
        structure = self.structure
        saved_scope = structure.current_scope
        structure.current_scope = self.prefix
        try:
            return structure.scope(scope, self.protect)
        finally:
            structure.current_scope = saved_scope

    def protected(self):
        """Returns a protected version of the scope.

        In a protected scope, doing scope[name] will return the full node name
        as a string instead of a NodeWrapper, and will *NOT* add the node if it
        does not exist.
        """
        return self.structure.scope(self.prefix, True)

    def __iter__(self):
        """Iterator for all nodes in the structure within the scope."""
        for member_name in self.structure.nodes:
            if member_name.startswith(self.prefix + ":"):
                yield self.structure[member_name]

    def __contains__(self, node):
        """True iff @node (a NodeWrapper or full name) is in the scope."""
        if isinstance(node, NodeWrapper):
            assert node.structure == self.structure
            node = node.full_name
        return node.startswith(self.prefix + ":")

    def __len__(self):
        """Returns the number of nodes in the structure within the scope."""
        return sum(node.startswith(self.prefix + ":")
                   for node in self.structure.nodes)

class NodeWrapper:
    """Represents a single node in a given structure."""
    def __init__(self, structure, full_name):
        """Initialize the NodeWrapper."""
        self.structure = structure
        self.full_name = full_name

    def map(self, mappings):
        """Helper for adding facts to the structure.

        node.map({A: B, C: D}) adds (node, A, B) and (node, C, D). A dict value
        may also be an iterable (e.g. a set) of nodes, in which case one fact
        is added per element: node.map({A: {B, C}}) adds (node, A, B) and
        (node, A, C).

        NOTE: Be wary of repeated keys!
        """
        def to_fact(middle_node, last_node):
            return (self.full_name, middle_node.full_name, last_node.full_name)
        facts = []
        for middle, last in mappings.items():
            if isinstance(last, NodeWrapper):
                facts.append(to_fact(middle, last))
            else:
                # Allow sets (or other iterables) of nodes in the last slot.
                facts.extend(to_fact(middle, sub_node) for sub_node in last)

        # We sort here to ensure it's deterministic.
        self.structure.add_facts(sorted(facts))

    def scoped_name(self, scope):
        """Returns string @x such that @scope[@x] = @self.

        This is the "first name" where @scope is the "last name." Used, for
        example, by ts_utils to find rules that should be marked /= or /MAYBE=
        in rules based on their name. If the node is not under @scope's
        prefix, its full name is returned unchanged.
        """
        if not self.full_name.startswith(scope.prefix):
            return self.full_name
        return self.full_name[len(scope.prefix):]

    def __sub__(self, scope):
        """Syntactic sugar for scoped_name(...)."""
        return self.scoped_name(scope)

    def remove_with_facts(self):
        """Removes the node and all associated facts from the structure."""
        self.structure.remove_facts(
            self.structure.facts_about_node(self.full_name))
        self.structure.remove_node(self.full_name)

    def remove(self):
        """Removes the node (without associated facts) from the structure.

        This is equivalent to assert not facts_about_node; remove_with_facts().
        It should be used when there is an invariant that no related facts
        should exist in the structure. See runtime/assignment.py for an
        example.
        """
        self.structure.remove_node(self.full_name)

    def display_name(self, set_to=None):
        """Gets or sets the display name of the node."""
        if set_to is not None:
            self.structure.display_names[self.full_name] = set_to
        return self.structure.display_names[self.full_name]

    def __eq__(self, other):
        """True iff @self and @other refer to the same node.

        Returns NotImplemented for non-NodeWrapper operands so comparisons
        with other types (strings, None, foreign dict keys) evaluate to False
        instead of raising AttributeError.
        """
        if not isinstance(other, NodeWrapper):
            return NotImplemented
        return ((self.structure, self.full_name) ==
                (other.structure, other.full_name))

    def __hash__(self):
        """Hash based on the structure and name of the node."""
        return hash((self.structure, self.full_name))

    def __lt__(self, other):
        """Lexicographical comparison (structure identity, name) for sorting."""
        if not isinstance(other, NodeWrapper):
            return NotImplemented
        return ((id(self.structure), self.full_name)
                < (id(other.structure), other.full_name))

    def __str__(self):
        """Returns the name of the node."""
        return self.full_name

class TSDelta:
    """Represents the change between two TripletStructures.

    The add/remove sets are kept disjoint: recording the opposite of a change
    already recorded in this delta cancels it instead of listing the item in
    both sets.
    """
    def __init__(self, ts):
        """Initialize an empty TSDelta for structure @ts."""
        self.ts = ts
        self.add_nodes, self.add_facts = set(), set()
        self.remove_nodes, self.remove_facts = set(), set()

    def apply(self):
        """Apply the TSDelta to self.ts and append it to self.ts.path."""
        assert self is not self.ts.buffer
        assert self.ts.is_clean()
        # NOTE: Sorted here is just for determinism.
        self.ts.add_nodes(sorted(self.add_nodes))
        self.ts.add_facts(sorted(self.add_facts))
        self.ts.remove_facts(sorted(self.remove_facts))
        self.ts.remove_nodes(sorted(self.remove_nodes))
        self.ts._force_clean()
        # TODO: maybe this should just wrap it?
        self.ts.path.append(self)

    def rollback(self):
        """Undo the TSDelta."""
        assert self is not self.ts.buffer
        # NOTE: Sorted here is just for determinism.
        self.ts.remove_facts(sorted(self.add_facts))
        self.ts.remove_nodes(sorted(self.add_nodes))
        self.ts.add_nodes(sorted(self.remove_nodes))
        self.ts.add_facts(sorted(self.remove_facts))
        # Maybe we should assert that this is at the end of the path and remove
        # it?

    def add_node(self, full_name):
        """Record the addition of a new node.

        Re-adding a node removed earlier in this delta cancels the removal;
        otherwise the node would sit in both sets and apply() (which removes
        after adding) would wrongly leave it removed.
        """
        try:
            self.remove_nodes.remove(full_name)
        except KeyError:
            self.add_nodes.add(full_name)

    def add_fact(self, fact):
        """Record the addition of a new fact (cancels an earlier removal)."""
        try:
            self.remove_facts.remove(fact)
        except KeyError:
            self.add_facts.add(fact)

    def remove_node(self, full_name):
        """Record the removal of an existing node (cancels an earlier add)."""
        try:
            self.add_nodes.remove(full_name)
        except KeyError:
            self.remove_nodes.add(full_name)

    def remove_fact(self, fact):
        """Record the removal of an existing fact (cancels an earlier add)."""
        try:
            self.add_facts.remove(fact)
        except KeyError:
            self.remove_facts.add(fact)

    def __bool__(self):
        """True iff the TSDelta is not a no-op."""
        return (bool(self.add_nodes) or bool(self.add_facts) or
                bool(self.remove_nodes) or bool(self.remove_facts))

    def __str__(self):
        """Human-readable format of the TSDelta.

        Entries are sorted so the output is deterministic. Fact tuples are
        rendered using each node's display name where one is registered.
        """
        def _format(entries):
            entries = sorted(entries)
            if entries and isinstance(entries[0], tuple):
                entries = [tuple(self.ts.display_names.get(x, x)
                                 for x in map(str, fact))
                           for fact in entries]
            return "\n\t\t" + "\n\t\t".join(map(str, entries))
        return ("TSDelta ({id}):"
                "\n\t- Nodes: {remove_nodes}" +
                "\n\t+ Nodes: {add_nodes}" +
                "\n\t- Facts: {remove_facts}" +
                "\n\t+ Facts: {add_facts}").format(
                    id=id(self),
                    remove_nodes=_format(self.remove_nodes),
                    add_nodes=_format(self.add_nodes),
                    remove_facts=_format(self.remove_facts),
                    add_facts=_format(self.add_facts))

class TSRecording:
    """Tracks every TSDelta committed to a structure after a checkpoint.

    The checkpoint is the structure's path at construction time; the structure
    must be clean at that moment.
    """
    def __init__(self, ts):
        """Snapshot @ts.path as the checkpoint (asserts @ts is clean)."""
        self.ts = ts
        self.start_path = list(ts.path)
        assert ts.is_clean()

    def _checkpoint_length(self):
        """Assert the checkpointed history is intact and the buffer is empty.

        Returns the length of the checkpointed path.
        """
        length = len(self.start_path)
        assert (self.ts.path[:length] == self.start_path
                and self.ts.is_clean())
        return length

    def commits(self, rollback=False):
        """TSDeltas committed to the structure since @self was created.

        When @rollback is True, the structure is also rolled back to the
        checkpoint after the deltas are collected.
        """
        length = self._checkpoint_length()
        recorded = self.ts.path[length:]
        if rollback:
            self.rollback()
        return recorded

    def rollback(self):
        """Return the structure to its state when @self was created."""
        self.ts.rollback(self._checkpoint_length())

class TSFreezeFrame:
    """Represents a snapshot of a TripletStructure's nodes and facts."""
    def __init__(self, ts):
        """Snapshot the current nodes and facts of @ts."""
        self.ts = ts
        self.nodes = set(ts.nodes)
        self.facts = set(ts.lookup(None, None, None))

    def delta_to_reach(self, desired, nodes=True, facts=True):
        """Return a TSDelta, applying which transforms @self to @desired.

        Set @nodes or @facts to False to leave that part of the delta empty.
        """
        delta = TSDelta(self.ts)
        if nodes:
            delta.add_nodes = desired.nodes - self.nodes
            delta.remove_nodes = self.nodes - desired.nodes
        if facts:
            delta.add_facts = desired.facts - self.facts
            delta.remove_facts = self.facts - desired.facts
        return delta

    def __sub__(self, other):
        """Syntactic sugar: `after - before` is the delta from before to after."""
        return other.delta_to_reach(self)

    def __eq__(self, other):
        """True iff @self and @other represent the same structure state.

        Returns NotImplemented for non-TSFreezeFrame operands, so such
        comparisons evaluate to False rather than raising AttributeError.
        (Defining __eq__ leaves instances unhashable, as before.)
        """
        if not isinstance(other, TSFreezeFrame):
            return NotImplemented
        return (self.ts == other.ts
                and self.nodes == other.nodes
                and self.facts == other.facts)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback