summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Sotoudeh <matthew@masot.net>2023-07-27 14:26:33 -0700
committerMatthew Sotoudeh <matthew@masot.net>2023-07-27 14:26:33 -0700
commit2b8985608d33abaae7b201a008e292cbbe2167ef (patch)
tree58a76aaf7c7447988b1c78095480dc8bce5d2ab4
parent578531395ecbabd8179e31520c2832ac7d6d3765 (diff)
add automated refcounting pass
-rw-r--r--python/README.txt56
-rw-r--r--python/dietc.py7
-rw-r--r--python/examples/refcounting/.gitignore1
-rw-r--r--python/examples/refcounting/Makefile8
l---------python/examples/refcounting/dietc.py1
-rwxr-xr-xpython/examples/refcounting/dietpass76
-rw-r--r--python/examples/refcounting/runtime/refcounting.c114
-rw-r--r--python/examples/refcounting/runtime/refcounting.h8
-rw-r--r--python/examples/refcounting/test.c34
9 files changed, 303 insertions, 2 deletions
diff --git a/python/README.txt b/python/README.txt
index 3dd6a96..cd682f0 100644
--- a/python/README.txt
+++ b/python/README.txt
@@ -18,7 +18,8 @@ Example of using a pass:
Foo return value: 0
Foo return value: 0
- The examples/dynamic_typing pass adds a "dynamic_typeof" feature to C:
+ The examples/dynamic_typing pass adds a "dynamic_typeof" feature to C, and
+ uses it to implement pretty-printing in C:
$ which dietcc
[make sure it's on your path!]
@@ -41,3 +42,56 @@ Example of using a pass:
$ ./test
{ .x = (char)1, .y = (int)2, .z = { (int)3, (int)4, (int)5 } }
{ .a = (char)100, .b = (int)200 }
+
+ The examples/refcounting pass adds automated reference counting to C:
+
+ $ which dietcc
+ [make sure it's on your path!]
+ $ cd examples/refcounting
+ $ cat test.c
+ ...
+ int *foo(void) {
+ int *ptr = 0;
+ for (int i = 0; i < 5; i++)
+ ptr = calloc(1, sizeof(*ptr));
+ ptr = calloc(1, sizeof(*ptr));
+ ptr = calloc(1, sizeof(*ptr));
+ *ptr = 5;
+ return ptr;
+ }
+ ...
+ int *bar(void) {
+ struct bar bar;
+ bar.x = 5;
+ bar.y = calloc(1, sizeof(int));
+ *bar.y = bar.x;
+ return bar.y;
+ }
+ ...
+ int main(void) {
+ int *x = foo();
+ printf("Result of foo(): %d\n", *x);
+ x = bar();
+ printf("Result of bar(): %d\n", *x);
+ return 0;
+ }
+ $ make
+ $ ./test
+ Allocated: 0x55bc2ca6e2a0 ; now tracking 1 regions
+ Allocated: 0x55bc2ca6e700 ; now tracking 2 regions
+ Freeing! Left: 1
+ Allocated: 0x55bc2ca6e750 ; now tracking 2 regions
+ Freeing! Left: 1
+ Allocated: 0x55bc2ca6e7a0 ; now tracking 2 regions
+ Freeing! Left: 1
+ Allocated: 0x55bc2ca6e7f0 ; now tracking 2 regions
+ Freeing! Left: 1
+ Allocated: 0x55bc2ca6e840 ; now tracking 2 regions
+ Allocated: 0x55bc2ca6e890 ; now tracking 3 regions
+ Freeing! Left: 2
+ Freeing! Left: 1
+ Result of foo(): 5
+ Allocated: 0x55bc2ca6e8e0 ; now tracking 2 regions
+ Result of bar(): 5
+ Freeing! Left: 1
+ Freeing! Left: 0
diff --git a/python/dietc.py b/python/dietc.py
index 1101062..58d7e42 100644
--- a/python/dietc.py
+++ b/python/dietc.py
@@ -43,7 +43,7 @@ class Function:
def code_start(self):
return next(i for i, l in enumerate(self.body)
- if not l.line.startswith("\tType_"))
+ if not l[0].startswith("Type_"))
def insert(self, i, string):
self.body.insert(i, Instruction(string))
@@ -71,6 +71,11 @@ class Instruction:
if self[0] == "*":
return program.object_types[self[1]].base
return program.object_types[self[0]]
+ def replace_token(self, old, new):
+ self = Instruction(self.line)
+ self.tokens = [t if t != old else new for t in self.tokens]
+ self.line = ' '.join(self.tokens)
+ return self
class Program:
def __init__(self, string):
diff --git a/python/examples/refcounting/.gitignore b/python/examples/refcounting/.gitignore
new file mode 100644
index 0000000..9daeafb
--- /dev/null
+++ b/python/examples/refcounting/.gitignore
@@ -0,0 +1 @@
+test
diff --git a/python/examples/refcounting/Makefile b/python/examples/refcounting/Makefile
new file mode 100644
index 0000000..2d5e9d5
--- /dev/null
+++ b/python/examples/refcounting/Makefile
@@ -0,0 +1,8 @@
+test: test.c runtime/refcounting.o
+ dietcc -I $(PWD)/runtime -o $@ $^ --dietc-pass $(PWD)/dietpass
+
+runtime/refcounting.o: runtime/refcounting.c
+ gcc -c -o $@ $^
+
+clean:
+ rm -f runtime/refcounting.o test
diff --git a/python/examples/refcounting/dietc.py b/python/examples/refcounting/dietc.py
new file mode 120000
index 0000000..8cb9097
--- /dev/null
+++ b/python/examples/refcounting/dietc.py
@@ -0,0 +1 @@
+../../dietc.py \ No newline at end of file
diff --git a/python/examples/refcounting/dietpass b/python/examples/refcounting/dietpass
new file mode 100755
index 0000000..9a7b0e8
--- /dev/null
+++ b/python/examples/refcounting/dietpass
@@ -0,0 +1,76 @@
+#!/bin/python3
+import os
+import sys
+from dietc import *
+
+def eprint(*args): print(*args, file=sys.stderr)
+
+def main():
+ prog = Program(open(sys.argv[1], "rb").read())
+ prog.preamble.append("#include <refcounting.h>")
+ for function in prog.functions:
+ # (1) rewrite calloc, free, ... calls to refcount_*
+ for name in ("calloc", "free", "malloc", "realloc"):
+ function.body = [instr.replace_token(name, f"refcount_{name}")
+ for instr in function.body]
+ new_body = function.body[:function.code_start()]
+ # (2a) initialize all pointer expressions to 0
+ for local in function.locals():
+ for p in ptr_exprs(local, prog.object_types[local]):
+ new_body.append(Instruction(f"{p} = 0 ;"))
+ # (2b) whenever a pointer is written to, call refcount_write
+ for instr in function.body[function.code_start():]:
+ if instr[0] == "*" and instr[2] == "=":
+ refcount_pairs = []
+ ty = prog.object_types[instr[1]].base
+ for expr in ptr_exprs(f"(*{instr[1]})", ty):
+ backup = f"_refcount_ptr_backup_{count()}"
+ new_body.append(Instruction(f"void * {backup} = {expr} ;"))
+ refcount_pairs.append((backup, expr))
+ new_body.append(instr)
+ for old, new in refcount_pairs:
+ new_body.append(Instruction(f"refcount_write({new},{old});"))
+ elif instr[1] == "=":
+ refcount_pairs = []
+ ty = prog.object_types[instr[0]]
+ for expr in ptr_exprs(instr[0], ty):
+ backup = f"_refcount_ptr_backup_{count()}"
+ new_body.append(Instruction(f"void * {backup} = {expr} ;"))
+ refcount_pairs.append((backup, expr))
+ new_body.append(instr)
+ for old, new in refcount_pairs:
+ new_body.append(Instruction(f"refcount_write({new},{old});"))
+ elif instr[0] == "return":
+ for local in function.locals():
+ if local == instr[1]: continue
+ for expr in ptr_exprs(local, prog.object_types[local]):
+ new_body.append(Instruction(f"refcount_write(0,({expr}));"))
+ if instr[1] != ";":
+ for expr in ptr_exprs(instr[1], prog.object_types[instr[1]]):
+ new_body.append(Instruction(f"refcount_returning({expr});"))
+ new_body.append(instr)
+ else:
+ new_body.append(instr)
+ function.body = new_body
+ prog.print()
+
+# given an object and its type, returns a list of l-value expression strings,
+# representing all of the pointers in that object
+def ptr_exprs(name, ty):
+ if isinstance(ty, PointerType):
+ return [name]
+ elif isinstance(ty, AggregateType):
+ exprs = []
+ for fname, ftype in ty.fields:
+ exprs.extend(ptr_exprs(f"{name}.{fname}", ftype))
+ return exprs
+ elif isinstance(ty, (BasicType, FunctionType)):
+ return []
+ elif isinstance(ty, ArrayType):
+ # TODO
+ return []
+ else:
+ print(ty.classify(), file=sys.stderr)
+ raise NotImplementedError
+
+main()
diff --git a/python/examples/refcounting/runtime/refcounting.c b/python/examples/refcounting/runtime/refcounting.c
new file mode 100644
index 0000000..ba8cacc
--- /dev/null
+++ b/python/examples/refcounting/runtime/refcounting.c
@@ -0,0 +1,114 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+
+struct data {
+ void *start;
+ size_t len;
+ size_t refcount;
+};
+
+struct node {
+ struct data data;
+
+ struct node *left;
+ struct node *right;
+};
+
+static struct node *ROOT = 0;
+static int TOTAL = 0;
+
+static struct node *lookup_interval(struct node *root, void *ptr) {
+ if (!root) return 0;
+
+ if (ptr < root->data.start)
+ return lookup_interval(root->left, ptr);
+
+ if (ptr < (root->data.start + root->data.len))
+ return root;
+
+ // our last hope is something on the right
+ return lookup_interval(root->right, ptr);
+}
+
+// exact lookup, ignoring range len
+static struct node **lookup(void *ptr) {
+ struct node **node = &ROOT;
+ while (*node) {
+ if ((*node)->data.start == ptr)
+ return node;
+ if (ptr > (*node)->data.start)
+ node = &((*node)->right);
+ else
+ node = &((*node)->left);
+ }
+ return node;
+}
+
+static struct node *insert(void *ptr, size_t len) {
+ struct node **loc = lookup(ptr);
+ assert(!(*loc));
+ *loc = calloc(1, sizeof(struct node));
+ (*loc)->data.start = ptr;
+ (*loc)->data.len = len;
+ return *loc;
+}
+
+static void remove_root(struct node **tree) {
+ if (!(*tree)) {
+ return;
+ } else if (!((*tree)->right)) {
+ *tree = (*tree)->left;
+ } else if (!((*tree)->left)) {
+ *tree = (*tree)->right;
+ } else {
+ // swap *tree and (*tree)->right, then delete the root in tree->right
+ struct data root_data = (*tree)->data;
+ (*tree)->data = (*tree)->right->data;
+ (*tree)->right->data = root_data;
+ remove_root(&((*tree)->right));
+ }
+}
+
+static struct node *remove_node(void *ptr) {
+ remove_root(lookup(ptr));
+}
+
+void *refcount_calloc(unsigned long count, unsigned long size) {
+ void *data = calloc(count, size);
+ TOTAL += 1;
+ printf("Allocated: %p ; now tracking %d regions\n", data, TOTAL);
+ insert(data, count * size);
+ return data;
+}
+
+void refcount_write(void *new, void *old) {
+ // printf("Refcount write new: %p old: %p\n", new, old);
+ if (new) {
+ struct node *new_node = lookup_interval(ROOT, new);
+ if (new_node) {
+ // printf("Incrementing refcount ...\n");
+ new_node->data.refcount++;
+ }
+ }
+ if (old) {
+ struct node *old_node = lookup_interval(ROOT, old);
+ if (old_node) {
+ // printf("Decrementing refcount ...\n");
+ old_node->data.refcount--;
+ if (!(old_node->data.refcount)) {
+ TOTAL -= 1;
+ printf("Freeing! Left: %d\n", TOTAL);
+ remove_node(old_node->data.start);
+ free(old);
+ }
+ }
+ }
+}
+
+void refcount_returning(void *ptr) {
+ if (!ptr) return;
+ struct node *ptr_node = lookup_interval(ROOT, ptr);
+ if (!ptr_node) return;
+ ptr_node->data.refcount--;
+}
diff --git a/python/examples/refcounting/runtime/refcounting.h b/python/examples/refcounting/runtime/refcounting.h
new file mode 100644
index 0000000..139f09a
--- /dev/null
+++ b/python/examples/refcounting/runtime/refcounting.h
@@ -0,0 +1,8 @@
+#pragma once
+
+// wrapper for calloc
+void *refcount_calloc(unsigned long count, unsigned long size);
+// overwrite a pointer that pointed to 'old' to now point to 'new'
+void refcount_write(void *new, void *old);
+// decrement the refcount, but do not free if the refcount hits 0
+void refcount_returning(void *ptr);
diff --git a/python/examples/refcounting/test.c b/python/examples/refcounting/test.c
new file mode 100644
index 0000000..212bb8b
--- /dev/null
+++ b/python/examples/refcounting/test.c
@@ -0,0 +1,34 @@
+#include <stdlib.h>
+#include <stdio.h>
+
+int *foo(void) {
+ int *ptr = 0;
+ for (int i = 0; i < 5; i++) {
+ ptr = calloc(1, sizeof(*ptr));
+ }
+ ptr = calloc(1, sizeof(*ptr));
+ ptr = calloc(1, sizeof(*ptr));
+ *ptr = 5;
+ return ptr;
+}
+
+struct bar {
+ int x;
+ int *y;
+};
+
+int *bar(void) {
+ struct bar bar;
+ bar.x = 5;
+ bar.y = calloc(1, sizeof(int));
+ *bar.y = bar.x;
+ return bar.y;
+}
+
+int main(void) {
+ int *x = foo();
+ printf("Result of foo(): %d\n", *x);
+ x = bar();
+ printf("Result of bar(): %d\n", *x);
+ return 0;
+}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback