Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Table.write_back(), replacing documents by ids #184

Merged
merged 12 commits into from
Feb 18, 2018
40 changes: 40 additions & 0 deletions tests/test_tinydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,46 @@ def test_update_ids(db):
assert db.count(where('int') == 2) == 2


def test_write_back(db):
docs = db.search(where('int') == 1)
for doc in docs:
doc['int'] = [1, 2, 3]

db.write_back(docs)
assert db.count(where('int') == [1, 2, 3]) == 3


def test_write_back_whole_doc(db):
docs = db.search(where('int') == 1)
doc_ids = [doc.doc_id for doc in docs]
for i, doc in enumerate(docs):
docs[i] = {'newField': i}

db.write_back(docs, doc_ids)
assert db.count(where('newField') == 0) == 1
assert db.count(where('newField') == 1) == 1
assert db.count(where('newField') == 2) == 1


def test_write_back_returns_ids(db):
db.purge()
assert db.insert({'int': 1, 'char': 'a'}) == 1
assert db.insert({'int': 1, 'char': 'a'}) == 2
assert db.write_back([{'word': 'hello'}, {'word': 'world'}], [1, 2]) == [1, 2]


def test_write_back_fails(db):
with pytest.raises(ValueError):
db.write_back([{'get': 'error'}], [1, 2])


def test_write_back_id_exceed(db):
db.purge()
db.insert({'int': 1})
with pytest.raises(IndexError):
db.write_back([{'get': 'error'}], [2])


def test_upsert(db):
assert len(db) == 3

Expand Down
35 changes: 35 additions & 0 deletions tinydb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,41 @@ def update(self, fields, cond=None, doc_ids=None, eids=None):
cond, doc_ids
)

def write_back(self, documents, doc_ids=None, eids=None):
"""
Write back documents by doc_id

:param documents: a list of document to write back
:param doc_ids: a list of documents' ID which needs to be wrote back
:returns: a list of documents' ID taht has been wrote back
"""
doc_ids = _get_doc_ids(doc_ids, eids)

if doc_ids is not None and not len(documents) == len(doc_ids):
raise ValueError(
'The length of documents and doc_ids is not match.')

if doc_ids is None:
doc_ids = [doc.doc_id for doc in documents]

# Since this function will write docs back like inserting, to ensure
# here only process existing or removed instead of inserting new,
# raise error if doc_id exceeded the last.
if sorted(doc_ids)[-1] > self._last_id:
raise IndexError(
'Id exceed table length, use existing or removed doc_id.')

data = self._read()

# Document specified by ID
documents.reverse()
for doc_id in doc_ids:
data[doc_id] = documents.pop()

self._write(data)

return doc_ids

def upsert(self, document, cond):
"""
Update a document, if it exist - insert it otherwise.
Expand Down