Skip to content

Instantly share code, notes, and snippets.

@rintcius
Created November 13, 2018 08:41
Show Gist options
  • Save rintcius/78e22111766806cfd8577ca0aaf9d58f to your computer and use it in GitHub Desktop.
Save rintcius/78e22111766806cfd8577ca0aaf9d58f to your computer and use it in GitHub Desktop.
/*
* Copyright 2014–2018 SlamData Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package quasar.physical.mongodb
package io
import slamdata.Predef._
import quasar.fp.ski.ι
import java.util.{ArrayList => JArrayList}
import scala.Predef.classOf
import scala.collection.JavaConverters._
import scala.util.Either
import cats.effect.Async
import com.mongodb._
import com.mongodb.async._
import com.mongodb.async.client._
import fs2.{Chunk, Stream}
import org.bson.{BsonBoolean, BsonDocument, Document}
import scalaz._, Scalaz._
import shims._
/** MongoDB metadata operations -- database and collection listing, existence
* checks, collection statistics and index descriptions -- expressed in an
* effect type `F` on top of the callback-based async driver API.
*
* The constructor is private; instances come from the companion object.
*/
final class MongoDbIO[F[_]: Async] private (client: MongoClient) {

  /** Whether `c` appears among the collections listed for its database. */
  def collectionExists(c: Collection): F[Boolean] =
    F.map(
      collectionsIn(c.database)
        .exists(candidate => candidate.collection === c.collection)
        .compile
        .last
    )(found => found.getOrElse(false))

  /** Streams every collection reported by `listCollectionNames` for `dbName`. */
  def collectionsIn(dbName: DatabaseName): Stream[F, Collection] =
    for {
      db   <- Stream.eval(database(dbName))
      name <- toStream(db.listCollectionNames)
    } yield Collection(dbName, CollectionName(name))

  /** Runs the `collStats` command against `coll` and reads the `count`,
  * `size` and `sharded` fields of the reply.
  *
  * The effect fails with a `MongoException` when `count` or `size` cannot be
  * read as a number. A missing `sharded` field is treated as `false`.
  */
  def collectionStatistics(coll: Collection): F[CollectionStatistics] = {
    val cmd = Bson.Doc(ListMap("collStats" -> coll.collection.bson))

    // Guards against both an exception thrown by `getNumber` and a null
    // result; either way the outcome is a textual error on the left.
    def longValue(doc: BsonDocument, field: String): String \/ Long = {
      val attempt: Throwable \/ (String \/ Long) =
        \/.fromTryCatchNonFatal(
          Option(doc.getNumber(field)).map(_.longValue) \/> s"expected field: $field")
      attempt.fold(thrown => thrown.getMessage.left, ι)
    }

    // Any stored value other than `BsonBoolean.FALSE` counts as true; an
    // absent field falls back to the FALSE default and so yields false.
    def booleanValue(doc: BsonDocument, field: String): Boolean = {
      val stored = doc.get(field, BsonBoolean.FALSE)
      stored != BsonBoolean.FALSE
    }

    def readStats(reply: BsonDocument): String \/ CollectionStatistics =
      for {
        count    <- longValue(reply, "count")
        dataSize <- longValue(reply, "size")
      } yield CollectionStatistics(count, dataSize, booleanValue(reply, "sharded"))

    F.flatMap(runCommand(coll.database, cmd)) { reply =>
      readStats(reply).fold(
        err =>
          fail[CollectionStatistics](
            new MongoException("could not read collection statistics: " + err)),
        stats => F.pure(stats))
    }
  }

  /** Whether `dbName` appears among the client's database names. */
  def databaseExists(dbName: DatabaseName): F[Boolean] =
    F.map(
      databaseNames
        .exists(candidate => candidate === dbName)
        .compile
        .last
    )(found => found.getOrElse(false))

  /** Streams the names returned by the client's `listDatabaseNames`. */
  def databaseNames: Stream[F, DatabaseName] =
    toStream(client.listDatabaseNames).map(name => DatabaseName(name))

  /** Set of indexes on a collection, including only simple index types and
  * ignoring the rest.
  */
  def indexes(coll: Collection): F[Set[Index]] = {
    // TODO: split on ".", but note that MongoDB seems to treat such keys
    // special anyway, at least when arrays are present.
    def decodeField(s: String): BsonField = BsonField.Name(s)

    // Only 1, -1 and "hashed" are recognised; anything else is undefined here.
    val decodeType: PartialFunction[java.lang.Object, IndexType] = {
      case n: java.lang.Number if n.intValue ≟ 1  => IndexType.Ascending
      case n: java.lang.Number if n.intValue ≟ -1 => IndexType.Descending
      case "hashed"                                => IndexType.Hashed
    }

    // Yields None when the name is not a string, when the key document is
    // missing or empty, or when any key entry has an unrecognised type.
    def decodeIndex(doc: Document): Option[Index] = {
      val name = Option(doc.get("name")).flatMap {
        case s: String => s.some
        case _         => None
      }

      val key = Option(doc.get("key")).flatMap {
        case kd: Document =>
          kd.asScala.toList.toNel.flatMap(_.traverse {
            case (field, tpe) => decodeType.lift(tpe).strengthL(decodeField(field))
          })
        case _ => None
      }

      // Absent or non-TRUE values mean the index is not unique.
      def unique: Boolean =
        Option(doc.get("unique")).exists {
          case java.lang.Boolean.TRUE => true
          case _                      => false
        }

      for {
        n <- name
        k <- key
      } yield Index(n, k, unique)
    }

    for {
      mongoColl <- collection(coll)
      docs      <- collect[Document](mongoColl.listIndexes)
    } yield docs.flatMap(decodeIndex(_).toList).toSet
  }

  ////

  private val F = Async[F]

  /** Adapts an `Either`-consuming continuation to the driver's callback
  * interface: a non-null `error` becomes a `Left`, otherwise `result` is
  * passed through as a `Right`.
  */
  private final class AsyncCallback[A](k: Either[Throwable, A] => Unit)
      extends SingleResultCallback[A] {
    def onResult(result: A, error: Throwable): Unit =
      k(Option(error).toLeft(result))
  }

  /** Lifts a callback-registering driver call into `F`. */
  private def async[A](f: SingleResultCallback[A] => Unit): F[A] =
    F.async(resume => f(new AsyncCallback[A](resume)))

  /** Drains `iter` into a fresh Java list and converts it to a Scala `List`. */
  private def collect[A](iter: MongoIterable[A]): F[List[A]] =
    F.map(
      async[JArrayList[A]](cb => iter.into[JArrayList[A]](new JArrayList[A], cb))
    )(filled => filled.asScala.toList)

  /** Handle on the named collection, decoding documents as `BsonDocument`. */
  private def collection(c: Collection): F[MongoCollection[BsonDocument]] =
    F.map(database(c.database))(db =>
      db.getCollection(c.collection.value, classOf[BsonDocument]))

  /** Handle on the named database, obtained lazily from the client. */
  private def database(named: DatabaseName): F[MongoDatabase] =
    F.delay(client.getDatabase(named.value))

  /** Raises `t` as a failure in `F`. */
  private def fail[A](t: Throwable): F[A] =
    F.raiseError[A](t)

  /** Runs `cmd` against the named database, returning the raw reply document. */
  private def runCommand(dbName: DatabaseName, cmd: Bson.Doc): F[BsonDocument] =
    F.flatMap(database(dbName))(db =>
      async[BsonDocument](cb => db.runCommand(cmd, classOf[BsonDocument], cb)))

  /** Exposes a driver iterable as a stream, one chunk per cursor batch.
  * The batch cursor is acquired and closed through `Stream.bracket`.
  */
  private def toStream[A](it: MongoIterable[A]): Stream[F, A] = {
    def acquire(i: MongoIterable[A]): F[AsyncBatchCursor[A]] =
      async(i.batchCursor)

    def release(cursor: AsyncBatchCursor[A]): F[Unit] =
      F.delay(cursor.close())

    // A null batch from the cursor ends the unfold.
    def next(cursor: AsyncBatchCursor[A]): F[Option[(Chunk[A], AsyncBatchCursor[A])]] =
      async(cursor.next).flatMap { batch =>
        Option(batch).fold(F.pure(none[(Chunk[A], AsyncBatchCursor[A])])) { docs =>
          F.delay(some((Chunk.seq(docs.asScala), cursor)))
        }
      }

    Stream
      .bracket(acquire(it))(release)
      .flatMap(cursor => Stream.unfoldChunkEval(cursor)(next))
  }
}
object MongoDbIO {
  /** Creates a `MongoDbIO` whose operations are issued through `client`.
  *
  * The result type is stated explicitly so the public signature does not
  * depend on inference from the implementation.
  */
  def apply[F[_]: Async](client: MongoClient): MongoDbIO[F] =
    new MongoDbIO[F](client)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment