spark 를 이용 하여 tfrecord 파일 생 성
공식 예:
package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
object TFRecordsExample {
def main(args: Array[String]): Unit = {
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()
val path = "file/test-output.tfrecord"
val testRows: Array[Row] = Array(
new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
val schema = StructType(List(
StructField("id", IntegerType),
StructField("IntegerCol", IntegerType),
StructField("LongCol", LongType),
StructField("FloatCol", FloatType),
StructField("DoubleCol", DoubleType),
StructField("VectorCol", ArrayType(DoubleType, true)),
StructField("StringCol", StringType)))
val rdd = spark.sparkContext.parallelize(testRows)
//Save DataFrame as TFRecords
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").option("recordType", "Example").save(path)
//Read TFRecords into DataFrame.
//The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
importedDf1.show()
//Read TFRecords into DataFrame using custom schema
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()
}
}
bert 모델 훈련 데이터 테스트 읽 기:
package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
object TFRcordsBert {
def main(args: Array[String]): Unit = {
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()
val path = "/Users/shuubiasahi/Desktop/textclass/"
val schema = StructType(List(
StructField("input_ids", ArrayType(IntegerType, true)),
StructField("input_mask", ArrayType(IntegerType, true)),
StructField("label_ids", ArrayType(IntegerType, true))))
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").load(path)
importedDf1.show()
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()
}
}
+--------------------+--------------------+---------+ | input_ids| input_mask|label_ids| +--------------------+--------------------+---------+ |[101, 4281, 3566,...|[1, 1, 1, 1, 1, 1...| [25]| |[101, 3433, 5866,...|[1, 1, 1, 1, 0, 0...| [40]| |[101, 6631, 5277,...|[1, 1, 1, 1, 1, 1...| [5]|
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
형태소 분석은 데스크톱을 구성하는 데 도움이?문자×기계 학습에 흥미를 가져와 개인 범위의 용도를 생각해, 폴더 정리에 사용할 수 있을까 생각해 검토를 시작했습니다. 이번 검토에서는 폴더 구성 & text의 읽기 → mecab × wordcloud를 실시하고 있...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.