# Backport of commit 6a6c527ee287 (available in 17.0). To avoid merge conflicts the whole function was copied from patch LLVM 15.0.
# It enables better code geneartion for FP16 convert on SPR: int64/uint64 -> half (vcvtqq2ph / vcvtuqq2ph) for x8 target.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 8bb7e81e19bb..049c68b7cc3d 100644
--- llvm/lib/Target/X86/X86ISelLowering.cpp.orig
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -53682,22 +53682,56 @@ static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG,
   if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
     return SDValue();
 
-  if (Subtarget.hasFP16())
-    return SDValue();
-
+  bool IsStrict = N->isStrictFPOpcode();
   EVT VT = N->getValueType(0);
-  SDValue Src = N->getOperand(0);
+  SDValue Src = N->getOperand(IsStrict ? 1 : 0);
   EVT SrcVT = Src.getValueType();
 
   if (!VT.isVector() || VT.getVectorElementType() != MVT::f16 ||
       SrcVT.getVectorElementType() != MVT::f32)
     return SDValue();
 
+  SDLoc dl(N);
+
+  SDValue Cvt, Chain;
   unsigned NumElts = VT.getVectorNumElements();
-  if (NumElts == 1 || !isPowerOf2_32(NumElts))
+  if (Subtarget.hasFP16()) {
+    // Combine (v8f16 fp_round(concat_vectors(v4f32 (xint_to_fp v4i64), ..)))
+    // into (v8f16 vector_shuffle(v8f16 (CVTXI2P v4i64), ..))
+    if (NumElts == 8 && Src.getOpcode() == ISD::CONCAT_VECTORS) {
+      SDValue Cvt0, Cvt1;
+      SDValue Op0 = Src.getOperand(0);
+      SDValue Op1 = Src.getOperand(1);
+      bool IsOp0Strict = Op0->isStrictFPOpcode();
+      if (Op0.getOpcode() != Op1.getOpcode() ||
+          Op0.getOperand(IsOp0Strict ? 1 : 0).getValueType() != MVT::v4i64 ||
+          Op1.getOperand(IsOp0Strict ? 1 : 0).getValueType() != MVT::v4i64) {
+        return SDValue();
+      }
+      int Mask[8] = {0, 1, 2, 3, 8, 9, 10, 11};
+      if (IsStrict) {
+        assert(IsOp0Strict && "Op0 must be strict node");
+        unsigned Opc = Op0.getOpcode() == ISD::STRICT_SINT_TO_FP
+                           ? X86ISD::STRICT_CVTSI2P
+                           : X86ISD::STRICT_CVTUI2P;
+        Cvt0 = DAG.getNode(Opc, dl, {MVT::v8f16, MVT::Other},
+                           {Op0.getOperand(0), Op0.getOperand(1)});
+        Cvt1 = DAG.getNode(Opc, dl, {MVT::v8f16, MVT::Other},
+                           {Op1.getOperand(0), Op1.getOperand(1)});
+        Cvt = DAG.getVectorShuffle(MVT::v8f16, dl, Cvt0, Cvt1, Mask);
+        return DAG.getMergeValues({Cvt, Cvt0.getValue(1)}, dl);
+      }
+      unsigned Opc = Op0.getOpcode() == ISD::SINT_TO_FP ? X86ISD::CVTSI2P
+                                                        : X86ISD::CVTUI2P;
+      Cvt0 = DAG.getNode(Opc, dl, MVT::v8f16, Op0.getOperand(0));
+      Cvt1 = DAG.getNode(Opc, dl, MVT::v8f16, Op1.getOperand(0));
+      return Cvt = DAG.getVectorShuffle(MVT::v8f16, dl, Cvt0, Cvt1, Mask);
+    }
     return SDValue();
+  }
 
-  SDLoc dl(N);
+  if (NumElts == 1 || !isPowerOf2_32(NumElts))
+    return SDValue();
 
   // Widen to at least 4 input elements.
   if (NumElts < 4)
@@ -53705,10 +53739,16 @@ static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG,
                       DAG.getConstantFP(0.0, dl, SrcVT));
 
   // Destination is v8i16 with at least 8 elements.
-  EVT CvtVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
-                               std::max(8U, NumElts));
-  SDValue Cvt = DAG.getNode(X86ISD::CVTPS2PH, dl, CvtVT, Src,
-                            DAG.getTargetConstant(4, dl, MVT::i32));
+  EVT CvtVT =
+      EVT::getVectorVT(*DAG.getContext(), MVT::i16, std::max(8U, NumElts));
+  SDValue Rnd = DAG.getTargetConstant(4, dl, MVT::i32);
+  if (IsStrict) {
+    Cvt = DAG.getNode(X86ISD::STRICT_CVTPS2PH, dl, {CvtVT, MVT::Other},
+                      {N->getOperand(0), Src, Rnd});
+    Chain = Cvt.getValue(1);
+  } else {
+    Cvt = DAG.getNode(X86ISD::CVTPS2PH, dl, CvtVT, Src, Rnd);
+  }
 
   // Extract down to real number of elements.
   if (NumElts < 8) {
@@ -53717,7 +53757,12 @@ static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG,
                       DAG.getIntPtrConstant(0, dl));
   }
 
-  return DAG.getBitcast(VT, Cvt);
+  Cvt = DAG.getBitcast(VT, Cvt);
+
+  if (IsStrict)
+    return DAG.getMergeValues({Cvt, Chain}, dl);
+
+  return Cvt;
 }
 
 static SDValue combineMOVDQ2Q(SDNode *N, SelectionDAG &DAG) {
