spark 를 이용 하여 tfrecord 파일 생 성

현재 데이터 가 점점 많아 지고 데 이 터 는 일반적으로 hdfs 에 저장 되 지만 현재 많은 깊이 있 는 학습 알고리즘 은 TensorFlow,pytorch 등 프레임 워 크 를 바탕 으로 이 루어 집 니 다.단기 python,자바 로 데이터 변환 을 하 는 것 이 비교적 느 립 니 다.hdfs 데 이 터 를 어떻게 대규모로 TensorFlow 에 직접 먹 입 니까?여기 서 TensorFlow 는 해결 방안 을 제공 하고 spark 를 이용 하여 tfrecord 파일 을 생 성 합 니 다.프로젝트 이름 은 spark-tensor flow-connector 이 고 GitHub 홈 페이지 는https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector 이 아래 는 readme 에 따라 jar 패 키 지 를 컴 파일 하고 자신의 프로젝트 에 의존 할 수 있 습 니 다.직접 jar 패 키 지 를 컴 파일 하고 싶 지 않 으 면 이 위 에 직접 다운로드 의존 도 를 추가 할 수 있 습 니 다.https://mvnrepository.com/artifact/org.tensorflow/spark-tensorflow-connector주요 원 리 는 이 프로젝트 에 은사 변환 클래스 를 쓰 고 출력 형식 을 다시 쓰 는 것 이다.상부 에서 출력 하 는 인터페이스 가 비교적 간단 하고 scala,python 의 인 터 페 이 스 를 제공 합 니 다.실제 배경 은 모두 proto 에 의존 하기 때문에 구 글 의 기술 의 강력 함 과 홍보 능력 에 탄복 할 수 밖 에 없습니다.다음은 어떻게 사용 하 는 지 보 겠 습 니 다.
 
공식 예:
package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf

import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

object TFRecordsExample {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    val path = "file/test-output.tfrecord"
    val testRows: Array[Row] = Array(
      new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
      new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
      
    val schema = StructType(List(
      StructField("id", IntegerType),
      StructField("IntegerCol", IntegerType),
      StructField("LongCol", LongType),
      StructField("FloatCol", FloatType),
      StructField("DoubleCol", DoubleType),
      StructField("VectorCol", ArrayType(DoubleType, true)),
      StructField("StringCol", StringType)))

    val rdd = spark.sparkContext.parallelize(testRows)

    //Save DataFrame as TFRecords
    val df: DataFrame = spark.createDataFrame(rdd, schema)
    df.write.format("tfrecords").option("recordType", "Example").save(path)

    //Read TFRecords into DataFrame.
    //The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
    val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
    importedDf1.show()

    //Read TFRecords into DataFrame using custom schema
    val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
    importedDf2.show()

  }

}

 
bert 모델 훈련 데이터 테스트 읽 기:
package com.xxx.tfrecords
import scala.collection.JavaConversions._;
import scala.collection.JavaConverters._;
import collection.JavaConversions._
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.SparkConf

import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
object TFRcordsBert {
  def main(args: Array[String]): Unit = {
    
     Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    val spark = SparkSession.builder().master("local[4]").appName("tfrecords_examples").getOrCreate()

    val path = "/Users/shuubiasahi/Desktop/textclass/"
    val schema = StructType(List(
      StructField("input_ids",  ArrayType(IntegerType, true)),
      StructField("input_mask",  ArrayType(IntegerType, true)),
      StructField("label_ids",  ArrayType(IntegerType, true))))
     
      
   val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").load(path)
    importedDf1.show()

    val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
    importedDf2.show()

    
    
  }
}

+--------------------+--------------------+---------+ |           input_ids|          input_mask|label_ids| +--------------------+--------------------+---------+ |[101, 4281, 3566,...|[1, 1, 1, 1, 1, 1...|     [25]| |[101, 3433, 5866,...|[1, 1, 1, 1, 0, 0...|     [40]| |[101, 6631, 5277,...|[1, 1, 1, 1, 1, 1...|      [5]|

좋은 웹페이지 즐겨찾기