Scala: reflection and case classes

8
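The following code works, but is there a better way to do the same thing? Perhaps something specific to case classes? For each String field in my simple case class, the code below goes through my list of instances of that case class and finds the length of the longest string in that field.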
case class CrmContractorRow(
                             id: Long,
                             bankCharges: String,
                             overTime: String,
                             name$id: Long,
                             mgmtFee: String,
                             contractDetails$id: Long,
                             email: String,
                             copyOfVisa: String)

object Go {
  def main(args: Array[String]): Unit = {
    val a = CrmContractorRow(1,"1","1",4444,"1",1,"1","1")
    val b = CrmContractorRow(22,"22","22",22,"55555",22,"nine long","22")
    val c = CrmContractorRow(333,"333","333",333,"333",333,"333","333")
    val rows = List(a,b,c)

    c.getClass.getDeclaredFields.filter(p => p.getType == classOf[String]).foreach{f =>
      f.setAccessible(true)
      // Report the length of the longest string in this field, not the string itself.
      println(f.getName + ": " + rows.map(row => f.get(row).asInstanceOf[String].length).max)
    }
  }
}

Result:

bankCharges: 3
overTime: 3
mgmtFee: 5
email: 9
copyOfVisa: 3
4 Answers

11
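If you want to use Shapeless for this kind of thing, I'd strongly suggest defining a custom type class that handles the complicated part and lets you keep it separate from the rest of your logic.
In this case it sounds like the tricky part of what you're trying to do is getting a map from field names to string lengths for all of the String members of a case class. Here's a type class that does that: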
import shapeless._, shapeless.labelled.FieldType

trait StringFieldLengths[A] { def apply(a: A): Map[String, Int] }

object StringFieldLengths extends LowPriorityStringFieldLengths {
  implicit val hnilInstance: StringFieldLengths[HNil] =
    new StringFieldLengths[HNil] {
      def apply(a: HNil): Map[String, Int] = Map.empty
    }

  implicit def caseClassInstance[A, R <: HList](implicit
    gen: LabelledGeneric.Aux[A, R],
    sfl: StringFieldLengths[R]
  ): StringFieldLengths[A] = new StringFieldLengths[A] {
    def apply(a: A): Map[String, Int] = sfl(gen.to(a))
  }

  implicit def hconsStringInstance[K <: Symbol, T <: HList](implicit
    sfl: StringFieldLengths[T],
    key: Witness.Aux[K]
  ): StringFieldLengths[FieldType[K, String] :: T] =
    new StringFieldLengths[FieldType[K, String] :: T] {
      def apply(a: FieldType[K, String] :: T): Map[String, Int] =
        sfl(a.tail).updated(key.value.name, a.head.length)
    }
}

sealed class LowPriorityStringFieldLengths {
  implicit def hconsInstance[K, V, T <: HList](implicit
    sfl: StringFieldLengths[T]
  ): StringFieldLengths[FieldType[K, V] :: T] =
    new StringFieldLengths[FieldType[K, V] :: T] {
      def apply(a: FieldType[K, V] :: T): Map[String, Int] = sfl(a.tail)
    }
}

This looks complicated, but once you've spent some time with Shapeless you'll learn to write this kind of code in your sleep.
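The split between StringFieldLengths and LowPriorityStringFieldLengths relies on implicit prioritization: an instance defined in the companion object is preferred over one inherited from a parent. A minimal sketch of that mechanism without Shapeless follows; the names PriorityDemo, Describe, and describe are made up for illustration and are not part of the answer's code.

```scala
object PriorityDemo {
  // A tiny type class, standing in for StringFieldLengths (hypothetical name).
  trait Describe[A] { def apply(a: A): String }

  // The fallback instance lives in a parent trait, so it has lower priority.
  trait LowPriorityDescribe {
    implicit def anyInstance[A]: Describe[A] =
      new Describe[A] { def apply(a: A): String = "not a string" }
  }

  // The specific instance lives in the companion itself and wins for String.
  object Describe extends LowPriorityDescribe {
    implicit val stringInstance: Describe[String] =
      new Describe[String] {
        def apply(a: String): String = s"string of length ${a.length}"
      }
  }

  def describe[A](a: A)(implicit d: Describe[A]): String = d(a)
}
```

Here PriorityDemo.describe("abc") picks stringInstance, while PriorityDemo.describe(42) falls back to anyInstance. Without the parent trait, both instances would apply to String and the compiler would report an ambiguity.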
Now you can write the logic of your operation in a relatively straightforward way:
def maxStringLengths[A: StringFieldLengths](as: List[A]): Map[String, Int] =
  as.map(implicitly[StringFieldLengths[A]].apply).foldLeft(
    Map.empty[String, Int]
  ) {
    case (x, y) => x.foldLeft(y) {
      case (acc, (k, v)) =>
        acc.updated(k, acc.get(k).fold(v)(accV => math.max(accV, v)))
    }
  }

And then (given rows as defined in the question):

scala> maxStringLengths(rows).foreach(println)
(bankCharges,3)
(overTime,3)
(mgmtFee,5)
(email,9)
(copyOfVisa,3)

This will work for any case class.

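If this is a one-off thing, you could use runtime reflection, or you could use the Poly1 approach in Giovanni Caporaletti's answer. It's less generic, and it mixes up the different parts of the solution in a way I don't like, but it should work fine. If this is something you do a lot, I'd suggest the approach I've given here.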
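The fold inside maxStringLengths merges one field-name-to-length map per row, keeping the larger length for each key. That merge step can be tried on its own with plain maps. This is only a sketch: MergeDemo, mergeMax, and the three hand-written row maps are illustrative stand-ins for what the type class would produce.

```scala
object MergeDemo {
  // Combine two name -> length maps, keeping the larger length for each key.
  def mergeMax(x: Map[String, Int], y: Map[String, Int]): Map[String, Int] =
    x.foldLeft(y) { case (acc, (k, v)) =>
      acc.updated(k, acc.get(k).fold(v)(accV => math.max(accV, v)))
    }

  // Hand-written stand-ins for the per-row maps (two fields only).
  val row1 = Map("email" -> 1, "mgmtFee" -> 1)
  val row2 = Map("email" -> 9, "mgmtFee" -> 5)
  val row3 = Map("email" -> 3, "mgmtFee" -> 3)

  // Fold all rows into a single map of per-field maxima.
  val merged: Map[String, Int] =
    List(row1, row2, row3).foldLeft(Map.empty[String, Int])(mergeMax)
}
```

Starting the fold from an empty map means an empty list of rows yields an empty result instead of throwing, which reduce would do.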


Thank you very much. I'll be back once I've mastered Shapeless. - Anthony Holland
1
Don't wait until then! I'm not sure anyone has mastered Shapeless! - Travis Brown
Very nice, much better than mine! - Giovanni Caporaletti
@GiovanniCaporaletti Not as concise, though, unfortunately. :) - Travis Brown

3
If you want to use shapeless to get the string fields of a case class and avoid reflection, you can do something like this:
import shapeless._
import labelled._

trait lowerPriorityfilterStrings extends Poly2 {
  implicit def default[A] = at[Vector[(String, String)], A] { case (acc, _) => acc }
}

object filterStrings extends lowerPriorityfilterStrings {
  implicit def caseString[K <: Symbol](implicit w: Witness.Aux[K]) = at[Vector[(String, String)], FieldType[K, String]] {
    case (acc, x) =>  acc :+ (w.value.name -> x)
  }
}

val gen = LabelledGeneric[CrmContractorRow]


val a = CrmContractorRow(1,"1","1",4444,"1",1,"1","1")
val b = CrmContractorRow(22,"22","22",22,"55555",22,"nine long","22")
val c = CrmContractorRow(333,"333","333",333,"333",333,"333","333")
val rows = List(a,b,c)

val result = rows
  // get for each element a Vector of (fieldName -> stringField) pairs for the string fields
  .map(r => gen.to(r).foldLeft(Vector[(String, String)]())(filterStrings))
  // get the maximum for each "column"
  .reduceLeft((best, row) => best.zip(row).map {
    case (kv1@(_, v1), (_, v2)) if v1.length > v2.length => kv1
    case (_, kv2) => kv2
  })

result foreach { case (k, v) => println(s"$k: $v") }

I've kept the strings here, but you can easily replace them with their lengths. - Giovanni Caporaletti
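As a rough sketch of what the column-wise reduce looks like once lengths are stored instead of strings, here is the same zip-and-compare step on hand-written data. ColumnMaxDemo and the perRow values are assumptions made for illustration; in the answer's code the vectors would come from folding filterStrings over each row.

```scala
object ColumnMaxDemo {
  // Hand-written stand-in for the per-row vectors, with each string
  // already replaced by its length (two fields only).
  val perRow: List[Vector[(String, Int)]] = List(
    Vector("bankCharges" -> 1, "email" -> 1),
    Vector("bankCharges" -> 2, "email" -> 9),
    Vector("bankCharges" -> 3, "email" -> 3)
  )

  // Every row lists its fields in the same order, so zip lines up the columns.
  val maxLengths: Vector[(String, Int)] =
    perRow.reduceLeft { (best, row) =>
      best.zip(row).map { case ((k, v1), (_, v2)) => k -> math.max(v1, v2) }
    }
}
```

Note that reduceLeft throws on an empty list, so this form assumes at least one row.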

2

You may want to use Scala reflection:

import scala.reflect.runtime.universe._

val rm = runtimeMirror(getClass.getClassLoader)
val instanceMirrors = rows map rm.reflect
typeOf[CrmContractorRow].members collect {
  case m: MethodSymbol if m.isCaseAccessor && m.returnType =:= typeOf[String] =>
    val maxValue = instanceMirrors map (_.reflectField(m).get.asInstanceOf[String]) maxBy (_.length)
    println(s"${m.name}: $maxValue")
}

This way you can avoid problems in cases like the following:

case class CrmContractorRow(id: Long, bankCharges: String, overTime: String, name$id: Long, mgmtFee: String, contractDetails$id: Long, email: String, copyOfVisa: String) {
  val unwantedVal = "jdjd"
}

Cheers
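To see why an extra val in the class body is a problem for the Java-reflection approach in the question, here is a small sketch using only the standard library. ExtraValDemo and the two-field Row class are hypothetical and only there to keep the example short.

```scala
object ExtraValDemo {
  // Hypothetical two-field row with an extra val in the class body.
  case class Row(id: Long, name: String) {
    val unwantedVal = "jdjd"
  }

  // Java reflection sees every String field of the class,
  // including the val that is not a constructor parameter.
  val javaFieldNames: Set[String] =
    classOf[Row].getDeclaredFields
      .filter(_.getType == classOf[String])
      .map(_.getName)
      .toSet

  // The Product view only exposes the constructor parameters.
  val productValues: List[Any] = Row(1L, "abc").productIterator.toList
}
```

The set is used because getDeclaredFields makes no promise about the order of the returned fields.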


This is the one I'm likely to use! At least until I understand Shapeless. - Anthony Holland

0

I've refactored your code to make it more reusable:

import scala.reflect.ClassTag

case class CrmContractorRow(
                             id: Long,
                             bankCharges: String,
                             overTime: String,
                             name$id: Long,
                             mgmtFee: String,
                             contractDetails$id: Long,
                             email: String,
                             copyOfVisa: String)

object Go{
  def main(args: Array[String]): Unit = {
    val a = CrmContractorRow(1,"1","1",4444,"1",1,"1","1")
    val b = CrmContractorRow(22,"22","22",22,"55555",22,"nine long","22")
    val c = CrmContractorRow(333,"333","333",333,"333",333,"333","333")
    val rows = List(a,b,c)
    val initEmptyColumns = List.fill(a.productArity)(List())

    def aggregateColumns[Tin:ClassTag,Tagg](rows: Iterable[Product], aggregate: Iterable[Tin] => Tagg) = {

      val columnsWithMatchingType = (0 until rows.head.productArity).filter {
        index => rows.head.productElement(index) match {case t: Tin => true; case _ => false}
      }

      def columnIterable(col: Int) = rows.map(_.productElement(col)).asInstanceOf[Iterable[Tin]]

      columnsWithMatchingType.map(index => (index,aggregate(columnIterable(index))))
    }

    def extractCaseClassFieldNames[T: scala.reflect.ClassTag] = {
      scala.reflect.classTag[T].runtimeClass.getDeclaredFields.filter(!_.isSynthetic).map(_.getName)
    }

    val agg = aggregateColumns[String,String] (rows,_.maxBy(_.length))
    val fieldNames = extractCaseClassFieldNames[CrmContractorRow]

    agg.map{case (index,value) => fieldNames(index) + ": "+ value}.foreach(println)
  }
}

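With shapeless you could get rid of the .asInstanceOf, but the essence is the same. The main problem with the given code is that it isn't reusable, because the aggregation logic is mixed with the reflection logic that gets the field names.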
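One way to keep the field names separate from the aggregation without any reflection is to read them from the Product interface. This is a sketch under the assumption of Scala 2.13 or later, where Product.productElementNames exists; it is not part of the answer above. ProductNamesDemo and the three-field Row class are hypothetical.

```scala
object ProductNamesDemo {
  // Hypothetical row type, smaller than CrmContractorRow.
  case class Row(id: Long, name: String, email: String)

  // Longest string length per String field, for any case class (any Product).
  // Assumes Scala 2.13+, where Product.productElementNames is available.
  def maxStringLengths(rows: Iterable[Product]): Map[String, Int] =
    rows.foldLeft(Map.empty[String, Int]) { (acc, row) =>
      row.productElementNames.zip(row.productIterator).foldLeft(acc) {
        case (m, (name, s: String)) =>
          m.updated(name, math.max(m.getOrElse(name, 0), s.length))
        case (m, _) => m // skip non-String fields
      }
    }

  val result: Map[String, Int] =
    maxStringLengths(List(Row(1, "a", "nine long"), Row(22, "bbb", "x")))
}
```

The type test on each element happens at runtime, so unlike the Shapeless versions this gives no compile-time guarantee about field types, but it needs no cast and no extra dependency.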

