diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -11,7 +11,6 @@
 import functools
 import itertools
 import random
-import re
 
 from collections import defaultdict
 from datetime import timedelta
@@ -56,6 +55,7 @@
     DirectoryRow,
     DirectoryEntryRow,
     ObjectCountRow,
+    OriginRow,
     ReleaseRow,
     RevisionRow,
     RevisionParentRow,
@@ -245,6 +245,7 @@
         self._releases = Table(ReleaseRow)
         self._snapshots = Table(SnapshotRow)
         self._snapshot_branches = Table(SnapshotBranchRow)
+        self._origins = Table(OriginRow)  # partition key: sha1 of the URL
         self._stat_counters = defaultdict(int)

     def increment_counter(self, object_type: str, nb: int):
@@ -482,6 +483,53 @@
             if count >= limit:
                 break

+    ##########################
+    # 'origin' table
+    ##########################
+
+    def origin_add_one(self, origin: OriginRow) -> None:
+        """Inserts ``origin`` into the 'origin' table and increments the
+        'origin' object counter by one.
+
+        The counter is incremented on every call, whether or not a row with
+        the same key was already stored."""
+        self._origins.insert(origin)
+        self.increment_counter("origin", 1)
+
+    def origin_get_by_sha1(self, sha1: bytes) -> Iterable[OriginRow]:
+        """Returns the rows of the 'origin' partition keyed by ``sha1``."""
+        return self._origins.get_from_partition_key((sha1,))
+
+    def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
+        """Looks up origin rows by URL, through the sha1 of ``url``."""
+        return self.origin_get_by_sha1(origin_url_to_sha1(url))
+
+    def origin_list(
+        self, start_token: int, limit: int
+    ) -> Iterable[Tuple[int, OriginRow]]:
+        """Returns an iterable of (token, origin), sorted by token, holding at
+        most ``limit`` entries whose token is greater than or equal to
+        ``start_token``."""
+        matches = [
+            (token, row)
+            for (token, partition) in self._origins.data.items()
+            if token >= start_token
+            for (_, row) in partition.items()
+        ]
+        # Sort on the token only: sorting the tuples themselves would fall back
+        # to comparing OriginRow objects whenever two entries share a token.
+        matches.sort(key=lambda match: match[0])
+        return matches[:limit]
+
+    def origin_iter_all(self) -> Iterable[OriginRow]:
+        """Returns a lazy iterator over every row of the 'origin' table, with
+        no ordering guarantee."""
+        return (
+            row
+            for (_, partition) in self._origins.data.items()
+            for (_, row) in partition.items()
+        )
+

 class InMemoryStorage(CassandraStorage):
     _cql_runner: InMemoryCqlRunner  # type: ignore
@@ -492,8 +527,6 @@

     def reset(self):
         self._cql_runner = InMemoryCqlRunner()
-        self._origins = {}
-        self._origins_by_sha1 = {}
         self._origin_visits = {}
         self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {}
         self._persons = {}
@@ -533,114 +566,11 @@
     def check_config(self, *, check_write: bool) -> bool:
         return True

-    def _convert_origin(self, t):
-        if t is None:
-            return None
-
-        return t.to_dict()
-
-    def origin_get_one(self, origin_url: str) -> Optional[Origin]:
-        return self._origins.get(origin_url)
-
-    def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]:
-        return [self.origin_get_one(origin_url) for origin_url in origins]
-
-    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
-        return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
-
-    def origin_list(
-        self, page_token: Optional[str] = None, limit: int = 100
-    ) -> PagedResult[Origin]:
-        origin_urls = sorted(self._origins)
-        from_ = bisect.bisect_left(origin_urls, page_token) if page_token else 0
-        next_page_token = None
-
-        # Take one more origin so we can reuse it as the next page token if any
-        origins = [Origin(url=url) for url in origin_urls[from_ : from_ + limit + 1]]
-
-        if len(origins) > limit:
-            # last origin id is the next page token
-            next_page_token = str(origins[-1].url)
-            # excluding that origin from the result to respect the limit size
-            origins = origins[:limit]
-
-        assert len(origins) <= limit
-
-        return PagedResult(results=origins, next_page_token=next_page_token)
-
-    def origin_search(
-        self,
-        url_pattern: str,
-        page_token: Optional[str] = None,
-        limit: int = 50,
-        regexp: bool = False,
-        with_visit: bool = False,
-    ) -> PagedResult[Origin]:
-        next_page_token = None
-        offset = int(page_token) if page_token else 0
-
-        origins = self._origins.values()
-        if regexp:
-            pat = re.compile(url_pattern)
-            origins = [orig for orig in origins if pat.search(orig.url)]
-        else:
-            origins = [orig for orig in origins if url_pattern in orig.url]
-
-        if with_visit:
-            filtered_origins = []
-            for orig in origins:
-                visits = (
-                    self._origin_visit_get_updated(ov.origin, ov.visit)
-                    for ov in self._origin_visits[orig.url]
-                )
-                for ov in visits:
-                    snapshot = ov["snapshot"]
-                    if snapshot and not list(self.snapshot_missing([snapshot])):
-                        filtered_origins.append(orig)
-                        break
-        else:
-            filtered_origins = origins
-
-        # Take one more origin so we can reuse it as the next page token if any
-        origins = filtered_origins[offset : offset + limit + 1]
-        if len(origins) > limit:
-            # next offset
-            next_page_token = str(offset + limit)
-            # excluding that origin from the result to respect the limit size
-            origins = origins[:limit]
-
-        assert len(origins) <= limit
-        return PagedResult(results=origins, next_page_token=next_page_token)
-
-    def origin_count(
-        self, url_pattern: str, regexp: bool = False, with_visit: bool = False
-    ) -> int:
-        actual_page = self.origin_search(
-            url_pattern, regexp=regexp, with_visit=with_visit, limit=len(self._origins),
-        )
-        assert actual_page.next_page_token is None
-        return len(actual_page.results)
-
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
-        added = 0
         for origin in origins:
-            if origin.url not in self._origins:
-                self.origin_add_one(origin)
-                added += 1
-
-        self._cql_runner.increment_counter("origin", added)
-
-        return {"origin:add": added}
-
-    def origin_add_one(self, origin: Origin) -> str:
-        if origin.url not in self._origins:
-            self.journal_writer.origin_add([origin])
-            self._origins[origin.url] = origin
-            self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin
-            self._origin_visits[origin.url] = []
-            self._objects[origin.url].append(("origin", origin.url))
-
-        return origin.url
+            # visits are still stored here: ensure each origin has a visit list
+            self._origin_visits.setdefault(origin.url, [])
+        return super().origin_add(origins)

     def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
         for visit in visits:
@@ -651,8 +581,7 @@
         all_visits = []
         for visit in visits:
             origin_url = visit.origin
-            if origin_url in self._origins:
-                origin = self._origins[origin_url]
+            if list(self._cql_runner.origin_get_by_url(origin_url)):  # origin exists
                 if visit.visit:
                     self.journal_writer.origin_visit_add([visit])
                     while len(self._origin_visits[origin_url]) < visit.visit:
@@ -800,12 +729,11 @@
                 f"{','.join(VISIT_STATUSES)} authorized"
             )

-        ori = self._origins.get(origin)
-        if not ori:
+        if not list(self._cql_runner.origin_get_by_url(origin)):  # unknown origin
             return None

         visits = sorted(
-            self._origin_visits[ori.url], key=lambda v: (v.date, v.visit), reverse=True,
+            self._origin_visits[origin], key=lambda v: (v.date, v.visit), reverse=True,
         )
         for visit in visits:
             if type is not None and visit.type != type:
@@ -876,8 +804,7 @@
                 f"{','.join(VISIT_STATUSES)} authorized"
             )

-        ori = self._origins.get(origin_url)
-        if not ori:
+        if not list(self._cql_runner.origin_get_by_url(origin_url)):  # unknown origin
             return None

         visit_key = (origin_url, visit)
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -8,7 +8,9 @@
 import swh.storage.api.server as server
 import swh.storage.storage
 from swh.storage import get_storage
-from swh.storage.tests.test_storage import TestStorageGeneratedData  # noqa
+from swh.storage.tests.test_storage import (
+    TestStorageGeneratedData as _TestStorageGeneratedData,
+)
 from swh.storage.tests.test_storage import TestStorage as _TestStorage
 
 # tests are executed using imported classes (TestStorage and
@@ -60,14 +62,36 @@
     storage.journal_writer = journal_writer


-class TestStorage(_TestStorage):
-    @pytest.mark.skip("content_update is not yet implemented for Cassandra")
-    def test_content_update(self):
-        pass
-
+class TestStorageApi(_TestStorage):
     @pytest.mark.skip(
         'The "person" table of the pgsql is a legacy thing, and not '
         "supported by the cassandra backend."
     )
     def test_person_fullname_unicity(self):
         pass
+
+    @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+    def test_content_update(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+
+class TestStorageApiGeneratedData(_TestStorageGeneratedData):
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_no_visits(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_and_snapshot(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_no_snapshot(self):
+        pass
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -10,7 +10,9 @@
 from swh.storage.cassandra.model import BaseRow
 from swh.storage.in_memory import SortedList, Table
 from swh.storage.tests.test_storage import TestStorage as _TestStorage
-from swh.storage.tests.test_storage import TestStorageGeneratedData  # noqa
+from swh.storage.tests.test_storage import (
+    TestStorageGeneratedData as _TestStorageGeneratedData,
+)
 
 
 # tests are executed using imported classes (TestStorage and
@@ -168,3 +170,25 @@
     @pytest.mark.skip("content_update is not yet implemented for Cassandra")
     def test_content_update(self):
         pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+
+class TestInMemoryStorageGeneratedData(_TestStorageGeneratedData):
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_no_visits(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_and_snapshot(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_no_snapshot(self):
+        pass
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -206,7 +206,6 @@
     assert got_persons == expected_persons

     for attr_ in (
-        "origins",
         "origin_visits",
         "origin_visit_statuses",
     ):
@@ -223,6 +222,7 @@
         "revisions",
         "releases",
         "snapshots",
+        "origins",
     ):
         if exclude and attr_ in exclude:
             continue
@@ -379,10 +379,7 @@
     got_persons = set(dst._persons.values())
     assert got_persons == expected_persons

-    for attr_ in (
-        "origins",
-        "origin_visit_statuses",
-    ):
+    for attr_ in ("origin_visit_statuses",):
         expected_objects = [
             (id, maybe_anonymize(attr_, obj))
             for id, obj in sorted(getattr(src, f"_{attr_}").items())
@@ -399,6 +396,7 @@
         "revisions",
         "releases",
         "snapshots",
+        "origins",
     ):
         expected_objects = [
             (id, nullify_ctime(maybe_anonymize(attr_, obj)))