Skip to content

Commit

Permalink
plugin: Don't try to tag field accessor for zios
Browse files Browse the repository at this point in the history
  • Loading branch information
mschuwalow committed Jun 16, 2023
1 parent a45e9cc commit ce2df08
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 38 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ lazy val core = project
stdSettings("zio-profiling"),
libraryDependencies ++= Seq(
"dev.zio" %% "zio" % zioVersion,
"dev.zio" %% "zio-streams" % zioVersion,
"org.scala-lang.modules" %% "scala-collection-compat" % collectionCompatVersion,
"dev.zio" %% "zio-test" % zioVersion % Test,
"dev.zio" %% "zio-test-sbt" % zioVersion % Test
Expand Down
6 changes: 3 additions & 3 deletions project/BuildHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ object BuildHelper {
val Scala213 = versions("2.13")
val Scala3 = versions("3")

val defaulScalaVersion = Scala213
val defaultScalaVersion = Scala213

def stdSettings(prjName: String) =
Seq(
name := s"$prjName",
crossScalaVersions := List(Scala212, Scala213, Scala3),
ThisBuild / scalaVersion := defaulScalaVersion,
ThisBuild / scalaVersion := defaultScalaVersion,
scalacOptions := stdOptions ++ extraOptions(scalaVersion.value, optimize = !isSnapshot.value),
libraryDependencies ++= {
if (scalaVersion.value == Scala3)
Expand All @@ -50,7 +50,7 @@ object BuildHelper {
compilerPlugin("com.github.ghik" % "silencer-plugin" % silencerVersion cross CrossVersion.full)
)
},
semanticdbEnabled := scalaVersion.value == defaulScalaVersion,
semanticdbEnabled := scalaVersion.value == defaultScalaVersion,
semanticdbOptions ++= (if (scalaVersion.value != Scala3) List("-P:semanticdb:synthetics:on") else Nil),
semanticdbVersion := scalafixSemanticdb.revision,
ThisBuild / scalafixScalaBinaryVersion := CrossVersion.binaryScalaVersion(scalaVersion.value),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ object BenchmarkUtils {
if (customRt ne null) customRt else Runtime.default
}

def getSupervisor(): Supervisor[Any] = {
val customRt = runtimeRef.get()
if (customRt ne null) customRt.environment.get[SamplingProfilerSupervisor] else Supervisor.none
}

def unsafeRun[E, A](zio: ZIO[Any, E, A]): A =
Unsafe.unsafe { implicit unsafe =>
getRuntime().unsafe.run(zio).getOrThrowFiberFailure()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,23 @@ class TaggingPlugin(val global: Global) extends Plugin {

class TaggingTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
override def transform(tree: Tree): Tree = tree match {
case valDef @ ValDef(_, _, ZioTypeTree(t1, t2, t3), rhs) if isNonAbstract(valDef) =>
val transformedRhs = tagEffectTree(descriptiveName(tree), rhs, t1, t2, t3)
case valDef @ ValDef(_, _, TaggableTypeTree(taggingTarget), rhs) if rhs.nonEmpty =>
val transformedRhs = tagEffectTree(descriptiveName(tree), rhs, taggingTarget)
val typedRhs = localTyper.typed(transformedRhs)
val updated = treeCopy.ValDef(tree, valDef.mods, valDef.name, valDef.tpt, rhs = typedRhs)

super.transform(updated)
case defDef @ DefDef(_, _, _, _, ZioTypeTree(t1, t2, t3), rhs) if isNonAbstract(defDef) =>
val transformedRhs = tagEffectTree(descriptiveName(tree), rhs, t1, t2, t3)
case defDef @ DefDef(_, _, _, _, TaggableTypeTree(taggingTarget), rhs) if rhs.nonEmpty =>
val transformedRhs = tagEffectTree(descriptiveName(tree), rhs, taggingTarget)
val typedRhs = localTyper.typed(transformedRhs)
val updated =
treeCopy.DefDef(tree, defDef.mods, defDef.name, defDef.tparams, defDef.vparamss, defDef.tpt, rhs = typedRhs)

super.transform(updated)
case _ =>
super.transform(tree)
}

private def isNonAbstract(tree: ValOrDefDef): Boolean =
!tree.mods.isDeferred

private def descriptiveName(tree: Tree): String = {
val fullName = tree.symbol.fullNameString
val sourceFile = tree.pos.source.file.name
Expand All @@ -53,21 +52,37 @@ class TaggingPlugin(val global: Global) extends Plugin {
s"$fullName($sourceFile:$sourceLine)"
}

private def tagEffectTree(name: String, tree: Tree, t1: Type, t2: Type, t3: Type): Tree = {
private def tagEffectTree(name: String, tree: Tree, taggingTarget: TaggingTarget): Tree = {
val costCenterModule = rootMirror.getRequiredModule("_root_.zio.profiling.CostCenter")
val traceModule = rootMirror.getRequiredModule("_root_.zio.Trace")

q"$costCenterModule.withChildCostCenter[$t1,$t2,$t3]($name)($tree)($traceModule.empty)"
taggingTarget match {
case ZioTaggingTarget(t1, t2, t3) =>
q"$costCenterModule.withChildCostCenter[$t1,$t2,$t3]($name)($tree)($traceModule.empty)"
case ZStreamTaggingTarget(t1, t2, t3) =>
println(name)
q"$costCenterModule.withChildCostCenterStream[$t1,$t2,$t3]($name)($tree)($traceModule.empty)"
}

}

private object ZioTypeTree {
private def zioTypeRef: Type =
rootMirror.getRequiredClass("zio.ZIO").tpe
private sealed trait TaggingTarget

private case class ZioTaggingTarget(rType: Type, eType: Type, aType: Type) extends TaggingTarget
private case class ZStreamTaggingTarget(rType: Type, eType: Type, aType: Type) extends TaggingTarget

private object TaggableTypeTree {
private def zioTypeRef: Type = rootMirror.getRequiredClass("_root_.zio.ZIO").tpe

private def zStreamTypeRef: Type = rootMirror.getRequiredClass("_root_.zio.stream.ZStream").tpe

def unapply(tpt: Tree): Option[(Type, Type, Type)] =
def unapply(tpt: Tree): Option[TaggingTarget] =
tpt.tpe.dealias match {
case TypeRef(_, sym, t1 :: t2 :: t3 :: Nil) if sym == zioTypeRef.typeSymbol => Some((t1, t2, t3))
case _ => None
case TypeRef(_, sym, t1 :: t2 :: t3 :: Nil) if sym == zioTypeRef.typeSymbol =>
Some(ZioTaggingTarget(t1, t2, t3))
case TypeRef(_, sym, t1 :: t2 :: t3 :: Nil) if sym == zStreamTypeRef.typeSymbol =>
Some(ZStreamTaggingTarget(t1, t2, t3))
case _ => None
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ object TaggingPhase extends PluginPhase {
override val runsBefore = Set(Staging.name)

override def transformValDef(tree: tpd.ValDef)(using Context): tpd.Tree = tree match {
case ValDef(_, ZioTypeTree(t1, t2, t3), _) if !tree.mods.flags.is(Flags.DeferredTerm) =>
val transformedRhs = tagEffectTree(descriptiveName(tree), tree.rhs, t1, t2, t3)
case ValDef(_, TaggableTypeTree(taggingTarget), rhs) if !tree.rhs.isEmpty =>
val transformedRhs = tagEffectTree(descriptiveName(tree), tree.rhs, taggingTarget)
cpy.ValDef(tree)(rhs = transformedRhs)
case _ =>
tree
}

override def transformDefDef(tree: tpd.DefDef)(using Context): tpd.Tree = tree match {
case DefDef(_, _, tpt @ ZioTypeTree(t1, t2, t3), _) if !tree.mods.flags.is(Flags.DeferredTerm) =>
val transformedRhs = tagEffectTree(descriptiveName(tree), tree.rhs, t1, t2, t3)
case DefDef(_, _, TaggableTypeTree(taggingTarget), rhs) if !tree.rhs.isEmpty =>
val transformedRhs = tagEffectTree(descriptiveName(tree), tree.rhs, taggingTarget)
cpy.DefDef(tree)(rhs = transformedRhs)
case _ =>
tree
Expand All @@ -47,27 +47,46 @@ object TaggingPhase extends PluginPhase {
s"$fullName($sourceFile:$sourceLine)"
}

private def tagEffectTree(name: String, tree: tpd.Tree, t1: Type, t2: Type, t3: Type)(using Context): tpd.Tree = {
val costcenterSym = requiredModule("zio.profiling.CostCenter")
val withChildCostCenterSym = costcenterSym.requiredMethod("withChildCostCenter")

val traceSym = requiredModule("zio.Trace")
private def tagEffectTree(name: String, tree: tpd.Tree, taggingTarget: TaggingTarget)(using Context): tpd.Tree = {
val costcenterSym = requiredModule("_root_.zio.profiling.CostCenter")
val traceSym = requiredModule("_root_.zio.Trace")
val emptyTraceSym = traceSym.requiredMethodRef("empty")

tpd.ref(withChildCostCenterSym)
.appliedToTypes(List(t1, t2, t3))
.appliedTo(tpd.Literal(Constant(name)))
.appliedTo(tree)
.appliedTo(tpd.ref(emptyTraceSym))
taggingTarget match {
case ZioTaggingTarget(t1, t2, t3) =>
val withChildCostCenterSym = costcenterSym.requiredMethod("withChildCostCenter")

tpd.ref(withChildCostCenterSym)
.appliedToTypes(List(t1, t2, t3))
.appliedTo(tpd.Literal(Constant(name)))
.appliedTo(tree)
.appliedTo(tpd.ref(emptyTraceSym))

case ZStreamTaggingTarget(t1, t2, t3) =>
val withChildCostCenterSym = costcenterSym.requiredMethod("withChildCostCenterStream")

tpd.ref(withChildCostCenterSym)
.appliedToTypes(List(t1, t2, t3))
.appliedTo(tpd.Literal(Constant(name)))
.appliedTo(tree)
.appliedTo(tpd.ref(emptyTraceSym))
}
}

private object ZioTypeTree {
private def zioTypeRef(using Context): TypeRef =
requiredClassRef("zio.ZIO")
private sealed trait TaggingTarget

private case class ZioTaggingTarget(rType: Type, eType: Type, aType: Type) extends TaggingTarget
private case class ZStreamTaggingTarget(rType: Type, eType: Type, aType: Type) extends TaggingTarget

private object TaggableTypeTree {
private def zioTypeRef(using Context): TypeRef = requiredClassRef("_root_.zio.ZIO")

private def zStreamTypeRef(using Context): TypeRef = requiredClassRef("_root_.stream.ZStream")

def unapply(tp: Tree[Type])(using Context): Option[(Type, Type, Type)] =
def unapply(tp: Tree[Type])(using Context): Option[TaggingTarget] =
tp.tpe.dealias match {
case AppliedType(at, t1 :: t2 :: t3 :: Nil) if at.isRef(zioTypeRef.symbol) => Some((t1, t2, t3))
case AppliedType(at, t1 :: t2 :: t3 :: Nil) if at.isRef(zioTypeRef.symbol) => Some(ZioTaggingTarget(t1, t2, t3))
case AppliedType(at, t1 :: t2 :: t3 :: Nil) if at.isRef(zioTypeRef.symbol) => Some(ZStreamTaggingTarget(t1, t2, t3))
case _ => None
}

Expand Down
9 changes: 9 additions & 0 deletions zio-profiling/src/main/scala/zio/profiling/CostCenter.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zio.profiling

import zio._
import zio.stream.ZStream

/**
* A CostCenter allows grouping multiple source code locations into one unit for reporting and targeting purposes.
Expand Down Expand Up @@ -75,6 +76,14 @@ object CostCenter {
def withChildCostCenter[R, E, A](name: String)(zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] =
globalRef.locallyWith(_ / name)(zio)

/**
* Run an effect with a child cost center nested under the current one.
*/
def withChildCostCenterStream[R, E, A](name: String)(stream: ZStream[R, E, A])(implicit
trace: Trace
): ZStream[R, E, A] =
ZStream.scoped[R](globalRef.locallyScopedWith(_ / name)) *> stream

private final val globalRef: FiberRef[CostCenter] =
Unsafe.unsafe(implicit u => FiberRef.unsafe.make(CostCenter.Root, identity, (old, _) => old))
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ final case class SamplingProfiler(

/**
* Create a runtime that will profile all effects executed with it. Use `runtime.environment.get` in order to get a
* reference to the supervisor. Make sure to shut down the runtime when down.
* reference to the supervisor. Make sure to shut down the runtime when done.
*/
def supervisedRuntime(implicit unsafe: Unsafe): Runtime.Scoped[SamplingProfilerSupervisor] = {
val layer = ZLayer.scoped[Any](makeSupervisor).flatMap(env => Runtime.addSupervisor(env.get).map(_ => env))
Expand Down

0 comments on commit ce2df08

Please sign in to comment.